[
  {
    "path": ".gitignore",
    "content": ".idea\n*.egg-info/\neggs/\n.eggs/\n*.exe\n*.pyc\n/.vscode/\n*.code-workspace\n__pycache__\n# Sphinx documentation\ndocs/_build/\ndocs/build/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# event data\n*.bin\n*.dat\n*.pt\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# BrainCog\n\n---\n\nBrainCog is an open source spiking neural network based brain-inspired \ncognitive intelligence engine for Brain-inspired Artificial Intelligence, Brain-inspired Embodied AI, and brain simulation. More information on BrainCog can be found on its homepage http://www.brain-cog.network/\n\nThe current version of BrainCog contains at least 50 functional spiking neural network algorithms (including but not limited to perception and learning, decision making, knowledge representation and reasoning, motor control, social cognition, etc.) built based on BrainCog infrastructures, and BrainCog also provide brain simulations to drosophila, rodent, monkey, and human brains at multiple scales based on spiking neural networks at multiple scales. More detail in http://www.brain-cog.network/docs/\n\nBrainCog is a community based effort for spiking neural network based artificial intelligence, and we welcome any forms of contributions, from contributing to the development of core components, to contributing for applications.\n\n<img src=\"http://braincog.ai/static_index/image/github_readme/logo.jpg\" alt=\"./figures/logo.jpg\" width=\"70%\" />\n\nBrainCog provides essential and fundamental components to model biological and artificial intelligence.\n\n![image]( http://braincog.ai/static_index/image/github_readme/braincog.png)\n\nOur paper is published in [Patterns](https://www.cell.com/patterns/fulltext/S2666-3899(23)00144-7?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS2666389923001447%3Fshowall%3Dtrue). 
If you use BrainCog in your research, the following paper can be cited as the source for BrainCog.\n```bib\n@article{Zeng2023,\n  doi = {10.1016/j.patter.2023.100789},\n  url = {https://doi.org/10.1016/j.patter.2023.100789},\n  year = {2023},\n  month = jul,\n  publisher = {Cell Press},\n  pages = {100789},\n  author = {Yi Zeng and Dongcheng Zhao and Feifei Zhao and Guobin Shen and Yiting Dong and Enmeng Lu and Qian Zhang and Yinqian Sun and Qian Liang and Yuxuan Zhao and Zhuoya Zhao and Hongjian Fang and Yuwei Wang and Yang Li and Xin Liu and Chengcheng Du and Qingqun Kong and Zizhe Ruan and Weida Bi},\n  title = {{BrainCog}: A spiking neural network based,  brain-inspired cognitive intelligence engine for brain-inspired {AI} and brain simulation},\n  journal = {Patterns}\n}\n```\n\n## Brain-Inspired AI\nBrainCog currently provides cognitive function components that can be classified \ninto seven categories: \n* Perception and Learning\n* Knowledge Representation and Reasoning\n* Decision Making\n* Motor Control\n* Social Cognition\n* Development and Evolution\n* Safety and Security\n\n<img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/mirror-test.gif\" alt=\"mt\" width=\"55%\" />\n<img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/joy.gif\" alt=\"mt\" width=\"55%\" />\n\n## Brain Simulation\nBrainCog currently includes two parts for brain simulation:\n* Brain Cognitive Function Simulation\n* Multi-scale Brain Structure Simulation\n\n<img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/braincog-mouse-brain-model-10s.gif\" alt=\"bmbm10s\" width=\"55%\" /> \n<img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/braincog-macaque-10s.gif\" alt=\"bm10s\" width=\"55%\" />\n<img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/braincog-humanbrain-10s.gif\" alt=\"bh10s\" width=\"55%\" />\n\nThe anatomical and imaging data is 
used to support our simulation from various aspects. \n\n## Software-Hardware Codesign (BrainCog Firefly)\n<img src=\"http://www.brain-cog.network/static/image/github_readme/firefly_logo.jpg\" alt=\"bh10s\" width=\"25%\" />\n\nBrainCog currently provides `hardware acceleration` for spiking neural network based brain-inspired AI.\n\n<img src=\"http://braincog.ai/static_index/image/github_readme/firefly.jpg\" alt=\"bh10s\" width=\"55%\" />\n\nThe following papers are most recent advancement of BrainCog Firefly series for Software-Hardware Codesign for Brain-inspired AI.\n* Tenglong Li, Jindong Li, Guobin Shen, Dongcheng Zhao, Qian Zhang, Yi Zeng. FireFly-S: Exploiting Dual-Side Sparsity for Spiking Neural Networks Acceleration With Reconfigurable Spatial Architecture. IEEE Transactions on Circuits and Systems I (TCAS-I), 2024.(https://doi.org/10.1109/TCSI.2024.3496554)\n* Jindong Li, Guobin Shen, Dongcheng Zhao, Qian Zhang, Yi Zeng. Firefly v2: Advancing hardware support for high-performance spiking neural network with a spatiotemporal fpga accelerator. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems, 2024. (https://ieeexplore.ieee.org/abstract/document/10478105/)\n* Jindong Li, Guobin Shen, Dongcheng Zhao, Qian Zhang, Yi Zeng. FireFly: A High-Throughput Hardware Accelerator for Spiking Neural Networks With Efficient DSP and Memory Optimization. IEEE Transactions on Very Large Scale Integration (VLSI) Systems, 2023. 
(https://ieeexplore.ieee.org/document/10143752)\n\n## Embodied AI and Robotics (BrainCog Embot)\n<img src=\"http://www.brain-cog.network/static/image/github_readme/Embot_logo/%E5%B8%A6%E4%B8%8A%E6%A0%87%E9%A2%98-Embot%20logo%20%E9%80%8F%E6%98%8E%E8%83%8C%E6%99%AF.png\" alt=\"bh10s\" width=\"25%\" />\n\n<img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/PushT.gif\" alt=\"bm10s\" width=\"10%\" /><img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/Can.gif\" alt=\"bh10s\" width=\"10%\" /> <img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/左数第三个.gif\" alt=\"bh10s\" width=\"10%\" /> <img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/Square.gif\" alt=\"bm10s\" width=\"10%\" /><img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/ToolHang.gif\" alt=\"bh10s\" width=\"10%\" /> <img src=\"https://raw.githubusercontent.com/Brain-Cog-Lab/Brain-Cog/main/figures/左数第六个.gif\" alt=\"bh10s\" width=\"10%\" /> \n\nBrainCog Embot is an Embodied AI platform under the Brain-inspired Cognitive Intelligence Engine (BrainCog) framework, which is an open-source Brain-inspired AI platform based on Spiking Neural Network.\nThe following papers are most recent advancement of BrainCog Embot:\n\n* Qianhao Wang, Yinqian Sun, Enmeng Lu, Qian Zhang, Yi Zeng. Brain-Inspired Action Generation with Spiking Transformer Diffusion Policy Model. Advances in Brain Inspired Cognitive Systems (BICS), 2024.(https://link.springer.com/chapter/10.1007/978-981-96-2882-7_23)\n* Yinqian Sun, Feifei Zhao, Mingyang Lv, Yi Zeng. Implementing Spiking World Model with Multi-Compartment Neurons for Model-based Reinforcement Learning, 2025. (https://arxiv.org/abs/2503.00713)\n* Qianhao Wang, Yinqian Sun, Enmeng Lu, Qian Zhang, Yi Zeng. MTDP: Modulated Transformer Diffusion Policy Model, 2025. 
(https://arxiv.org/abs/2502.09029)\n\n## Resources\n### [[Lectures]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Lectures.md)  |   [[Tutorial]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Tutorial.md)\n\n\n## Publications using BrainCog \n### [[Brain Inspired AI]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Publication.md) | [[Brain Simulation]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Pub_brain_simulation.md) | [[Software-Hardware Co-design]](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Pub_sh_codesign.md)\n\n## BrainCog Data Engine\n###  [BrainCog Data Engine](https://github.com/BrainCog-X/Brain-Cog/blob/main/documents/Data_engine.md)\n\n\n## Requirements:\n* numpy\n* scipy\n* h5py\n* torch\n* torchvision\n* torchaudio\n* timm == 0.6.13\n* scikit-learn\n* einops\n* thop\n* pyyaml\n* matplotlib\n* seaborn\n* pygame\n* dv\n* tensorboard\n* tonic\n\n\n\n## Install \n\n\n\n### Install Online\n\n1. You can install braincog by running:\n\n    > `pip install braincog`\n\n2. Also, install from github by running:\n\n    > `pip install git+https://github.com/braincog-X/Brain-Cog.git`\n\n\n### Install locally\n\n1.  If you are a developer, it is recommended to download or clone\n    braincog from github.\n\n    > `git clone https://github.com/braincog-X/Brain-Cog.git`\n\n2.  Enter the folder of braincog\n\n    > `cd Brain-Cog`\n\n3.  Install braincog locally\n\n    > `pip install -e .`\n \n\n## Example \n\n1. Examples for Image Classification\n```shell \ncd ./examples/Perception_and_Learning/img_cls/bp \npython main.py --model cifar_convnet --dataset cifar10 --node-type LIFNode --step 8 --device 0\n```\n\n2. 
Examples for Event Classification \n\n```shell\ncd ./examples/Perception_and_Learning/img_cls/bp \npython main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0 \n```\n      \nOther BrainCog features and tutorials can be found at http://www.brain-cog.network/docs/\n\n## BrainCog Assistant \nPlease add our BrainCog Assistant via wechat and we will invite you to our wechat developer group.\n![image](https://github.com/Brain-Cog-Lab/Brain-Cog/blob/main/figures/wechat_ass.jpg)\n\n\n## Maintenance\nThis project is led by \n\n**1.Brain-inspired Cognitive Intelligence Lab, Institute of Automation, Chinese Academy of Sciences http://www.braincog.ai/**\n\n**2.Center for Long-term Artificial Intelligence (CLAI) http://long-term-ai.center/**\n"
  },
  {
    "path": "braincog/__init__.py",
    "content": "# __all__ = ['base', 'datasets', 'model_zoo', 'utils']\n#\n# from . import (\n#     base,\n#     datasets,\n#     model_zoo,\n#     utils\n# )\n"
  },
  {
    "path": "braincog/base/__init__.py",
    "content": "__all__ = ['node', 'connection', 'learningrule', 'brainarea', 'encoder', 'utils', 'conversion']\n\nfrom . import (\n    node,\n    strategy,\n    connection,\n    conversion,\n    learningrule,\n    brainarea,\n    utils,\n    encoder\n)\n"
  },
  {
    "path": "braincog/base/brainarea/BrainArea.py",
    "content": "import numpy as np\r\nimport torch, os, sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.learningrule.STDP import *\r\nfrom braincog.base.connection.CustomLinear import *\r\n\r\n\r\nclass BrainArea(nn.Module, abc.ABC):\r\n    \"\"\"\r\n    脑区基类\r\n    \"\"\"\r\n\r\n    @abc.abstractmethod\r\n    def __init__(self):\r\n        \"\"\"\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n    @abc.abstractmethod\r\n    def forward(self, x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:x是脉冲\r\n        \"\"\"\r\n\r\n        return x\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:x是脉冲\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n\r\nclass ThreePointForward(BrainArea):\r\n    \"\"\"\r\n    三点前馈脑区\r\n    \"\"\"\r\n\r\n    def __init__(self, w1, w2, w3):\r\n        \"\"\"\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = [IFNode(), IFNode(), IFNode()]\r\n        self.connection = [CustomLinear(w1), CustomLinear(w2), CustomLinear(w3)]\r\n        self.stdp = []\r\n\r\n        self.stdp.append(STDP(self.node[0], self.connection[0]))\r\n        self.stdp.append(STDP(self.node[1], self.connection[1]))\r\n        self.stdp.append(STDP(self.node[2], self.connection[2]))\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:x是脉冲\r\n        \"\"\"\r\n        x, dw1 = self.stdp[0](x)\r\n        x, dw2 = self.stdp[1](x)\r\n        x, dw3 = self.stdp[2](x)\r\n\r\n        return x, (*dw1, *dw2, *dw3)\r\n\r\n\r\nclass Feedback(BrainArea):\r\n    \"\"\"\r\n    反馈网络\r\n    \"\"\"\r\n\r\n    def __init__(self, w1, w2, w3):\r\n        \"\"\"\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        
self.node = [IFNode(), IFNode()]\r\n        self.connection = [CustomLinear(w1), CustomLinear(w2), CustomLinear(w3)]\r\n        self.stdp = []\r\n\r\n        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[2]]))\r\n        self.stdp.append(STDP(self.node[1], self.connection[1]))\r\n        self.x1 = torch.zeros(1, w3.shape[0])\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:x是脉冲\r\n        \"\"\"\r\n        x, dw1 = self.stdp[0](x, self.x1)\r\n        self.x1, dw2 = self.stdp[1](x)\r\n\r\n        return self.x1, (*dw1, *dw2)\r\n\r\n    def reset(self):\r\n        self.x1 *= 0\r\n\r\n\r\nclass TwoInOneOut(BrainArea):\r\n    \"\"\"\r\n    反馈网络\r\n    \"\"\"\r\n\r\n    def __init__(self, w1, w2):\r\n        \"\"\"\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = [IFNode()]\r\n        self.connection = [CustomLinear(w1), CustomLinear(w2)]\r\n        self.stdp = []\r\n\r\n        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]]))\r\n\r\n    def forward(self, x1, x2):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:x是脉冲\r\n        \"\"\"\r\n        x, dw1 = self.stdp[0](x1, x2)\r\n\r\n        return x, dw1\r\n\r\n\r\nclass SelfConnectionArea(BrainArea):\r\n    \"\"\"\r\n    反馈网络\r\n    \"\"\"\r\n\r\n    def __init__(self, w1, w2 ):\r\n        \"\"\"\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = [IFNode() ]\r\n        self.connection = [CustomLinear(w1), CustomLinear(w2) ]\r\n        self.stdp = []\r\n\r\n        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]]))\r\n        self.x1 = torch.zeros(1, w2.shape[0])\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:x是脉冲\r\n        \"\"\"\r\n        self.x1, dw1 = self.stdp[0](x, self.x1)\r\n\r\n\r\n        return self.x1, dw1\r\n\r\n    def reset(self):\r\n\r\n        
self.x1 *= 0\r\n\r\nif __name__ == \"__main__\":\r\n    T = 20\r\n    w1 = torch.tensor([[1., 1], [1, 1]])\r\n    w2 = torch.tensor([[1., 1], [1, 1]])\r\n    w3 = torch.tensor([[0.4, 0.4], [0.4, 0.4]])\r\n    ba = TwoInOneOut(w1, w2)\r\n    for i in range(T):\r\n        x = ba(torch.tensor([[0.1, 0.1]]), torch.tensor([[0.1, 0.1]]))\r\n        print(x[0])\r\n"
  },
  {
    "path": "braincog/base/brainarea/IPL.py",
    "content": "\r\nfrom braincog.base.learningrule.STDP import *\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.connection.CustomLinear import *\r\nimport random\r\nimport numpy as np\r\nimport torch\r\nimport os\r\nimport sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\nfrom braincog.base.strategy.surrogate import *\r\n\r\nimport os\r\nos.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\r\n\r\n\r\nclass IPLNet(nn.Module):\r\n    \"\"\"\r\n    inferior parietal lobule (IPL)\r\n    \"\"\"\r\n\r\n    def __init__(self, connection):\r\n        \"\"\"\r\n        Setting the network structure of IPL\r\n        \"\"\"\r\n        super().__init__()\r\n        # IPLM, IPLV\r\n        self.num_subMB = 2\r\n        self.node = [IzhNodeMU(threshold=30., a=0.02, b=0.2, c=-65., d=6., mem=-70.) 
for i in range(self.num_subMB)]\r\n\r\n        self.connection = connection\r\n        self.learning_rule = []\r\n\r\n        self.learning_rule.append(STDP(self.node[0], self.connection[0]))  # vPMC_input-IPLM\r\n        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))  # STS_input-IPLV, IPLM-IPLV\r\n\r\n        self.out_IPLM = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)\r\n        self.out_IPLV = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\r\n\r\n    def forward(self, input1, input2):  # input from vPMC and STS\r\n        \"\"\"\r\n        Calculate the output of IPLv and the weight update between IPLm and IPLv\r\n        :param input1: input from vPMC\r\n        :param input2: input from STS\r\n        :return: output of IPLv, weight update between IPLm and IPLv\r\n        \"\"\"\r\n        self.out_IPLM = self.node[0](self.connection[0](input1))\r\n        self.out_IPLV, dw_IPLv = self.learning_rule[1](input2, self.out_IPLM)\r\n        if sum(sum(self.out_IPLV)) == 1:\r\n            dw_IPLv = dw_IPLv[0][torch.nonzero(dw_IPLv[1])[0][1]][torch.nonzero(dw_IPLv[1])[0][1]] * dw_IPLv[1]\r\n        else:\r\n            dw_IPLv = dw_IPLv[0]\r\n        return self.out_IPLV, dw_IPLv\r\n\r\n    def UpdateWeight(self, i, dw):\r\n        \"\"\"\r\n        Update the weight\r\n        :param i: index of the connection to update\r\n        :param dw: weight update\r\n        :return: None\r\n        \"\"\"\r\n        self.connection[i].update(dw)\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        reset the network\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subMB):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n\r\n    def getweight(self):\r\n        \"\"\"\r\n        Get the connection and weight in IPL\r\n        :return: connection\r\n        
\"\"\"\r\n        return self.connection\r\n"
  },
  {
    "path": "braincog/base/brainarea/Insula.py",
    "content": "import numpy as np\r\nimport torch,os,sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter \r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\nfrom braincog.base.strategy.surrogate import *\r\n\r\nimport os\r\nos.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\r\nimport random\r\n\r\nfrom braincog.base.connection.CustomLinear import *\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.learningrule.STDP import *\r\n\r\n\r\nclass InsulaNet(nn.Module):\r\n    \"\"\"\r\n    Insula\r\n    \"\"\"\r\n    def __init__(self,connection):\r\n        \"\"\"\r\n        Setting the network structure of Insula\r\n        \"\"\"\r\n        super().__init__()\r\n        # Insula\r\n        self.num_subMB = 1\r\n        self.node = [IzhNodeMU(threshold=30., a=0.02, b=0.2, c=-65., d=6., mem=-70.) for i in range(self.num_subMB)]\r\n        self.connection = connection\r\n        self.learning_rule = []        \r\n        self.learning_rule.append(MutliInputSTDP(self.node[0], [self.connection[0],self.connection[1]]))# IPLv-Insula, STS-Insula\r\n        self.Insula=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\r\n\r\n    def forward(self, input1, input2): # input from IPLv and STS\r\n        \"\"\"\r\n        Calculate the output of Insula \r\n        :param input1: input from IPLv\r\n        :param input2: input from STS\r\n        :return: output of Insula, weight update (unused)\r\n        \"\"\"\r\n        self.out_Insula, dw_Insula = self.learning_rule[0](input1, input2)\r\n        return self.out_Insula\r\n\r\n    def UpdateWeight(self,i,dw):\r\n        \"\"\"\r\n        Update the weight\r\n        :param i: index of the connection to update\r\n        :param dw: weight update\r\n        :return: None\r\n        \"\"\"\r\n        
self.connection[i].update(dw)\r\n   \r\n    def reset(self):\r\n        \"\"\"\r\n        reset the network\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subMB):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n    \r\n    def getweight(self):\r\n        \"\"\"\r\n        Get the connection and weight in Insula\r\n        :return: connection\r\n        \"\"\"\r\n        return self.connection"
  },
  {
    "path": "braincog/base/brainarea/PFC.py",
    "content": "import torch\nfrom torch import nn\nfrom braincog.base.brainarea import BrainArea\nfrom braincog.model_zoo.base_module import BaseLinearModule, BaseModule\n\n\nclass PFC:\n    \"\"\"\n    PFC\n    \"\"\"\n    def __init__(self):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n    def forward(self, x):\n        \"\"\"\n\n        :return:x\n        \"\"\"\n\n        return x\n\n    def reset(self):\n        \"\"\"\n\n        :return:x\n        \"\"\"\n\n        pass\n\n\nclass dlPFC(BaseModule, PFC):\n    \"\"\"\n    SNNLinear\n    \"\"\"\n    def __init__(self,\n                 step,\n                 encode_type,\n                 in_features:int,\n                 out_features:int,\n                 bias,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.bias = bias\n        self.in_features = in_features\n        self.out_features = out_features\n        self.fc = self._create_fc()\n        self.c = self._rest_c()\n\n    def _rest_c(self):\n        c = torch.rand((self.out_features, self.in_features)) # eligibility trace\n        return c\n\n    def _create_fc(self):\n        \"\"\"\n        the connection of the SNN linear\n        @return: nn.Linear\n        \"\"\"\n        fc = nn.Linear(in_features=self.in_features,\n                  out_features=self.out_features, bias=self.bias)\n        return fc\n\n\n\n"
  },
  {
    "path": "braincog/base/brainarea/__init__.py",
    "content": "from .basalganglia import basalganglia\nfrom .BrainArea import BrainArea, ThreePointForward, Feedback, TwoInOneOut, SelfConnectionArea\nfrom .Insula import InsulaNet\nfrom .IPL import IPLNet\nfrom .PFC import PFC, dlPFC\n\n\n__all__ = [\n    'basalganglia',\n    'BrainArea', 'ThreePointForward', 'Feedback', 'TwoInOneOut', 'SelfConnectionArea',\n    'InsulaNet',\n    'IPLNet',\n    'PFC', 'dlPFC'\n]\n"
  },
  {
    "path": "braincog/base/brainarea/basalganglia.py",
    "content": "import numpy as np\r\nimport torch\r\nimport os\r\nimport sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom braincog.base.strategy.surrogate import *\r\nfrom braincog.base.node.node import IFNode, SimHHNode\r\nfrom braincog.base.learningrule.STDP import STDP, MutliInputSTDP\r\nfrom braincog.base.connection.CustomLinear import CustomLinear\r\n\r\n\r\nclass basalganglia(nn.Module):\r\n    \"\"\"\r\n    Basal Ganglia\r\n    \"\"\"\r\n\r\n    def __init__(self, ns, na, we, wi, node_type):\r\n        super().__init__()\r\n        \"\"\"\r\n        :param ns: 状态个数\r\n        :param na:动作个数\r\n        :param we:兴奋性连接权重\r\n        :param wi:抑制性连接权重\r\n        \"\"\"\r\n        num_state = ns\r\n        num_action = na\r\n        num_STN = 2\r\n        weight_exc = we\r\n        weight_inh = wi\r\n        # connetions: 0DLPFC-StrD1 1DLPFC-StrD2 2DLPFC-STN 3StrD1-GPi 4StrD2-GPe 5Gpe-Gpi 6STN-Gpi 7STN-Gpe 8Gpe-STN\r\n        bg_connection = []\r\n        bg_con_mask = []\r\n        # DLPFC-StrD1\r\n        con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)\r\n        for i in range(num_state):\r\n            for j in range(num_action):\r\n                con_matrix1[i, i * num_action + j] = 1\r\n        bg_con_mask.append(con_matrix1)\r\n        bg_connection.append(CustomLinear(weight_exc * con_matrix1, con_matrix1))\r\n        # DLPFC-StrD2\r\n        bg_connection.append(CustomLinear(weight_exc * con_matrix1, con_matrix1))\r\n        bg_con_mask.append(con_matrix1)\r\n        # DLPFC-STN\r\n        con_matrix3 = torch.ones((num_state, num_STN), dtype=torch.float)\r\n        bg_con_mask.append(con_matrix3)\r\n        bg_connection.append(CustomLinear(weight_exc * con_matrix3, con_matrix3))\r\n        # StrD1-GPi\r\n        con_matrix4 = torch.zeros((num_state * num_action, 
num_action), dtype=torch.float)\r\n        for i in range(num_state):\r\n            for j in range(num_action):\r\n                con_matrix4[i * num_action + j, j] = 1\r\n        bg_con_mask.append(con_matrix4)\r\n        bg_connection.append(CustomLinear(weight_inh * con_matrix4, con_matrix4))\r\n        # StrD2-GPe\r\n        bg_con_mask.append(con_matrix4)\r\n        bg_connection.append(CustomLinear(weight_inh * con_matrix4, con_matrix4))\r\n        # Gpe-Gpi\r\n        con_matrix5 = torch.eye((num_action), dtype=torch.float)\r\n        bg_con_mask.append(con_matrix5)\r\n        bg_connection.append(CustomLinear(weight_inh * con_matrix5, con_matrix5))\r\n        # STN-Gpi\r\n        con_matrix6 = torch.ones((num_STN, num_action), dtype=torch.float)\r\n        bg_con_mask.append(con_matrix6)\r\n        bg_connection.append(CustomLinear(0.5 * weight_exc * con_matrix6, con_matrix6))\r\n        # STN-Gpe\r\n        bg_con_mask.append(con_matrix6)\r\n        bg_connection.append(CustomLinear(0.5 * weight_exc * con_matrix6, con_matrix6))\r\n        # Gpe-STN\r\n        con_matrix7 = torch.ones((num_action, num_STN), dtype=torch.float)\r\n        bg_con_mask.append(con_matrix7)\r\n        bg_connection.append(CustomLinear(0.5 * weight_inh * con_matrix7, con_matrix7))\r\n\r\n        self.num_subBG = 5\r\n        self.node_type = node_type\r\n        if self.node_type == \"hh\":\r\n            self.node = [SimHHNode() for i in range(self.num_subBG)]\r\n        if self.node_type == \"lif\":\r\n            self.node = [IFNode() for i in range(self.num_subBG)]\r\n        self.connection = bg_connection\r\n        self.mask = bg_con_mask\r\n        self.learning_rule = []\r\n\r\n        trace_stdp = 0.99\r\n        self.learning_rule.append(STDP(self.node[0], self.connection[0], trace_stdp))  # DLPFC-StrD1\r\n        self.learning_rule.append(STDP(self.node[1], self.connection[1], trace_stdp))  # DLPFC-StrD2\r\n        
self.learning_rule.append(MutliInputSTDP(self.node[2], [self.connection[2], self.connection[8]]))  # DLPFC-STN\r\n        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[7]]))  # StrD2-GPe STN-Gpe\r\n        self.learning_rule.append(MutliInputSTDP(self.node[4], [self.connection[3], self.connection[5], self.connection[6]]))  # StrD1-GPi Gpe-Gpi STN-Gpi\r\n        self.out_StrD1 = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)\r\n        self.out_StrD2 = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\r\n        self.out_STN = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)\r\n        self.out_Gpi = torch.zeros((self.connection[3].weight.shape[1]), dtype=torch.float)\r\n        self.out_Gpe = torch.zeros((self.connection[4].weight.shape[1]), dtype=torch.float)\r\n\r\n    def forward(self, input):\r\n        \"\"\"\r\n        计算由当前输入基底节网络的输出\r\n        :param input: 输入电流\r\n        :return: 输出脉冲\r\n        \"\"\"\r\n        self.out_StrD1, dw_strd1 = self.learning_rule[0](input)\r\n        self.out_StrD2, dw_strd2 = self.learning_rule[1](input)\r\n        self.out_STN, dw_stn = self.learning_rule[2](input, self.out_Gpe)\r\n        self.out_Gpe, dw_gpe = self.learning_rule[3](self.out_StrD2, self.out_STN)\r\n        self.out_Gpi, dw_gpi = self.learning_rule[4](self.out_StrD1, self.out_Gpe, self.out_STN)\r\n        return self.out_Gpi\r\n\r\n    def UpdateWeight(self, i, dw):\r\n        \"\"\"\r\n        更新基底节内第i组连接的权重 根据传入的dw值\r\n        :param i: 要更新的连接的索引\r\n        :param dw: 更新的量\r\n        :return: None\r\n        \"\"\"\r\n        self.connection[i].update(dw)\r\n        self.connection[i].weight.data = F.normalize(self.connection[i].weight.data.float(), p=1, dim=1)\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        reset神经元或学习法则的中间量\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subMB):\r\n            
self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n\r\n    def getweight(self):\r\n        \"\"\"\r\n        获取基底节网络的连接(包括权值等)\r\n        :return: 基底节网络的连接\r\n        \"\"\"\r\n        return self.connection\r\n\r\n    def getmask(self):\r\n        \"\"\"\r\n        获取基底节网络的连接（仅连接矩阵）\r\n        :return: 基底节网络的连接矩阵\r\n        \"\"\"\r\n        return self.mask\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    BG = basalganglia(4, 2, 0.2, -4)\r\n    con = BG.getweight()\r\n    print(con)\r\n"
  },
  {
    "path": "braincog/base/brainarea/dACC.py",
    "content": "import torch\nimport matplotlib.pyplot as plt\nimport numpy as np\nnp.set_printoptions(threshold=np.inf)\nfrom utils.one_hot import *\nimport os\nimport time\nimport sys\nfrom tqdm import tqdm\n\nfrom braincog.base.encoder.population_coding import *\nfrom braincog.model_zoo.base_module import BaseLinearModule, BaseModule\nfrom braincog.base.learningrule.STDP import *\nimport sys\nsys.path.append(\"..\")\n\nclass dACC(BaseModule):\n    \"\"\"\n    SNNLinear\n    \"\"\"\n    def __init__(self,\n                 step,\n                 encode_type,\n                 in_features:int,\n                 out_features:int,\n                 bias,\n                 node,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.bias = bias\n        self.in_features = in_features\n        self.out_features = out_features\n        self.node1 = node(threshold=0.5, tau=2.)\n        self.node_name1 = node\n        self.node2 = node(threshold=0.1, tau=2.)\n        self.node_name2 = node\n        self.fc = self._create_fc()\n        self.c = self._rest_c()\n\n\n    def _rest_c(self):\n        c = torch.rand((self.out_features, self.in_features)) # eligibility trace\n        return c\n\n    def _create_fc(self):\n        \"\"\"\n        the connection of the SNN linear\n        @return: nn.Linear\n        \"\"\"\n        fc = nn.Linear(in_features=self.in_features,\n                  out_features=self.out_features, bias=self.bias)\n        return fc\n\n    def update_c(self, c, STDP, tau_c=0.2):\n        \"\"\"\n        update the trace of eligibility\n        @param c: a tensor to record eligibility\n        @param STDP: the results of STDP\n        @param tau_c: the parameter of trace decay\n        @return: a update tensor to record eligibility\n        Equation:\n        delta_c = (-(c / tau_c) + STDP) * dela_t\n        c = c + delta_c\n        reference:<Solving the Distal Reward Problem 
through ...>\n        \"\"\"\n        c = c + tau_c * STDP\n        return c\n\n    def forward(self, inputs, epoch):\n        \"\"\"\n        decision\n        @param inputs: state\n        @return: action\n        \"\"\"\n        output = []\n        stdp = STDP(self.node2, self.fc, decay=0.80)\n        self.c = self._rest_c()\n        # stdp.connection.weight.data = torch.rand((self.out_features, self.in_features))\n\n        for i in range(inputs.shape[0]):\n            for t in range(self.step):\n                l1_in = torch.tensor(inputs[i, :])\n                l1_out = self.node1(l1_in).unsqueeze(0)  #pre  : l1_out\n                l2_out, dw = stdp(l1_out)   #dw -- STDP\n                self.c = self.update_c(self.c, dw[0])\n            output.append(torch.min(l2_out))\n            # output.append((l2_out.any() == 0).cpu().detach().numpy().tolist())\n\n        return output\n\n\n# if __name__ == '__main__':\n#     np.random.seed(6)\n#     T = 5\n#     num_popneurons = 2\n#     safety = 2\n#     epoch = 50\n#     file_name = \"/home/zhaozhuoya/braincog/examples/ToM/data/injury_value.txt\"\n#     state = []\n#     with open(file_name) as f:\n#         data = []\n#         data_split = f.readlines()  #\n#         for i in data_split:\n#             state.append(one_hot(int(i[0])))\n#\n#     output = np.array(state)\n#     train_y = output\n#     test_y = output[79:82]#output[12].reshape(1,2)\n#\n#     file_name = \"/home/zhaozhuoya/braincog/examples/ToM/data/injury_memory.txt\"\n#     state = []\n#     with open(file_name) as f:\n#         data_split = f.readlines()\n#         for i in data_split:\n#             data = []\n#             data.append(int(bool(abs(int(i[2]) - int(i[18]))))*10)\n#             data.append(int(bool(abs(int(i[5]) - int(i[21]))))*10)\n#             state.append(data)\n#     input = np.array(state)\n#     train_x = input\n#     test_x = input[79:82]\n#     dACC_net = dACC(step=T, encode_type='rate', bias=True,\n#                       
  in_features=num_popneurons, out_features=safety,\n#                         node=node.LIFNode)\n#     dACC_net.fc.weight.data = torch.rand((safety, num_popneurons))\n#     dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])\n#     output = dACC_net(inputs=train_x, epoch=50)\n#     for i in range(len(output)):\n#         print(output[i], train_x[i])\n    # torch.save({'dacc': dACC_net.state_dict()}, os.path.join('./checkpoint', 'dACC_net.pth'))\n    # dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])\n    # output = dACC_net(inputs=test_x, epoch=50)\n    # for i in range(len(test_x)):\n    #\n    #     print(output[i],test_x[i])\n\n\n"
  },
  {
    "path": "braincog/base/connection/CustomLinear.py",
    "content": "import os\r\nimport sys\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch import einsum\r\nimport torch.nn.functional as F\r\n\r\n\r\nclass CustomLinear(nn.Module):\r\n    \"\"\"\r\n    用户自定义连接 通常stdp的计算\r\n    \"\"\"\r\n\r\n    def __init__(self, weight, mask=None):\r\n        super().__init__()\r\n\r\n        self.weight = nn.Parameter(weight, requires_grad=True)\r\n        self.mask = mask\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        \"\"\"\r\n        :param x:输入 x.shape = [N ]\r\n        \"\"\"\r\n        #\r\n        # ret.shape = [C]\r\n\r\n        return x.matmul(self.weight)\r\n\r\n    def update(self, dw):\r\n        \"\"\"\r\n        :param dw:权重更新量\r\n        \"\"\"\r\n        with torch.no_grad():\r\n            if self.mask is not None:\r\n                dw *= self.mask\r\n            self.weight.data += dw\r\n"
  },
  {
    "path": "braincog/base/connection/__init__.py",
    "content": "from .CustomLinear import CustomLinear\nfrom .layer import VotingLayer, WTALayer, NDropout, ThresholdDependentBatchNorm2d, LayerNorm, SMaxPool, LIPool\n\n\n__all__ = [\n    'CustomLinear',\n    'VotingLayer', 'WTALayer', 'NDropout', 'ThresholdDependentBatchNorm2d', 'LayerNorm', 'SMaxPool', 'LIPool'\n]"
  },
  {
    "path": "braincog/base/connection/layer.py",
    "content": "import warnings\nimport math\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch import einsum\nfrom torch.nn.modules.batchnorm import _BatchNorm\nimport torch.nn.functional as F\nfrom torch.nn import Parameter\nfrom einops import rearrange\n\n\nclass VotingLayer(nn.Module):\n    \"\"\"\n    用于SNNs的输出层, 几个神经元投票选出最终的类\n    :param voter_num: 投票的神经元的数量, 例如 ``voter_num = 10``, 则表明会对这10个神经元取平均\n    \"\"\"\n\n    def __init__(self, voter_num: int):\n        super().__init__()\n        self.voting = nn.AvgPool1d(voter_num, voter_num)\n\n    def forward(self, x: torch.Tensor):\n        # x.shape = [N, voter_num * C]\n        # ret.shape = [N, C]\n        return self.voting(x.unsqueeze(1)).squeeze(1)\n\n\nclass WTALayer(nn.Module):\n    \"\"\"\n    winner take all用于SNNs的每层后，将随机选取一个或者多个输出\n    :param k: X选取的输出数目 k默认等于1\n    \"\"\"\n    def __init__(self, k=1):\n        super().__init__()\n        self.k = k\n\n    def forward(self, x: torch.Tensor):\n        # x.shape = [N, C,W,H]\n        # ret.shape = [N, C,W,H]\n        pos = x * torch.rand(x.shape, device=x.device)\n        if self.k > 1:\n            x = x * (pos >= pos.topk(self.k, dim=1)[0][:, -1:]).float()\n        else:\n            x = x * (pos >= pos.max(1, True)[0]).float()\n\n        return x\n\n\nclass NDropout(nn.Module):\n    \"\"\"\n    与Drop功能相同, 但是会保证同一个样本不同时刻的mask相同.\n    \"\"\"\n\n    def __init__(self, p):\n        super(NDropout, self).__init__()\n        self.p = p\n        self.mask = None\n\n    def n_reset(self):\n        \"\"\"\n        重置, 能够生成新的mask\n        :return:\n        \"\"\"\n        self.mask = None\n\n    def create_mask(self, x):\n        \"\"\"\n        生成新的mask\n        :param x: 输入Tensor, 生成与之形状相同的mask\n        :return:\n        \"\"\"\n        self.mask = F.dropout(torch.ones_like(x.data), self.p, training=True)\n\n    def forward(self, x):\n        if self.training:\n            if self.mask is None:\n                self.create_mask(x)\n\n         
   return self.mask * x\n        else:\n            return x\n\n\nclass WSConv2d(nn.Conv2d):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1, bias=True, gain=True):\n        super(WSConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,\n                                       padding, dilation, groups, bias)\n\n        if gain:\n            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))\n        else:\n            self.gain = 1.\n\n    def forward(self, x):\n        weight = self.weight\n        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,\n                                                            keepdim=True).mean(dim=3, keepdim=True)\n        weight = weight - weight_mean\n        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5\n        weight = self.gain * weight / std.expand_as(weight)\n        return F.conv2d(x, weight, self.bias, self.stride,\n                        self.padding, self.dilation, self.groups)\n\n\nclass ThresholdDependentBatchNorm2d(_BatchNorm):\n    \"\"\"\n    tdBN\n    https://ojs.aaai.org/index.php/AAAI/article/view/17320\n    \"\"\"\n\n    def __init__(self, num_features, alpha: float, threshold: float = .5, layer_by_layer: bool = True, affine: bool = True,**kwargs):\n        self.alpha = alpha\n        self.threshold = threshold\n\n        super().__init__(num_features=num_features, affine=affine)\n\n        assert layer_by_layer, \\\n            'tdBN may works in step-by-step mode, which will not take temporal dimension into batch norm'\n        assert self.affine, 'ThresholdDependentBatchNorm needs to set `affine = True`!'\n\n        torch.nn.init.constant_(self.weight, alpha * threshold)\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError(\"expected 4D input (got {}D input)\".format(input.dim()))\n\n    def forward(self, input):\n   
     # input = rearrange(input, '(t b) c w h -> b (t c) w h', t=self.step)\n        output = super().forward(input)\n        return output\n        # return rearrange(output, 'b (t c) w h -> (t b) c w h', t=self.step)\n\nclass TEBN(nn.Module):\n    def __init__(self, num_features,step, eps=1e-5, momentum=0.1,**kwargs):\n        super(TEBN, self).__init__()\n        self.bn = nn.BatchNorm3d(num_features)\n        self.p = nn.Parameter(torch.ones(4, 1, 1, 1, 1))\n        self.step=step\n    def forward(self, input):\n        #y = input.transpose(1, 2).contiguous()  # N T C H W ,  N C T H W\n        y = rearrange(input,\"(t b) c w h -> t c b w h\",t=self.step)\n        y = self.bn(y)\n        # y = y.contiguous().transpose(1, 2)\n        # y = y.transpose(0, 1).contiguous()  # NTCHW  TNCHW\n        y = rearrange(y,\"t c b w h -> t b c w h\")\n        y = y * self.p\n        #y = y.contiguous().transpose(0, 1)  # TNCHW  NTCHW\n        y = rearrange(y, \"t b c w h -> (t b) c w h\")\n        return y\nclass LayerNorm(nn.Module):\n    \"\"\" LayerNorm that supports two data formats: channels_last (default) or channels_first.\n    The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with\n    shape (batch_size, height, width, channels) while channels_first corresponds to inputs\n    with shape (batch_size, channels, height, width).\n    \"\"\"\n\n    def __init__(self, normalized_shape, eps=1e-6, data_format=\"channels_last\"):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.eps = eps\n        self.data_format = data_format\n        if self.data_format not in [\"channels_last\", \"channels_first\"]:\n            raise NotImplementedError\n        self.normalized_shape = (normalized_shape,)\n\n    def forward(self, x):\n        if self.data_format == \"channels_last\":\n            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        elif self.data_format == \"channels_first\":\n            u = x.mean(1, keepdim=True)\n            s = (x - u).pow(2).mean(1, keepdim=True)\n            x = (x - u) / torch.sqrt(s + self.eps)\n            x = self.weight[:, None, None] * x + self.bias[:, None, None]\n            return x\n\n\nclass SMaxPool(nn.Module):\n    \"\"\"用于转换方法的最大池化层的常规替换\n    选用具有最大脉冲发放率的神经元的脉冲通过，能够满足一般性最大池化层的需要\n\n    Reference:\n    https://arxiv.org/abs/1612.04052\n    \"\"\"\n\n    def __init__(self, child):\n        super(SMaxPool, self).__init__()\n        self.opration = child\n        self.sumspike = 0\n\n    def forward(self, x):\n        self.sumspike += x\n        single = self.opration(self.sumspike * 1000)\n        sum_plus_spike = self.opration(x + self.sumspike * 1000)\n\n        return sum_plus_spike - single\n\n    def reset(self):\n        self.sumspike = 0\n\n\nclass LIPool(nn.Module):\n    r\"\"\"用于转换方法的最大池化层的精准替换\n    LIPooling通过引入侧向抑制机制保证在转换后的SNN中输出的最大值与期望值相同。\n\n    Reference:\n    https://arxiv.org/abs/2204.13271\n    \"\"\"\n\n    def __init__(self, child=None):\n        super(LIPool, self).__init__()\n        if child 
is None:\n            raise NotImplementedError(\"child should be Pooling operation with torch.\")\n\n        self.opration = child\n        self.sumspike = 0\n\n    def forward(self, x):\n        self.sumspike += x\n        out = self.opration(self.sumspike)\n        self.sumspike -= F.interpolate(out, scale_factor=2, mode='nearest')\n        return out\n\n    def reset(self):\n        self.sumspike = 0\n\n\nclass CustomLinear(nn.Module):\n\n    def __init__(self, in_channels, out_channels, bias=True):\n        super(CustomLinear, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        # self.weight = Parameter(torch.tensor([\n        #     [1., .5, .25, .125],\n        #     [0., 1., .5, .25],\n        #     [0., 0., 1., .5],\n        #     [0., 0., 0., 1.]\n        # ]), requires_grad=True)\n        self.weight = Parameter(torch.diag(torch.ones(self.in_channels)), requires_grad=True)\n        # self.weight = Parameter(torch.randn(self.in_channels, self.in_channels))\n        mask = torch.tril(torch.ones(self.in_channels, self.in_channels), diagonal=0)\n        self.register_buffer('mask', mask)\n\n        if bias:\n            self.bias = Parameter(torch.zeros(out_channels), requires_grad=True)\n        else:\n            self.register_parameter('bias', None)\n\n    def forward(self, inputs):\n        weight = self.mask * self.weight\n        return F.linear(inputs, weight, self.bias)\n"
  },
  {
    "path": "braincog/base/conversion/__init__.py",
    "content": "from .convertor import HookScale, Hookoutput, Scale, Convertor, SNode\nfrom .merge import mergeConvBN, merge\n\n\n__all__ = [\n    'Hookoutput', 'HookScale', 'Scale', 'Convertor', 'SNode',\n    'merge', 'mergeConvBN'\n]"
  },
  {
    "path": "braincog/base/conversion/convertor.py",
    "content": "import torch\nimport torch.nn as nn\nfrom braincog.base.connection.layer import SMaxPool, LIPool\nfrom .merge import mergeConvBN\nfrom .spicalib import SpiCalib\nimport types\n\n\nclass HookScale(nn.Module):\n    \"\"\" 在每个ReLU层后记录该层的百分位最大值\n\n    For channelnorm: 获取最大值时使用了torch.quantile\n    For layernorm：  使用sort，然后手动取百分比，因为quantile在计算单个通道时有上限，batch较大时易出错\n    \"\"\"\n\n    def __init__(self,\n                 p: float = 0.9995,\n                 channelnorm: bool = False,\n                 gamma: float = 0.999,\n                 ):\n        super().__init__()\n        if channelnorm:\n            self.register_buffer('scale', torch.tensor(0.0))\n        else:\n            self.register_buffer('scale', torch.tensor(0.0))\n\n        self.p = p\n        self.channelnorm = channelnorm\n        self.gamma = gamma\n\n    def forward(self, x):\n        x = torch.where(x.detach() < self.gamma, x.detach(),\n                        torch.tensor(self.gamma, dtype=x.dtype, device=x.device))\n        if len(x.shape) == 4 and self.channelnorm:\n            num_channel = x.shape[1]\n            tmp = torch.quantile(x.permute(1, 0, 2, 3).reshape(num_channel, -1), self.p, dim=1,\n                                 interpolation='lower') + 1e-10\n            self.scale = torch.max(tmp, self.scale)\n        else:\n            sort, _ = torch.sort(x.view(-1))\n            self.scale = torch.max(sort[int(sort.shape[0] * self.p) - 1], self.scale)\n        return x\n\n\nclass Hookoutput(nn.Module):\n    \"\"\"\n    在伪转换中为ReLU和ClipQuan提供包装，用于监控其输出\n    \"\"\"\n\n    def __init__(self, module):\n        super(Hookoutput, self).__init__()\n        self.activation = 0.\n        self.operation = module\n\n    def forward(self, x):\n        output = self.operation(x)\n        self.activation = output.detach()\n        return output\n\n\nclass Scale(nn.Module):\n    \"\"\"\n    对前向过程的值进行缩放\n    \"\"\"\n\n    def __init__(self, scale: float = 1.0):\n        super().__init__()\n  
      self.register_buffer('scale', scale)\n\n    def forward(self, x):\n        if len(self.scale.shape) == 1:\n            return self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x\n        else:\n            return self.scale * x\n\n\ndef reset(self):\n    \"\"\"\n    转换的网络来自ANN，需要将新附加上的脉冲module进行reset\n    判断module名称并调用各自节点的reset方法\n    \"\"\"\n    children = list(self.named_children())\n    for i, (name, child) in enumerate(children):\n        if isinstance(child, (SNode, LIPool, SMaxPool)):\n            child.reset()\n        else:\n            reset(child)\n\n\nclass Convertor(nn.Module):\n    \"\"\"ANN2SNN转换器\n\n    用于转换完整的pytorch模型，使用dataloader中部分数据进行最大值计算，通过p控制获取第p百分比最大值\n\n    channlenorm: https://arxiv.org/abs/1903.06530\n    channelnorm可以对每个通道获取最大值并进行权重归一化\n\n    gamma: https://arxiv.org/abs/2204.13271\n    gamma可以控制burst spikes的脉冲数，burst spike可以提高神经元的脉冲发放能力，减小信息残留\n\n    lipool: https://arxiv.org/abs/2204.13271\n    lipool用于使用侧向抑制机制进行最大池化，LIPooling能够对SNN中的最大池化进行有效的转换\n\n    soft_mode: https://arxiv.org/abs/1612.04052\n    soft_mode被称为软重置，可以减小重置过程神经元的信息损失，有效提高转换的性能\n\n    merge用于是否对网络中相邻的卷积和BN层进行融合\n    batch_norm控制对dataloader的数据集的用量\n    \"\"\"\n\n    def __init__(self,\n                 dataloader,\n                 device=None,\n                 p=0.9995,\n                 channelnorm=False,\n                 lipool=True,\n                 gamma=1,\n                 soft_mode=True,\n                 merge=True,\n                 batch_num=1,\n                 spicalib=0\n                 ):\n        super(Convertor, self).__init__()\n        self.dataloader = dataloader\n        self.device = device\n        self.p = p\n        self.channelnorm = channelnorm\n        self.lipool = lipool\n        self.gamma = gamma\n        self.soft_mode = soft_mode\n        self.merge = merge\n        self.batch_num = batch_num\n        self.spicalib = spicalib\n\n    def forward(self, model):\n        model.eval()\n        model = 
Convertor.register_hook(model, self.p, self.channelnorm, self.gamma)\n        model = Convertor.get_percentile(model, self.dataloader, self.device, batch_num=self.batch_num)\n        model = mergeConvBN(model) if self.merge else model\n        model = Convertor.replace_for_spike(model, self.lipool, self.soft_mode, self.gamma, self.spicalib)\n        model.reset = types.MethodType(reset, model)\n        return model\n\n    @staticmethod\n    def register_hook(model, p=0.99, channelnorm=False, gamma=0.999):\n        \"\"\" Reference: https://github.com/fangwei123456/spikingjelly\n\n        将网络的每一层后注册一个HookScale类\n        该方法在仿真上等效于与对权重进行归一化操作，且易扩展到任意结构的网络中\n        \"\"\"\n        children = list(model.named_children())\n        for _, (name, child) in enumerate(children):\n            if isinstance(child, nn.ReLU):\n                model._modules[name] = nn.Sequential(nn.ReLU(), HookScale(p, channelnorm, gamma))\n            else:\n                Convertor.register_hook(child, p, channelnorm, gamma)\n        return model\n\n    @staticmethod\n    def get_percentile(model, dataloader, device, batch_num=1):\n        \"\"\"\n        该函数需与具有HookScale层的网络配合使用\n        \"\"\"\n        for idx, (data, _) in enumerate(dataloader):\n            data = data.to(device)\n            if idx >= batch_num:\n                break\n            model(data)\n        return model\n\n    @staticmethod\n    def replace_for_spike(model, lipool=True, soft_mode=True, gamma=1, spicalib=0):\n        \"\"\"\n        该函数用于将定义好的ANN模型转换为SNN模型\n        ReLU单元将被替换为脉冲神经元，\n        如果模型中使用了最大池化，lipool参数将定义使用常规模型还是LIPooling方法\n        \"\"\"\n        children = list(model.named_children())\n        for _, (name, child) in enumerate(children):\n            if isinstance(child, nn.Sequential) and len(child) == 2 and isinstance(child[0], nn.ReLU) and isinstance(child[1], HookScale):\n                model._modules[name] = nn.Sequential(\n                    Scale(1.0 / child[1].scale),\n                 
   SNode(soft_mode, gamma),\n                    SpiCalib(spicalib),\n                    Scale(child[1].scale)\n                )\n            if isinstance(child, nn.MaxPool2d):\n                model._modules[name] = LIPool(child) if lipool else SMaxPool(child)\n            else:\n                Convertor.replace_for_spike(child, lipool, soft_mode, gamma)\n        return model\n\n\nclass SNode(nn.Module):\n    \"\"\"\n    用于转换后的SNN的神经元模型\n    IF神经元模型由gamma=1确定，当gamma为其他大于1的值时，即为使用burst神经元模型\n    soft_mode用于定义神经元的重置方法，soft重置能够极大地减少神经元在重置过程的信息损失\n    \"\"\"\n\n    def __init__(self, soft_mode=False, gamma=5):\n        super(SNode, self).__init__()\n        self.threshold = 1.0\n        self.soft_mode = soft_mode\n        self.gamma = gamma\n\n        self.mem = 0\n        self.spike = 0\n\n    def forward(self, x):\n        self.mem = self.mem + x\n        self.spike = (self.mem / self.threshold).floor().clamp(min=0, max=self.gamma)\n        self.soft_reset() if self.soft_mode else self.hard_reset\n\n        out = self.spike\n        return out\n\n    def hard_reset(self):\n        \"\"\"\n        硬重置后神经元的膜电势被重置为0\n        \"\"\"\n        self.mem = self.mem * (1 - self.spike.detach())\n\n    def soft_reset(self):\n        \"\"\"\n        软重置后神经元的膜电势为神经元当前膜电势减去阈值\n        \"\"\"\n        self.mem = self.mem - self.threshold * self.spike.detach()\n\n    def reset(self):\n        self.mem = 0\n        self.spike = 0\n"
  },
  {
    "path": "braincog/base/conversion/merge.py",
    "content": "import torch\nimport torch.nn as nn\n\n\ndef mergeConvBN(m):\n    \"\"\"\n    合并网络模块中的卷积与BN层\n    \"\"\"\n    children = list(m.named_children())\n    c, cn = None, None\n\n    for i, (name, child) in enumerate(children):\n        if isinstance(child, nn.BatchNorm2d):\n            bc = merge(c, child)\n            m._modules[cn] = bc\n            m._modules[name] = torch.nn.Identity()\n            c = None\n        elif isinstance(child, nn.Conv2d):\n            c = child\n            cn = name\n        else:\n            mergeConvBN(child)\n    return m\n\n\ndef merge(conv, bn):\n    \"\"\"\n    conv: 卷积层实例\n    bn: BN层实例\n    \"\"\"\n    w = conv.weight\n    mean, var_sqrt, beta, gamma = bn.running_mean, torch.sqrt(bn.running_var + bn.eps), bn.weight, bn.bias\n    b = conv.bias if conv.bias is not None else mean.new_zeros(mean.shape)\n\n    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])\n    b = (b - mean) / var_sqrt * beta + gamma\n    fused_conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, bias=True)\n    fused_conv.weight = nn.Parameter(w)\n    fused_conv.bias = nn.Parameter(b)\n    return fused_conv\n"
  },
  {
    "path": "braincog/base/conversion/spicalib.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass SpiCalib(nn.Module):\n    def __init__(self, allowance):\n        super(SpiCalib, self).__init__()\n        self.allowance = allowance\n        self.sumspike = 0\n        self.t = 0\n\n    def forward(self, x):\n        if self.allowance == 0:\n            return x\n\n        if self.t == 0:\n            self.last_spike = torch.zeros_like(x)\n            self.avg_time = torch.zeros_like(x)\n            self.num_spike = torch.zeros_like(x)\n\n        SPIKE_MASK = x > 0\n        self.num_spike[SPIKE_MASK] += 1\n        self.avg_time[SPIKE_MASK] = (self.t - self.last_spike + self.avg_time * (self.num_spike - 1))[SPIKE_MASK] / \\\n                                    self.num_spike[SPIKE_MASK]\n        self.last_spike[SPIKE_MASK] = self.t\n        SIN_MASK = self.t - self.last_spike > self.avg_time + self.allowance\n        x[SIN_MASK] -= 1.0\n        self.sumspike += x\n        x[self.sumspike <= -1] = 0\n        self.t += 1\n        return x\n\n    def reset(self):\n        self.sumspike = 0\n        self.t = 0"
  },
  {
    "path": "braincog/base/encoder/__init__.py",
    "content": "from .encoder import Encoder\nfrom .population_coding import PEncoder\nfrom.qs_coding import QSEncoder\n\n\n__all__ = [\n    'Encoder',\n    'PEncoder',\n    'QSEncoder'\n]"
  },
  {
    "path": "braincog/base/encoder/encoder.py",
    "content": "import torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\nfrom braincog.base.strategy.surrogate import GateGrad\n\n\nclass AutoEncoder(nn.Module):\n    def __init__(self, step, spike_output=True):\n        super(AutoEncoder, self).__init__()\n        self.step = step\n        self.spike_output = spike_output\n\n        # self.gru = nn.GRU(input_size=1, hidden_size=1, num_layers=3)\n        self.sigmoid = nn.Sigmoid()\n        self.fc1 = nn.Linear(1, self.step)\n        self.fc2 = nn.Linear(self.step, self.step)\n        self.relu = nn.ReLU()\n        #\n        self.act_fun = GateGrad()\n\n    def forward(self, x):\n        shape = x.shape\n\n        x = self.fc1(x.view(-1, 1))\n        x = self.relu(x)\n        x = self.fc2(x).transpose_(1, 0)\n\n        # x = x.view(1, -1, 1).repeat(self.step, 1, 1)\n        # x, _ = self.gru(x)\n\n        x = self.sigmoid(x)\n        if not self.spike_output:\n            return x.view(self.step, *shape)\n        else:\n            return self.act_fun(x).view(self.step, *shape)\n\n\n# class TransEncoder(nn.Module):\n#     def __init__(self, step):\n#         super(TransEncoder, self).__init__()\n#         self.step = step\n#         self.trans = Transformer(dim=128, depth=3, heads=8, dim_head=, mlp_dim, dropout=0.)\n\n\nclass Encoder(nn.Module):\n    '''\n    将static image编码\n    :param step: 仿真步长\n    :param encode_type: 编码方式, 可选 ``direct``, ``ttfs``, ``rate``, ``phase``\n    :param temporal_flatten: 直接将temporal维度concat到channel维度\n    :param layer_by_layer: 是否使用计算每一层的所有的输出的方式进行推理\n    :param\n    (step, batch_size, )\n    '''\n\n    def __init__(self, step, encode_type='ttfs', *args, **kwargs):\n        super(Encoder, self).__init__()\n        self.step = step\n        self.fun = getattr(self, encode_type)\n        self.encode_type = encode_type\n        self.temporal_flatten = kwargs['temporal_flatten'] if 'temporal_flatten' in kwargs else False\n        self.layer_by_layer = 
kwargs['layer_by_layer'] if 'layer_by_layer' in kwargs else False\n        self.no_encode = kwargs['adaptive_node'] if 'adaptive_node' in kwargs else False\n        self.groups = kwargs['n_groups'] if 'n_groups' in kwargs else 1\n        # if encode_type == 'auto':\n        #     self.fun = AutoEncoder(self.step, spike_output=False)\n\n    def forward(self, inputs, deletion_prob=None, shift_var=None):\n        if len(inputs.shape) == 5:  # DVS data\n            outputs = inputs.permute(1, 0, 2, 3, 4).contiguous()  # t, b, c, w, h\n        elif len(inputs.shape) == 3:  # DAS data\n            outputs = inputs.permute(1, 0, 2).contiguous()  # t, b, c\n        else:\n            if self.encode_type == 'auto':\n                if self.fun.device != inputs.device:\n                    self.fun.to(inputs.device)\n            outputs = self.fun(inputs)\n\n        if deletion_prob:\n            outputs = self.delete(outputs, deletion_prob)\n        if shift_var:\n            outputs = self.shift(outputs, shift_var)\n\n        if self.temporal_flatten or self.no_encode:\n            outputs = rearrange(outputs, 't b c w h -> 1 b (t c) w h')\n        elif self.groups != 1:\n            outputs = rearrange(outputs, 't b c w h -> b (c t) w h')\n        elif self.layer_by_layer:\n            if len(inputs.shape) == 3:\n                outputs = rearrange(outputs, 't b c-> (t b) c')\n            else:\n                outputs = rearrange(outputs, 't b c w h -> (t b) c w h')\n\n        return outputs\n\n    @torch.no_grad()\n    def direct(self, inputs):\n        \"\"\"\n        直接编码\n        :param inputs: 形状(b, c, w, h)\n        :return: (t, b, c, w, h)\n        \"\"\"\n        outputs = repeat(inputs, 'b c w h -> t b c w h', t=self.step)\n        # outputs = inputs.unsqueeze(0).repeat(self.step, *([1] * len(shape)))\n        return outputs\n\n    def auto(self, inputs):\n        # TODO: Calc loss for firing-rate\n        shape = inputs.shape\n        outputs = 
self.fun(inputs)\n        print(outputs.shape)\n        return outputs\n\n    @torch.no_grad()\n    def ttfs(self, inputs):\n        \"\"\"\n        Time-to-First-Spike Encoder\n        :param inputs: static data\n        :return: Encoded data\n        \"\"\"\n        # print(\"ttfs\")\n        shape = (self.step,) + inputs.shape\n        outputs = torch.zeros(shape, device=self.device)\n        for i in range(self.step):\n            mask = (inputs * self.step <= (self.step - i)\n                    ) & (inputs * self.step > (self.step - i - 1))\n            outputs[i, mask] = 1 / (i + 1)\n        return outputs\n\n    @torch.no_grad()\n    def rate(self, inputs):\n        \"\"\"\n        Rate Coding\n        :param inputs:\n        :return:\n        \"\"\"\n        shape = (self.step,) + inputs.shape\n        return (inputs > torch.rand(shape, device=inputs.device)).float()\n\n    @torch.no_grad()\n    def phase(self, inputs):\n        \"\"\"\n        Phase Coding\n        相位编码\n        :param inputs: static data\n        :return: encoded data\n        \"\"\"\n        shape = (self.step,) + inputs.shape\n        outputs = torch.zeros(shape, device=self.device)\n        inputs = (inputs * 256).long()\n        val = 1.\n        for i in range(self.step):\n            if i < 8:\n                mask = (inputs >> (8 - i - 1)) & 1 != 0\n                outputs[i, mask] = val\n                val /= 2.\n            else:\n                outputs[i] = outputs[i % 8]\n        return outputs\n\n    @torch.no_grad()\n    def delete(self, inputs, prob):\n        \"\"\"\n        在Coding 过程中随机删除脉冲\n        :param inputs: encoded data\n        :param prob: 删除脉冲的概率\n        :return: 随机删除脉冲之后的数据\n        \"\"\"\n        mask = (inputs >= 0) & (torch.randn_like(\n            inputs, device=self.device) < prob)\n        inputs[mask] = 0.\n        return inputs\n\n    @torch.no_grad()\n    def shift(self, inputs, var):\n        \"\"\"\n        对数据进行随机平移, 添加噪声\n        :param 
inputs: encoded data\n        :param var: 随机平移的方差\n        :return: shifted data\n        \"\"\"\n        # TODO: Real-time shift\n        outputs = torch.zeros_like(inputs)\n        for step in range(self.step):\n            shift = (var * torch.randn(1)).round_() + step\n            shift.clamp_(min=0, max=self.step - 1)\n            outputs[step] += inputs[int(shift)]\n        return outputs\n"
  },
  {
    "path": "braincog/base/encoder/population_coding.py",
    "content": "import torch\nimport torch.nn as nn\nimport torchvision.utils\n\nclass PEncoder(nn.Module):\n    \"\"\"\n    Population coding\n    :param step: time steps\n    :param encode_type: encoder type (str)\n    \"\"\"\n    def __init__(self, step, encode_type):\n        super().__init__()\n        self.step = step\n        self.fun = getattr(self, encode_type)\n\n    def forward(self, inputs, num_popneurons, *args, **kwargs):\n        outputs = self.fun(inputs, num_popneurons, *args, **kwargs)\n        return outputs\n\n    @torch.no_grad()\n    def population_time(self, inputs, m):\n        \"\"\"\n        one feature will be encoded into gauss_neurons\n        the center of i-th neuron is:  gauss --\n\n        .. math::\n            \\\\mu  u_i = I_min + (2i-3)/2(I_max-I_min)/(m -2)\n        the width of i-th neuron is :  gauss --\n\n        .. math::\n            \\\\sigma sigma_i = \\\\frac{1}{1.5}\\\\frac{(I_max-I_min)}{m - 2}\n\n        :param inputs:   (N_num, N_feature) array\n        :param m: the number of the gaussian neurons\n        i : the i_th gauss_neuron\n        1.5: experience value\n        popneurons_spike_t: gauss -- function\n        I_min = min(inputs)\n        I_max = max(inputs)\n        :return: (step, num_gauss_neuron) \n        \"\"\"\n        # m = self.step\n        I_min, I_max = torch.min(inputs), torch.max(inputs)\n        mu = [i for i in range(0, m)]\n        mu = torch.ones((1, m)) * I_min + ((2 * torch.tensor(mu) - 3) / 2) * ((I_max-I_min) / (m -2))\n        sigma = (1 / 1.5) * ((I_max-I_min) / (m -2))\n        # shape = (self.step,) + inputs.shape\n        shape = (self.step,m)\n        popneurons_spike_t = torch.zeros(((m,) + inputs.shape))\n        for i in range(m):\n            popneurons_spike_t[i, :] = torch.exp(-(inputs - mu[0, i]) ** 2 / (2 * sigma * sigma))\n\n        spike_time = (self.step * popneurons_spike_t).type(torch.int)\n        spikes = torch.zeros(shape)\n        for spike_time_k in 
range(self.step):\n            if torch.where(spike_time == spike_time_k)[1].numel() != 0:\n                spikes[spike_time_k][torch.where(spike_time == spike_time_k)[0]] = 1\n\n        return spikes\n\n    @torch.no_grad()\n    def population_voltage(self, inputs, m, VTH):\n        '''\n        The more similar the input is to the mean,\n        the more sensitive the neuron corresponding to the mean is to the input.\n        You can change the maen.\n        :param inputs:   (N_num, N_feature) array\n        :param m : the number of the gaussian neurons\n        :param VTH : threshold voltage\n        i : the i_th gauss_neuron\n        one feature will be encoded into gauss_neurons\n        the center of i-th neuron is:  gauss -- \\mu  u_i = I_min + (2i-3)/2(I_max-I_min)/(m -2)\n        the width of i-th neuron is :  gauss -- \\sigma sigma_i = 1/1.5(I_max-I_min)/(m -2) 1.5: experience value\n        popneuron_v: gauss -- function\n        I_min = min(inputs)\n        I_max = max(inputs)\n        :return: (step, num_gauss_neuron, dim_inputs) \n        '''\n        ENCODER_REGULAR_VTH = VTH\n        I_min, I_max = torch.min(inputs), torch.max(inputs)\n        mu = [i for i in range(0, m)]\n        mu = torch.ones((1, m)) * I_min + ((2 * torch.tensor(mu) - 3) / 2) * ((I_max-I_min) / (m -2))\n        sigma = (1 / 1.5) * ((I_max-I_min) / (m -2))\n        popneuron_v = torch.zeros(((m,) + inputs.shape))\n        delta_v = torch.zeros(((m,) + inputs.shape))\n        for i in range(m):\n            delta_v[i] = torch.exp(-(inputs - mu[0, i]) ** 2 / (2 * sigma * sigma))\n        spikes = torch.zeros((self.step,) + ((m,) + inputs.shape))\n        for spike_time_k in range(self.step):\n            popneuron_v = popneuron_v + delta_v\n            spikes[spike_time_k][torch.where(popneuron_v.ge(ENCODER_REGULAR_VTH))] = 1\n            popneuron_v = popneuron_v - spikes[spike_time_k] * ENCODER_REGULAR_VTH\n\n        popneuron_rate = torch.sum(spikes, dim=0)/self.step\n\n      
  return spikes, popneuron_rate\n\n\n## test\n# if __name__ == '__main__':\n#     a = (torch.rand((2,4))*10).type(torch.int)\n#     print(a)\n#     pencoder = PEncoder(10, 'population_time')\n#     spikes=pencoder(inputs=a, num_popneurons=3)\n#     print(spikes, spikes.shape)\n\n#     pencoder = PEncoder(10, 'population_voltage')\n#     spikes, popneuron_rate = pencoder(inputs=a, num_popneurons=5, VTH=0.99)\n#     print(spikes, spikes.shape)\n"
  },
  {
    "path": "braincog/base/encoder/qs_coding.py",
    "content": "from signal import signal\nfrom subprocess import call\nimport numpy as np\nimport random\nimport copy\n\n\nclass QSEncoder:\n    \"\"\"\n    QS Encoding.\n    :param lambda_max: 最大发放率\n    :param steps: 脉冲发放周期长度 T\n    :param sig_len: 脉冲发放窗口\n    :param shift: 是否反转背景\n    :param noise: 是否增加噪声\n    :param noise_rate: 噪声比例\n    :param eps: 防止溢出参数\n    \"\"\"\n    def __init__(self,\n        lambda_max,\n        steps,\n        sig_len,\n        shift=False,\n        noise=None,\n        noise_rate=None,\n        eps=1e-6\n    ) -> None:\n        self._lambda_max = lambda_max\n        self._steps = steps\n        self._sig_len = sig_len\n        self._shift = shift\n        self._noise = noise\n        self._noise_rate = noise_rate\n        self._eps = eps\n\n\n    def __call__(self, image, image_delta, image_ori, image_ori_delta):\n        \"\"\"\n        将图片转换为脉冲。\n        :param image: 背景反转图片\n        :param image_delta: 扰动图片，用于计算相位\n        :param image_ori: 原始图片\n        :param image_ori_delta: 原始扰动图片\n        \"\"\"\n        if self._noise:\n            signals = self.noise_trans(image, image_ori, image_ori_delta)\n        elif self._shift:\n            signals = self.shift_trans(image, image_delta, image_ori, image_ori_delta)\n        else:\n            signals = np.zeros((self.steps, image.shape[0]))\n            signal_possion = np.random.poisson(image, (self._sig_len, image.shape[0]))\n            signals[:self._sig_len] = signal_possion[:]\n        return signal.T\n\n\n    def shift_trans(self, image, image_delta, image_ori, image_ori_delta):\n        \"\"\"\n        背景翻转图片转脉冲序列。\n        :param image: 背景反转图片\n        :param image_delta: 扰动图片，用于计算相位\n        :param image_ori: 原始图片\n        :param image_ori_delta: 原始扰动图片\n        \"\"\"\n        signal = np.zeros((self._steps, image.shape[0]))\n        assert image_ori is not None\n        assert self.noise is False\n        assert image_delta is not None\n        assert image_ori_delta is 
not None\n        image_ori_reverse = self._lambda_max - image_ori\n        image_ori_delta_reverse = self._lambda_max - image_ori_delta\n        zeta = image / (image_ori**2 + image_ori_reverse**2) ** 0.5\n        zeta_delta = image_delta / (image_ori_delta**2 + image_ori_delta_reverse**2)**0.5\n        idx_left = zeta < zeta_delta\n        phi = np.arctan(image_ori / (image_ori_reverse + self._eps))\n        zeta = np.clip(zeta, -1, 1)\n        zeta = np.arcsin(zeta)\n        theta1 = zeta - phi\n        theta2 = np.pi - zeta - phi\n        theta = np.zeros(theta1.shape)\n        theta[idx_left] = theta1[idx_left]\n        theta[~idx_left] = theta2[~idx_left]\n        theta = np.mean(theta)\n        cos_theta = np.cos(theta)\n        sin_theta = np.sin(theta)\n        spike_rate = np.abs((self._lambda_max * sin_theta - image) / (sin_theta - cos_theta + self._eps))\n        signal_possion = np.random.poisson(spike_rate, (self._sig_len, spike_rate.shape[0]))\n        shift_step = np.rint(np.clip(2 * theta / np.pi, a_min=0, a_max=1.0) * (self._steps - self._sig_len))\n        shift_step = shift_step.astype(int)\n        signal[shift_step:shift_step + self._sig_len] = signal_possion[:]\n        return signal\n\n\n\n    def noise_trans(self, image, image_ori, image_ori_delta):\n        \"\"\"\n        噪声图片转脉冲序列\n        :param image: 背景反转图片\n        :param image_ori: 原始图片\n        :param image_ori_delta: 原始扰动图片\n        \"\"\"\n        signal = np.zeros((self._steps, image.shape[0]))\n        assert image_ori is not None\n        assert self._shift is False\n        assert self._noise_rate is not None\n        image_ori_delta = copy.deepcopy(image_ori)\n        idx = image_ori_delta < (self._lambda_max - 0.001)\n        image_ori_delta[idx] += 0.001\n        image_ori_reverse = self._lambda_max - image_ori\n        image_ori_delta_reverse = self._lambda_max - image_ori_delta\n        image_noise, image_delta_noise = self.reverse_pixels(image_ori, image_ori_delta, 
noise_rate=self._noise_rate)\n        zeta = image_noise / (image_ori**2 + image_ori_reverse**2)**0.5\n        zeta_delta = image_delta_noise / (image_ori_delta**2 + image_ori_delta_reverse**2)**0.5\n        idx_left = zeta < zeta_delta\n        phi = np.arctan(image_ori / (image_ori_reverse + self._eps))\n        zeta = np.clip(zeta, -1, 1)\n        zeta = np.arcsin(zeta)\n        theta1 = zeta - phi\n        theta2 = np.pi - zeta - phi\n        theta = np.zeros(theta1.shape)\n        theta[idx_left] = theta1[idx_left]\n        theta[~idx_left] = theta2[~idx_left]\n        theta = np.mean(theta)\n        cos_theta = np.cos(theta)\n        sin_theta = np.sin(theta)\n        spike_rate = np.abs((self._lambda_max * sin_theta - image_noise) / (sin_theta - cos_theta + self._eps))\n        signal_possion = np.random.poisson(spike_rate, (self._sig_len, spike_rate.shape[0]))\n        shift_step = np.rint(np.clip(2 * theta / np.pi, a_min=0, a_max=1.0) * (self._steps - self._sig_len))\n        shift_step = shift_step.astype(int)\n        signal[shift_step:shift_step + self._sig_len] = signal_possion[:]\n        return signal\n\n    def reverse_pixels(self, image, image_delta, noise_rate, flip_bits=None):\n        \"\"\"\n        反转图片像素\n        \"\"\"\n        if flip_bits is None:\n            N = int(noise_rate * image.shape[0])\n            flip_bits = random.sample(range(image.shape[0]), N)\n        img = copy.copy(image)\n        img_delta = copy.copy(image_delta)\n\n        img[flip_bits] = self._lambda_max - img[flip_bits]\n        img_delta[flip_bits] = self._lambda_max - img_delta[flip_bits]\n        return img, img_delta"
  },
  {
    "path": "braincog/base/learningrule/BCM.py",
    "content": "import numpy as np\nimport torch\nimport os\nimport sys\nfrom torch import nn\nfrom torch.nn import Parameter\n\nimport abc\nimport math\nfrom abc import ABC\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\nfrom braincog.base.node import *\n\n\nclass BCM(nn.Module):\n    \"\"\"\n    BCM learning rule 多组神经元输入到该节点\n    \"\"\"\n\n    def __init__(self, node, connection, cfunc=None, weightdecay=0.99, tau=10):\n        \"\"\"\n        :param node:node神经元类型实例如IFNode LIFNode\n        :param connection:连接 类的实例列表 里面只能有一个操作\n        :param cfunc:BCM的频率函数 默认y(y-th)\n        :param weightdecay:权重衰减系数 默认0.99\n        :param tau: 频率更新时间常数\n        \"\"\"\n        super().__init__()\n\n        self.node = node\n        self.connection = connection\n        if not isinstance(connection, list):\n            self.connection = [self.connection]\n        self.weightdecay = weightdecay\n        self.tau = tau\n        self.threshold = 0\n\n    def forward(self, *x):\n        \"\"\"\n        计算前向传播过程\n        :return:s是脉冲 dw更新量\n        \"\"\"\n        i = 0\n        x = [xi.clone().detach() for xi in x]\n        for xi, coni in zip(x, self.connection):\n            i += coni(xi)\n        with torch.no_grad():\n            s = self.node(i)\n\n            i.data += self.cfunc(s) - i.data\n\n        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)\n        for dwi, i in zip(dw, self.connection):\n            dwi -= (1 - self.weightdecay) * i.weight\n        return s, dw\n\n    def cfunc(self, s):\n        self.threshold = ((self.tau - 1) * self.threshold + s) / self.tau\n\n        return (s * (s - self.threshold)).detach()\n\n    def reset(self):\n        \"\"\"\n        重置\n        \"\"\"\n        self.threshold = 0\n        pass\n"
  },
  {
    "path": "braincog/base/learningrule/Hebb.py",
    "content": "import numpy as np\nimport torch\nimport os\nimport sys\nfrom torch import nn\nfrom torch.nn import Parameter\n\nimport abc\nimport math\nfrom abc import ABC\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\nfrom braincog.base.node.node import *\n\n\nclass Hebb(nn.Module):\n    \"\"\"\n    Hebb learning rule 多组神经元输入到该节点\n    \"\"\"\n\n    def __init__(self, node, connection):\n        \"\"\"\n        :param node:node神经元类型实例如IFNode LIFNode\n        :param connection:连接 类的实例列表 里面只能有一个操作\n        \"\"\"\n        super().__init__()\n\n        self.node = node\n        self.connection = connection\n        self.trace = [None for i in self.connection]\n\n    def forward(self, *x):\n        \"\"\"\n        计算前向传播过程\n        :return:s是脉冲 dw更新量\n        \"\"\"\n        i = 0\n        x = [xi.clone().detach() for xi in x]\n        for xi, coni in zip(x, self.connection):\n            i += coni(xi)\n        with torch.no_grad():\n            s = self.node(i)\n\n            i.data += s - i.data\n\n        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)\n\n        return s, dw\n\n    def reset(self):\n        \"\"\"\n        重置\n        \"\"\"\n        self.trace = [None for i in self.connection]\n\n\nif __name__ == \"__main__\":\n    node = IFNode()\n    linear1 = nn.Linear(2, 2, bias=False)\n    linear2 = nn.Linear(2, 2, bias=False)\n    linear1.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)\n    linear2.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)\n\n    hebb = Hebb(node, [linear1, linear2])\n    for i in range(10):\n        x, dw1 = hebb(torch.tensor([1.1, 1.1]), torch.tensor([1.1, 1.1]))\n        print(dw1)\n"
  },
  {
    "path": "braincog/base/learningrule/RSTDP.py",
    "content": "import numpy as np\nimport torch\nimport os\nimport sys\nfrom torch import nn\nfrom torch.nn import Parameter\n\nimport abc\nimport math\nfrom abc import ABC\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\nfrom braincog.base.node import *\n\n\nclass RSTDP(nn.Module):\n    \"\"\"\n    RSTDP算法\n    \"\"\"\n    def __init__(self, node, connection, decay=0.99, reward_decay=0.5):\n        \"\"\"\n        :param node:node神经元类型实例如IFNode LIFNode\n        :param connection:连接 类的实例列表 里面只能有一个操作\n        \"\"\"\n        super().__init__()\n\n        self.node = node\n        self.connection = connection\n        if not isinstance(connection, list):\n            self.connection = [self.connection]\n        self.trace = [None for i in self.connection]\n        self.decay = decay\n        self.reward_decay = reward_decay\n        self.stdp = STDP(self.node, self.node, self.decay)\n\n    def forward(self, *x, r):\n        \"\"\"\n        计算前向传播过程\n        :return:s是脉冲 dw更新量\n        \"\"\"\n        s, dw = self.stdp(x)\n        trace = self.cal_trace(r)\n        return s, dw * trace\n\n    def cal_trace(self, x):\n        \"\"\"\n        计算trace\n        \"\"\"\n        for i in range(len(x)):\n            if self.trace[i] is None:\n                self.trace[i] = Parameter(x[i].clone().detach(), requires_grad=False)\n            else:\n                self.trace[i] *= self.decay\n                self.trace[i] += x[i].detach()\n        return self.trace\n\n    def reset(self):\n        self.trace = [None for i in self.connection]\n"
  },
  {
    "path": "braincog/base/learningrule/STDP.py",
    "content": "import numpy as np\r\nimport torch\r\nimport os\r\nimport sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nfrom braincog.base.node.node import *\r\n\r\n\r\nclass STDP(nn.Module):\r\n    \"\"\"\r\n    STDP learning rule\r\n    \"\"\"\r\n\r\n    def __init__(self, node, connection, decay=0.99):\r\n        \"\"\"\r\n        :param node:node神经元类型实例如IFNode LIFNode\r\n        :param connection:连接 类的实例 里面只能有一个操作\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = node\r\n        self.connection = connection\r\n        self.trace = None\r\n        self.decay = decay\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:s是脉冲 dw更新量\r\n        \"\"\"\r\n        x = x.clone().detach()\r\n        i = self.connection(x)\r\n        with torch.no_grad():\r\n            s = self.node(i)\r\n\r\n            i.data += s - i.data\r\n            trace = self.cal_trace(x)\r\n            x.data += trace - x.data\r\n\r\n        dw = torch.autograd.grad(outputs=i, inputs=self.connection.weight, grad_outputs=i)\r\n\r\n        return s, dw\r\n\r\n    def cal_trace(self, x):\r\n        \"\"\"\r\n        计算trace\r\n        \"\"\"\r\n        if self.trace is None:\r\n            self.trace = Parameter(x.clone().detach(), requires_grad=False)\r\n        else:\r\n            self.trace *= self.decay\r\n            self.trace += x\r\n        return self.trace.detach()\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        重置\r\n        \"\"\"\r\n        self.trace = None\r\n\r\n\r\nclass MutliInputSTDP(nn.Module):\r\n    \"\"\"\r\n    STDP learning rule 多组神经元输入到该节点\r\n    \"\"\"\r\n\r\n    def __init__(self, node, connection, decay=0.99):\r\n        \"\"\"\r\n        :param node:node神经元类型实例如IFNode LIFNode\r\n        
:param connection:连接 类的实例列表 里面只能有一个操作\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = node\r\n        self.connection = connection\r\n        self.trace = [None for i in self.connection]\r\n        self.decay = decay\r\n\r\n    def forward(self, *x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:s是脉冲 dw更新量\r\n        \"\"\"\r\n        i = 0\r\n        x = [xi.clone().detach() for xi in x]\r\n        for xi, coni in zip(x, self.connection):\r\n            i += coni(xi)\r\n        with torch.no_grad():\r\n            s = self.node(i)\r\n\r\n            i.data += s - i.data\r\n\r\n            trace = self.cal_trace(x)\r\n            for xi, ti in zip(x, trace):\r\n                xi.data += ti - xi.data\r\n\r\n        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)\r\n\r\n        return s, dw\r\n\r\n    def cal_trace(self, x):\r\n        \"\"\"\r\n        计算trace\r\n        \"\"\"\r\n        for i in range(len(x)):\r\n            if self.trace[i] is None:\r\n                self.trace[i] = Parameter(x[i].clone().detach(), requires_grad=False)\r\n            else:\r\n                self.trace[i] *= self.decay\r\n                self.trace[i] += x[i].detach()\r\n        return self.trace\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        重置\r\n        \"\"\"\r\n        self.trace = [None for i in self.connection]\r\n\r\n\r\nclass LTP(MutliInputSTDP):\r\n    \"\"\"\r\n    STDP learning rule 多组神经元输入到该节点\r\n    \"\"\"\r\n    pass\r\n\r\n\r\nclass LTD(nn.Module):\r\n    \"\"\"\r\n    STDP learning rule 多组神经元输入到该节点\r\n    \"\"\"\r\n\r\n    def __init__(self, node, connection, decay=0.99):\r\n        \"\"\"\r\n        :param node:node神经元类型实例如IFNode LIFNode\r\n        :param connection:连接 类的实例列表 里面只能有一个操作\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = node\r\n        self.connection = connection\r\n        self.trace = None\r\n        self.decay = 
decay\r\n\r\n    def forward(self, *x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:s是脉冲 dw更新量\r\n        \"\"\"\r\n        i = 0\r\n        x = [xi.clone().detach() for xi in x]\r\n        for xi, coni in zip(x, self.connection):\r\n            i += coni(xi)\r\n        with torch.no_grad():\r\n            s = self.node(i)\r\n\r\n            trace = self.cal_trace(s)\r\n            i.data += trace - i.data\r\n\r\n        dw = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)\r\n\r\n        return s, dw\r\n\r\n    def cal_trace(self, x):\r\n        \"\"\"\r\n        计算trace\r\n        \"\"\"\r\n        if self.trace is None:\r\n            self.trace = Parameter(torch.zeros_like(x), requires_grad=False)\r\n        else:\r\n            self.trace *= self.decay\r\n        trace = self.trace.clone().detach()\r\n        self.trace += x\r\n        return trace\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        重置\r\n        \"\"\"\r\n        self.trace = None\r\n\r\n\r\nclass FullSTDP(nn.Module):\r\n    \"\"\"\r\n    STDP learning rule 多组神经元输入到该节点\r\n    \"\"\"\r\n\r\n    def __init__(self, node, connection, decay=0.99, decay2=0.99):\r\n        \"\"\"\r\n        :param node:node神经元类型实例如IFNode LIFNode\r\n        :param connection:连接 类的实例列表 里面只能有一个操作\r\n        \"\"\"\r\n        super().__init__()\r\n\r\n        self.node = node\r\n        self.connection = connection\r\n        self.tracein = [None for i in self.connection]\r\n        self.traceout = None\r\n        self.decay = decay\r\n        self.decay2 = decay2\r\n\r\n    def forward(self, *x):\r\n        \"\"\"\r\n        计算前向传播过程\r\n        :return:s是脉冲 dw更新量\r\n        \"\"\"\r\n        i = 0\r\n        x = [xi.clone().detach() for xi in x]\r\n        for xi, coni in zip(x, self.connection):\r\n            i += coni(xi)\r\n        with torch.no_grad():\r\n            s = self.node(i)\r\n            traceout = self.cal_traceout(s)\r\n            i.data += 
traceout - i.data\r\n        dw1 = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], retain_graph=True,\r\n                                  grad_outputs=i)\r\n\r\n        with torch.no_grad():\r\n            i.data += s - i.data\r\n\r\n            tracein = self.cal_tracein(x)\r\n            for xi, ti in zip(x, tracein):\r\n                xi.data += ti - xi.data\r\n\r\n        dw2 = torch.autograd.grad(outputs=i, inputs=[i.weight for i in self.connection], grad_outputs=i)\r\n\r\n        return s, dw2, dw1\r\n\r\n    def cal_tracein(self, x):\r\n        \"\"\"\r\n        计算trace\r\n        \"\"\"\r\n        for i in range(len(x)):\r\n            if self.tracein[i] is None:\r\n                self.tracein[i] = Parameter(x[i].clone().detach(), requires_grad=False)\r\n            else:\r\n                self.tracein[i] *= self.decay\r\n                self.tracein[i] += x[i].detach()\r\n        return self.tracein\r\n\r\n    def cal_traceout(self, x):\r\n        \"\"\"\r\n        计算trace\r\n        \"\"\"\r\n        if self.traceout is None:\r\n            self.traceout = Parameter(torch.zeros_like(x), requires_grad=False)\r\n        else:\r\n            self.traceout *= self.decay2\r\n        trace = self.traceout.clone().detach()\r\n        self.traceout += x\r\n        return trace\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        重置\r\n        \"\"\"\r\n        self.traceout = [None for i in self.connection]\r\n        self.tracein = None\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    node = IFNode()\r\n    linear1 = nn.Linear(2, 2, bias=False)\r\n    linear2 = nn.Linear(2, 2, bias=False)\r\n    linear1.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)\r\n    linear2.weight.data = torch.tensor([[1., 1], [1, 1]], requires_grad=True)\r\n\r\n    stdp = LTD(node, [linear1, linear2])\r\n    for i in range(10):\r\n        x, dw1 = stdp(torch.tensor([1.1, 1.1]), torch.tensor([1.1, 1.1]))\r\n        print(dw1)\r\n"
  },
  {
    "path": "braincog/base/learningrule/STP.py",
    "content": "import math\n\n\nclass short_time():\n    \"\"\"\n        计算短期突触可塑性的变量详见Tsodyks和Markram 1997\n        :param Syn:突出可塑性结构体\n        :param ISI:棘突间期\n        :param Nsp:突触前棘波\n    \"\"\"\n    def __init__(self, SizeHistOutput):\n        super().__init__()   \n        self.SizeHistOutput = SizeHistOutput\n\n\n    def syndepr(self, Syn=None, ISI=None, Nsp=None):\n        \"\"\"\n            短期突触可塑性计算\n        \"\"\"\n        SizeHistOutput = self.SizeHistOutput\n        qu = Syn.uprev[Nsp] * math.exp(-ISI / Syn.tc_fac)\n        qR = math.exp(-ISI / Syn.tc_rec)\n        u = qu + Syn.use * (1.0 - qu)\n        R = Syn.Rprev[Nsp] * (1.0 - Syn.uprev[Nsp]) * qR + 1.0 - qR\n        Syn.uprev[(Nsp + 1) % SizeHistOutput] = u\n        Syn.Rprev[(Nsp + 1) % SizeHistOutput] = R\n        return R * u\n\n    \n\n    def set_gsyn(self, np=None, dt=None, v=None, NoiseSyn=None):\n        \"\"\"\n            突触电流参数计算\n        \"\"\"\n        Isyn = 0\n        gsyn_AN = 0\n        gsyn_G = 0\n\n        for j in range(np.NumSynType):\n            syn = np.STList[j]\n            sgate = 1.0\n            if (syn.Mg_gate > 0.0):\n                sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * (syn.Mg_half - v[0])))\n            Isyn += sgate * (\n                np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on)) * (\n                syn.Erev - v[0])\n            if (syn.Erev == 0.0):\n                gsyn_AN = gsyn_AN + sgate * (\n                    np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))\n            else:\n                gsyn_G = gsyn_G + sgate * (\n                    np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))\n\n        \n        for j in range(NoiseSyn.NumSyn):\n            syn = NoiseSyn.Syn[j].STPtr\n            sgate = 1.0\n            if (syn.Mg_gate > 0.0):  \n                sgate = syn.Mg_gate / (1.0 + syn.Mg_fac 
* math.exp(syn.Mg_slope * (syn.Mg_half - v)))\n            Isyn += sgate * (\n                np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on)) * (\n                syn.Erev - v)\n            if (syn.Erev == 0.0):\n                gsyn_AN = gsyn_AN + sgate * (\n                    np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))\n            else:\n                gsyn_G = gsyn_G + sgate * (\n                    np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))\n\n        I_tot = Isyn + np.Iinj\n        return gsyn_AN, I_tot, gsyn_G\n    \n    \n    def IDderiv(self, np=None, v=None, dt=None, dv=None, NoiseSyn=None, flag_dv=None):\n        \"\"\"\n         定义模型的常微分方程计算单个神经元常微分方程\n         :param np:神经元参数\n         :param v:当前变量\n         :param dt:时间步长\n        \"\"\"\n        Isyn = 0\n        gsyn_G = 0\n        gsyn_AN = 0\n        for j in range(np.NumSynType):\n            syn = np.STList[j]\n            sgate = 1.0\n            if (syn.Mg_gate > 0.0):  \n                sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * (syn.Mg_half - v[0])))\n            Isyn += sgate * (\n                np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on)) * (\n                syn.Erev - v[0])\n            if (syn.Erev == 0.0):\n                gsyn_AN = gsyn_AN + sgate * (\n                    np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))\n            else:\n                gsyn_G = gsyn_G + sgate * (\n                    np.gfOFFsyn[j] * math.exp(-dt / syn.tc_off) - np.gfONsyn[j] * math.exp(-dt / syn.tc_on))\n\n       \n        for j in range(NoiseSyn.NumSyn):\n            syn = NoiseSyn.Syn[j].STPtr\n            sgate = 1.0\n            if (syn.Mg_gate > 0.0):  \n                sgate = syn.Mg_gate / (1.0 + syn.Mg_fac * math.exp(syn.Mg_slope * 
(syn.Mg_half - v[0])))\n            Isyn += sgate * (\n                np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on)) * (\n                syn.Erev - v[0])\n            if (syn.Erev == 0.0):\n                gsyn_AN = gsyn_AN + sgate * (\n                    np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))\n            else:\n                gsyn_G = gsyn_G + sgate * (\n                    np.gfOFFnoise[j] * math.exp(-dt / syn.tc_off) - np.gfONnoise[j] * math.exp(-dt / syn.tc_on))\n\n       \n        I_ex = np.gL * np.sf * math.exp((v[0] - np.Vth) / np.sf)\n        \n        wV = np.Iinj + Isyn - np.gL * (v[0] - np.EL) + I_ex\n        \n        D0 = (np.Cm / np.gL) * wV\n\n        \n        if ((\n                np.Iinj + Isyn) >= np.I_ref and flag_dv == 0):  \n            dv[0] = -(np.gL / np.Cm) * (v[0] - np.v_dep)\n            flag_regime_osc = 0\n        else:\n            dv[0] = (np.Iinj - np.gL * (v[0] - np.EL) - v[1] + I_ex + Isyn) / np.Cm\n            flag_regime_osc = 1\n\n        \n        dD0 = np.Cm * (math.exp((v[0] - np.Vth) / np.sf) - 1)\n\n        \n        if ((v[1] > wV - D0 / np.tcw) and (v[1] < wV + D0 / np.tcw) and v[0] <= np.Vth and (\n                np.Iinj + Isyn) < np.I_ref):\n            dv[1] = -(np.gL * (1 - math.exp((v[0] - np.Vth) / np.sf)) + dD0 / np.tcw) * dv[0]\n        else:\n            dv[1] = 0\n        I_tot = Isyn + np.Iinj\n\n        return wV, D0, gsyn_AN, gsyn_G, I_tot, dv\n\n   \n\n    def update(self, np=None, dt=None, NoiseSyn=None, flag_dv=None):\n        \"\"\"\n         用二阶显式龙格-库塔法积分常微分方程\n         :param np:神经元参数\n         :param dt:时间步长\n        \"\"\"\n        nvar = 2\n        v = [0] * 2\n        dv1 = [0] * 2\n        dv2 = [0] * 2\n        for i in range(nvar):\n            v[i] = np.v[i]\n        wV, D0, gsyn_AN, gsyn_G, I_tot, dv1 = short_time(self.SizeHistOutput).IDderiv(np, v, 0.0, dv1, NoiseSyn, flag_dv)\n        
for i in range(nvar):\n            v[i] += dt * dv1[i]\n        wV, D0, gsyn_AN, gsyn_G, I_tot, dv2 = short_time(self.SizeHistOutput).IDderiv(np, v, 0.0, dv2, NoiseSyn, flag_dv)\n        for i in range(nvar):\n            np.v[i] += dt / 2.0 * (dv1[i] + dv2[i])\n            np.dv[i] = dt / 2.0 * (dv1[i] + dv2[i])\n\n       \n        if ((np.v[1] > wV - D0 / np.tcw) and (np.v[1] < wV + D0 / np.tcw) and np.v[0] <= np.Vth):\n            np.v[1] = wV - (D0 / np.tcw)\n\n        return np, gsyn_AN, gsyn_G, I_tot\n"
  },
  {
    "path": "braincog/base/learningrule/__init__.py",
    "content": "from .BCM import BCM\nfrom .Hebb import Hebb\nfrom .RSTDP import RSTDP\nfrom .STDP import STDP, MutliInputSTDP, LTP, LTD, FullSTDP\nfrom .STP import short_time\n\n\n__all__ = [\n    'BCM',\n    \"Hebb\",\n    'RSTDP',\n    'STDP', 'MutliInputSTDP', 'LTP', 'LTD', 'FullSTDP',\n    'short_time'\n]\n"
  },
  {
    "path": "braincog/base/node/__init__.py",
    "content": "from .node import *"
  },
  {
    "path": "braincog/base/node/node.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/4/10 18:46\n# User      : Floyed\n# Product   : PyCharm\n# Project   : braincog\n# File      : node.py\n# explain   : 神经元节点类型\n\nimport abc\nimport math\nfrom abc import ABC\nimport numpy as np\nimport random\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nfrom braincog.base.connection.layer import CustomLinear\nfrom braincog.base.strategy.surrogate import *\n\n\nclass BaseNode(nn.Module, abc.ABC):\n    \"\"\"\n    神经元模型的基类\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param mem_detach: 是否将上一时刻的膜电位在计算图中截断\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self,\n                 threshold=.5,\n                 v_reset=0.,\n                 dt=1.,\n                 step=8,\n                 requires_thres_grad=False,\n                 sigmoid_thres=False,\n                 requires_fp=False,\n                 layer_by_layer=False,\n                 n_groups=1,\n                 *args,\n                 **kwargs):\n\n        super(BaseNode, self).__init__()\n        self.threshold = Parameter(torch.tensor(threshold), requires_grad=requires_thres_grad)\n        self.sigmoid_thres = sigmoid_thres\n        self.mem = 0.\n        self.spike = 0.\n        self.dt = dt\n        self.feature_map = []\n        self.mem_collect = []\n        self.requires_fp = requires_fp\n        self.v_reset = v_reset\n        self.step = 
step\n        self.layer_by_layer = layer_by_layer\n        self.groups = n_groups\n        self.mem_detach = kwargs['mem_detach'] if 'mem_detach' in kwargs else False\n        self.requires_mem = kwargs['requires_mem'] if 'requires_mem' in kwargs else False\n\n    @abc.abstractmethod\n    def calc_spike(self):\n        \"\"\"\n        通过当前的mem计算是否发放脉冲，并reset\n        :return: None\n        \"\"\"\n\n        pass\n\n    def integral(self, inputs):\n        \"\"\"\n        计算由当前inputs对于膜电势的累积\n        :param inputs: 当前突触输入电流\n        :type inputs: torch.tensor\n        :return: None\n        \"\"\"\n\n        pass\n\n    def get_thres(self):\n        return self.threshold if not self.sigmoid_thres else self.threshold.sigmoid()\n\n    def rearrange2node(self, inputs):\n        if self.groups != 1:\n            if len(inputs.shape) == 4:\n                outputs = rearrange(inputs, 'b (c t) w h -> t b c w h', t=self.step)\n            elif len(inputs.shape) == 2:\n                outputs = rearrange(inputs, 'b (c t) -> t b c', t=self.step)\n            else:\n                raise NotImplementedError\n\n        elif self.layer_by_layer:\n            if len(inputs.shape) == 4:\n                outputs = rearrange(inputs, '(t b) c w h -> t b c w h', t=self.step)\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, '(t b) n c -> t b n c', t=self.step)\n            elif len(inputs.shape) == 2:\n                outputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)\n            else:\n                raise NotImplementedError\n\n\n        else:\n            outputs = inputs\n\n        return outputs\n\n    def rearrange2op(self, inputs):\n        if self.groups != 1:\n            if len(inputs.shape) == 5:\n                outputs = rearrange(inputs, 't b c w h -> b (c t) w h')\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, ' t b c -> b (c t)')\n            else:\n                raise 
NotImplementedError\n        elif self.layer_by_layer:\n            if len(inputs.shape) == 5:\n                outputs = rearrange(inputs, 't b c w h -> (t b) c w h')\n            elif len(inputs.shape) == 4:\n                outputs = rearrange(inputs, ' t b n c -> (t b) n c')\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, ' t b c -> (t b) c')\n            else:\n                raise NotImplementedError\n\n        else:\n            outputs = inputs\n\n        return outputs\n\n    def forward(self, inputs):\n        \"\"\"\n        torch.nn.Module 默认调用的函数，用于计算膜电位的输入和脉冲的输出\n        在```self.requires_fp is True``` 的情况下，可以使得```self.feature_map```用于记录trace\n        :param inputs: 当前输入的膜电位\n        :return: 输出的脉冲\n        \"\"\"\n\n        if hasattr(self, 'parallel') and self.parallel is True:\n            inputs = self.rearrange2node(inputs)\n            if self.mem_detach and hasattr(self.mem, 'detach'):\n                self.mem = self.mem.detach()\n                self.spike = self.spike.detach()\n            self.integral(inputs)\n\n            self.calc_spike()\n\n            if self.requires_fp is True:\n                self.feature_map.append(self.spike)\n            if self.requires_mem is True:\n                self.mem_collect.append(self.mem)\n\n            return self.rearrange2op(self.spike)\n\n        elif self.layer_by_layer or self.groups != 1:\n            inputs = self.rearrange2node(inputs)\n\n            outputs = []\n            for i in range(self.step):\n                \n                if self.mem_detach and hasattr(self.mem, 'detach'):\n                    self.mem = self.mem.detach()\n                    self.spike = self.spike.detach()\n                self.integral(inputs[i])\n                \n                self.calc_spike()\n                \n                if self.requires_fp is True:\n                    self.feature_map.append(self.spike)\n                if self.requires_mem is True:\n   
                 self.mem_collect.append(self.mem)\n                outputs.append(self.spike)\n            outputs = torch.stack(outputs)\n\n            outputs = self.rearrange2op(outputs)\n            return outputs\n        else:\n            if self.mem_detach and hasattr(self.mem, 'detach'):\n                self.mem = self.mem.detach()\n                self.spike = self.spike.detach()\n            self.integral(inputs)\n            self.calc_spike()\n            if self.requires_fp is True:\n                self.feature_map.append(self.spike)\n            if self.requires_mem is True:\n                self.mem_collect.append(self.mem)   \n            return self.spike\n\n    def n_reset(self):\n        \"\"\"\n        神经元重置，用于模型接受两个不相关输入之间，重置神经元所有的状态\n        :return: None\n        \"\"\"\n        self.mem = self.v_reset\n        self.spike = 0.\n        self.feature_map = []\n        self.mem_collect = []\n    def get_n_attr(self, attr):\n\n        if hasattr(self, attr):\n            return getattr(self, attr)\n        else:\n            return None\n\n    def set_n_warm_up(self, flag):\n        \"\"\"\n        一些训练策略会在初始的一些epoch，将神经元视作ANN的激活函数训练，此为设置是否使用该方法训练\n        :param flag: True：神经元变为激活函数， False：不变\n        :return: None\n        \"\"\"\n        self.warm_up = flag\n\n    def set_n_threshold(self, thresh):\n        \"\"\"\n        动态设置神经元的阈值\n        :param thresh: 阈值\n        :return:\n        \"\"\"\n        self.threshold = Parameter(torch.tensor(thresh, dtype=torch.float), requires_grad=False)\n\n    def set_n_tau(self, tau):\n        \"\"\"\n        动态设置神经元的衰减系数，用于带Leaky的神经元\n        :param tau: 衰减系数\n        :return:\n        \"\"\"\n        if hasattr(self, 'tau'):\n            self.tau = Parameter(torch.tensor(tau, dtype=torch.float), requires_grad=False)\n        else:\n            raise NotImplementedError\n\n#============================================================================\n# node的基类\nclass BaseMCNode(nn.Module, abc.ABC):\n   
 \"\"\"\n    多房室神经元模型的基类\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param comps: 神经元不同房室, 例如[\"apical\", \"basal\", \"soma\"]\n    \"\"\"\n    def __init__(self,\n                 threshold=1.0,\n                 v_reset=0.,\n                 comps=[]):\n        super().__init__()\n        self.threshold = Parameter(torch.tensor(threshold), requires_grad=False)\n        # self.decay = Parameter(torch.tensor(decay), requires_grad=False)\n        self.v_reset = v_reset\n        assert len(comps) != 0\n        self.mems = dict()\n        for c in comps:\n            self.mems[c] = None \n        self.spike = None\n        self.warm_up = False\n\n    @abc.abstractmethod\n    def calc_spike(self):\n        pass\n    @abc.abstractmethod\n    def integral(self, inputs):\n        pass        \n    \n    def forward(self, inputs: dict):\n        '''\n        Params:\n            inputs dict: Inputs for every compartments of neuron \n        '''\n        if self.warm_up:\n            return inputs\n        else:\n            self.integral(**inputs)\n            self.calc_spike()\n            return self.spike\n\n    def n_reset(self):\n        for c in self.mems.keys():\n            self.mems[c] = self.v_reset\n        self.spike = 0.0\n\n    def get_n_fire_rate(self):\n        if self.spike is None:\n            return 0.\n        return float((self.spike.detach() >= self.threshold).sum()) / float(np.product(self.spike.shape))\n\n    def set_n_warm_up(self, flag):\n        self.warm_up = flag\n\n    def set_n_threshold(self, thresh):\n        self.threshold = Parameter(torch.tensor(thresh, dtype=torch.float), requires_grad=False)\n\n\nclass ThreeCompNode(BaseMCNode):\n    \"\"\"\n    三房室神经元模型\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param tau: 胞体膜电位时间常数, 用于控制胞体膜电位衰减\n    :param tau_basal: 基底树突膜电位时间常数, 用于控制基地树突胞体膜电位衰减\n    :param tau_apical: 远端树突膜电位时间常数, 用于控制远端树突胞体膜电位衰减\n    :param comps: 神经元不同房室, 例如[\"apical\", 
\"basal\", \"soma\"]\n    :param act_fun: 脉冲梯度代理函数\n    \"\"\"\n    def __init__(self,\n                 threshold=1.0,\n                 tau=2.0,\n                 tau_basal=2.0,\n                 tau_apical=2.0,\n                 v_reset=0.0,\n                 comps=['basal', 'apical', 'soma'],\n                 act_fun=AtanGrad):\n        g_B = 0.6\n        g_L = 0.05\n        super().__init__(threshold, v_reset, comps)\n        self.tau = tau\n        self.tau_basal = tau_basal\n        self.tau_apical = tau_apical\n        self.act_fun = act_fun(alpha=tau, requires_grad=False)\n    \n    def integral(self, basal_inputs, apical_inputs):\n        '''\n        Params:\n            inputs torch.Tensor: Inputs for basal dendrite  \n        '''\n\n        self.mems['basal'] =  (self.mems['basal'] + basal_inputs) / self.tau_basal\n        self.mems['apical'] =  (self.mems['apical'] + apical_inputs) / self.tau_apical\n\n        self.mems['soma'] = self.mems['soma'] + (self.mems['apical'] + self.mems['basal'] - self.mems['soma']) / self.tau\n\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mems['soma'] - self.threshold)\n        self.mems['soma'] = self.mems['soma']  * (1. - self.spike.detach())\n        self.mems['basal'] = self.mems['basal'] * (1. - self.spike.detach())\n        self.mems['apical'] = self.mems['apical']  * (1. 
- self.spike.detach())\n\n\n#============================================================================\n\n# 用于静态测试 使用ANN的情况 不累积电位 \nclass ReLUNode(BaseNode):\n    \"\"\"\n    用于相同连接的ANN的测试\n    \"\"\"\n\n    def __init__(self,\n                 *args,\n                 **kwargs):\n        super().__init__(requires_fp=False, *args, **kwargs)\n        self.act_fun = nn.ReLU()\n\n    def forward(self, x):\n        \"\"\"\n        参考```BaseNode```\n        :param x:\n        :return:\n        \"\"\"\n        self.spike = self.act_fun(x)\n        if self.requires_fp is True:\n            self.feature_map.append(self.spike)\n        if self.requires_mem is True:\n            self.mem_collect.append(self.mem)\n        return self.spike\n\n    def calc_spike(self):\n        pass\n\n\nclass BiasReLUNode(BaseNode):\n    \"\"\"\n    用于相同连接的ANN的测试, 会在每个时刻注入恒定电流, 使得神经元更容易激发\n    \"\"\"\n\n    def __init__(self,\n                 *args,\n                 **kwargs):\n        super().__init__(*args, **kwargs)\n        self.act_fun = nn.ReLU()\n\n    def forward(self, x):\n        self.spike = self.act_fun(x + 0.1)\n        if self.requires_fp is True:\n            self.feature_map += self.spike\n        return self.spike\n\n    def calc_spike(self):\n        pass\n\n\n# ============================================================================\n# 用于SNN的node\nclass IFNode(BaseNode):\n    \"\"\"\n    Integrate and Fire Neuron\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    
:param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=.5, act_fun=AtanGrad, *args, **kwargs):\n        \"\"\"\n        :param threshold:\n        :param act_fun:\n        :param args:\n        :param kwargs:\n        \"\"\"\n        super().__init__(threshold, *args, **kwargs)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n\n    def integral(self, inputs):\n        self.mem = self.mem + inputs * self.dt\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.get_thres())\n        self.mem = self.mem * (1 - self.spike.detach())\n\n\nclass LIFNode(BaseNode):\n    \"\"\"\n    Leaky Integrate and Fire\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=0.5, tau=2., act_fun=QGateGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        # self.threshold = threshold\n        # print(threshold)\n        # print(tau)\n\n    def integral(self, inputs):\n        self.mem = self.mem + (inputs - self.mem) / self.tau\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        
self.mem = self.mem * (1 - self.spike.detach())\n\n\nclass BurstLIFNode(LIFNode):\n    def __init__(self, threshold=.5, tau=2., act_fun=RoundGrad, *args, **kwargs):\n        super().__init__(threshold=threshold, tau=tau, act_fun=act_fun, *args, **kwargs)\n        self.burst_factor = 1.5\n\n    def calc_spike(self):\n        LIFNode.calc_spike(self)\n        self.spike = torch.where(self.spike > 1., self.burst_factor * self.spike, self.spike)\n\n\n\nclass BackEINode(BaseNode):\n    \"\"\"\n    BackEINode with self feedback connection and excitatory and inhibitory neurons\n    Reference：https://www.sciencedirect.com/science/article/pii/S0893608022002520\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param if_back whether to use self feedback\n    :param if_ei whether to use excitotory and inhibitory neurons\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n    def __init__(self, threshold=0.5, decay=0.2, act_fun=BackEIGateGrad, th_fun=EIGrad, channel=40, if_back=True,\n                 if_ei=True, cfg_backei=2, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.decay = decay\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        if isinstance(th_fun, str):\n            th_fun = eval(th_fun)\n        self.act_fun = act_fun()\n        self.th_fun = th_fun()\n        self.channel = channel\n        self.if_back = if_back\n\n        if self.if_back:\n            self.back = nn.Conv2d(channel, channel, kernel_size=2 * cfg_backei+1, stride=1, padding=cfg_backei)\n        self.if_ei = if_ei\n        if self.if_ei:\n            self.ei = nn.Conv2d(channel, channel, kernel_size=2 * cfg_backei+1, stride=1, padding=cfg_backei)\n\n    def integral(self, inputs):\n        if self.mem is None:\n            self.mem = torch.zeros_like(inputs)\n            self.spike = torch.zeros_like(inputs)\n        self.mem = self.decay * self.mem\n        if self.if_back:\n            self.mem += 
F.sigmoid(self.back(self.spike)) * inputs\n        else:\n            self.mem += inputs\n\n    def calc_spike(self):\n        if self.if_ei:\n            ei_gate = self.th_fun(self.ei(self.mem))\n            self.spike = self.act_fun(self.mem-self.threshold)\n            self.mem = self.mem * (1 - self.spike)\n            self.spike = ei_gate * self.spike\n        else:\n            self.spike = self.act_fun(self.mem-self.threshold)\n            self.mem = self.mem * (1 - self.spike)\n\n    def n_reset(self):\n        self.mem = None\n        self.spike = None\n        self.feature_map = []\n        self.mem_collect = []\n\n\nclass NoiseLIFNode(LIFNode):\n    \"\"\"\n    Noisy Leaky Integrate and Fire\n    在神经元中注入噪声, 默认的噪声分布为 ``Beta(log(2), log(6))``\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param log_alpha: 控制 beta 分布的参数 ``a``\n    :param log_beta: 控制 beta 分布的参数 ``b``\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self,\n                 threshold=1,\n                 tau=2.,\n                 act_fun=GateGrad,\n                 log_alpha=np.log(2),\n                 log_beta=np.log(6),\n                 *args,\n                 **kwargs):\n        super().__init__(threshold=threshold, tau=tau, act_fun=act_fun, *args, **kwargs)\n        self.log_alpha = Parameter(torch.as_tensor(log_alpha), requires_grad=True)\n        self.log_beta = Parameter(torch.as_tensor(log_beta), 
requires_grad=True)\n\n        # self.fc = nn.Sequential(\n        #     nn.Linear(1, 5),\n        #     nn.ReLU(),\n        #     nn.Linear(5, 5),\n        #     nn.ReLU(),\n        #     nn.Linear(5, 2)\n        # )\n\n    def integral(self, inputs):  # b, c, w, h / b, c\n        # self.mu, self.log_var = self.fc(inputs.mean().unsqueeze(0)).split(1)\n        alpha, beta = torch.exp(self.log_alpha), torch.exp(self.log_beta)\n        mu = alpha / (alpha + beta)\n        var = ((alpha + 1) * alpha) / ((alpha + beta + 1) * (alpha + beta))\n        noise = torch.distributions.beta.Beta(alpha, beta).sample(inputs.shape) * self.get_thres()\n        noise = noise * var / var.detach() + mu - mu.detach()\n\n        self.mem = self.mem + ((inputs - self.mem) / self.tau + noise) * self.dt\n\n\nclass BiasLIFNode(BaseNode):\n    \"\"\"\n    带有恒定电流输入Bias的LIF神经元，用于带有抑制性/反馈链接的网络的测试\n    Noisy Leaky Integrate and Fire\n    在神经元中注入噪声, 默认的噪声分布为 ``Beta(log(2), log(6))``\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n\n    def integral(self, inputs):\n        self.mem = self.mem + ((inputs - 
self.mem) / self.tau) * self.dt + 0.1\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.get_thres())\n        self.mem = self.mem * (1 - self.spike.detach())\n\n\nclass LIFSTDPNode(BaseNode):\n    \"\"\"\n    用于执行STDP运算时使用的节点 decay的方式是膜电位乘以decay并直接加上输入电流\n    \"\"\"\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n\n    def integral(self, inputs):\n        self.mem = self.mem * self.tau + inputs\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        # print(( self.threshold).max())\n        self.mem = self.mem * (1 - self.spike.detach())\n\n    def requires_activation(self):\n        return False\n\n\nclass PLIFNode(BaseNode):\n    \"\"\"\n    Parametric LIF， 其中的 ```tau``` 会被backward过程影响\n    Reference：https://arxiv.org/abs/2007.05785\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = -math.log(tau - 1.)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = 
act_fun(alpha=2., requires_grad=True)\n        self.w = nn.Parameter(torch.as_tensor(init_w))\n\n    def integral(self, inputs):\n        self.mem = self.mem + ((inputs - self.mem) * self.w.sigmoid()) * self.dt\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.get_thres())\n        self.mem = self.mem * (1 - self.spike.detach())\n\n\nclass PSU(BaseNode):\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = -math.log(tau - 1.)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.parallel = True\n        self.act_fun = act_fun(alpha=2., requires_grad=True)\n\n        T = self.step\n        m1, m2 = generate_matrix(T, tau)\n        self.register_buffer('m1', m1)\n        self.register_buffer('m2', m2)\n        self.m2 *= self.threshold\n\n    def integral(self, inputs):\n        d1 = self.m1 @ inputs.flatten(1)\n        self.mem = (d1 + self.m2 @ d1.sigmoid()).view(inputs.shape)\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n\n\nclass IPSU(BaseNode):\n    def masked_weight(self):\n        return self.fc.weight * self.mask0\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = -math.log(tau - 1.)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.parallel = True\n        self.act_fun = act_fun(alpha=2., requires_grad=True)\n\n        T = self.step\n        matrix, matrix2 = generate_matrix(T, tau)\n        self.register_buffer('m1', matrix)\n        self.register_buffer('m2', matrix2)\n        # self.m2 *= self.threshold\n\n        self.fc = nn.Linear(T, T)\n        nn.init.constant_(self.fc.bias, 0.)\n        nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')\n\n        mask0 = 
torch.tril(torch.ones([T, T]))\n        self.register_buffer('mask0', mask0)\n\n    def integral(self, inputs):\n        d1 = torch.addmm(self.fc.bias.unsqueeze(1), self.masked_weight(), inputs.flatten((1)))\n        self.mem = (d1 + self.m2 @ inputs.flatten(1)).view(inputs.shape)\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n\n\nclass RPSU(BaseNode):\n    def masked_weight(self):\n        return self.fc.weight * self.mask0\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = -math.log(tau - 1.)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.parallel = True\n        self.act_fun = act_fun(alpha=2., requires_grad=True)\n\n        T = self.step\n        matrix, matrix2 = generate_matrix(T, tau)\n        self.register_buffer('m1', matrix)\n        self.register_buffer('m2', matrix2)\n        # self.m2 *= self.threshold\n\n        self.fc = nn.Linear(T, T)\n        nn.init.constant_(self.fc.bias, 0.)\n        nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')\n\n        mask0 = torch.tril(torch.ones([T, T]))\n        self.register_buffer('mask0', mask0)\n\n    def integral(self, inputs):\n        d1 = self.m1 @ inputs.flatten(1)\n        d2 = torch.addmm(self.fc.bias.unsqueeze(1), self.masked_weight(), inputs.flatten((1)))\n        self.mem = (d1 + self.m2 @ d2.sigmoid()).view(inputs.shape)\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n\n\nclass SPSN(BaseNode):\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = -math.log(tau - 1.)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.parallel = True\n        self.act_fun = act_fun(alpha=2., requires_grad=True)\n\n        m1, 
m2 = generate_matrix(self.step, tau)\n        self.register_buffer('m1', m1)\n\n    def integral(self, inputs):\n        self.mem = (self.m1 @ inputs.flatten(1)).sigmoid().view(inputs.shape)\n\n    def calc_spike(self):\n        self.spike = torch.bernoulli(self.mem)\n\n\nclass NoisePLIFNode(PLIFNode):\n    \"\"\"\n    Noisy Parametric Leaky Integrate and Fire\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self,\n                 threshold=1,\n                 tau=2.,\n                 act_fun=GateGrad,\n                 *args,\n                 **kwargs):\n        super().__init__(threshold=threshold, tau=tau, act_fun=act_fun, *args, **kwargs)\n        log_alpha = kwargs['log_alpha'] if 'log_alpha' in kwargs else np.log(2)\n        log_beta = kwargs['log_beta'] if 'log_beta' in kwargs else np.log(6)\n        self.log_alpha = Parameter(torch.as_tensor(log_alpha), requires_grad=True)\n        self.log_beta = Parameter(torch.as_tensor(log_beta), requires_grad=True)\n\n        # self.fc = nn.Sequential(\n        #     nn.Linear(1, 5),\n        #     nn.ReLU(),\n        #     nn.Linear(5, 5),\n        #     nn.ReLU(),\n        #     nn.Linear(5, 2)\n        # )\n\n    def integral(self, inputs):  # b, c, w, h / b, c\n        # self.mu, self.log_var = self.fc(inputs.mean().unsqueeze(0)).split(1)\n        alpha, beta = torch.exp(self.log_alpha), 
torch.exp(self.log_beta)\n        mu = alpha / (alpha + beta)\n        var = ((alpha + 1) * alpha) / ((alpha + beta + 1) * (alpha + beta))\n        noise = torch.distributions.beta.Beta(alpha, beta).sample(inputs.shape) * self.get_thres()\n        noise = noise * var / var.detach() + mu - mu.detach()\n        self.mem = self.mem + ((inputs - self.mem) * self.w.sigmoid() + noise) * self.dt\n\n\nclass BiasPLIFNode(BaseNode):\n    \"\"\"\n    Parametric LIF with bias\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = -math.log(tau - 1.)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=True)\n        self.w = nn.Parameter(torch.as_tensor(init_w))\n\n    def integral(self, inputs):\n        self.mem = self.mem + ((inputs - self.mem) * self.w.sigmoid() + 0.1) * self.dt\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.get_thres())\n        self.mem = self.mem * (1 - self.spike.detach())\n\n\nclass DoubleSidePLIFNode(LIFNode):\n    \"\"\"\n    能够输入正负脉冲的 PLIF\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 
使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self,\n                 threshold=.5,\n                 tau=2.,\n                 act_fun=AtanGrad,\n                 *args,\n                 **kwargs):\n        super().__init__(threshold, tau, act_fun, *args, **kwargs)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=True)\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.get_thres()) - self.act_fun(self.get_thres - self.mem)\n        self.mem = self.mem * (1. 
- torch.abs(self.spike.detach()))\n\n\nclass IzhNode(BaseNode):\n    \"\"\"\n    Izhikevich 脉冲神经元\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.a = kwargs['a'] if 'a' in kwargs else 0.02\n        self.b = kwargs['b'] if 'b' in kwargs else 0.2\n        self.c = kwargs['c'] if 'c' in kwargs else -55.\n        self.d = kwargs['d'] if 'd' in kwargs else -2.\n        '''\n        v' = 0.04v^2 + 5v + 140 -u + I\n        u' = a(bv-u)\n        下面是将Izh离散化的写法\n        if v>= thresh:\n            v = c\n            u = u + d\n        '''\n        # 初始化膜电势 以及 对应的U\n        self.mem = 0.\n        self.u = 0.\n        self.dt = kwargs['dt'] if 'dt' in kwargs else 1.\n\n    def integral(self, inputs):\n        self.mem = self.mem + self.dt * (0.04 * self.mem * self.mem + 5 * self.mem - self.u + 140 + inputs)\n        self.u = self.u + self.dt * (self.a * self.b * self.mem - self.a * self.u)\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.get_thres())  # 大于阈值释放脉冲\n        self.mem = self.mem * (1 - self.spike.detach()) + self.spike.detach() * self.c\n        self.u = self.u + self.spike.detach() * self.d\n\n    def n_reset(self):\n        self.mem = 0.\n        self.u = 0.\n        self.spike = 0.\n\n\nclass IzhNodeMU(BaseNode):\n    \"\"\"\n    Izhikevich 脉冲神经元多参数版\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n  
  :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.a = kwargs['a'] if 'a' in kwargs else 0.02\n        self.b = kwargs['b'] if 'b' in kwargs else 0.2\n        self.c = kwargs['c'] if 'c' in kwargs else -55.\n        self.d = kwargs['d'] if 'd' in kwargs else -2.\n        self.mem = kwargs['mem'] if 'mem' in kwargs else 0.\n        self.u = kwargs['u'] if 'u' in kwargs else 0.\n        self.dt = kwargs['dt'] if 'dt' in kwargs else 1.\n\n    def integral(self, inputs):\n        self.mem = self.mem + self.dt * (0.04 * self.mem * self.mem + 5 * self.mem - self.u + 140 + inputs)\n        self.u = self.u + self.dt * (self.a * self.b * self.mem - self.a * self.u)\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        self.mem = self.mem * (1 - self.spike.detach()) + self.spike.detach() * self.c\n        self.u = self.u + self.spike.detach() * self.d\n\n    def n_reset(self):\n        self.mem = -70.\n        self.u = 0.\n        self.spike = 0.\n\n    def requires_activation(self):\n        return False\n\n\nclass DGLIFNode(BaseNode):\n    \"\"\"\n    Reference: https://arxiv.org/abs/2110.08858\n    :param threshold: 神经元的脉冲发放阈值\n    :param tau: 神经元的膜常数, 控制膜电位衰减\n    \"\"\"\n\n    def __init__(self, threshold=.5, tau=2., *args, **kwargs):\n        super().__init__(threshold, tau, *args, **kwargs)\n        self.act = nn.ReLU()\n        self.tau = tau\n\n    def integral(self, inputs):\n        inputs = self.act(inputs)\n        self.mem = self.mem + ((inputs - self.mem) / self.tau) * self.dt\n\n    def calc_spike(self):\n        spike = 
self.mem.clone()\n        spike[(spike < self.get_thres())] = 0.\n        # self.spike = spike / (self.mem.detach().clone() + 1e-12)\n        self.spike = spike - spike.detach() + \\\n                     torch.where(spike.detach() > self.get_thres(), torch.ones_like(spike), torch.zeros_like(spike))\n        self.spike = spike\n        self.mem = torch.where(self.mem >= self.get_thres(), torch.zeros_like(self.mem), self.mem)\n\n\nclass HTDGLIFNode(IFNode):\n    \"\"\"\n    Reference: https://arxiv.org/abs/2110.08858\n    :param threshold: 神经元的脉冲发放阈值\n    :param tau: 神经元的膜常数, 控制膜电位衰减\n    \"\"\"\n\n    def __init__(self, threshold=.5, tau=2., *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.warm_up = False\n\n    def calc_spike(self):\n        spike = self.mem.clone()\n        spike[(spike < self.get_thres())] = 0.\n        # self.spike = spike / (self.mem.detach().clone() + 1e-12)\n        self.spike = spike - spike.detach() + \\\n                     torch.where(spike.detach() > self.get_thres(), torch.ones_like(spike), torch.zeros_like(spike))\n        self.spike = spike\n        self.mem = torch.where(self.mem >= self.get_thres(), torch.zeros_like(self.mem), self.mem)\n        # self.mem[[(spike > self.get_thres())]] = self.mem[[(spike > self.get_thres())]] - self.get_thres()\n\n        self.mem = (self.mem + 0.2 * self.spike - 0.2 * self.spike.detach()) * self.dt\n\n    def forward(self, inputs):\n        if self.warm_up:\n            return F.relu(inputs)\n        else:\n            return super(IFNode, self).forward(F.relu(inputs))\n\n\nclass SimHHNode(BaseNode):\n    \"\"\"\n    简单版本的HH模型\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=50., tau=2., 
act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        '''\n        I = Cm dV/dt + g_k*n^4*(V_m-V_k) + g_Na*m^3*h*(V_m-V_Na) + g_l*(V_m - V_L)\n        '''\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.g_Na, self.g_K, self.g_l = torch.tensor(120.), torch.tensor(120), torch.tensor(0.3)  # k 36\n        self.V_Na, self.V_K, self.V_l = torch.tensor(120.), torch.tensor(-120.), torch.tensor(10.6)  # k -12\n        self.m, self.n, self.h = torch.tensor(0), torch.tensor(0), torch.tensor(0)\n        self.mem = 0\n        self.dt = 0.01\n\n    def integral(self, inputs):\n        self.I_Na = torch.pow(self.m, 3) * self.g_Na * self.h * (self.mem - self.V_Na)\n        self.I_K = torch.pow(self.n, 4) * self.g_K * (self.mem - self.V_K)\n        self.I_L = self.g_l * (self.mem - self.V_l)\n        self.mem = self.mem + self.dt * (inputs - self.I_Na - self.I_K - self.I_L) / 0.02\n        # non Na\n        # self.mem = self.mem + 0.01 * (inputs -  self.I_K - self.I_L) / 0.02  #decayed\n        # NON k\n        # self.mem = self.mem + 0.01 * (inputs - self.I_Na - self.I_L) / 0.02  #increase\n\n        self.alpha_n = 0.01 * (self.mem + 10.0) / (1 - torch.exp(-(self.mem + 10.0) / 10))\n        self.beta_n = 0.125 * torch.exp(-(self.mem) / 80)\n\n        self.alpha_m = 0.1 * (self.mem + 25) / (1 - torch.exp(-(self.mem + 25) / 10))\n        self.beta_m = 4 * torch.exp(-(self.mem) / 18)\n\n        self.alpha_h = 0.07 * torch.exp(-(self.mem) / 20)\n        self.beta_h = 1 / (1 + torch.exp(-(self.mem + 30) / 10))\n\n        self.n = self.n + self.dt * (self.alpha_n * (1 - self.n) - self.beta_n * self.n)\n        self.m = self.m + self.dt * (self.alpha_m * (1 - self.m) - self.beta_m * self.m)\n        self.h = self.h + self.dt * (self.alpha_h * (1 - self.h) - self.beta_h * self.h)\n\n    def calc_spike(self):\n        
self.spike = self.act_fun(self.mem - self.threshold)\n        self.mem = self.mem * (1 - self.spike.detach())\n\n    def forward(self, inputs):\n        self.integral(inputs)\n        self.calc_spike()\n        return self.spike\n\n    def n_reset(self):\n        self.mem = 0.\n        self.spike = 0.\n        self.m, self.n, self.h = torch.tensor(0), torch.tensor(0), torch.tensor(0)\n\n    def requires_activation(self):\n        return False\n\n\nclass CTIzhNode(IzhNode):\n    def __init__(self, threshold=1., tau=2., act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, tau, act_fun, *args, **kwargs)\n\n        self.name = kwargs['name'] if 'name' in kwargs else ''\n        self.excitability = kwargs['excitability'] if 'excitability' in kwargs else 'TRUE'\n        self.spikepattern = kwargs['spikepattern'] if 'spikepattern' in kwargs else 'RS'\n        self.synnum = kwargs['synnum'] if 'synnum' in kwargs else 0\n        self.locationlayer = kwargs['locationlayer'] if 'locationlayer' in kwargs else ''\n        self.adjneuronlist = {}\n        self.proximal_dendrites = []\n        self.distal_dendrites = []\n        self.totalindex = kwargs['totalindex'] if 'totalindex' in kwargs else 0\n        self.colindex = 0\n        self.state = 'inactive'\n\n        self.Gup = kwargs['Gup'] if 'Gup' in kwargs else 0.0\n        self.Gdown = kwargs['Gdown'] if 'Gdown' in kwargs else 0.0\n        self.Vr = kwargs['Vr'] if 'Vr' in kwargs else 0.0\n        self.Vt = kwargs['Vt'] if 'Vt' in kwargs else 0.0\n        self.Vpeak = kwargs['Vpeak'] if 'Vpeak' in kwargs else 0.0\n        self.capicitance = kwargs['capacitance'] if 'capacitance' in kwargs else 0.0\n        self.k = kwargs['k'] if 'k' in kwargs else 0.0\n        self.mem = -65\n        self.vtmp = -65\n        self.u = -13.0\n        self.spike = 0\n        self.dc = 0\n\n    def integral(self, inputs):\n        self.mem += self.dt * (\n                self.k * (self.mem - self.Vr) * (self.mem - self.Vt) 
- self.u + inputs) / self.capicitance\n        self.u += self.dt * (self.a * (self.b * (self.mem - self.Vr) - self.u))\n\n    def calc_spike(self):\n        if self.mem >= self.Vpeak:\n            self.mem = self.c\n            self.u = self.u + self.d\n            self.spike = 1\n            self.spreadMarkPostNeurons()\n\n    def spreadMarkPostNeurons(self):\n        for post, list in self.adjneuronlist.items():\n            if self.excitability == \"TRUE\":\n                post.dc = random.randint(140, 160)\n            else:\n                post.dc = random.randint(-160, -140)\n\n\nclass adth(BaseNode):\n    \"\"\"\n        The adaptive Exponential Integrate-and-Fire model (aEIF)\n        :param args: Other parameters\n        :param kwargs: Other parameters\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(requires_fp=False, *args, **kwargs)\n\n    def adthNode(self, v, dt, c_m, g_m, alpha_w, ad, Ieff, Ichem, Igap, tau_ad, beta_ad, vt, vm1):\n        \"\"\"\n                Calculate the neurons that discharge after the current threshold is reached\n                :param v: Current neuron voltage\n                :param dt: time step\n                :param ad:Adaptive variable\n                :param vv:Spike, if the voltage exceeds the threshold from below\n        \"\"\"\n        v = v + dt / c_m * (-g_m * v + alpha_w * ad + Ieff + Ichem + Igap)\n        ad = ad + dt / tau_ad * (-ad + beta_ad * v)\n        vv = (v >= vt).astype(int) * (vm1 < vt).astype(int)\n        vm1 = v\n        return v, ad, vv, vm1\n\n    def calc_spike(self):\n        pass\n    \n    \nclass HHNode(BaseNode):\n    \"\"\"\n    用于脑模拟的HH模型\n    p: [threshold, g_Na, g_K, g_l, V_Na, V_K, V_l, C]\n\n    \"\"\"\n\n    def __init__(self, p, dt, device, act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold=p[0], *args, **kwargs)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        '''\n        I = Cm dV/dt + 
g_k*n^4*(V_m-V_k) + g_Na*m^3*h*(V_m-V_Na) + g_l*(V_m - V_L)\n        '''\n        self.neuron_num = len(p[0])\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.tau_I = 3\n        self.g_Na = torch.tensor(p[1])\n        self.g_K = torch.tensor(p[2])\n        self.g_l = torch.tensor(p[3])\n        self.V_Na = torch.tensor(p[4])\n        self.V_K = torch.tensor(p[5])\n        self.V_l = torch.tensor(p[6])\n        self.C = torch.tensor(p[7])\n        self.m = 0.05 * torch.ones(self.neuron_num, device=device, requires_grad=False)\n        self.n = 0.31 * torch.ones(self.neuron_num, device=device, requires_grad=False)\n        self.h = 0.59 * torch.ones(self.neuron_num, device=device, requires_grad=False)\n        self.v_reset = 0\n        self.dt = dt\n        self.dt_over_tau = self.dt / self.tau_I\n        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))\n        self.mu = 10\n        self.sig = 12\n\n        self.mem = torch.tensor(self.v_reset, device=device, requires_grad=False)\n        self.mem_p = self.mem\n        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n\n    def integral(self, inputs):\n        self.alpha_n = (0.1 - 0.01 * self.mem) / (torch.exp(1 - 0.1 * self.mem) - 1)\n        self.alpha_m = (2.5 - 0.1 * self.mem) / (torch.exp(2.5 - 0.1 * self.mem) - 1)\n        self.alpha_h = 0.07 * torch.exp(-self.mem / 20.0)\n\n        self.beta_n = 0.125 * torch.exp(-self.mem / 80.0)\n        self.beta_m = 4.0 * torch.exp(-self.mem / 18.0)\n        self.beta_h = 1 / (torch.exp(3 - 0.1 * self.mem) + 1)\n\n        self.n = self.n + self.dt * (self.alpha_n * (1 - self.n) - self.beta_n * self.n)\n        self.m = self.m + self.dt * (self.alpha_m * (1 - self.m) - self.beta_m * self.m)\n        self.h = self.h + self.dt * 
(self.alpha_h * (1 - self.h) - self.beta_h * self.h)\n\n        self.I_Na = torch.pow(self.m, 3) * self.g_Na * self.h * (self.mem - self.V_Na)\n        self.I_K = torch.pow(self.n, 4) * self.g_K * (self.mem - self.V_K)\n        self.I_L = self.g_l * (self.mem - self.V_l)\n\n        self.mem_p = self.mem\n        self.mem = self.mem + self.dt * (inputs - self.I_Na - self.I_K - self.I_L) / self.C\n\n    def calc_spike(self):\n        self.spike = (self.threshold > self.mem_p).float() * (self.mem > self.threshold).float()\n\n    def forward(self, inputs):\n        self.integral(inputs)\n        self.calc_spike()\n        return self.spike, self.mem\n\n    def requires_activation(self):\n        return False\n\n    \nclass aEIF(BaseNode):\n    \"\"\"\n        The adaptive Exponential Integrate-and-Fire model (aEIF)\n        This class define the membrane, spike, current and parameters of a neuron group of a specific type\n        :param args: Other parameters\n        :param kwargs: Other parameters\n    \"\"\"\n\n    def __init__(self, p, dt, device, *args, **kwargs):\n        \"\"\"\n            p:[threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n\n        \"\"\"\n        super().__init__(threshold=p[0], requires_fp=False, *args, **kwargs)\n        self.neuron_num = len(p[0])\n        self.g_m = 0.1  # neuron conduction\n        self.dt = dt\n        self.tau_I = 3  # Time constant to filter the synaptic inputs\n        self.Delta_T = 0.5  # parameter\n        self.v_reset = p[1]  # membrane potential reset to v_reset after fire spike\n        self.c_m = p[2]\n        self.tau_w = p[3]  # Time constant of adaption coupling\n        self.alpha_ad = p[4]\n        self.beta_ad = p[5]\n        self.refrac = 5 / self.dt  # refractory period\n        self.dt_over_tau = self.dt / self.tau_I\n        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))\n        self.mem = self.v_reset\n        self.spike = torch.zeros(self.neuron_num, device=device, 
requires_grad=False)\n        self.ad = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.ref = torch.randint(0, int(self.refrac + 1), (1, self.neuron_num), device=device, requires_grad=False).squeeze(\n            0)  # refractory counter\n        self.ref = self.ref.float()\n        self.mu = 10\n        self.sig = 12\n        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n\n    def integral(self, inputs):\n\n        self.mem = self.mem + (self.ref > self.refrac) * self.dt / self.c_m * \\\n                   (-self.g_m * (self.mem - self.v_reset) + self.g_m * self.Delta_T *\n                    torch.exp((self.mem - self.threshold) / self.Delta_T) +\n                    self.alpha_ad * self.ad + inputs)\n\n        self.ad = self.ad + (self.ref > self.refrac) * self.dt / self.tau_w * \\\n                  (-self.ad + self.beta_ad * (self.mem - self.v_reset))\n\n    def calc_spike(self):\n        self.spike = (self.mem > self.threshold).float()\n        self.ref = self.ref * (1 - self.spike) + 1\n        self.ad = self.ad + self.spike * 30\n        self.mem = self.spike * self.v_reset + (1 - self.spike.detach()) * self.mem\n\n    def forward(self, inputs):\n\n        # aeifnode_cuda.forward(self.threshold, self.c_m, self.alpha_w, self.beta_ad, inputs, self.ref, self.ad, self.mem, self.spike)\n        self.integral(inputs)\n        self.calc_spike()\n\n        return self.spike, self.mem\n    \nclass LIAFNode(BaseNode):\n    \"\"\"\n    Leaky Integrate and Analog Fire (LIAF), Reference: https://ieeexplore.ieee.org/abstract/document/9429228\n    与LIF相同, 但前传的是膜电势, 更新沿用阈值和膜电势\n    :param act_fun: 前传使用的激活函数 [ReLU, SeLU, LeakyReLU]\n    :param threshold_related: 阈值依赖模式，若为\"True\"则 self.spike = act_fun(mem-threshold)\n    :note that BaseNode return self.spike, and here self.spike is analog value.\n    \"\"\"\n    def 
__init__(self, spike_act=BackEIGateGrad(), act_fun=\"SELU\", threshold=0.5, tau=2., threshold_related=True, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        if isinstance(act_fun, str):\n            act_fun = eval(\"nn.\" + act_fun + \"()\")\n        self.tau = tau\n        self.act_fun = act_fun\n        self.spike_act = spike_act\n        self.threshold_related = threshold_related\n\n    def integral(self, inputs):\n        self.mem = self.mem + (inputs - self.mem) / self.tau\n\n    def calc_spike(self):\n        if self.threshold_related:\n            spike_tmp = self.act_fun(self.mem - self.threshold)\n        else:\n            spike_tmp = self.act_fun(self.mem)\n        self.spike = self.spike_act(self.mem - self.threshold)\n        self.mem = self.mem * (1 - self.spike)\n        self.spike = spike_tmp\n\n\n\nclass OnlineLIFNode(BaseNode):\n    \"\"\"\n    Online-update Leaky Integrate and Fire\n    与LIF模型相同，但是时序信息在反传时从计算图剥离，因此可以实现在线的更新；模型占用显存固定，不随仿真步step线性提升。\n    使用此神经元需要修改:  1. 将模型中t次forward从model_zoo写到main.py中\n                       2. 在Conv层与OnelineLIFNode层中加入Replace函数，即时序前传都是detach的，但仍计算该层空间梯度信息。\n                       3. 
网络结构不适用BN层，使用weight standardization\n    注意该神经元不同于OTTT，而是将时序信息全部扔弃。对应这篇文章：https://arxiv.org/abs/2302.14311\n    若需保留时序，需要对self.rate_tracking进行计算。实现可参考https://github.com/pkuxmq/OTTT-SNN\n    \"\"\"\n\n    def __init__(self, threshold=0.5, tau=2., act_fun=QGateGrad, init=False, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.rate_tracking = None\n        self.init = True\n\n\n    def integral(self, inputs):\n        if self.init is True:\n            self.mem = torch.zeros_like(inputs)\n            self.init = False\n        self.mem = self.mem.detach() + (inputs - self.mem.detach()) / self.tau\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        self.mem = self.mem * (1 - self.spike.detach())\n        with torch.no_grad():\n            if self.rate_tracking == None:\n                self.rate_tracking = self.spike.clone().detach()\n        self.spike = torch.cat((self.spike, self.rate_tracking), dim=0)\n\n\nclass AdaptiveNode(LIFNode):\n\n    def __init__(self, threshold=1., act_fun=QGateGrad, step=10, spike_output=True, *args, **kwargs):\n        super().__init__(threshold=threshold, step=step, **kwargs)\n        self.n_encode_type = kwargs['n_encode_type'] if 'n_encode_type' in kwargs else 'linear'\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        # self.act_fun = BinaryActivation()\n        print(self.n_encode_type)\n        if self.n_encode_type == 'linear':\n            self.encoder = nn.Sequential(\n                CustomLinear(self.step, self.step)\n            )\n        elif self.n_encode_type == 'mlp':\n            # Direct\n            self.encoder = nn.Sequential(\n                CustomLinear(self.step, 
self.step),\n                nn.ReLU(),\n                CustomLinear(self.step, self.step),\n                nn.ReLU(),\n                CustomLinear(self.step, self.step),\n                nn.ReLU(),\n                CustomLinear(self.step, self.step),\n            )\n        elif self.n_encode_type == 'att':\n            # -> SE block\n            self.encoder = nn.Sequential(\n                nn.Linear(self.step, self.step),\n                nn.ReLU(),\n                nn.Linear(self.step, self.step),\n                nn.ReLU(),\n                nn.Linear(self.step, self.step),\n                nn.Sigmoid()\n            )\n        elif self.n_encode_type == 'conv':\n            self.encoder = nn.Sequential(\n                nn.Linear(self.step, self.step),\n                nn.ReLU(),\n                nn.Linear(self.step, self.step),\n            )\n            # self.init_weight()\n        else:\n            raise NotImplementedError('Unrecognizable categories {}.'.format(self.n_encode_type))\n\n        self.saved_mem = 0.\n\n    def init_weight(self):\n        for mod in self.encoder.modules():\n            if isinstance(mod, nn.Conv1d):\n                mod.weight.data[:, :, 4] = 1. / mod.weight.shape[0]\n                mod.weight.data[:, :, [0, 1, 2, 3, 5, 6, 7, 8]] = 0.\n                mod.bias.data[:] = 0.\n\n    def forward(self, inputs):  # (t b) c w h\n        if self.n_encode_type != 'conv':\n            x = rearrange(inputs, '(t b) ... -> b ... t', t=self.step)\n        else:\n            c, w, h = inputs.shape[1:]\n            x = rearrange(inputs, '(t b) c w h -> (b c w h) 1 t', t=self.step)\n\n        if self.n_encode_type != 'att':\n            x = self.encoder(x)  # Direct\n        else:\n            x = x * self.encoder(x)  # SE Block\n\n        if self.n_encode_type != 'conv':\n            x = rearrange(x, 'b ... 
t -> (t b) ...')\n        else:\n            x = rearrange(x, '(b c w h) 1 t -> (t b) c w h', c=c, w=w, h=h)\n\n        # self.spike = self.act_fun(x - 0.5)\n        # # print(self.spike.mean())\n        # # print(self.requires_fp)\n        # if self.requires_fp:\n        #     spike = rearrange(self.spike, '(t b) c w h -> t b c w h', t=self.step)\n        #     for t in range(self.step):\n        #         # print(t, float(spike[t].mean()), float(spike[t].std()))\n        #         self.feature_map.append(spike[t])\n        #     self.saved_mem = x\n        # return self.spike\n\n        return super().forward(x)\n\n    # def get_thres(self):\n    #     mem_relu = F.relu(self.mem.detach())\n    #     return mem_relu[mem_relu > 0.].median()\n\n    def n_reset(self):\n        super().n_reset()\n        self.saved_mem = 0."
  },
  {
    "path": "braincog/base/strategy/LateralInhibition.py",
    "content": "import warnings\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\nclass LateralInhibition(nn.Module):\n    \"\"\"\n    侧抑制 用于发放脉冲的神经元抑制其他同层神经元 在膜电位上作用\n    \"\"\"\n    def __init__(self, node, inh, mode=\"constant\"):\n        super().__init__()\n        self.inh = inh\n        self.node = node\n        self.mode = mode\n\n    def forward(self, x: torch.Tensor, xori=None):\n        # x.shape = [N, C,W,H]\n        # ret.shape = [N, C,W,H]\n        if self.mode == \"constant\":\n\n            self.node.mem = self.node.mem - self.inh * (x.max(1, True)[0] - x)\n\n        elif self.mode == \"max\":\n            self.node.mem = self.node.mem - self.inh * xori.max(1, True)[0] .detach() * (x.max(1, True)[0] - x)\n        elif self.mode == \"threshold\":\n            self.node.mem = self.node.mem - self.inh * self.node.threshold * (x.max(1, True)[0] - x)\n        else:\n            pass\n        return x\n"
  },
  {
    "path": "braincog/base/strategy/__init__.py",
    "content": "__all__ = ['surrogate', 'LateralInhibition']\n\nfrom . import (\n    surrogate,\n    LateralInhibition\n)\n"
  },
  {
    "path": "braincog/base/strategy/surrogate.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ndef heaviside(x):\n    return (x >= 0.).to(x.dtype)\n\n\nclass SurrogateFunctionBase(nn.Module):\n    \"\"\"\n    Surrogate Function 的基类\n    :param alpha: 为一些能够调控函数形状的代理函数提供参数.\n    :param requires_grad: 参数 ``alpha`` 是否需要计算梯度, 默认为 ``False``\n    \"\"\"\n\n    def __init__(self, alpha, requires_grad=True):\n        super().__init__()\n        self.alpha = nn.Parameter(\n            torch.tensor(alpha, dtype=torch.float),\n            requires_grad=requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        \"\"\"\n        :param x: 膜电位的输入\n        :param alpha: 控制代理梯度形状的变量, 可以为 ``NoneType``\n        :return: 激发之后的spike, 取值为 ``[0, 1]``\n        \"\"\"\n        raise NotImplementedError\n\n    def forward(self, x):\n        \"\"\"\n        :param x: 膜电位输入\n        :return: 激发之后的spike\n        \"\"\"\n        return self.act_fun(x, self.alpha)\n\n\n'''\n    sigmoid surrogate function.\n'''\n\n\nclass sigmoid(torch.autograd.Function):\n    \"\"\"\n    使用 sigmoid 作为代理梯度函数\n    对应的原函数为:\n\n    .. math::\n            g(x) = \\\\mathrm{sigmoid}(\\\\alpha x) = \\\\frac{1}{1+e^{-\\\\alpha x}}\n    反向传播的函数为:\n\n    .. 
math::\n            g'(x) = \\\\alpha * (1 - \\\\mathrm{sigmoid} (\\\\alpha x)) \\\\mathrm{sigmoid} (\\\\alpha x)\n\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x, alpha):\n        if x.requires_grad:\n            ctx.save_for_backward(x)\n            ctx.alpha = alpha\n        return heaviside(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_x = None\n        if ctx.needs_input_grad[0]:\n            s_x = torch.sigmoid(ctx.alpha * ctx.saved_tensors[0])\n            grad_x = grad_output * s_x * (1 - s_x) * ctx.alpha\n        return grad_x, None\n\n\nclass SigmoidGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=1., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return sigmoid.apply(x, alpha)\n\n\n'''\n    atan surrogate function.\n'''\n\n\nclass atan(torch.autograd.Function):\n    \"\"\"\n    使用 Atan 作为代理梯度函数\n    对应的原函数为:\n\n    .. math::\n            g(x) = \\\\frac{1}{\\\\pi} \\\\arctan(\\\\frac{\\\\pi}{2}\\\\alpha x) + \\\\frac{1}{2}\n    反向传播的函数为:\n\n    .. 
math::\n            g'(x) = \\\\frac{\\\\alpha}{2(1 + (\\\\frac{\\\\pi}{2}\\\\alpha x)^2)}\n\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, inputs, alpha):\n        ctx.save_for_backward(inputs, alpha)\n        return inputs.gt(0.).float()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_x = None\n        grad_alpha = None\n\n        shared_c = grad_output / \\\n                   (1 + (ctx.saved_tensors[1] * math.pi /\n                         2 * ctx.saved_tensors[0]).square())\n        if ctx.needs_input_grad[0]:\n            grad_x = ctx.saved_tensors[1] / 2 * shared_c\n        if ctx.needs_input_grad[1]:\n            grad_alpha = (ctx.saved_tensors[0] / 2 * shared_c).sum()\n\n        return grad_x, grad_alpha\n\n\nclass AtanGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=2., requires_grad=True):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return atan.apply(x, alpha)\n\n\n'''\n    gate surrogate function. \n'''\n\n\nclass gate(torch.autograd.Function):\n    \"\"\"\n    使用 gate 作为代理梯度函数\n    对应的原函数为:\n\n    .. math::\n            g(x) = \\\\mathrm{NonzeroSign}(x) \\\\log (|\\\\alpha x| + 1)\n    反向传播的函数为:\n\n    .. math::\n            g'(x) = \\\\frac{\\\\alpha}{1 + |\\\\alpha x|} = \\\\frac{1}{\\\\frac{1}{\\\\alpha} + |x|}\n\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x, alpha):\n        if x.requires_grad:\n            grad_x = torch.where(x.abs() < 1. 
/ alpha, torch.ones_like(x), torch.zeros_like(x))\n            ctx.save_for_backward(grad_x)\n        return x.gt(0).float()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_x = None\n        if ctx.needs_input_grad[0]:\n            grad_x = grad_output * ctx.saved_tensors[0]\n        return grad_x, None\n\n\nclass GateGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=2., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return gate.apply(x, alpha)\n\n\n'''\n    quadratic_gate surrogate function.\n'''\n\n\nclass quadratic_gate(torch.autograd.Function):\n    \"\"\"\n    使用 quadratic_gate 作为代理梯度函数\n    对应的原函数为:\n\n    .. math::\n        g(x) =\n        \\\\begin{cases}\n        0, & x < -\\\\frac{1}{\\\\alpha} \\\\\\\\\n        -\\\\frac{1}{2}\\\\alpha^2|x|x + \\\\alpha x + \\\\frac{1}{2}, & |x| \\\\leq \\\\frac{1}{\\\\alpha}  \\\\\\\\\n        1, & x > \\\\frac{1}{\\\\alpha} \\\\\\\\\n        \\\\end{cases}\n\n    反向传播的函数为:\n\n    .. 
math::\n        g'(x) =\n        \\\\begin{cases}\n        0, & |x| > \\\\frac{1}{\\\\alpha} \\\\\\\\\n        -\\\\alpha^2|x|+\\\\alpha, & |x| \\\\leq \\\\frac{1}{\\\\alpha}\n        \\\\end{cases}\n\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x, alpha):\n        if x.requires_grad:\n            mask_zero = (x.abs() > 1 / alpha)\n            grad_x = -alpha * alpha * x.abs() + alpha\n            grad_x.masked_fill_(mask_zero, 0)\n            ctx.save_for_backward(grad_x)\n        return x.gt(0.).float()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_x = None\n        if ctx.needs_input_grad[0]:\n            grad_x = grad_output * ctx.saved_tensors[0]\n        return grad_x, None\n\n\nclass QGateGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=2., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return quadratic_gate.apply(x, alpha)\n\n\nclass relu_like(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, alpha):\n        if x.requires_grad:\n            ctx.save_for_backward(x, alpha)\n        return heaviside(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_x, grad_alpha = None, None\n        x, alpha = ctx.saved_tensors\n        if ctx.needs_input_grad[0]:\n            grad_x = grad_output * x.gt(0.).float() * alpha\n        if ctx.needs_input_grad[1]:\n            grad_alpha = (grad_output * F.relu(x)).sum()\n        return grad_x, grad_alpha\n\nclass RoundGrad(nn.Module):\n    def __init__(self, **kwargs):\n        super(RoundGrad, self).__init__()\n        self.act = nn.Hardtanh(-.5, 4.5)\n\n    def forward(self, x):\n        x = self.act(x)\n        return x.ceil() + x - x.detach()\n\nclass ReLUGrad(SurrogateFunctionBase):\n    \"\"\"\n    使用ReLU作为代替梯度函数, 主要用为相同结构的ANN的测试\n    \"\"\"\n\n    def __init__(self, alpha=2., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    
@staticmethod\n    def act_fun(x, alpha):\n        return relu_like.apply(x, alpha)\n\n\n'''\n    Straight-Through (ST) Estimator\n'''\n\n\nclass straight_through_estimator(torch.autograd.Function):\n    \"\"\"\n    使用直通估计器作为代理梯度函数\n    http://arxiv.org/abs/1308.3432\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, inputs):\n        outputs = heaviside(inputs)\n        ctx.save_for_backward(outputs)\n        return outputs\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_x = None\n        if ctx.needs_input_grad[0]:\n            grad_x = grad_output\n        return grad_x\n\n\nclass stdp(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inputs):\n        outputs = inputs.gt(0.).float()\n        ctx.save_for_backward(outputs)\n        return outputs\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        inputs, = ctx.saved_tensors\n        return inputs * grad_output\n\n\nclass STDPGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=2., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return stdp.apply(x)\n\n\n\n\n\nclass backeigate(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        return input.gt(0.).float()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, = ctx.saved_tensors\n        grad_input = grad_output.clone()\n        temp = abs(input) < 0.5\n        return grad_input * temp.float()\n\n\nclass BackEIGateGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=2., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return backeigate.apply(x)\n\nclass ei(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        return torch.sign(input).float()\n\n    @staticmethod\n    def 
backward(ctx, grad_output):\n        input, = ctx.saved_tensors\n        grad_input = grad_output.clone()\n        temp = abs(input) < 0.5\n        return grad_input * temp.float()\n\n\nclass EIGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=2., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return ei.apply(x)\n\n\n"
  },
  {
    "path": "braincog/base/utils/__init__.py",
    "content": "from .criterions import UnilateralMse, MixLoss\nfrom .visualization import plot_tsne, plot_tsne_3d, plot_confusion_matrix\nfrom torch.autograd import Variable\nimport torch\n\n__all__ = [\n    'UnilateralMse', 'MixLoss',\n    'plot_tsne', 'plot_tsne_3d', 'plot_confusion_matrix', 'drop_path'\n]\n\n\ndef drop_path(x, drop_prob):\n    if drop_prob > 0.:\n        keep_prob = 1. - drop_prob\n        mask = Variable(torch.cuda.FloatTensor(\n            x.size(0), 1, 1, 1).bernoulli_(keep_prob))\n        x.div_(keep_prob)\n        x.mul_(mask)\n    return x\n"
  },
  {
    "path": "braincog/base/utils/criterions.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\n\nclass UnilateralMse(torch.nn.Module):\n    \"\"\"\n    扩展单边的MSE损失, 用于控制输出层的期望fire-rate 高于 thresh\n    :param thresh: 输出层的期望输出频率\n    \"\"\"\n    def __init__(self, thresh=1.):\n        super(UnilateralMse, self).__init__()\n        self.thresh = thresh\n        self.loss = torch.nn.MSELoss()\n\n    def forward(self, x, target):\n        # x = nn.functional.softmax(x, dim=1)\n        torch.clip(x, max=self.thresh)\n        if x.shape == target.shape:\n            return self.loss(x, target)\n        return self.loss(x, torch.zeros_like(x).scatter_(1, target.view(-1, 1), self.thresh))\n\n\nclass MixLoss(torch.nn.Module):\n    \"\"\"\n    混合损失函数, 可以将任意的损失函数与UnilateralMse损失混合\n    :param ce_loss: 任意的损失函数\n    \"\"\"\n    def __init__(self, ce_loss):\n        super(MixLoss, self).__init__()\n        self.ce = ce_loss\n        self.mse = UnilateralMse(1.)\n\n    def forward(self, x, target):\n        return 0.1 * self.ce(x, target) + self.mse(x, target)\n\n\nclass TetLoss(torch.nn.Module):\n    def __init__(self, loss_fn):\n        super(TetLoss, self).__init__()\n        self.loss_fn = loss_fn\n\n    def forward(self, x, target):\n        loss = 0.\n        for logit in x:\n            loss += self.loss_fn(logit, target)\n\n        return loss / x.shape[0]\n\n\nclass OnehotMse(torch.nn.Module):\n    \"\"\"\n    将类别转换为onehot进行mse损失计算, 用于带vote的SNN中\n    \"\"\"\n    def __init__(self, num_class):\n        super(OnehotMse, self).__init__()\n        self.num_class = num_class\n        self.loss_fn = torch.nn.MSELoss()\n\n    def forward(self, x, target):\n        target = F.one_hot(target.to(torch.int64), self.num_class).float()\n        loss = self.loss_fn(x, target)\n        return loss\n"
  },
  {
    "path": "braincog/base/utils/visualization.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/7/1 11:10\n# User      : Floyed\n# Product   : PyCharm\n# Project   : braincog\n# File      : visualization.py\n# explain   : add t-SNE\n\nimport os\nimport numpy as np\nimport sklearn\nfrom sklearn.manifold import TSNE\nfrom sklearn.metrics import confusion_matrix\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nimport matplotlib.pyplot as plt\nimport matplotlib.patheffects as PathEffects\nimport matplotlib\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d import proj3d\n\nimport seaborn as sns\n\n# Random state.\nRS = 20150101\n\n\n\ndef spike_rate_vis_1d(data, output_dir=''):\n    assert len(data.shape) == 2, 'Shape should be (t, c).'\n\n    data = rearrange(data, 'i j -> j i')\n    if isinstance(data, torch.Tensor):\n        data = data.to('cpu').numpy()\n\n    plt.figure(figsize=(8, 8))\n    sns.heatmap(data, annot=None, cmap='YlGnBu')\n    # plt.ylim(0, _max + 1)\n    plt.xlabel('Predicted labels')\n    plt.ylabel('True labels')\n    plt.show()\n\n\ndef spike_rate_vis(data, output_dir=''):\n    assert len(data.shape) == 3, 'Shape should be (t, r, c).'\n    data = data.mean(axis=0)\n\n    if isinstance(data, torch.Tensor):\n        data = data.to('cpu').numpy()\n\n    plt.figure(figsize=(8, 8))\n    sns.heatmap(data, annot=None, cmap='YlGnBu')\n    # plt.ylim(0, _max + 1)\n    plt.xlabel('Predicted labels')\n    plt.ylabel('True labels')\n    plt.show()\n\n\ndef plot_mem_distribution(data,\n                          output_dir='',\n                          legend='',\n                          xlabel='Membrane Potential',\n                          ylabel='Density',\n                          **kwargs):\n    # print(type(data), len(data))\n    if isinstance(data, torch.Tensor):\n        data = data.reshape(-1).to('cpu').numpy()\n\n    mean = data.mean()\n    std = data.std()\n    idx = np.argwhere(data < mean - 3 
* std)\n    data = np.delete(data, idx)\n    idx = np.argwhere(data > mean + 3 * std)\n    data = np.delete(data, idx)\n    \n    sns.set_style('darkgrid')\n    # sns.set_palette('deep', desat=.6)\n    sns.set_context(\"notebook\", font_scale=1.5,\n                    rc={\"lines.linewidth\": 2.5})\n \n    # fig = plt.figure(figsize=(8, 8))\n    # ax = fig.add_subplot(111, aspect='equal')\n         \n    # sns.distplot(data, bins=int(np.sqrt(data.shape[0])),\n    #              hist=True, kde=False, hist_kws={'histtype': 'stepfilled'}, **kwargs)\n\n    # print('hist begin')\n    print(len(data))\n    n, bins, patches = plt.hist(data,\n                                density=True,\n                                histtype='stepfilled',\n                                alpha=0.618,\n                                bins=int(np.sqrt(data.shape[0])),\n                                **kwargs)\n    # print('hist finished')\n    # sns.kdeplot(data, color='#5294c3')\n    # print('kde finished')\n\n    plt.xlabel(xlabel)\n    plt.ylabel(ylabel)\n    # if legend != '':\n    #     plt.legend(legend)\n    # ax.axis('tight')\n\n    if output_dir != '':\n        plt.savefig(output_dir, bbox_inches='tight')\n        print('{} saved'.format(output_dir))\n    # plt.show()\n\n\ndef plot_tsne(x, colors,output_dir=\"\", num_classes=None):\n    if isinstance(x, torch.Tensor):\n        x = x.to('cpu').numpy()\n    if isinstance(colors, torch.Tensor):\n        colors = colors.to('cpu').numpy()\n\n    if num_classes is None:\n        num_classes=colors.max()+1\n    x = TSNE(random_state=RS, n_components=2).fit_transform(x)\n    sns.set_style('darkgrid')\n    sns.set_palette('muted')\n    sns.set_context(\"notebook\", font_scale=1.5,\n                    rc={\"lines.linewidth\": 2.5})\n    palette = np.array(sns.color_palette(\"hls\", num_classes))\n    fig = plt.figure(figsize=(8, 8))\n    ax = fig.add_subplot(111, aspect='equal')\n    sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=25,\n       
             c=palette[colors.astype(np.int)])\n    # plt.xlim(-25, 25)\n    # plt.ylim(-25, 25)\n    # ax.axis('off')\n    ax.axis('tight')\n    # plt.grid('off')\n\n    plt.savefig(output_dir, facecolor=fig.get_facecolor(), bbox_inches='tight')\n    #plt.show()\n\n\ndef plot_tsne_3d(x, colors,output_dir=\"\", num_classes=None):\n    \"\"\"\n    绘制3D t-SNE聚类图, 直接将图片保存到输出路径\n    :param x: 输入的feature map / spike\n    :param colors: predicted labels 作为不同类别的颜色\n    :param output_dir: 图片输出的路径(包括图片名及后缀)\n    :return: None\n    \"\"\"\n    if isinstance(x, torch.Tensor):\n        x = x.to('cpu').numpy()\n    if isinstance(colors, torch.Tensor):\n        colors = colors.to('cpu').numpy()\n\n    if num_classes is None:\n        num_classes=colors.max()+1\n    x = TSNE(random_state=RS, n_components=3, perplexity=30).fit_transform(x)\n    # sns.set_style('darkgrid')\n    sns.set_palette('muted')\n    sns.set_context(\"notebook\", font_scale=1.5,\n                    rc={\"lines.linewidth\": 2.5})\n    fig = plt.figure(figsize=(8, 8))\n\n    palette = np.array(sns.color_palette(\"hls\", num_classes))\n    ax = fig.add_subplot(111, projection='3d')\n\n    sc = ax.scatter(x[:, 0], x[:, 1], x[:, 2], lw=0, s=20, alpha=0.8,\n                    c=palette[colors.astype(np.int)])\n\n    # ax.set_xlabel('X')\n    # ax.set_ylabel('Y')\n    # ax.set_zlabel('Z')\n    # ax.view_init(20, -120)\n    ax.axis('tight')\n    plt.savefig(output_dir, facecolor=fig.get_facecolor(), bbox_inches='tight')\n    #plt.show()\n\n\ndef plot_confusion_matrix(logits, labels, output_dir):\n    \"\"\"\n    绘制混淆矩阵图\n    :param logits: predicted labels\n    :param labels: true labels\n    :param output_dir: 输出路径, 需要包括文件名以及后缀\n    :return: None\n    \"\"\"\n    sns.set_style('darkgrid')\n    sns.set_palette('Blues_r')\n    sns.set_context(\"notebook\", font_scale=1.,\n                    rc={\"lines.linewidth\": 2.})\n\n    logits = logits.argmax(dim=1).cpu()\n    labels = labels.cpu()\n    _max = 
labels.max()\n    if _max > 10:\n        annot = False\n    else:\n        annot = True\n    # print(labels.shape, logits.shape)\n    conf_matrix = confusion_matrix(labels, logits)\n    con_mat_norm = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis]  # 归一化\n    con_mat_norm = np.around(con_mat_norm, decimals=2)\n    plt.figure(figsize=(8, 8))\n    sns.heatmap(con_mat_norm, annot=annot, cmap='Blues')\n    plt.ylim(0, _max + 1)\n    plt.xlabel('Predicted labels')\n    plt.ylabel('True labels')\n\n    plt.savefig(output_dir, bbox_inches='tight')\n    #plt.show()\n\n\nif __name__ == '__main__':\n    # Test for T-SNE\n    # x = torch.randn((100, 100))\n    # y = torch.randint(low=0, high=10, size=[100])\n    # plot_tsne_3d(x, y, output_dir='./t-sne.eps')\n\n    # Test for confusion matrix\n    # x = torch.rand(5012, 100)\n    # y = torch.randint(0, 100, (5012,))\n    # plot_confusion_matrix(x, y, '')\n\n    # Test for Mem Distribution\n    x = torch.randn(100000)\n    plot_mem_distribution(x, legend=['test'])\n\n"
  },
  {
    "path": "braincog/datasets/CUB2002011.py",
    "content": "import os\n\nimport pandas as pd\nfrom torchvision.datasets import VisionDataset\nfrom torchvision.datasets.folder import default_loader\nfrom torchvision.datasets.utils import download_file_from_google_drive\n\n\nclass CUB2002011(VisionDataset):\n    \"\"\"`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.\n        Args:\n            root (string): Root directory of the dataset.\n            train (bool, optional): If True, creates dataset from training set, otherwise\n               creates from test set.\n            transform (callable, optional): A function/transform that  takes in an PIL image\n               and returns a transformed version. E.g, ``transforms.RandomCrop``\n            target_transform (callable, optional): A function/transform that takes in the\n               target and transforms it.\n            download (bool, optional): If true, downloads the dataset from the internet and\n               puts it in root directory. If dataset is already downloaded, it is not\n               downloaded again.\n    \"\"\"\n    base_folder = 'CUB_200_2011/images'\n    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'\n    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'\n    filename = 'CUB_200_2011.tgz'\n    tgz_md5 = '97eceeb196236b17998738112f37df78'\n\n    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):\n        super(CUB2002011, self).__init__(root, transform=transform, target_transform=target_transform)\n\n        self.loader = default_loader\n        self.train = train\n        if download:\n            self._download()\n\n        if not self._check_integrity():\n            raise RuntimeError('Dataset not found or corrupted. 
You can use download=True to download it')\n\n    def _load_metadata(self):\n        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',\n                             names=['img_id', 'filepath'])\n        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),\n                                         sep=' ', names=['img_id', 'target'])\n        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),\n                                       sep=' ', names=['img_id', 'is_training_img'])\n\n        data = images.merge(image_class_labels, on='img_id')\n        self.data = data.merge(train_test_split, on='img_id')\n\n        class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),\n                                  sep=' ', names=['class_name'], usecols=[1])\n        self.class_names = class_names['class_name'].to_list()\n        if self.train:\n            self.data = self.data[self.data.is_training_img == 1]\n        else:\n            self.data = self.data[self.data.is_training_img == 0]\n\n    def _check_integrity(self):\n        try:\n            self._load_metadata()\n        except Exception:\n            return False\n\n        for index, row in self.data.iterrows():\n            filepath = os.path.join(self.root, self.base_folder, row.filepath)\n            if not os.path.isfile(filepath):\n                print(filepath)\n                return False\n        return True\n\n    def _download(self):\n        import tarfile\n\n        if self._check_integrity():\n            print('Files already downloaded and verified')\n            return\n\n        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)\n\n        with tarfile.open(os.path.join(self.root, self.filename), \"r:gz\") as tar:\n            tar.extractall(path=self.root)\n\n    def __len__(self):\n        return 
len(self.data)\n\n    def __getitem__(self, idx):\n        sample = self.data.iloc[idx]\n        path = os.path.join(self.root, self.base_folder, sample.filepath)\n        target = sample.target - 1  # Targets start at 1 by default, so shift to 0\n        img = self.loader(path)\n\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return img, target\n\n\nif __name__ == '__main__':\n    train_dataset = CUB2002011('./cub2011', train=True, download=False)\n    test_dataset = CUB2002011('./cub2011', train=False, download=False)\n "
  },
  {
    "path": "braincog/datasets/ESimagenet/ES_imagenet.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2022/11/1 11:06\n# Author : Regulus\n# FileName: ES_imagenet.py\n# Explain: \n# Software: PyCharm\nimport numpy as np\nimport torch\nimport linecache\nimport torch.utils.data as data\n\n\nclass ESImagenet_Dataset(data.Dataset):\n    def __init__(self, mode, data_set_path='/data/dvsimagenet/', transform=None):\n        super().__init__()\n        self.mode = mode\n        self.filenames = []\n        self.trainpath = data_set_path + 'train'\n        self.testpath = data_set_path + 'val'\n        self.traininfotxt = data_set_path + 'trainlabel.txt'\n        self.testinfotxt = data_set_path + 'vallabel.txt'\n        self.formats = '.npz'\n        self.transform = transform\n        if mode == 'train':\n            self.path = self.trainpath\n            trainfile = open(self.traininfotxt, 'r')\n            for line in trainfile:\n                filename, classnum, a, b = line.split()\n                realname, sub = filename.split('.')\n                self.filenames.append(realname + self.formats)\n        else:\n            self.path = self.testpath\n            testfile = open(self.testinfotxt, 'r')\n            for line in testfile:\n                filename, classnum, a, b = line.split()\n                realname, sub = filename.split('.')\n                self.filenames.append(realname + self.formats)\n\n    def __getitem__(self, index):\n        if self.mode == 'train':\n            info = linecache.getline(self.traininfotxt, index + 1)\n        else:\n            info = linecache.getline(self.testinfotxt, index + 1)\n        filename, classnum, a, b = info.split()\n        realname, sub = filename.split('.')\n        filename = realname + self.formats\n        filename = self.path + r'/' + filename\n        classnum = int(classnum)\n        a = int(a)\n        b = int(b)\n        datapos = np.load(filename)['pos'].astype(np.float64)\n        dataneg = np.load(filename)['neg'].astype(np.float64)\n\n 
       dy = (254 - b) // 2\n        dx = (254 - a) // 2\n        input = torch.zeros([2, 8, 256, 256])\n\n        x = datapos[:, 0] + dx\n        y = datapos[:, 1] + dy\n        t = datapos[:, 2] - 1\n        input[0, t, x, y] = 1\n\n        x = dataneg[:, 0] + dx\n        y = dataneg[:, 1] + dy\n        t = dataneg[:, 2] - 1\n        input[1, t, x, y] = 1\n\n        reshape = input[:, :, 16:240, 16:240].permute(0, 1, 2, 3).contiguous()\n        if self.transform is not None:\n            reshape = self.transform(reshape)\n        label = torch.tensor([classnum])\n        return reshape, label\n\n    def __len__(self):\n        return len(self.filenames)"
  },
  {
    "path": "braincog/datasets/ESimagenet/__init__.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2022/11/1 11:05\n# Author : Regulus\n# FileName: __init__.py.py\n# Explain: \n# Software: PyCharm\n\n\"\"\"\nfrom: https://github.com/lyh983012/ES-imagenet-master\n\"\"\"\n\n__all__ = ['ES_imagenet', 'reconstructed_ES_imagenet']\nfrom . import (\n    ES_imagenet,\n    reconstructed_ES_imagenet\n)"
  },
  {
    "path": "braincog/datasets/ESimagenet/reconstructed_ES_imagenet.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2022/11/1 11:06\n# Author : Regulus\n# FileName: reconstructed_ES_imagenet.py\n# Explain: \n# Software: PyCharm\n\nimport numpy as np\nimport torch\nimport linecache\nimport torch.utils.data as data\nfrom tqdm import tqdm\n\nclass ESImagenet2D_Dataset(data.Dataset):\n    def __init__(self, mode, data_set_path='/data/ESimagenet-0.18/', transform=None):\n        super().__init__()\n        self.mode = mode\n        self.filenames = []\n        self.trainpath = data_set_path + 'train'\n        self.testpath = data_set_path + 'val'\n        self.traininfotxt = data_set_path + 'trainlabel.txt'\n        self.testinfotxt = data_set_path + 'vallabel.txt'\n        self.formats = '.npz'\n        self.transform = transform\n        if mode == 'train':\n            self.path = self.trainpath\n            trainfile = open(self.traininfotxt, 'r')\n            for line in trainfile:\n                filename, classnum, a, b = line.split()\n                realname, sub = filename.split('.')\n                self.filenames.append(realname + self.formats)\n            trainfile = open(self.traininfotxt, 'r')\n            self.infolist = trainfile.readlines()\n        else:\n            self.path = self.testpath\n            testfile = open(self.testinfotxt, 'r')\n            for line in testfile:\n                filename, classnum, a, b = line.split()\n                realname, sub = filename.split('.')\n                self.filenames.append(realname + self.formats)\n            testfile = open(self.testinfotxt, 'r')\n            self.infolist = testfile.readlines()\n\n    def __getitem__(self, index):\n        info = self.infolist[index]\n        filename, classnum, a, b = info.split()\n        realname, sub = filename.split('.')\n        filename = realname + self.formats\n        filename = self.path + r'/' + filename\n        classnum = int(classnum)\n        a = int(a)\n        b = int(b)\n        with 
open(filename, \"rb\") as f:\n            data = np.load(f)\n            datapos = data['pos'].astype(np.float64)\n            dataneg = data['neg'].astype(np.float64)\n        tracex = [0, 2, 1, 0, 2, 1, 1, 2]\n        tracey = [2, 1, 0, 1, 2, 0, 1, 1]\n\n        dy = (254 - b) // 2\n        dx = (254 - a) // 2\n        input = torch.zeros([2, 8, 256, 256])\n\n        x = datapos[:, 0] + dx\n        y = datapos[:, 1] + dy\n        t = datapos[:, 2] - 1\n        input[0, t, x, y] += 1\n\n        x = dataneg[:, 0] + dx\n        y = dataneg[:, 1] + dy\n        t = dataneg[:, 2] - 1\n        input[1, t, x, y] += 1\n\n        sum_gary_data = torch.zeros([1, 1, 256, 256])\n        reshape = input[:, :, 16:240, 16:240]\n        H = 224\n        W = 224\n        for t in range(8):\n            dx = tracex[t]\n            dy = tracey[t]\n            sum_gary_data[0, 0, 2 - dx:2 - dx + H, 2 - dy:2 - dy + W] += reshape[0, t, :, :]\n            sum_gary_data[0, 0, 2 - dx:2 - dx + H, 2 - dy:2 - dy + W] -= reshape[1, t, :, :]\n\n        sum_gary_data = sum_gary_data[:, :, 1:225, 1:225]\n        # if self.transform is not None:\n        #     sum_gary_data = self.transform(sum_gary_data)\n        label = classnum\n        return sum_gary_data, label\n\n    def __len__(self):\n        return len(self.filenames)\n"
  },
  {
    "path": "braincog/datasets/NOmniglot/NOmniglot.py",
    "content": "from torch.utils.data import Dataset\nfrom braincog.datasets.NOmniglot.utils import *\n\n\nclass NOmniglot(Dataset):\n    def __init__(self, root='data/', frames_num=12, train=True, data_type='event',\n                 transform=None, target_transform=None, use_npz=False, crop=True, create=True, thread_num=16):\n        super().__init__()\n        self.crop = crop\n        self.data_type = data_type\n        self.use_npz = use_npz\n        self.transform = transform\n        self.target_transform = target_transform\n        events_npy_root = os.path.join(root, 'events_npy', 'background' if train else \"evaluation\")\n\n        frames_root = os.path.join(root, f'fnum_{frames_num}_dtype_{data_type}_npz_{use_npz}',\n                                   'background' if train else \"evaluation\")\n\n        if not os.path.exists(frames_root) and create:\n            if not os.path.exists(events_npy_root) and create:\n                os.makedirs(events_npy_root)\n                print('creating event data..')\n                convert_aedat4_dir_to_events_dir(root, train)\n            else:\n                print(f'npy format events data root {events_npy_root}, already exists')\n\n            os.makedirs(frames_root)\n            print('creating frames data..')\n            convert_events_dir_to_frames_dir(events_npy_root, frames_root, '.npy', frames_num, data_type,\n                                             thread_num=thread_num, compress=use_npz)\n        else:\n            print(f'frames data root {frames_root} already exists.')\n\n        self.datadict, self.num_classes = list_class_files(events_npy_root, frames_root, True, use_npz=use_npz)\n\n        self.datalist = []\n        for i in self.datadict:\n            self.datalist.extend([(j, i) for j in self.datadict[i]])\n\n    def __len__(self):\n        return len(self.datalist)\n\n    def __getitem__(self, index):\n        image, label = self.datalist[index]\n        image, label = 
self.readimage(image, label)\n        return image, label\n\n    def readimage(self, image, label):\n        if self.use_npz:\n            image = torch.tensor(np.load(image)['arr_0']).float()\n        else:\n            image = torch.tensor(np.load(image)).float()\n        if self.crop:\n            image = image[:, :, 4:254, 54:304]\n        if self.transform is not None: image = self.transform(image)\n        if self.target_transform is not None: label = self.target_transform(label)\n        return image, label\n\n\n\n"
  },
  {
    "path": "braincog/datasets/NOmniglot/__init__.py",
    "content": "__all__ = ['NOmniglot', 'nomniglot_full', 'nomniglot_nw_ks','nomniglot_pair','utils']\r\nfrom . import (\r\n    NOmniglot,\r\n    nomniglot_full,\r\n    nomniglot_nw_ks,\r\n    nomniglot_pair,\r\n    utils\r\n)"
  },
  {
    "path": "braincog/datasets/NOmniglot/nomniglot_full.py",
    "content": "import torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom braincog.datasets.NOmniglot.NOmniglot import NOmniglot\n\n\nclass NOmniglotfull(Dataset):\n    '''\n    solve few-shot learning as general classification problem,\n    We combine the original training set with the test set and take 3/4 as the training set\n    '''\n\n    def __init__(self, root='data/', train=True, frames_num=4, data_type='event',\n                 transform=None, target_transform=None, use_npz=False, crop=True, create=True):\n        super().__init__()\n\n        trainSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type,\n                             transform=transform, target_transform=target_transform,\n                             use_npz=use_npz, crop=crop, create=create)\n        testSet = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type,\n                            transform=transform, target_transform=lambda x: x + 964,\n                            use_npz=use_npz, crop=crop, create=create)\n        self.data = torch.utils.data.ConcatDataset([trainSet, testSet])\n        if train:\n            self.id = [j for j in range(len(self.data)) if j % 20 in [i for i in range(15)]]\n\n        else:\n            self.id = [j for j in range(len(self.data)) if j % 20 in [i for i in range(15, 20)]]\n\n    def __len__(self):\n        return len(self.id)\n\n    def __getitem__(self, index):\n        image, label = self.data[self.id[index]]\n        return image, label\n\n\nif __name__ == '__main__':\n    db_train = NOmniglotfull('../../data/', train=True, frames_num=4, data_type='event')\n    dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)\n    # __getitem__ returns a (image, label) pair, so unpack two values per batch\n    for image, label in dataloadertrain:\n        print(image.shape)\n"
  },
  {
    "path": "braincog/datasets/NOmniglot/nomniglot_nw_ks.py",
    "content": "import torch\nimport torchvision\nimport numpy as np\nfrom torch.utils.data import Dataset, DataLoader\nfrom braincog.datasets.NOmniglot.NOmniglot import NOmniglot\n\n\nclass NOmniglotNWayKShot(Dataset):\n    '''\n    get n-way k-shot data as meta learning\n    We set the sampling times of each epoch as \"len(self.dataSet) // (self.n_way * (self.k_shot + self.k_query))\"\n    you can increase or decrease the number of epochs to determine the total training times\n    '''\n\n    def __init__(self, root, n_way, k_shot, k_query, train=True, frames_num=12, data_type='event',\n                 transform=torchvision.transforms.Resize((28, 28))):\n        self.dataSet = NOmniglot(root=root, train=train,\n                                 frames_num=frames_num, data_type=data_type, transform=transform)\n        self.n_way = n_way  # n way\n        self.k_shot = k_shot  # k shot\n        self.k_query = k_query  # k query\n        assert (k_shot + k_query) <= 20\n        self.length = 256\n        self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length)\n\n    def load_data_cache(self, data_dict, length):\n        '''\n        The dataset is sampled randomly length times, and the address is saved to obtain\n        '''\n        data_cache = []\n        for i in range(length):\n            selected_cls = np.random.choice(len(data_dict), self.n_way, False)\n\n            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []\n            for j, cur_class in enumerate(selected_cls):\n                selected_img = np.random.choice(20, self.k_shot + self.k_query, False)\n\n                x_spts.append(np.array(data_dict[cur_class])[selected_img[:self.k_shot]])\n                x_qrys.append(np.array(data_dict[cur_class])[selected_img[self.k_shot:]])\n                y_spts.append([j for _ in range(self.k_shot)])\n                y_qrys.append([j for _ in range(self.k_query)])\n\n            shufflespt = np.random.choice(self.n_way * self.k_shot, self.n_way * self.k_shot, False)\n            shuffleqry = np.random.choice(self.n_way * self.k_query, self.n_way * self.k_query, False)\n\n            temp = [np.array(x_spts).reshape(-1)[shufflespt], np.array(y_spts).reshape(-1)[shufflespt],\n                    np.array(x_qrys).reshape(-1)[shuffleqry], np.array(y_qrys).reshape(-1)[shuffleqry]]\n            data_cache.append(temp)\n        return data_cache\n\n    def __getitem__(self, index):\n        x_spts, y_spts, x_qrys, y_qrys = self.data_cache[index]\n        x_sptst, y_sptst, x_qryst, y_qryst = [], [], [], []\n\n        for i, j in zip(x_spts, y_spts):\n            i, j = self.dataSet.readimage(i, j)\n            x_sptst.append(i.unsqueeze(0))\n            y_sptst.append(j)\n        for i, j in zip(x_qrys, y_qrys):\n            i, j = self.dataSet.readimage(i, j)\n            x_qryst.append(i.unsqueeze(0))\n            y_qryst.append(j)\n        return torch.cat(x_sptst, dim=0), np.array(y_sptst), torch.cat(x_qryst, dim=0), np.array(y_qryst)\n\n    def reset(self):\n        self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length)\n\n    def __len__(self):\n        return len(self.data_cache)\n\n\nif __name__ == \"__main__\":\n    db_train = NOmniglotNWayKShot('./data/', n_way=5, k_shot=1, k_query=15,\n                                  frames_num=4, data_type='frequency', train=True)\n    dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)\n    for x_spt, y_spt, x_qry, y_qry in dataloadertrain:\n        print(x_spt.shape)\n    # the class defines reset(), not resampling()\n    db_train.reset()\n"
  },
  {
    "path": "braincog/datasets/NOmniglot/nomniglot_pair.py",
    "content": "import torch\nfrom torch.utils.data import Dataset, DataLoader\nimport numpy as np\nfrom numpy.random import choice as npc\nimport random\nimport torch.nn.functional as F\nfrom braincog.datasets.NOmniglot import NOmniglot\n\n\nclass NOmniglotTrainSet(Dataset):\n    '''\n    Dataloader for Siamese Net\n    The pairs of similar samples are labeled as 1, and those of different samples are labeled as 0\n    '''\n\n    def __init__(self, root='data/', use_frame=True, frames_num=10, data_type='event', use_npz=False, resize=None):\n        super(NOmniglotTrainSet, self).__init__()\n        self.resize = resize\n        self.data_type = data_type\n        self.use_frame = use_frame\n        self.dataSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type, use_npz=use_npz)\n        self.datas, self.num_classes = self.dataSet.datadict, self.dataSet.num_classes\n\n        np.random.seed(0)\n\n    def __len__(self):\n        '''\n        Sampling upper limit, you can set the maximum sampling times when using to terminate\n        '''\n        return 21000000\n\n    def __getitem__(self, index):\n        # get image from same class\n        if index % 2 == 1:\n            label = 1.0\n            idx1 = random.randint(0, self.num_classes - 1)\n            image1 = random.choice(self.datas[idx1])\n            image2 = random.choice(self.datas[idx1])\n        # get image from different class\n        else:\n            label = 0.0\n            idx1 = random.randint(0, self.num_classes - 1)\n            idx2 = random.randint(0, self.num_classes - 1)\n            while idx1 == idx2:\n                idx2 = random.randint(0, self.num_classes - 1)\n            image1 = random.choice(self.datas[idx1])\n            image2 = random.choice(self.datas[idx2])\n\n        if self.use_frame:\n            if self.data_type == 'event':\n                image1 = torch.tensor(np.load(image1)['arr_0']).float()\n                image2 = torch.tensor(np.load(image2)['arr_0']).float()\n            elif self.data_type == 'frequency':\n                image1 = torch.tensor(np.load(image1)['arr_0']).float()\n                image2 = torch.tensor(np.load(image2)['arr_0']).float()\n            else:\n                raise NotImplementedError\n\n        if self.resize is not None:\n            image1 = image1[:, :, 4:254, 54:304]\n            image1 = F.interpolate(image1, size=(self.resize, self.resize))\n            image2 = image2[:, :, 4:254, 54:304]\n            image2 = F.interpolate(image2, size=(self.resize, self.resize))\n\n        return image1, image2, torch.from_numpy(np.array([label], dtype=np.float32))\n\n\nclass NOmniglotTestSet(Dataset):\n    '''\n        Dataloader for Siamese Net\n\n        '''\n\n    def __init__(self, root='data/', time=1000, way=20, shot=1, query=1, use_frame=True, frames_num=10, data_type='event', use_npz=True, resize=None):\n        super(NOmniglotTestSet, self).__init__()\n        self.resize = resize\n        self.use_frame = use_frame\n        self.time = time         # Sampling times\n        self.way = way\n        self.shot = shot\n        self.query = query\n        self.img1 = None         # Fix test sample while sampling support set\n        self.c1 = None           # Fixed categories when sampling multiple samples\n        self.c2 = None\n        self.select_class = []   # selected classes\n        self.select_sample = []  # selected samples\n\n        self.data_type = data_type\n        np.random.seed(0)\n        self.dataSet = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type, use_npz=use_npz)\n        self.datas, self.num_classes = self.dataSet.datadict, self.dataSet.num_classes\n\n    def __len__(self):\n        '''\n        In general, the total number of test tasks is 1000.\n        Since one test sample is collected at a time, way * shot support samples are used for each test\n        '''\n        return self.time * self.way * self.shot\n\n    def __getitem__(self, index):\n        '''\n        The 0th sample of each way*shot is used for query and recorded in the selected sample\n        to achieve the effect of selecting K +1\n        '''\n        idx = index % (self.way * self.shot)\n        # generate image pair from same class\n        if idx == 0:  #\n            self.select_class = []\n            self.c1 = random.randint(0, self.num_classes - 1)\n            self.c2 = self.c1\n            sind = random.randint(0, len(self.datas[self.c1]) - 1)\n            self.select_sample.append(sind)\n            self.img1 = self.datas[self.c1][sind]\n\n            sind = random.randint(0, len(self.datas[self.c2]) - 1)\n            while sind in self.select_sample:\n                sind = random.randint(0, len(self.datas[self.c2]) - 1)\n            img2 = self.datas[self.c1][sind]\n            self.select_sample.append(sind)\n            self.select_class.append(self.c1)\n        # generate image pair from different class\n        else:\n            if index % self.shot == 0:\n                self.c2 = random.randint(0, self.num_classes - 1)\n                while self.c2 in self.select_class:  # self.c1 == c2:\n                    self.c2 = random.randint(0, self.num_classes - 1)\n                self.select_class.append(self.c2)\n                self.select_sample = []\n            sind = random.randint(0, len(self.datas[self.c2]) - 1)\n            while sind in self.select_sample:\n                sind = random.randint(0, len(self.datas[self.c2]) - 1)\n            img2 = self.datas[self.c2][sind]\n            self.select_sample.append(sind)\n\n        # default to the stored sample so img1 is always defined, even when use_frame is False\n        img1 = self.img1\n        if self.use_frame:\n            if self.data_type == 'event':\n                img1 = torch.tensor(np.load(self.img1)['arr_0']).float()\n                img2 = torch.tensor(np.load(img2)['arr_0']).float()\n            elif self.data_type == 'frequency':\n                img1 = torch.tensor(np.load(self.img1)['arr_0']).float()\n                img2 = torch.tensor(np.load(img2)['arr_0']).float()\n            else:\n                raise NotImplementedError\n\n        if self.resize is not None:\n            img1 = img1[:, :, 4:254, 54:304]\n            img1 = F.interpolate(img1, size=(self.resize, self.resize))\n            img2 = img2[:, :, 4:254, 54:304]\n            img2 = F.interpolate(img2, size=(self.resize, self.resize))\n        return img1, img2\n\n\nif __name__ == '__main__':\n    data_type = 'frequency'\n    T = 4\n    trainSet = NOmniglotTrainSet(root='data/', use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105)\n    testSet = NOmniglotTestSet(root='data/', time=1000, way=5, shot=1, use_frame=True, frames_num=T,\n                               data_type=data_type, use_npz=True, resize=105)\n    trainLoader = DataLoader(trainSet, batch_size=48, shuffle=False, num_workers=4)\n    testLoader = DataLoader(testSet, batch_size=5 * 1, shuffle=False, num_workers=4)\n    for batch_id, (img1, img2) in enumerate(testLoader, 1):\n        # img1.shape [batch, T, 2, H, W]\n        print(batch_id)\n        break\n\n    for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):\n        # img1.shape [batch, T, 2, H, W]\n        print(batch_id)\n        break\n"
  },
  {
    "path": "braincog/datasets/NOmniglot/utils.py",
    "content": "import torch\nimport threading\nimport numpy as np\nimport pandas\nimport os\nfrom dv import AedatFile\n\n\nclass FunctionThread(threading.Thread):\n    def __init__(self, f, *args, **kwargs):\n        super().__init__()\n        self.f = f\n        self.args = args\n        self.kwargs = kwargs\n\n    def run(self):\n        self.f(*self.args, **self.kwargs)\n\n\ndef integrate_events_to_frames(events, height, width, frames_num=10, data_type='event'):\n    frames = np.zeros(shape=[frames_num, 2, height * width])\n\n    # create j_{l} and j_{r}\n    j_l = np.zeros(shape=[frames_num], dtype=int)\n    j_r = np.zeros(shape=[frames_num], dtype=int)\n\n    # split by time\n    events['t'] -= events['t'][0]       # start with 0 timestamp\n    assert events['t'][-1] > frames_num\n    dt = events['t'][-1] // frames_num  # get length of each frame\n    idx = np.arange(events['t'].size)\n    for i in range(frames_num):\n        t_l = dt * i\n        t_r = t_l + dt\n        mask = np.logical_and(events['t'] >= t_l, events['t'] < t_r)\n        idx_masked = idx[mask]\n        if len(idx_masked) == 0:\n            j_l[i] = -1\n            j_r[i] = -1\n        else:\n            j_l[i] = idx_masked[0]\n            j_r[i] = idx_masked[-1] + 1 if i < frames_num - 1 else events['t'].size\n\n    for i in range(frames_num):\n        if j_l[i] >= 0:\n            x = events['x'][j_l[i]:j_r[i]]\n            y = events['y'][j_l[i]:j_r[i]]\n            p = events['p'][j_l[i]:j_r[i]]\n            mask = []\n            mask.append(p == 0)\n            mask.append(np.logical_not(mask[0]))\n            for j in range(2):\n                position = y[mask[j]] * width + x[mask[j]]\n                events_number_per_pos = np.bincount(position)\n                frames[i][j][np.arange(events_number_per_pos.size)] += events_number_per_pos\n\n        if data_type == 'frequency':\n            if i < frames_num - 1:\n                frames[i] /= dt\n            else:\n                frames[i] /= (dt + events['t'][-1] % frames_num)\n\n    # cast once after accumulation instead of copying the whole array every iteration\n    frames = frames.astype(np.float16)\n\n    if data_type == 'event':\n        # np.bool was removed in NumPy 1.24; use the builtin bool dtype\n        frames = (frames > 0).astype(bool)\n    else:\n        frames = normalize_frame(frames, 'max')\n    return frames.reshape((frames_num, 2, height, width))\n\n\ndef normalize_frame(frames: np.ndarray or torch.Tensor, normalization: str):\n    eps = 1e-5\n    for i in range(frames.shape[0]):\n        if normalization == 'max':\n            frames[i][0] = frames[i][0] / max(frames[i][0].max(), eps)\n            frames[i][1] = frames[i][1] / max(frames[i][1].max(), eps)\n\n        elif normalization == 'norm':\n            frames[i][0] = (frames[i][0] - frames[i][0].mean()) / np.sqrt(max(frames[i][0].var(), eps))\n            frames[i][1] = (frames[i][1] - frames[i][1].mean()) / np.sqrt(max(frames[i][1].var(), eps))\n\n        elif normalization == 'sum':\n            frames[i][0] = frames[i][0] / max(frames[i][0].sum(), eps)\n            frames[i][1] = frames[i][1] / max(frames[i][1].sum(), eps)\n\n        else:\n            raise NotImplementedError\n    return frames\n\n\ndef convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, suffix,\n                                     frames_num=12, result_type='event', thread_num=1,\n                                     compress=True):\n    \"\"\"\n    Iterate through all event data in events_data_dir and generate frame data files in frames_data_dir\n    \"\"\"\n    def read_function(file_name):\n        return np.load(file_name, allow_pickle=True).item()\n\n    def cvt_fun(events_file_list):\n        for events_file in events_file_list:\n            print(events_file)\n            frames = integrate_events_to_frames(read_function(events_file), 260, 346, frames_num, result_type)\n            if compress:\n                frames_file = os.path.join(frames_data_dir,\n                                           os.path.basename(events_file)[0: -suffix.__len__()] + '.npz')\n                np.savez_compressed(frames_file, frames)\n            else:\n                frames_file = os.path.join(frames_data_dir,\n                                           os.path.basename(events_file)[0: -suffix.__len__()] + '.npy')\n                np.save(frames_file, frames)\n\n    # Obtain the path of the all files\n    events_file_list = list_all_files(events_data_dir, '.npy')\n\n    if thread_num == 1:\n        cvt_fun(events_file_list)\n    else:\n        # Multithreading acceleration\n        thread_list = []\n        block = events_file_list.__len__() // thread_num\n        for i in range(thread_num - 1):\n            thread_list.append(FunctionThread(cvt_fun, events_file_list[i * block: (i + 1) * block]))\n            thread_list[-1].start()\n            print(f'thread {i} start, processing files index: {i * block} : {(i + 1) * block}.')\n        thread_list.append(FunctionThread(cvt_fun, events_file_list[(thread_num - 1) * block:]))\n        thread_list[-1].start()\n        print(\n            f'thread {thread_num} start, processing files index: {(thread_num - 1) * block} : {events_file_list.__len__()}.')\n        for i in range(thread_num):\n            thread_list[i].join()\n            print(f'thread {i} finished.')\n\n\ndef convert_aedat4_dir_to_events_dir(root, train):\n    kind = 'background' if train else \"evaluation\"\n    originroot = root\n    root = root + '/dvs_' + kind + '/'\n    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']  # get folder names\n\n    for a in range(len(alphabet_names)):\n        alpha_name = alphabet_names[a]\n\n        for b in range(len(os.listdir(os.path.join(root, alpha_name)))):\n            character_id = b + 1\n            character_path = alpha_name + '/character' + num2str(character_id)\n            print('Parsing %s \\\\ character%s ...' % (alpha_name, num2str(character_id)))\n\n            file_path = os.path.join(root, character_path)\n            aedat4_name = [a for a in os.listdir(file_path) if a[-4:] == 'dat4' and len(a) == 11][0]\n            csv_name = [a for a in os.listdir(file_path) if a[-4:] == '.csv' and len(a) == 8][0]\n            number = csv_name[:4]\n            new_path = originroot + '/events_npy/' + kind + '/' + alpha_name + '/character' + num2str(character_id)\n            if not os.path.exists(new_path):\n                os.makedirs(new_path)\n\n            start_end_timestamp = pandas.read_csv(os.path.join(file_path, csv_name)).values\n\n            a_timestamp, a_polarity, a_x, a_y = [], [], [], []\n            with AedatFile(os.path.join(file_path, aedat4_name)) as f:  # read aedat4\n                for e in f['events']:\n                    a_timestamp.append(e.timestamp)\n                    a_polarity.append(e.polarity)\n                    a_x.append(e.x)\n                    a_y.append(e.y)\n\n            for ii in range(20):  # each file has 20 samples\n                name = str(number) + '_' + num2str(ii + 1) + '.npy'\n                start_index = a_timestamp.index(start_end_timestamp[ii][1])\n                end_index = a_timestamp.index(start_end_timestamp[ii][2])\n                tmp = {'t': np.array(a_timestamp[start_index:end_index]),\n                       'x': np.array(a_x[start_index:end_index]),\n                       'y': np.array(a_y[start_index:end_index]),\n                       'p': np.array(a_polarity[start_index:end_index])}\n                np.save(os.path.join(new_path, name), tmp)\n\n\ndef num2str(idx):\n    if idx < 10:\n        return '0' + str(idx)\n    return str(idx)\n\n\ndef list_all_files(root, suffix, getlen=False):\n    '''\n    List the path of all files under root, output a list\n    '''\n    file_list = []\n    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']  # get folder names\n    idx = 0\n    for a in range(len(alphabet_names)):\n        alpha_name = alphabet_names[a]\n        for b in range(len(os.listdir(os.path.join(root, alpha_name)))):\n            character_id = b + 1\n            character_path = os.path.join(root, alpha_name, 'character' + num2str(character_id))\n            idx += 1\n            for c in range(len(os.listdir(character_path))):\n                fn_example = os.listdir(character_path)[c]\n                if fn_example[-4:] == suffix:\n                    file_list.append(os.path.join(character_path, fn_example))\n    if getlen:\n        return file_list, idx\n    else:\n        return file_list\n\n\ndef list_class_files(root, frames_kind_root, getlen=False, use_npz=False):\n    '''\n    index the generated samples,\n    get dictionaries according to categories, each corresponding to a list,\n    the list contain the address of the new file in fnum_x_dtype_x_npz_True\n    '''\n    file_list = {}\n    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']  # get folder names\n    idx = 0\n    for a in range(len(alphabet_names)):\n        alpha_name = alphabet_names[a]\n        for b in range(len(os.listdir(os.path.join(root, alpha_name)))):\n            character_id = b + 1\n            character_path = os.path.join(root, alpha_name, 'character' + num2str(character_id))\n            file_list[idx] = []\n            for c in range(len(os.listdir(character_path))):\n                fn_example = os.listdir(character_path)[c]\n                if use_npz:\n                    fn_example = fn_example[:-1] + 'z'\n                file_list[idx].append(os.path.join(frames_kind_root, fn_example))\n            idx += 1\n    if getlen:\n        return file_list, idx\n    else:\n        return file_list\n"
  },
  {
    "path": "braincog/datasets/StanfordDogs.py",
    "content": "import os\nimport scipy.io\nfrom os.path import join\nfrom torchvision.datasets import VisionDataset\nfrom torchvision.datasets.folder import default_loader\nfrom torchvision.datasets.utils import download_url, list_dir\n\n\nclass StanfordDogs(VisionDataset):\n    \"\"\"`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.\n        Args:\n            root (string): Root directory of the dataset.\n            train (bool, optional): If True, creates dataset from training set, otherwise\n               creates from test set.\n            transform (callable, optional): A function/transform that  takes in an PIL image\n               and returns a transformed version. E.g, ``transforms.RandomCrop``\n            target_transform (callable, optional): A function/transform that takes in the\n               target and transforms it.\n            download (bool, optional): If true, downloads the dataset from the internet and\n               puts it in root directory. If dataset is already downloaded, it is not\n               downloaded again.\n    \"\"\"\n    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'\n\n    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):\n        super(StanfordDogs, self).__init__(root, transform=transform, target_transform=target_transform)\n\n        self.loader = default_loader\n        self.train = train\n\n        if download:\n            self.download()\n\n        split = self.load_split()\n\n        self.images_folder = join(self.root, 'Images')\n        self.annotations_folder = join(self.root, 'Annotation')\n        self._breeds = list_dir(self.images_folder)\n\n        self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]\n\n        self._flat_breed_images = self._breed_images\n\n    def __len__(self):\n        return len(self._flat_breed_images)\n\n    def __getitem__(self, index):\n        image_name, target = self._flat_breed_images[index]\n        image_path = join(self.images_folder, image_name)\n        image = self.loader(image_path)\n\n        if self.transform is not None:\n            image = self.transform(image)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return image, target\n\n    def download(self):\n        import tarfile\n\n        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):\n            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:\n                print('Files already downloaded and verified')\n                return\n\n        for filename in ['images', 'annotation', 'lists']:\n            tar_filename = filename + '.tar'\n            url = self.download_url_prefix + '/' + tar_filename\n            download_url(url, self.root, tar_filename, None)\n            print('Extracting downloaded file: ' + join(self.root, tar_filename))\n            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:\n                tar_file.extractall(self.root)\n            os.remove(join(self.root, tar_filename))\n\n    def load_split(self):\n        if self.train:\n            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']\n            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']\n        else:\n            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']\n            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']\n\n        split = [item[0][0] for item in split]\n        labels = [item[0] - 1 for item in labels]\n        return list(zip(split, labels))\n\n    def stats(self):\n        counts = {}\n        for index in range(len(self._flat_breed_images)):\n            image_name, target_class = self._flat_breed_images[index]\n            if target_class not in counts.keys():\n                counts[target_class] = 1\n            else:\n                counts[target_class] += 1\n\n        print(\"%d samples spanning %d classes (avg %f per class)\" % (len(self._flat_breed_images), len(counts.keys()),\n                                                                     float(len(self._flat_breed_images)) / float(\n                                                                         len(counts.keys()))))\n\n        return counts\n\n\nif __name__ == '__main__':\n    # the class defined above is StanfordDogs, not Dogs\n    train_dataset = StanfordDogs('./dogs', train=True, download=False)\n    test_dataset = StanfordDogs('./dogs', train=False, download=False)\n"
  },
  {
    "path": "braincog/datasets/TinyImageNet.py",
    "content": "import os\nimport pandas as pd\nimport warnings\nfrom torchvision.datasets import ImageFolder\nfrom torchvision.datasets import VisionDataset\nfrom torchvision.datasets.folder import default_loader\nfrom torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg\n\n\nclass TinyImageNet(VisionDataset):\n    \"\"\"`tiny-imageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.\n        Args:\n            root (string): Root directory of the dataset.\n            split (string, optional): The dataset split, supports ``train``, or ``val``.\n            transform (callable, optional): A function/transform that  takes in an PIL image\n               and returns a transformed version. E.g, ``transforms.RandomCrop``\n            target_transform (callable, optional): A function/transform that takes in the\n               target and transforms it.\n            download (bool, optional): If true, downloads the dataset from the internet and\n               puts it in root directory. If dataset is already downloaded, it is not\n               downloaded again.\n    \"\"\"\n    base_folder = 'tiny-imagenet-200/'\n    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'\n    filename = 'tiny-imagenet-200.zip'\n    md5 = '90528d7ca1a48142e341f4ef8d21d0de'\n\n    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):\n        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)\n\n        self.dataset_path = os.path.join(root, self.base_folder)\n        self.loader = default_loader\n        self.split = verify_str_arg(split, \"split\", (\"train\", \"val\",))\n\n        if self._check_integrity():\n            print('Files already downloaded and verified.')\n        elif download:\n            self._download()\n        else:\n            raise RuntimeError(\n                'Dataset not found. You can use download=True to download it.')\n        if not os.path.isdir(self.dataset_path):\n            print('Extracting...')\n            extract_archive(os.path.join(root, self.filename))\n\n        _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))\n\n        self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)\n\n    def _download(self):\n        print('Downloading...')\n        download_url(self.url, root=self.root, filename=self.filename)\n        print('Extracting...')\n        extract_archive(os.path.join(self.root, self.filename))\n\n    def _check_integrity(self):\n        return check_integrity(os.path.join(self.root, self.filename), self.md5)\n\n    def __getitem__(self, index):\n        img_path, target = self.data[index]\n        image = self.loader(img_path)\n\n        if self.transform is not None:\n            image = self.transform(image)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return image, target\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef find_classes(class_file):\n    with open(class_file) as r:\n        classes = list(map(lambda s: s.strip(), r.readlines()))\n\n    classes.sort()\n    class_to_idx = {classes[i]: i for i in range(len(classes))}\n\n    return classes, class_to_idx\n\n\ndef make_dataset(root, base_folder, dirname, class_to_idx):\n    images = []\n    dir_path = os.path.join(root, base_folder, dirname)\n\n    if dirname == 'train':\n        for fname in sorted(os.listdir(dir_path)):\n            cls_fpath = os.path.join(dir_path, fname)\n            if os.path.isdir(cls_fpath):\n                cls_imgs_path = os.path.join(cls_fpath, 'images')\n                for imgname in sorted(os.listdir(cls_imgs_path)):\n                    path = os.path.join(cls_imgs_path, imgname)\n                    item = (path, class_to_idx[fname])\n                    images.append(item)\n    else:\n        imgs_path = os.path.join(dir_path, 'images')\n        imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')\n\n        with open(imgs_annotations) as r:\n            data_info = map(lambda s: s.split('\\t'), r.readlines())\n\n        cls_map = {line_data[0]: line_data[1] for line_data in data_info}\n\n        for imgname in sorted(os.listdir(imgs_path)):\n            path = os.path.join(imgs_path, imgname)\n            item = (path, class_to_idx[cls_map[imgname]])\n            images.append(item)\n\n    return images\n\n\nif __name__ == '__main__':\n    train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)\n    test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)\n"
  },
  {
    "path": "braincog/datasets/__init__.py",
    "content": "from .datasets import build_transform, build_dataset, get_mnist_data, get_fashion_data, \\\n    get_cifar10_data, get_cifar100_data, get_imnet_data, get_dvsg_data, get_dvsc10_data, \\\n    get_NCALTECH101_data, get_NCARS_data, get_nomni_data, get_bullyingdvs_data\nfrom .utils import rescale, dvs_channel_check_expend\n\nfrom .hmdb_dvs import HMDBDVS\nfrom .ucf101_dvs import ucf101_dvs\nfrom .ncaltech101 import NCALTECH101\nfrom .bullying10k import BULLYINGDVS\n\n__all__ = [\n    'build_transform', 'build_dataset',\n    'get_mnist_data', 'get_fashion_data', 'get_cifar10_data', 'get_cifar100_data', 'get_imnet_data',\n    'get_dvsg_data', 'get_dvsc10_data', 'get_NCALTECH101_data', 'get_NCARS_data', 'get_nomni_data',\n    'rescale', 'dvs_channel_check_expend', 'get_bullyingdvs_data'\n]\n\n\n# names of the datasets that are DVS (event-based); 'dvsg' was listed twice before\ndvs_data = [\n    'dvsg',\n    'dvsc10',\n    'ncaltech101',\n    'ncars',\n    'ucf101dvs',\n    'hmdbdvs',\n    'shd',\n    'ntidigits',\n    'nmnist'\n]\n\n\ndef is_dvs_data(dataset):\n    # case-insensitive membership test against the known DVS dataset names\n    return dataset.lower() in dvs_data\n"
  },
  {
    "path": "braincog/datasets/bullying10k/__init__.py",
    "content": "from .bullying10k import BULLYINGDVS\n\n__all__ = ['BULLYINGDVS']\n"
  },
  {
    "path": "braincog/datasets/bullying10k/bullying10k.py",
    "content": "import os\nimport numpy as np\nfrom numpy.lib import recfunctions\nimport scipy.io as scio\nfrom typing import Tuple, Any, Optional\nfrom tonic.dataset import Dataset\nfrom tonic.download_utils import extract_archive\nimport dv\n\n\nclass BULLYINGDVS(Dataset):\n    classes = [\"fingerguess\", \"greeting\", \"hairgrabs\", \"handshake\", \"kicking\",\n               \"punching\", \"pushing\", \"slapping\", \"strangling\", \"walking\"]\n    class_dict = {cls: idx for idx, cls in enumerate(classes)}\n\n    sensor_size = (346, 260, 2)\n    dtype = np.dtype([(\"t\", int), (\"x\", int), (\"y\", int), (\"p\", int)])\n    ordering = dtype.names\n\n    def __init__(self, save_to, transform=None, target_transform=None):\n        super(BULLYINGDVS, self).__init__(\n            save_to, transform=transform, target_transform=target_transform\n        )\n        self.aedat4 = True\n\n        for path, dirs, files in os.walk(self.location_on_system):\n            dirs.sort()\n            files.sort()\n            for file in files:\n                if file.endswith(\"aedat4\"):\n                    self.data.append(path + \"/\" + file)\n                    self.targets.append(self.class_dict[path.split('/')[-2]])\n\n                if file.endswith(\"npy\"):\n                    self.aedat4 = False\n                    self.data.append(path + \"/\" + file)\n                    self.targets.append(self.class_dict[path.split('/')[-2]])\n\n\n    def __getitem__(self, index: int) -> Tuple[Any, Any]:\n        \"\"\"\n        Returns:\n            (events, target) where target is index of the target class.\n        \"\"\"\n        # look up the label first so both loading branches return it\n        # (previously only the aedat4 branch assigned target -> NameError for npy files)\n        target = self.targets[index]\n        if self.aedat4:\n            events = dv.AedatFile(self.data[index])['events']\n            events = np.concatenate([event for event in events.numpy()])\n        else:\n            events = np.concatenate(np.load(self.data[index], allow_pickle=True))\n\n        events = np.column_stack(\n            [\n                events['timestamp'] - events['timestamp'][0],\n                events['x'],\n                events['y'],\n                events['polarity']\n            ]\n        )\n\n        events = np.lib.recfunctions.unstructured_to_structured(events, self.dtype)\n        if self.transform is not None:\n            events = self.transform(events)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return events, target\n\n    def __len__(self):\n        return len(self.data)\n\n    def _check_exists(self):\n        return True\n"
  },
  {
    "path": "braincog/datasets/cut_mix.py",
    "content": "import math\nimport numpy as np\nimport random\nfrom torch.utils.data.dataset import Dataset\nfrom braincog.datasets.rand_aug import SaltAndPepperNoise\nimport numpy as np\nimport torch\nfrom torch.nn import functional as F\n\n\ndef event_difference(x1, x2, kernel_size=3):\n    padding = kernel_size // 2\n    x1 = F.avg_pool2d(x1, kernel_size=kernel_size, stride=1, padding=padding)\n    x2 = F.avg_pool2d(x2, kernel_size=kernel_size, stride=1, padding=padding)\n    return F.mse_loss(x1, x2)\n\n\ndef onehot(size, target):\n    vec = torch.zeros(size, dtype=torch.float32)\n    vec[target] = 1.\n    return vec\n\n\ndef rand_bbox_time(size, rat):\n    if len(size) == 4:  # step, channel, height, width\n        step = size[0]\n    else:\n        raise Exception\n\n    cut_t = np.int(step * rat)\n    ct = np.random.randint(step)\n    bbt1 = np.clip(ct - cut_t // 2, 0, step)\n    bbt2 = np.clip(ct + cut_t // 2, 0, step)\n\n    return bbt1, bbt2\n\n\ndef rand_bbox(size, rat):\n    if len(size) == 4:\n        W = size[2]\n        H = size[3]\n    else:\n        raise Exception\n\n    cut_rat = np.sqrt(rat)\n    cut_w = np.int(W * cut_rat)\n    cut_h = np.int(H * cut_rat)\n\n    # uniform\n    cx = np.random.randint(W)\n    cy = np.random.randint(H)\n\n    bbx1 = np.clip(cx - cut_w // 2, 0, W)\n    bby1 = np.clip(cy - cut_h // 2, 0, H)\n    bbx2 = np.clip(cx + cut_w // 2, 0, W)\n    bby2 = np.clip(cy + cut_h // 2, 0, H)\n\n    return bbx1, bby1, bbx2, bby2\n\n\ndef calc_lam(x1, x2, bbt1, bbt2, bbx1, bbx2, bby1, bby2):\n    tot_x1 = x1.sum()\n    tot_x2 = x2.sum()\n    tot_bb1 = x1[bbt1:bbt2, :, bbx1:bbx2, bby1:bby2].sum()\n    tot_bb2 = x2[bbt1:bbt2, :, bbx1:bbx2, bby1:bby2].sum()\n    x1_rat = tot_bb1 / tot_x1\n    x2_rat = tot_bb2 / tot_x2\n    lam = 1. - (x2_rat / (1. 
- x1_rat + x2_rat))\n    return lam\n\n\ndef rand_bbox_st(size, rat):\n    temporal_rat = np.random.uniform(rat, 1.)\n    wh_rat = rat / temporal_rat\n    bbt1, bbt2 = rand_bbox_time(size, temporal_rat)\n    bbx1, bby1, bbx2, bby2 = rand_bbox(size, wh_rat)\n    return bbt1, bbt2, bbx1, bby1, bbx2, bby2\n\n\ndef spatio_mask(size, rat):\n    t = size[0]\n    x = torch.rand(2, 2)\n    y = torch.rand(2, 2)\n    f = torch.zeros(*size[-2:], dtype=torch.complex64)\n    # f[0:2, 0:2] = x + y * 0.j\n    f[[[0, -1], [-1, -1]], [[0, -1], [0, -1]]] = x + y * 1.j\n    mask = torch.fft.ifftn(f).real\n    idx = int(np.prod(size[-2:]) * rat)\n    val = mask.flatten().sort()[0][idx]\n\n    return (mask < val).unsqueeze(0).unsqueeze(0).repeat(t, 2, 1, 1)\n\n\ndef temporal_mask(size, rat):\n    bbt1, bbt2 = rand_bbox_time(size, rat)\n    mask = torch.zeros(*size, dtype=torch.bool)\n    mask[bbt1:bbt2] = True\n    return mask\n\n\ndef st_mask(size, rat):\n    t = size[0]\n    temporal_rat = np.random.uniform(rat, 1.)\n    wh_rat = rat / temporal_rat\n    bbt1, bbt2 = rand_bbox_time(size, temporal_rat)\n    mask = spatio_mask(size, wh_rat)\n    mask[0:bbt1] = False\n    mask[bbt2:t] = False\n    return mask\n\n\ndef GMM_mask_clip(size, rat):\n    t = size[0]\n    temporal_rat = np.random.uniform(rat, 1.)\n    wh_rat = rat / temporal_rat\n    bbt1, bbt2 = rand_bbox_time(size, temporal_rat)\n    mask = GMM_mask(size, wh_rat)\n    mask[0:bbt1] = False\n    mask[bbt2:t] = False\n    return mask\n\n\ndef GMM_mask(size, rat, n=None):\n    if n is None:\n        n = np.random.randint(2, 5)\n    pi = torch.tensor(np.random.rand(n))\n    # pi = torch.ones(n) / n\n\n    mask = torch.zeros((size[0], size[2], size[3]))\n    t = torch.tensor(list(range(size[0])))\n    x = torch.tensor(list(range(size[2])))\n    y = torch.tensor(list(range(size[3])))\n    t, x, y = torch.meshgrid(t, x, y, indexing='ij')\n\n    for p in pi:\n        mt = np.random.randint(0, size[0])\n        mx = 
np.random.randint(0, size[2])\n        my = np.random.randint(0, size[3])\n        # print(mt, mx, my)\n        st = max(np.random.rand(), 0.1) * size[0] * 0.5\n        sx = max(np.random.rand(), 0.1) * size[2] * .5\n        sy = max(np.random.rand(), 0.1) * size[3] * .5\n        # st, sx, sy = size[0], 0000.5 * size[2], 0000.5 * size[3]\n        # print(st, sx, sy)\n        tt = t - mt\n        xx = x - mx\n        yy = y - my\n        tmp = -((tt ** 2) / (st ** 2) + (xx ** 2) / (sx ** 2) + (yy ** 2) / (sy ** 2)) / 2\n        mask += p * tmp.exp()\n\n    idx = int(np.prod(mask.shape) * rat)\n    val = mask.flatten().sort()[0][idx - 1]\n    return (mask > val).unsqueeze(1).repeat(1, 2, 1, 1)\n    # return mask.unsqueeze(1).repeat(1, 2, 1, 1)\n\n# FOR EVENT VIS\n# def spatio_mask(size, rat):\n#     t = size[0]\n#     x = torch.rand(2, 2)\n#     y = torch.rand(2, 2)\n#     f = torch.zeros(*size[-2:], dtype=torch.complex64)\n#     # f[0:2, 0:2] = x + y * 0.j\n#     f[[[0, -1], [-1, -1]], [[0, -1], [0, -1]]] = x + y * 1.j\n#\n#     f = f.unsqueeze(0).repeat(t, 1, 1)\n#     f[1:-2, :, :] = 0\n#\n#     mask = torch.fft.ifftn(f).real\n#     # print(mask.shape)\n#     idx = int(np.prod(mask.shape) * 0.6)\n#     # print(idx)\n#     val = mask.flatten().sort()[0][idx]\n#     print(mask.unsqueeze(1).repeat(1, 2, 1, 1).shape)\n#     return (mask < val).unsqueeze(1).repeat(1, 2, 1, 1)\n#\n# def st_mask(size, rat):\n#     # t = size[0]\n#     # temporal_rat = np.random.uniform(rat, 1.)\n#     # wh_rat = rat / temporal_rat\n#     wh_rat = rat\n#     # bbt1, bbt2 = rand_bbox_time(size, temporal_rat)\n#     mask = spatio_mask(size, wh_rat)\n#     # mask[0:bbt1] = False\n#     # mask[bbt2:t] = False\n#     return mask\n\n\ndef calc_masked_lam(x1, x2, mask):\n    tot_x1 = x1.sum()\n    tot_x2 = x2.sum()\n    tot_mask1 = x1[mask].sum()\n    tot_mask2 = x2[mask].sum()\n    x1_rat = tot_mask1 / tot_x1\n    x2_rat = tot_mask2 / tot_x2\n    lam = 1. - (x2_rat / (1. 
- x1_rat + x2_rat))\n    # print(tot_x1, tot_x2, tot_mask1, tot_mask2)\n    return lam\n\n\ndef calc_masked_lam_with_difference(x1, x2, mix, kernel_size=3):\n    s1 = event_difference(x1, mix, kernel_size=kernel_size)\n    s2 = event_difference(x2, mix, kernel_size=kernel_size)\n    return (s2 * s2) / (s1 * s1 + s2 * s2)\n\n\nclass MixUp(Dataset):\n    def __init__(self, dataset, num_class, num_mix=1, beta=1., prob=1.0, indices=None, noise=0.0, vis=False, **kwargs):\n        self.dataset = dataset\n        self.num_class = num_class\n        self.num_mix = num_mix\n        self.beta = beta\n        self.prob = prob\n        self.indices = indices\n        self.noise = noise\n        self.vis = vis\n\n    def __getitem__(self, index):\n        img, lb = self.dataset[index]\n        lb_onehot = onehot(self.num_class, lb)\n\n        if self.vis:\n            origin = img.clone()\n\n        for _ in range(self.num_mix):\n            r = np.random.rand(1)\n            if self.beta <= 0 or r > self.prob:\n                continue\n\n            # generate mixed sample\n            lam = np.random.beta(self.beta, self.beta)\n\n            if self.indices is None:\n                rand_index = random.choice(range(len(self)))\n            else:\n                rand_index = random.choice(self.indices)\n\n            img2, lb2 = self.dataset[rand_index]\n            lb2_onehot = onehot(self.num_class, lb2)\n\n            img = img * lam + img2 * (1. - lam)\n            lb_onehot = lb_onehot * lam + lb2_onehot * (1. 
- lam)\n\n            if self.noise != 0.:\n                img = SaltAndPepperNoise(img, self.noise)\n\n        if self.vis:\n            return origin, img, img2\n        else:\n            return img, lb_onehot\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nclass CutMix(Dataset):\n    #   81.45161290322581 (epoch 584) /data/floyed/BrainCog/train/20220413-050658-resnet34-dvsc10-10-cut_mix before lam\n    def __init__(self, dataset, num_class, num_mix=1, beta=1., prob=1.0, indices=None, noise=0.0, vis=False, **kwargs):\n        self.dataset = dataset\n        self.num_class = num_class\n        self.num_mix = num_mix\n        self.beta = beta\n        self.prob = prob\n        self.indices = indices\n        self.noise = noise\n        self.vis = vis\n\n    def __getitem__(self, index):\n        img, lb = self.dataset[index]\n        lb_onehot = onehot(self.num_class, lb)\n\n        if self.vis:\n            origin = img.clone()\n\n        for _ in range(self.num_mix):\n            r = np.random.rand(1)\n            if self.beta <= 0 or r > self.prob:\n                continue\n\n            # generate mixed sample\n            lam = np.random.beta(self.beta, self.beta)\n\n            if self.indices is None:\n                rand_index = random.choice(range(len(self)))\n            else:\n                rand_index = random.choice(self.indices)\n\n            img2, lb2 = self.dataset[rand_index]\n            lb2_onehot = onehot(self.num_class, lb2)\n            # shape: step, channel, height, width\n            # alpha = np.random.rand()\n\n            # if alpha < 0.333:\n            bbx1, bby1, bbx2, bby2 = rand_bbox(img.shape, 1. 
- lam)\n            # bbx1, bby1, bbx2, bby2 = 32, 0, 48, 16\n\n            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img.shape[-1] * img.shape[-2]))  # area\n            # lam = calc_lam(img, img2, 0, shape[0], bbx1, bbx2, bby1, bby2)  # count\n\n            #  distance\n            # mask = torch.zeros_like(img, dtype=torch.bool)\n            # mask[:, :, bbx1:bbx2, bby1:bby2] = True\n            # mix = img.clone()\n            # mix[mask] = img2[mask]\n            # lam = calc_masked_lam_with_difference(img, img2, mix, kernel_size=3)\n            if self.vis:\n                img[:, :, bbx1:bbx2, bby1:bby2] = -img2[:, :, bbx1:bbx2, bby1:bby2]\n                img2 = -img2\n            else:\n                img[:, :, bbx1:bbx2, bby1:bby2] = img2[:, :, bbx1:bbx2, bby1:bby2]\n\n            # elif alpha > 0.667:\n            #     bbt1, bbt2 = rand_bbox_time(img.shape, 1. - lam)\n            #     lam = calc_lam(img, img2, bbt1, bbt2, 0, shape[2], 0, shape[3])\n            #     img[:, bbt1:bbt2, :, :] = img2[:, bbt1:bbt2, :, :]\n            #     # lam = 1 - (bbt2 - bbt1) / (img.shape[-4])\n            # else:\n            #     bbt1, bbt2, bbx1, bby1, bbx2, bby2 = rand_bbox_st(img.shape, 1. - lam)\n            #     lam = calc_lam(img, img2, bbt1, bbt2, bbx1, bbx2, bby1, bby2)\n            #     img[:, bbt1:bbt2, bbx1:bbx2, bby1:bby2] = img2[:, bbt1:bbt2, bbx1:bbx2, bby1:bby2]\n            #     # lam = 1 - ((bbt2 - bbt1) * (bbx2 - bbx1) * (bby2 - bby1) / (img.shape[-1] * img.shape[-2] * img.shape[-4]))\n\n            if self.noise != 0.:\n                img = SaltAndPepperNoise(img, self.noise)\n            lb_onehot = lb_onehot * lam + lb2_onehot * (1. 
- lam)\n\n        if self.vis:\n            mask = torch.zeros_like(img)\n            mask[:, :, bbx1:bbx2, bby1:bby2] = 1.\n            return origin, img, img2, mask\n        else:\n            return img, lb_onehot\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nclass EventMix(Dataset):\n    #  82.15725806451613 (epoch 554) /data/floyed/BrainCog/train/20220413-014843-resnet34-dvsc10-10-masked\n    def __init__(self,\n                 dataset,\n                 num_class,\n                 num_mix=1,\n                 beta=1.,\n                 prob=1.0,\n                 indices=None,\n                 noise=0.1,\n                 vis=False,\n                 gaussian_n=None,\n                 **kwargs):\n        self.dataset = dataset\n        self.num_class = num_class\n        self.num_mix = num_mix\n        self.beta = beta\n        self.prob = prob\n        self.indices = indices\n        self.noise = noise\n        self.vis = vis\n        self.gaussian_n = gaussian_n\n        print(self.prob, self.gaussian_n, self.beta)\n\n    def __getitem__(self, index):\n        img, lb = self.dataset[index]\n        lb_onehot = onehot(self.num_class, lb)\n\n        shape = img.shape\n        if self.vis:\n            origin = img.clone()\n\n        for _ in range(self.num_mix):\n            r = np.random.rand(1)\n            if self.beta <= 0 or r > self.prob:\n                continue\n\n            # generate mixed sample\n            lam = np.random.beta(self.beta, self.beta)  # lam -> remain ratio\n\n            if self.indices is None:\n                rand_index = random.choice(range(len(self)))\n            else:\n                rand_index = random.choice(self.indices)\n\n            img2, lb2 = self.dataset[rand_index]\n            lb2_onehot = onehot(self.num_class, lb2)\n            # shape: step, channel, height, width\n            # alpha = np.random.rand()\n            # if alpha < 0.333:\n            # mask = spatio_mask(shape, 1. 
- lam)\n            # elif alpha > 0.667:\n            # mask = temporal_mask(shape, 1. - lam)\n            # else:\n            # mask = st_mask(shape, 1. - lam)\n            mask = GMM_mask(shape, 1. - lam, self.gaussian_n)\n            # mask = GMM_mask_clip(shape, 1. - lam)\n            # mask = torch.logical_not(mask)\n\n            # lam = 1 - (mask.sum() / np.prod(img.shape))  #  area\n            lam = calc_masked_lam(img, img2, mask)  # count\n            img[mask] = img2[mask]  # count && mask required\n\n            # distance\n            # mix = torch.clone(img)\n            # if self.vis:\n            #     mix[mask] = -img2[mask]\n            #     img2 = -img2\n            # else:\n            #     mix[mask] = img2[mask]\n            # lam = calc_masked_lam_with_difference(img, img2, mix, kernel_size=3)\n            # img = mix\n\n            if self.noise != 0.:\n                img = SaltAndPepperNoise(img, self.noise)\n\n            lb_onehot = lb_onehot * lam + lb2_onehot * (1. 
- lam)\n\n        if self.vis:\n            return origin, img, img2, mask\n        else:\n            return img, lb_onehot\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nif __name__ == '__main__':\n    import matplotlib.pyplot as plt\n    from mpl_toolkits.mplot3d import Axes3D\n    from mpl_toolkits.mplot3d import proj3d\n\n\n    def get_proj(self):\n        \"\"\"\n         Create the projection matrix from the current viewing position.\n\n         elev stores the elevation angle in the z plane\n         azim stores the azimuth angle in the (x, y) plane\n\n         dist is the distance of the eye viewing point from the object point.\n        \"\"\"\n        # chosen for similarity with the initial view before gh-8896\n\n        relev, razim = np.pi * self.elev / 180, np.pi * self.azim / 180\n\n        # EDITED TO HAVE SCALED AXIS\n        xmin, xmax = np.divide(self.get_xlim3d(), self.pbaspect[0])\n        ymin, ymax = np.divide(self.get_ylim3d(), self.pbaspect[1])\n        zmin, zmax = np.divide(self.get_zlim3d(), self.pbaspect[2])\n\n        # transform to uniform world coordinates 0-1, 0-1, 0-1\n        worldM = proj3d.world_transformation(xmin, xmax,\n                                             ymin, ymax,\n                                             zmin, zmax)\n\n        # look into the middle of the new coordinates\n        R = self.pbaspect / 2\n\n        xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist\n        yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist\n        zp = R[2] + np.sin(relev) * self.dist\n        E = np.array((xp, yp, zp))\n\n        self.eye = E\n        self.vvec = R - E\n        self.vvec = self.vvec / np.linalg.norm(self.vvec)\n\n        if abs(relev) > np.pi / 2:\n            # upside down\n            V = np.array((0, 0, -1))\n        else:\n            V = np.array((0, 0, 1))\n        zfront, zback = -self.dist, self.dist\n\n        viewM = proj3d.view_transformation(E, R, V)\n        projM = 
self._projection(zfront, zback)\n        M0 = np.dot(viewM, worldM)\n        M = np.dot(projM, M0)\n        return M\n\n\n    Axes3D.get_proj = get_proj\n\n    size = (100, 2, 48, 48)\n    mask = GMM_mask(size, 0.3)\n    print(mask.shape)\n    # for i in range(100):\n    #     plt.figure()\n    #     plt.imshow(mask[i, 0])\n    # plt.show()\n\n    pos_idx1 = []\n    neg_idx1 = []\n    for t in range(100):\n        for r in range(48):\n            for c in range(48):\n                if mask[t, 0, r, c] > 0:\n                    pos_idx1.append((t, r, c))\n                if mask[t, 1, r, c] > 0:\n                    neg_idx1.append((t, r, c))\n    pos_t1, pos_x1, pos_y1 = np.split(np.array(pos_idx1), 3, axis=1)\n    neg_t1, neg_x1, neg_y1 = np.split(np.array(neg_idx1), 3, axis=1)\n\n    fig = plt.figure(figsize=plt.figaspect(0.5) * 1.5)\n    ax = Axes3D(fig)\n    ax.pbaspect = np.array([1, 1, 1])  # np.array([2.0, 1.0, 0.5])\n    ax.view_init(elev=10, azim=-75)\n    # ax.axis('off')\n    ax.scatter(pos_t1[:, 0], pos_y1[:, 0], 48 - pos_x1[:, 0], color='red', alpha=0.1, s=2.)\n    ax.scatter(neg_t1[:, 0], neg_y1[:, 0], 48 - neg_x1[:, 0], color='blue', alpha=0.1, s=2.)\n    plt.show()\n"
  },
  {
    "path": "braincog/datasets/datasets.py",
    "content": "import os, warnings\n\nimport tonic\nfrom tonic import DiskCachedDataset\n \nimport torch\nimport torch.nn.functional as F\nimport torch.utils\nimport torchvision.datasets as datasets\nfrom timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.data import create_transform\n\nfrom torchvision import transforms\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport braincog\nfrom braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull\nfrom braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot\nfrom braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet\nfrom braincog.datasets.ESimagenet.ES_imagenet import ESImagenet_Dataset\nfrom braincog.datasets.ESimagenet.reconstructed_ES_imagenet import ESImagenet2D_Dataset\nfrom braincog.datasets.CUB2002011 import CUB2002011\nfrom braincog.datasets.TinyImageNet import TinyImageNet\nfrom braincog.datasets.StanfordDogs import StanfordDogs\nfrom braincog.datasets.bullying10k import BULLYINGDVS\n\nfrom .cut_mix import CutMix, EventMix, MixUp\nfrom .rand_aug import *\nfrom .utils import dvs_channel_check_expend, rescale\n\nDVSCIFAR10_MEAN_16 = [0.3290, 0.4507]\nDVSCIFAR10_STD_16 = [1.8398, 1.6549]\n\nDATA_DIR = '/data/datasets'\n\nDEFAULT_CROP_PCT = 0.875\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\nIMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)\nIMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)\n\nCIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)\nCIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)\n\n\ndef unpack_mix_param(args):\n    mix_up = args['mix_up'] if 'mix_up' in args else False\n    cut_mix = args['cut_mix'] if 'cut_mix' in args else False\n    event_mix = args['event_mix'] if 'event_mix' in args else False\n    beta = args['beta'] if 'beta' in args else 1.\n 
   prob = args['prob'] if 'prob' in args else .5\n    num = args['num'] if 'num' in args else 1\n    num_classes = args['num_classes'] if 'num_classes' in args else 10\n    noise = args['noise'] if 'noise' in args else 0.\n    gaussian_n = args['gaussian_n'] if 'gaussian_n' in args else None\n    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n\n\n\ndef build_transform(is_train, img_size):\n    \"\"\"\n    构建数据增强, 适用于static data\n    :param is_train: 是否训练集\n    :param img_size: 输出的图像尺寸\n    :return: 数据增强策略\n    \"\"\"\n    resize_im = img_size > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=img_size,\n            is_training=True,\n            color_jitter=0.4,\n            auto_augment='rand-m9-mstd0.5-inc1',\n            interpolation='bicubic',\n            re_prob=0.25,\n            re_mode='pixel',\n            re_count=1,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(\n                img_size, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        size = int((256 / 224) * img_size)\n        t.append(\n            # to maintain same ratio w.r.t. 
224 images\n            transforms.Resize(size, interpolation=3),\n        )\n        t.append(transforms.CenterCrop(img_size))\n\n    t.append(transforms.ToTensor())\n    if img_size > 32:\n        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    else:\n        t.append(transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD))\n    return transforms.Compose(t)\n\n\ndef build_dataset(is_train, img_size, dataset, path, same_da=False):\n    \"\"\"\n    构建带有增强策略的数据集\n    :param is_train: 是否训练集\n    :param img_size: 输出图像尺寸\n    :param dataset: 数据集名称\n    :param path: 数据集路径\n    :param same_da: 为训练集使用测试集的增广方法\n    :return: 增强后的数据集\n    \"\"\"\n    transform = build_transform(False, img_size) if same_da else build_transform(is_train, img_size)\n\n    if dataset == 'CIFAR10':\n        dataset = datasets.CIFAR10(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 10\n    elif dataset == 'CIFAR100':\n        dataset = datasets.CIFAR100(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 100\n    else:\n        raise NotImplementedError\n\n    return dataset, nb_classes\n\n\nclass MNISTData(object):\n    \"\"\"\n    Load MNIST datesets.\n    \"\"\"\n\n    def __init__(self,\n                 data_path: str,\n                 batch_size: int,\n                 train_trans: Sequence[torch.nn.Module] = None,\n                 test_trans: Sequence[torch.nn.Module] = None,\n                 pin_memory: bool = True,\n                 drop_last: bool = True,\n                 shuffle: bool = True,\n                 ) -> None:\n        self._data_path = data_path\n        self._batch_size = batch_size\n        self._pin_memory = pin_memory\n        self._drop_last = drop_last\n        self._shuffle = shuffle\n        self._train_transform = transforms.Compose(train_trans) if train_trans else None\n        self._test_transform = transforms.Compose(test_trans) if 
test_trans else None\n\n    def get_data_loaders(self):\n        print('Batch size: ', self._batch_size)\n        train_datasets = datasets.MNIST(root=self._data_path, train=True, transform=self._train_transform, download=True)\n        test_datasets = datasets.MNIST(root=self._data_path, train=False, transform=self._test_transform, download=True)\n        train_loader = torch.utils.data.DataLoader(\n            train_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=self._drop_last, shuffle=self._shuffle\n        )\n        test_loader = torch.utils.data.DataLoader(\n            test_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=False\n        )\n        return train_loader, test_loader\n\n    def get_standard_data(self):\n        MNIST_MEAN = 0.1307\n        MNIST_STD = 0.3081\n        self._train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                                    transforms.ToTensor(),\n                                                    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        self._test_transform = transforms.Compose([transforms.ToTensor(),\n                                                   transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        return self.get_data_loaders()\n\n\ndef get_mnist_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取MNIST数据\n    http://data.pymvpa.org/datasets/mnist/\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    MNIST_MEAN = 0.1307\n    MNIST_STD = 0.3081\n    if 'root' in kwargs:root=kwargs[\"root\"]\n    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:\n        train_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n      
  test_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n    else:\n        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                              # transforms.RandomRotation(10),\n                                              transforms.ToTensor(),\n                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        test_transform = transforms.Compose([transforms.ToTensor(),\n                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n\n    train_datasets = datasets.MNIST(\n        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.MNIST(\n        root=root, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_fashion_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取fashion MNIST数据\n    http://arxiv.org/abs/1708.07747\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                          transforms.RandomHorizontalFlip(),\n                                          transforms.RandomRotation(10),\n                                          transforms.ToTensor()])\n    test_transform = 
transforms.Compose([transforms.ToTensor()])\n\n    train_datasets = datasets.FashionMNIST(\n        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.FashionMNIST(\n        root=root, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_cifar10_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取CIFAR10数据\n     https://www.cs.toronto.edu/~kriz/cifar.html\n    :param batch_size: batch size\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_datasets, _ = build_dataset(True, 32, 'CIFAR10', root, same_da)\n    test_datasets, _ = build_dataset(False, 32, 'CIFAR10', root, same_da)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True,\n        num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False,\n        num_workers=num_workers\n    )\n    return train_loader, test_loader, None, None\n\n\ndef get_cifar100_data(batch_size, num_workers=8, same_data=False,root=DATA_DIR, *args, **kwargs):\n    \"\"\"\n    获取CIFAR100数据\n    https://www.cs.toronto.edu/~kriz/cifar.html\n    :param batch_size: batch size\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_datasets, _ = build_dataset(True, 32, 'CIFAR100', root, same_data)\n  
  test_datasets, _ = build_dataset(False, 32, 'CIFAR100', root, same_data)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n    return train_loader, test_loader, False, None\n\ndef get_TinyImageNet_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, *args, **kwargs):\n    size=kwargs[\"size\"] if \"size\" in kwargs else 224\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(size),\n        transforms.RandomHorizontalFlip(), \n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(size*8//7),\n        transforms.CenterCrop(size),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(root, 'TinyImageNet')\n    train_datasets = TinyImageNet(\n        root=root, split=\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = TinyImageNet(\n        root=root, split=\"val\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_imnet_data(args, _logger, data_config, num_aug_splits,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取ImageNet数据集\n    
http://arxiv.org/abs/1409.0575\n    :param args: 其他的参数\n    :param _logger: 日志路径\n    :param data_config: 增强策略\n    :param num_aug_splits: 不同增强策略的数量\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_dir = os.path.join(root, 'ILSVRC2012/train')\n    if not os.path.exists(train_dir):\n        _logger.error(\n            'Training folder does not exist at: {}'.format(train_dir))\n        exit(1)\n    dataset_train = ImageDataset(train_dir)\n    # collate_fn = None\n    # mixup_fn = None\n    # mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None\n    # if mixup_active:\n    #     mixup_args = dict(\n    #         mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n    #         prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n    #         label_smoothing=args.smoothing, num_classes=args.num_classes)\n    #     if args.prefetcher:\n    #         # collate conflict (need to support deinterleaving in collate mixup)\n    #         assert not num_aug_splits\n    #         collate_fn = FastCollateMixup(**mixup_args)\n    #     else:\n    #         mixup_fn = Mixup(**mixup_args)\n\n    # if num_aug_splits > 1:\n    #     dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)\n\n    train_interpolation = args.train_interpolation\n    if args.no_aug or not train_interpolation:\n        train_interpolation = data_config['interpolation']\n    loader_train = create_loader(\n        dataset_train,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        is_training=True,\n        use_prefetcher=args.prefetcher,\n        no_aug=args.no_aug,\n        # re_prob=args.reprob,\n        # re_mode=args.remode,\n        # re_count=args.recount,\n        # re_split=args.resplit,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        # vflip=arg,\n     
   color_jitter=args.color_jitter,\n        #auto_augment=args.aa,\n        num_aug_splits=num_aug_splits,\n        interpolation=train_interpolation,\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        #collate_fn=collate_fn,\n        pin_memory=args.pin_mem,\n        use_multi_epochs_loader=args.use_multi_epochs_loader)\n    eval_dir = os.path.join(root, 'ILSVRC2012/val')\n    if not os.path.isdir(eval_dir):\n        eval_dir = os.path.join(root, 'ILSVRC2012/validation')\n        if not os.path.isdir(eval_dir):\n            _logger.error(\n                'Validation folder does not exist at: {}'.format(eval_dir))\n            exit(1)\n    dataset_eval = ImageDataset(eval_dir)\n\n    loader_eval = create_loader(\n        dataset_eval,\n        input_size=data_config['input_size'],\n        batch_size=args.validation_batch_size_multiplier * args.batch_size,\n        is_training=False,\n        use_prefetcher=args.prefetcher,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        crop_pct=data_config['crop_pct'],\n        pin_memory=args.pin_mem,\n    )\n    return loader_train, loader_eval, False, None\n\n\ndef get_dvsg_data(batch_size, step,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取DVS Gesture数据\n    DOI: 10.1109/CVPR.2017.781\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.DVSGesture.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 48\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    
])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.DVSGesture(os.path.join(root, 'DVS/DVSGesture'),\n                                              transform=train_transform, train=True)\n    test_dataset = tonic.datasets.DVSGesture(os.path.join(root, 'DVS/DVSGesture'),\n                                             transform=test_transform, train=False)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n        transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(root, 'DVS/DVSGesture/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                 
    cache_path=os.path.join(root, 'DVS/DVSGesture/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_bullyingdvs_data(batch_size, step, root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取Bullying10K数据\n    NeurIPS 2023\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return:\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    
sensor_size = BULLYINGDVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = BULLYINGDVS('/data/datasets/Bullying10k_processed', transform=train_transform)\n    # train_dataset = BULLYINGDVS(os.path.join(root, 'DVS/BULLYINGDVS'), transform=train_transform)\n    test_dataset = BULLYINGDVS(os.path.join(root, 'DVS/BULLYINGDVS'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(root, 'DVS/BULLYINGDVS/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(root, 'DVS/BULLYINGDVS/test_cache_{}'.format(step)),\n                                     
transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n      
  pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\n\ndef get_dvsc10_data(batch_size, step, root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(root, 'DVS/DVS_Cifar10'), transform=train_transform)\n    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(root, 'DVS/DVS_Cifar10'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # 
lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(root, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(root, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, 
prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_NCALTECH101_data(batch_size, step,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    
http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.NCALTECH101.sensor_size\n    cls_count = tonic.datasets.NCALTECH101.cls_count\n    dataset_length = tonic.datasets.NCALTECH101.length\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += train_sample\n        train_sample_weight.extend(\n            [sample_weight] * train_sample\n        )\n        train_sample_weight.extend(\n            [0.] 
* test_sample\n        )\n        train_sample_index.extend(\n            list((range(idx_begin, idx_begin + train_sample)))\n        )\n        test_sample_index.extend(\n            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = tonic.datasets.NCALTECH101(os.path.join(root, 'DVS/NCALTECH101'), transform=train_transform)\n    test_dataset = tonic.datasets.NCALTECH101(os.path.join(root, 'DVS/NCALTECH101'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: print(x.shape),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        #transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: temporal_flatten(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 
'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(root, 'DVS/NCALTECH101/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(root, 'DVS/NCALTECH101/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              
indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_NCARS_data(batch_size, step,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取N-Cars数据\n    https://ieeexplore.ieee.org/document/8578284/\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.NCARS.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 48\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.NCARS(os.path.join(root, 'DVS/NCARS'), transform=train_transform, train=True)\n    test_dataset = tonic.datasets.NCARS(os.path.join(root, 'DVS/NCARS'), transform=test_transform, train=False)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n      
  lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(root, 'DVS/NCARS/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(root, 'DVS/NCARS/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n  
  if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_nomni_data(batch_size, train_portion=1.,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取N-Omniglot数据\n    :param batch_size:batch的大小\n    :param data_mode:一共full nkks pair三种模式\n    :param frames_num:一个样本帧的个数\n    :param data_type:event frequency两种模式\n    \"\"\"\n    data_mode = kwargs[\"data_mode\"] if \"data_mode\" in kwargs else \"full\"\n    frames_num = kwargs[\"frames_num\"] if \"frames_num\" in kwargs else 4\n    data_type = kwargs[\"data_type\"] if \"data_type\" in kwargs else \"event\"\n\n    train_transform = transforms.Compose([\n        transforms.Resize((28, 28))])\n    test_transform = transforms.Compose([\n        transforms.Resize((28, 28))])\n    if data_mode == \"full\":\n        train_datasets = NOmniglotfull(root=os.path.join(root, 'DVS/NOmniglot'), train=True, frames_num=frames_num,\n                                       data_type=data_type,\n                                       transform=train_transform)\n        test_datasets = NOmniglotfull(root=os.path.join(root, 'DVS/NOmniglot'), train=False, frames_num=frames_num,\n                                      data_type=data_type,\n                                      transform=test_transform)\n\n    elif data_mode == \"nkks\":\n        train_datasets = 
NOmniglotNWayKShot(os.path.join(root, 'DVS/NOmniglot'),\n                                            n_way=kwargs[\"n_way\"],\n                                            k_shot=kwargs[\"k_shot\"],\n                                            k_query=kwargs[\"k_query\"],\n                                            train=True,\n                                            frames_num=frames_num,\n                                            data_type=data_type,\n                                            transform=train_transform)\n        test_datasets = NOmniglotNWayKShot(os.path.join(root, 'DVS/NOmniglot'),\n                                           n_way=kwargs[\"n_way\"],\n                                           k_shot=kwargs[\"k_shot\"],\n                                           k_query=kwargs[\"k_query\"],\n                                           train=False,\n                                           frames_num=frames_num,\n                                           data_type=data_type,\n                                           transform=test_transform)\n    elif data_mode == \"pair\":\n        train_datasets = NOmniglotTrainSet(root=os.path.join(root, 'DVS/NOmniglot'), use_frame=True,\n                                           frames_num=frames_num, data_type=data_type,\n                                           use_npz=False, resize=105)\n        test_datasets = NOmniglotTestSet(root=os.path.join(root, 'DVS/NOmniglot'), time=2000, way=kwargs[\"n_way\"],\n                                         shot=kwargs[\"k_shot\"], use_frame=True,\n                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)\n\n    else:\n        pass\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size, num_workers=12,\n        pin_memory=True, drop_last=True, shuffle=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size, 
num_workers=12,\n        pin_memory=True, drop_last=False\n    )\n    return train_loader, test_loader, None, None\n\n\ndef get_esimnet_data(batch_size, step,root=DATA_DIR, **kwargs):\n    \"\"\"\n    获取ES imagenet数据\n    DOI: 10.3389/fnins.2021.726582\n    :param batch_size: batch size\n    :param step: 仿真步长，固定为8\n    :param reconstruct: 重构则时间步为1, 否则为8\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    :note: 没有自动下载, 下载及md5请参考spikingjelly, sampler默认为DistributedSampler\n    \"\"\"\n\n    reconstruct = kwargs[\"reconstruct\"] if \"reconstruct\" in kwargs else False\n\n    train_transform = transforms.Compose([\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: dvs_channel_check_expend(x),\n    ])\n\n    if reconstruct:\n        assert step == 1\n        train_dataset = ESImagenet2D_Dataset(mode='train',\n                                            data_set_path=os.path.join(root, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                            transform=train_transform)\n\n        test_dataset = ESImagenet2D_Dataset(mode='test',\n                                            data_set_path=os.path.join(root, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                            transform=test_transform)\n    else:\n        assert step == 8\n        train_dataset = ESImagenet_Dataset(mode='train',\n                                             data_set_path=os.path.join(root,\n                                                                        'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                             transform=train_transform)\n\n        test_dataset = ESImagenet_Dataset(mode='test',\n                                            data_set_path=os.path.join(root,\n                                                                       
'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                            transform=test_transform)\n\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)\n    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=False, sampler=train_sampler\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=1,\n        shuffle=False, sampler=test_sampler\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_nmnist_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取N-MNIST数据\n    
http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.NMNIST.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 34\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/N-MNIST'),\n                                              transform=train_transform, train=True)\n    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/N-MNIST'),\n                                             transform=test_transform, train=False)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n        # transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in 
kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/N-MNIST/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/N-MNIST/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n  
      shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_ntidigits_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取N-TIDIGITS数据 (tonic 新版本中的下载链接可能挂了，可以参考0.4.0的版本)\n    https://www.frontiersin.org/articles/10.3389/fnins.2018.00023/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    :format: (b,t,c,len) 不同于vision, audio中c为1, 并且没有h,w; 只有len=64\n    \"\"\"\n    sensor_size = tonic.datasets.NTIDIGITS.sensor_size\n    train_transform = transforms.Compose([\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: x.squeeze(1)\n    ])\n    test_transform = transforms.Compose([\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: x.squeeze(1)\n    ])\n\n    train_dataset = tonic.datasets.NTIDIGITS(os.path.join(DATA_DIR, 'DVS/NTIDIGITS'),\n                                              transform=train_transform, train=True)\n\n    test_dataset = tonic.datasets.NTIDIGITS(os.path.join(DATA_DIR, 'DVS/NTIDIGITS'),\n                                             transform=test_transform, train=False)\n\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, None, None\n\n\ndef get_shd_data(batch_size, step, 
**kwargs):\n    \"\"\"\n    获取SHD数据\n    https://ieeexplore.ieee.org/abstract/document/9311226\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    :format: (b,t,c,len) 不同于vision, audio中c为1, 并且没有h,w; 只有len=700. Transform后变为(b, t, len)\n    \"\"\"\n    sensor_size = tonic.datasets.SHD.sensor_size\n    train_transform = transforms.Compose([\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step)\n    ])\n    test_transform = transforms.Compose([\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step)\n    ])\n\n    train_dataset = tonic.datasets.SHD(os.path.join(DATA_DIR, 'DVS/SHD'),\n                                              transform=train_transform, train=True)\n\n    test_dataset = tonic.datasets.SHD(os.path.join(DATA_DIR, 'DVS/SHD'),\n                                             transform=test_transform, train=False)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: x.squeeze(1)\n    ])\n\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: x.squeeze(1)\n    ])\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/SHD/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/SHD/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        
test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, None, None\n\n\ndef get_CUB2002011_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(), \n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(root, 'CUB2002011')\n    train_datasets = CUB2002011(\n        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = CUB2002011(\n        root=root, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\ndef get_StanfordCars_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(), \n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 
0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(root, 'StanfordCars')\n    train_datasets = datasets.StanfordCars(\n        root=root, split =\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.StanfordCars(\n        root=root, split =\"test\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\ndef get_StanfordDogs_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(), \n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(root, 'StanfordDogs')\n    train_datasets = StanfordDogs(\n        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = StanfordDogs(\n        root=root, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n 
   )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_FGVCAircraft_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(), \n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(root, 'FGVCAircraft')\n    train_datasets = datasets.FGVCAircraft(\n        root=root, split=\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.FGVCAircraft(\n        root=root, split=\"test\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_Flowers102_data(batch_size, num_workers=8, same_da=False,root=DATA_DIR, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(), \n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(root, 'Flowers102')\n   
 train_datasets = datasets.Flowers102(\n        root=root, split=\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.Flowers102(\n        root=root, split=\"test\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_UCF101DVS_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = braincog.datasets.ucf101_dvs.UCF101DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'UCF101DVS'), train=True, transform=train_transform)\n    test_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'UCF101DVS'), train=False, transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n  
      # lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        # transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'UCF101DVS/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'UCF101DVS/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active 
= cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size, shuffle=True,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size, shuffle=False,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_HMDBDVS_data(batch_size, step, **kwargs):\n    sensor_size = braincog.datasets.hmdb_dvs.HMDBDVS.sensor_size\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    
train_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=train_transform)\n    test_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=test_transform)\n\n    cls_count = train_dataset.cls_count\n    dataset_length = train_dataset.length\n\n    portion = .5\n    # portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += train_sample\n        train_sample_weight.extend(\n            [sample_weight] * train_sample\n        )\n        train_sample_weight.extend(\n            [0.] * test_sample\n        )\n        lst = list(range(idx_begin, idx_begin + train_sample + test_sample))\n        random.seed(0)\n        random.shuffle(lst)\n        train_sample_index.extend(\n            lst[:train_sample]\n            # list((range(idx_begin, idx_begin + train_sample)))\n        )\n        test_sample_index.extend(\n            lst[train_sample:train_sample + test_sample]\n            # list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: print(x.shape),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # transforms.RandomCrop(size, padding=size // 12),\n        # 
transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: temporal_flatten(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'HMDBDVS/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'HMDBDVS/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                
                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n"
  },
  {
    "path": "braincog/datasets/gen_input_signal.py",
    "content": "import numpy as np\nimport random\nimport copy\n\ndt = 1.0                  # ms\nlambda_max = 0.25 * dt  # maximum spike rate (spikes per time step)\n\n\n"
  },
  {
    "path": "braincog/datasets/hmdb_dvs/__init__.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/1/30 20:54\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : __init__.py\n# explain   :\n\nfrom .hmdb_dvs import HMDBDVS\n\n__all__ = [\n    'HMDBDVS'\n]"
  },
  {
    "path": "braincog/datasets/hmdb_dvs/hmdb_dvs.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/1/30 20:54\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : hmdb_dvs.py\n# explain   :\n\nimport os\nimport numpy as np\nfrom numpy.lib import recfunctions\nimport scipy.io as scio\nfrom typing import Tuple, Any, Optional\nfrom tonic.dataset import Dataset\nfrom tonic.download_utils import extract_archive\n\nclass HMDBDVS(Dataset):\n    \"\"\"ASL-DVS dataset <https://github.com/PIX2NVS/NVS2Graph>. Events have (txyp) ordering.\n    ::\n\n        @inproceedings{bi2019graph,\n            title={Graph-based Object Classification for Neuromorphic Vision Sensing},\n            author={Bi, Y and Chadha, A and Abbas, A and and Bourtsoulatze, E and Andreopoulos, Y},\n            booktitle={2019 IEEE International Conference on Computer Vision (ICCV)},\n            year={2019},\n            organization={IEEE}\n        }\n\n    Parameters:\n        save_to (string): Location to save files to on disk.\n        transform (callable, optional): A callable of transforms to apply to the data.\n        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.\n    \"\"\"\n\n    sensor_size = (240, 180, 2)\n    dtype = np.dtype([(\"t\", int), (\"x\", int), (\"y\", int), (\"p\", int)])\n    ordering = dtype.names\n\n    def __init__(self, save_to, transform=None, target_transform=None):\n        super(HMDBDVS, self).__init__(\n            save_to, transform=transform, target_transform=target_transform\n        )\n\n        if not self._check_exists():\n            raise NotImplementedError(\n                'Please manually download the dataset from'\n                ' https://www.dropbox.com/sh/ie75dn246cacf6n/AACoU-_zkGOAwj51lSCM0JhGa?dl=0 '\n                'and extract it to {}'.format(self.location_on_system))\n\n        classes = os.listdir(self.location_on_system)\n        self.int_classes = dict(zip(classes, 
range(len(classes))))\n\n        for path, dirs, files in os.walk(self.location_on_system):\n            dirs.sort()\n            files.sort()\n            for file in files:\n                if file.endswith(\"mat\"):\n                    fsize = os.path.getsize(path + '/' + file) / float(1024)\n                    if fsize < 1:\n                        # print('{} size {} K'.format(file, fsize))\n                        continue\n                    self.data.append(path + \"/\" + file)\n                    self.targets.append(self.int_classes[path.split('/')[-1]])\n\n        self.length = self.__len__()\n        self.cls_count = np.bincount(self.targets)\n\n    def __getitem__(self, index: int) -> Tuple[Any, Any]:\n        \"\"\"\n        Returns:\n            (events, target) where target is index of the target class.\n        \"\"\"\n        events, target = scio.loadmat(self.data[index]), self.targets[index]\n        events = np.column_stack(\n            [\n                events[\"ts\"],\n                events[\"x\"],\n                self.sensor_size[1] - 1 - events[\"y\"],\n                events[\"pol\"],\n            ]\n        )\n        events = np.lib.recfunctions.unstructured_to_structured(events, self.dtype)\n        if self.transform is not None:\n            events = self.transform(events)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return events, target\n\n    def __len__(self):\n        return len(self.data)\n\n    def _check_exists(self):\n        return self._folder_contains_at_least_n_files_of_type(\n            6765, \".mat\"\n        )\n"
  },
  {
    "path": "braincog/datasets/ncaltech101/__init__.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/1/30 21:26\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : __init__.py.py\n# explain   :\n\nfrom .ncaltech101 import NCALTECH101\n\n__all__ = [\n    'NCALTECH101'\n]"
  },
  {
    "path": "braincog/datasets/ncaltech101/ncaltech101.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/1/30 21:28\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : ncaltech101.py\n# explain   :\nimport os\nimport numpy as np\n\nfrom tonic.io import read_mnist_file\nfrom tonic.dataset import Dataset\nfrom tonic.download_utils import extract_archive\n\n\nclass NCALTECH101(Dataset):\n    \"\"\"N-CALTECH101 dataset <https://www.garrickorchard.com/datasets/n-caltech101>. Events have (xytp) ordering.\n    ::\n\n        @article{orchard2015converting,\n          title={Converting static image datasets to spiking neuromorphic datasets using saccades},\n          author={Orchard, Garrick and Jayawant, Ajinkya and Cohen, Gregory K and Thakor, Nitish},\n          journal={Frontiers in neuroscience},\n          volume={9},\n          pages={437},\n          year={2015},\n          publisher={Frontiers}\n        }\n\n    Parameters:\n        save_to (string): Location to save files to on disk.\n        transform (callable, optional): A callable of transforms to apply to the data.\n        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.\n    \"\"\"\n\n    url = \"https://data.mendeley.com/public-files/datasets/cy6cvx3ryv/files/36b5c52a-b49d-4853-addb-a836a8883e49/file_downloaded\"\n    filename = \"N-Caltech101-archive.zip\"\n    file_md5 = \"66201824eabb0239c7ab992480b50ba3\"\n    data_filename = \"N-Caltech101-archive.zip\"\n    folder_name = \"Caltech101\"\n    cls_count = [467,\n                 435, 200, 798, 55, 800, 42, 42, 47, 54, 46,\n                 33, 128, 98, 43, 85, 91, 50, 43, 123, 47,\n                 59, 62, 107, 47, 69, 73, 70, 50, 51, 57,\n                 67, 52, 65, 68, 75, 64, 53, 64, 85, 67,\n                 67, 45, 34, 34, 51, 99, 100, 42, 54, 88,\n                 80, 31, 64, 86, 114, 61, 81, 78, 41, 66,\n                 43, 40, 87, 32, 76, 55, 35, 39, 47, 38,\n               
  45, 53, 34, 57, 82, 59, 49, 40, 63, 39,\n                 84, 57, 35, 64, 45, 86, 59, 64, 35, 85,\n                 49, 86, 75, 239, 37, 59, 34, 56, 39, 60]\n    # length = 8242\n    length = 8709\n\n    sensor_size = None  # all recordings are of different size\n    dtype = np.dtype([(\"x\", int), (\"y\", int), (\"t\", int), (\"p\", int)])\n    ordering = dtype.names\n\n    def __init__(self, save_to, transform=None, target_transform=None):\n        super(NCALTECH101, self).__init__(\n            save_to, transform=transform, target_transform=target_transform\n        )\n\n        classes = {\n            'BACKGROUND_Google': 0,\n            'Faces_easy': 1,\n            'Leopards': 2,\n            'Motorbikes': 3,\n            'accordion': 4,\n            'airplanes': 5,\n            'anchor': 6,\n            'ant': 7,\n            'barrel': 8,\n            'bass': 9,\n            'beaver': 10,\n            'binocular': 11,\n            'bonsai': 12,\n            'brain': 13,\n            'brontosaurus': 14,\n            'buddha': 15,\n            'butterfly': 16,\n            'camera': 17,\n            'cannon': 18,\n            'car_side': 19,\n            'ceiling_fan': 20,\n            'cellphone': 21,\n            'chair': 22,\n            'chandelier': 23,\n            'cougar_body': 24,\n            'cougar_face': 25,\n            'crab': 26,\n            'crayfish': 27,\n            'crocodile': 28,\n            'crocodile_head': 29,\n            'cup': 30,\n            'dalmatian': 31,\n            'dollar_bill': 32,\n            'dolphin': 33,\n            'dragonfly': 34,\n            'electric_guitar': 35,\n            'elephant': 36,\n            'emu': 37,\n            'euphonium': 38,\n            'ewer': 39,\n            'ferry': 40,\n            'flamingo': 41,\n            'flamingo_head': 42,\n            'garfield': 43,\n            'gerenuk': 44,\n            'gramophone': 45,\n            'grand_piano': 46,\n            'hawksbill': 47,\n  
          'headphone': 48,\n            'hedgehog': 49,\n            'helicopter': 50,\n            'ibis': 51,\n            'inline_skate': 52,\n            'joshua_tree': 53,\n            'kangaroo': 54,\n            'ketch': 55,\n            'lamp': 56,\n            'laptop': 57,\n            'llama': 58,\n            'lobster': 59,\n            'lotus': 60,\n            'mandolin': 61,\n            'mayfly': 62,\n            'menorah': 63,\n            'metronome': 64,\n            'minaret': 65,\n            'nautilus': 66,\n            'octopus': 67,\n            'okapi': 68,\n            'pagoda': 69,\n            'panda': 70,\n            'pigeon': 71,\n            'pizza': 72,\n            'platypus': 73,\n            'pyramid': 74,\n            'revolver': 75,\n            'rhino': 76,\n            'rooster': 77,\n            'saxophone': 78,\n            'schooner': 79,\n            'scissors': 80,\n            'scorpion': 81,\n            'sea_horse': 82,\n            'snoopy': 83,\n            'soccer_ball': 84,\n            'stapler': 85,\n            'starfish': 86,\n            'stegosaurus': 87,\n            'stop_sign': 88,\n            'strawberry': 89,\n            'sunflower': 90,\n            'tick': 91,\n            'trilobite': 92,\n            'umbrella': 93,\n            'watch': 94,\n            'water_lilly': 95,\n            'wheelchair': 96,\n            'wild_cat': 97,\n            'windsor_chair': 98,\n            'wrench': 99,\n            'yin_yang': 100,\n        }\n\n        # if not self._check_exists():\n            # self.download()\n            # extract_archive(os.path.join(self.location_on_system, self.data_filename))\n\n        file_path = os.path.join(self.location_on_system, self.folder_name)\n        for path, dirs, files in os.walk(file_path):\n            dirs.sort()\n            # if 'BACKGROUND_Google' in path:\n            #     continue\n            for file in files:\n                if file.endswith(\"bin\"):\n  
                  self.data.append(path + \"/\" + file)\n                    label_name = os.path.basename(path)\n\n                    if isinstance(label_name, bytes):\n                        label_name = label_name.decode()\n                    self.targets.append(classes[label_name])\n\n    def __getitem__(self, index):\n        \"\"\"\n        Returns:\n            a tuple of (events, target) where target is the index of the target class.\n        \"\"\"\n        events = read_mnist_file(self.data[index], dtype=self.dtype)\n        target = self.targets[index]\n        events[\"x\"] -= events[\"x\"].min()\n        events[\"y\"] -= events[\"y\"].min()\n        if self.transform is not None:\n            events = self.transform(events)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return events, target\n\n    def __len__(self):\n        return len(self.data)\n\n    def _check_exists(self):\n        return self._is_file_present() and self._folder_contains_at_least_n_files_of_type(\n            8709, \".bin\"\n        )\n"
  },
  {
    "path": "braincog/datasets/rand_aug.py",
    "content": "import random\nimport numpy as np\nimport torch\nfrom torchvision import transforms\nfrom torchvision.transforms import functional\nfrom torchvision.transforms import InterpolationMode\n\n\ndef ShearX(x, v):  # [-0.3, 0.3]\n    assert 0 <= v <= 30\n    v = np.random.uniform(0, v)\n    if random.random() > 0.5:\n        v = -v\n    return functional.affine(x, angle=0, translate=[0, 0], scale=1., shear=[v, 0])\n\n\ndef ShearY(x, v):  # [-0.3, 0.3]\n    assert 0 <= v <= 30\n    v = np.random.uniform(0, v)\n    if random.random() > 0.5:\n        v = -v\n    return functional.affine(x, angle=0, translate=[0, 0], scale=1., shear=[0, v])\n\n\ndef TranslateX(x, v):\n    assert 0 <= v <= 0.45\n    v = np.random.uniform(0, v)\n    w, h = x.shape[-2::]\n    v = round(w * v)\n    if random.random() > 0.5:\n        v = -v\n    return functional.affine(x, angle=0, translate=[0, v], scale=1., shear=[0, 0])\n\n\ndef TranslateY(x, v):\n    assert 0 <= v <= 0.45\n    v = np.random.uniform(0, v)\n    w, h = x.shape[-2::]\n    v = round(w * v)\n    if random.random() > 0.5:\n        v = -v\n    return functional.affine(x, angle=0, translate=[v, 0], scale=1., shear=[0, 0])\n\n\ndef Rotate(x, v):  # [-30, 30]\n    assert 0 <= v <= 30\n    v = np.random.uniform(0, v)\n    if random.random() > 0.5:\n        v = -v\n    return functional.affine(x, angle=v, translate=[0, 0], scale=1., shear=[0, 0])\n\n\ndef CutoutAbs(x, v):  # [0, 60] => percentage: [0, 0.2]\n    assert 0 <= v <= 0.5\n    w, h = x.shape[-2::]\n    v = round(v * w)\n\n    x0 = np.random.uniform(w)\n    y0 = np.random.uniform(h)\n\n    x0 = round(max(0, x0 - v / 2.))\n    y0 = round(max(0, y0 - v / 2.))\n    x1 = min(w, x0 + v)\n    y1 = min(h, y0 + v)\n\n    x[:, :, y0:y1, x0:x1] = 0.\n    return x\n\n\ndef CutoutTemporal(x, v):\n    assert 0 <= v <= 0.5\n    v = np.random.uniform(0, v)\n    step = x.shape[0]\n    v = round(v * step)\n    t0 = np.random.randint(step)\n    t1 = min(step, t0 + v)\n    x[t0:t1, 
:, :, :] = 0.\n    return x\n\n\ndef TemporalShift(x, v):\n    # TODO: Maybe shift too mach than origin has\n    assert 0 <= v <= 0.2\n    v = v / 2.\n    shape = x.shape\n    # p = torch.zeros(2 * (shape[0] - 1), *shape[-3:], device=x.device)\n    shift = []\n    for i in range(x.shape[0] - 1):\n        spike = x[i].clone()\n        _max = int(spike.max())\n        sft = torch.zeros(shape[-3:], device=x.device)\n        for j in range(_max):\n            p = torch.rand_like(sft)\n            sft[torch.logical_and(p < v, spike > 0.)] += 1.\n            spike -= 1\n        shift.append(sft)\n\n        spike = x[i + 1].clone()\n        _max = int(spike.max())\n        sft = torch.zeros(shape[-3:], device=x.device)\n        for j in range(_max):\n            p = torch.rand_like(sft)\n            sft[torch.logical_and(p < v, spike > 0.)] += 1.\n        shift.append(sft)\n\n    for i in range(shape[0] - 1):\n        sft_next = shift[i * 2]\n        sft_pre = shift[i * 2 + 1]\n        x[i + 1] = torch.clip(x[i + 1] + sft_next - sft_pre, 0.)\n        x[i] = torch.clip(x[i] - sft_next + sft_pre, 0.)\n\n    return x\n\n\ndef SpatioShift(x, v):\n    # assert 0 <= v <= 0.1\n    w, h = x.shape[-2::]\n    shift_x = round(random.uniform(-v, v) * w)\n    shift_y = round(random.uniform(-v, v) * h)\n    output = []\n    step = x.shape[0]\n    for t in range(step):\n        output.append(functional.affine(x[t],\n                                        angle=0,\n                                        translate=[\n                                            round(shift_x * t / step),\n                                            round(shift_y * t / step)],\n                                        scale=1.,\n                                        shear=[0, 0]))\n    return torch.stack(output, dim=0)\n\n\ndef drop(x, v):\n    assert 0 <= v <= 0.5\n    v = np.random.uniform(0, v)\n    _max = int(torch.max(x))\n    p = torch.rand((_max, *x.shape), device=x.device)\n\n    for i in 
range(_max):\n        p[i, x > 0] += 1.\n        x -= 1.\n\n    p = torch.where(p > 1. + v, 1., 0.)\n    return torch.sum(p, dim=0)\n\n\ndef GaussianBlur(x, v):\n    assert 0.1 <= v <= 1.\n    v = np.random.uniform(0.1, v)\n    return functional.gaussian_blur(x, kernel_size=[5, 5], sigma=v)\n\n\ndef SaltAndPepperNoise(x, v):\n    assert 0 <= v <= 0.3\n    v = np.random.uniform(0, v)\n    p = torch.rand_like(x)\n    p = torch.where(p > v, 0., 1.)\n    return x + p\n\n\ndef Identity(x, v):\n    return x\n\n#                           DVSC10 NCAL\naugment_list = [  # normal: 79.44  77.57\n    # (ShearX, 0, 20),  # 75.71\n    # (ShearY, 0, 20),\n    # (TranslateX, 0, 0.25),  # 77.52\n    # (TranslateY, 0, 0.25),\n    # (Rotate, 0, 30),  # 77.02\n    (CutoutAbs, 0, 0.5),  # 79.13\n    (CutoutTemporal, 0, 0.5),  # 80.65\n    # (TemporalShift, 0, 0.2), # 75.30\n    # (SpatioShift, 0, 0.1),   # 78.43\n    (GaussianBlur, 0, 1.),  # 79.83\n    # (drop, 0, 0.5),  # 74.00\n    (SaltAndPepperNoise, 0, 0.3),  # 79.64\n    # cutmix_normal_aug: 90.02  86.52\n]\n\n\nclass RandAugment:\n    def __init__(self, n, m):\n        self.n = n\n        self.m = m      # [0, 30]\n        self.augment_list = augment_list\n\n    def __call__(self, x):\n        ops = random.choices(self.augment_list, k=self.n)\n        for op, minvalue, maxvalue in ops:\n            val = (float(self.m) / 30) * float(maxvalue - minvalue) + minvalue\n            x = op(x, val)\n\n        return x\n"
  },
  {
    "path": "braincog/datasets/scripts/testlist01.txt",
    "content": "ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g01_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g01_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g01_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g01_c05.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g01_c06.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g02_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g02_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g02_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g02_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g03_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g03_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g03_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g03_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g03_c05.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g03_c06.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c05.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c06.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g04_c07.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c05.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c06.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g05_c07.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c05.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c06.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g06_c07.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c01.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c02.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c03.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c05.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi\nApplyEyeMakeup/v_ApplyEyeMakeup_g07_c07.avi\nAppl
yLipstick/v_ApplyLipstick_g01_c01.avi\nApplyLipstick/v_ApplyLipstick_g01_c02.avi\nApplyLipstick/v_ApplyLipstick_g01_c03.avi\nApplyLipstick/v_ApplyLipstick_g01_c04.avi\nApplyLipstick/v_ApplyLipstick_g01_c05.avi\nApplyLipstick/v_ApplyLipstick_g02_c01.avi\nApplyLipstick/v_ApplyLipstick_g02_c02.avi\nApplyLipstick/v_ApplyLipstick_g02_c03.avi\nApplyLipstick/v_ApplyLipstick_g02_c04.avi\nApplyLipstick/v_ApplyLipstick_g03_c01.avi\nApplyLipstick/v_ApplyLipstick_g03_c02.avi\nApplyLipstick/v_ApplyLipstick_g03_c03.avi\nApplyLipstick/v_ApplyLipstick_g03_c04.avi\nApplyLipstick/v_ApplyLipstick_g04_c01.avi\nApplyLipstick/v_ApplyLipstick_g04_c02.avi\nApplyLipstick/v_ApplyLipstick_g04_c03.avi\nApplyLipstick/v_ApplyLipstick_g04_c04.avi\nApplyLipstick/v_ApplyLipstick_g04_c05.avi\nApplyLipstick/v_ApplyLipstick_g05_c01.avi\nApplyLipstick/v_ApplyLipstick_g05_c02.avi\nApplyLipstick/v_ApplyLipstick_g05_c03.avi\nApplyLipstick/v_ApplyLipstick_g05_c04.avi\nApplyLipstick/v_ApplyLipstick_g05_c05.avi\nApplyLipstick/v_ApplyLipstick_g06_c01.avi\nApplyLipstick/v_ApplyLipstick_g06_c02.avi\nApplyLipstick/v_ApplyLipstick_g06_c03.avi\nApplyLipstick/v_ApplyLipstick_g06_c04.avi\nApplyLipstick/v_ApplyLipstick_g06_c05.avi\nApplyLipstick/v_ApplyLipstick_g07_c01.avi\nApplyLipstick/v_ApplyLipstick_g07_c02.avi\nApplyLipstick/v_ApplyLipstick_g07_c03.avi\nApplyLipstick/v_ApplyLipstick_g07_c04.avi\nArchery/v_Archery_g01_c01.avi\nArchery/v_Archery_g01_c02.avi\nArchery/v_Archery_g01_c03.avi\nArchery/v_Archery_g01_c04.avi\nArchery/v_Archery_g01_c05.avi\nArchery/v_Archery_g01_c06.avi\nArchery/v_Archery_g01_c07.avi\nArchery/v_Archery_g02_c01.avi\nArchery/v_Archery_g02_c02.avi\nArchery/v_Archery_g02_c03.avi\nArchery/v_Archery_g02_c04.avi\nArchery/v_Archery_g02_c05.avi\nArchery/v_Archery_g02_c06.avi\nArchery/v_Archery_g02_c07.avi\nArchery/v_Archery_g03_c01.avi\nArchery/v_Archery_g03_c02.avi\nArchery/v_Archery_g03_c03.avi\nArchery/v_Archery_g03_c04.avi\nArchery/v_Archery_g03_c05.avi\nArchery/v_Archery_g04_c01.avi\nArchery/
v_Archery_g04_c02.avi\nArchery/v_Archery_g04_c03.avi\nArchery/v_Archery_g04_c04.avi\nArchery/v_Archery_g04_c05.avi\nArchery/v_Archery_g05_c01.avi\nArchery/v_Archery_g05_c02.avi\nArchery/v_Archery_g05_c03.avi\nArchery/v_Archery_g05_c04.avi\nArchery/v_Archery_g05_c05.avi\nArchery/v_Archery_g06_c01.avi\nArchery/v_Archery_g06_c02.avi\nArchery/v_Archery_g06_c03.avi\nArchery/v_Archery_g06_c04.avi\nArchery/v_Archery_g06_c05.avi\nArchery/v_Archery_g06_c06.avi\nArchery/v_Archery_g07_c01.avi\nArchery/v_Archery_g07_c02.avi\nArchery/v_Archery_g07_c03.avi\nArchery/v_Archery_g07_c04.avi\nArchery/v_Archery_g07_c05.avi\nArchery/v_Archery_g07_c06.avi\nBabyCrawling/v_BabyCrawling_g01_c01.avi\nBabyCrawling/v_BabyCrawling_g01_c02.avi\nBabyCrawling/v_BabyCrawling_g01_c03.avi\nBabyCrawling/v_BabyCrawling_g01_c04.avi\nBabyCrawling/v_BabyCrawling_g02_c01.avi\nBabyCrawling/v_BabyCrawling_g02_c02.avi\nBabyCrawling/v_BabyCrawling_g02_c03.avi\nBabyCrawling/v_BabyCrawling_g02_c04.avi\nBabyCrawling/v_BabyCrawling_g02_c05.avi\nBabyCrawling/v_BabyCrawling_g02_c06.avi\nBabyCrawling/v_BabyCrawling_g03_c01.avi\nBabyCrawling/v_BabyCrawling_g03_c02.avi\nBabyCrawling/v_BabyCrawling_g03_c03.avi\nBabyCrawling/v_BabyCrawling_g03_c04.avi\nBabyCrawling/v_BabyCrawling_g04_c01.avi\nBabyCrawling/v_BabyCrawling_g04_c02.avi\nBabyCrawling/v_BabyCrawling_g04_c03.avi\nBabyCrawling/v_BabyCrawling_g04_c04.avi\nBabyCrawling/v_BabyCrawling_g05_c01.avi\nBabyCrawling/v_BabyCrawling_g05_c02.avi\nBabyCrawling/v_BabyCrawling_g05_c03.avi\nBabyCrawling/v_BabyCrawling_g05_c04.avi\nBabyCrawling/v_BabyCrawling_g05_c05.avi\nBabyCrawling/v_BabyCrawling_g06_c01.avi\nBabyCrawling/v_BabyCrawling_g06_c02.avi\nBabyCrawling/v_BabyCrawling_g06_c03.avi\nBabyCrawling/v_BabyCrawling_g06_c04.avi\nBabyCrawling/v_BabyCrawling_g06_c05.avi\nBabyCrawling/v_BabyCrawling_g06_c06.avi\nBabyCrawling/v_BabyCrawling_g07_c01.avi\nBabyCrawling/v_BabyCrawling_g07_c02.avi\nBabyCrawling/v_BabyCrawling_g07_c03.avi\nBabyCrawling/v_BabyCrawling_g07_c04.avi\nBaby
Crawling/v_BabyCrawling_g07_c05.avi\nBabyCrawling/v_BabyCrawling_g07_c06.avi\nBalanceBeam/v_BalanceBeam_g01_c01.avi\nBalanceBeam/v_BalanceBeam_g01_c02.avi\nBalanceBeam/v_BalanceBeam_g01_c03.avi\nBalanceBeam/v_BalanceBeam_g01_c04.avi\nBalanceBeam/v_BalanceBeam_g02_c01.avi\nBalanceBeam/v_BalanceBeam_g02_c02.avi\nBalanceBeam/v_BalanceBeam_g02_c03.avi\nBalanceBeam/v_BalanceBeam_g02_c04.avi\nBalanceBeam/v_BalanceBeam_g03_c01.avi\nBalanceBeam/v_BalanceBeam_g03_c02.avi\nBalanceBeam/v_BalanceBeam_g03_c03.avi\nBalanceBeam/v_BalanceBeam_g03_c04.avi\nBalanceBeam/v_BalanceBeam_g04_c01.avi\nBalanceBeam/v_BalanceBeam_g04_c02.avi\nBalanceBeam/v_BalanceBeam_g04_c03.avi\nBalanceBeam/v_BalanceBeam_g04_c04.avi\nBalanceBeam/v_BalanceBeam_g05_c01.avi\nBalanceBeam/v_BalanceBeam_g05_c02.avi\nBalanceBeam/v_BalanceBeam_g05_c03.avi\nBalanceBeam/v_BalanceBeam_g05_c04.avi\nBalanceBeam/v_BalanceBeam_g06_c01.avi\nBalanceBeam/v_BalanceBeam_g06_c02.avi\nBalanceBeam/v_BalanceBeam_g06_c03.avi\nBalanceBeam/v_BalanceBeam_g06_c04.avi\nBalanceBeam/v_BalanceBeam_g06_c05.avi\nBalanceBeam/v_BalanceBeam_g06_c06.avi\nBalanceBeam/v_BalanceBeam_g06_c07.avi\nBalanceBeam/v_BalanceBeam_g07_c01.avi\nBalanceBeam/v_BalanceBeam_g07_c02.avi\nBalanceBeam/v_BalanceBeam_g07_c03.avi\nBalanceBeam/v_BalanceBeam_g07_c04.avi\nBandMarching/v_BandMarching_g01_c01.avi\nBandMarching/v_BandMarching_g01_c02.avi\nBandMarching/v_BandMarching_g01_c03.avi\nBandMarching/v_BandMarching_g01_c04.avi\nBandMarching/v_BandMarching_g01_c05.avi\nBandMarching/v_BandMarching_g01_c06.avi\nBandMarching/v_BandMarching_g01_c07.avi\nBandMarching/v_BandMarching_g02_c01.avi\nBandMarching/v_BandMarching_g02_c02.avi\nBandMarching/v_BandMarching_g02_c03.avi\nBandMarching/v_BandMarching_g02_c04.avi\nBandMarching/v_BandMarching_g02_c05.avi\nBandMarching/v_BandMarching_g02_c06.avi\nBandMarching/v_BandMarching_g02_c07.avi\nBandMarching/v_BandMarching_g03_c01.avi\nBandMarching/v_BandMarching_g03_c02.avi\nBandMarching/v_BandMarching_g03_c03.avi\nBandMarching/v_B
andMarching_g03_c04.avi\nBandMarching/v_BandMarching_g03_c05.avi\nBandMarching/v_BandMarching_g03_c06.avi\nBandMarching/v_BandMarching_g03_c07.avi\nBandMarching/v_BandMarching_g04_c01.avi\nBandMarching/v_BandMarching_g04_c02.avi\nBandMarching/v_BandMarching_g04_c03.avi\nBandMarching/v_BandMarching_g04_c04.avi\nBandMarching/v_BandMarching_g05_c01.avi\nBandMarching/v_BandMarching_g05_c02.avi\nBandMarching/v_BandMarching_g05_c03.avi\nBandMarching/v_BandMarching_g05_c04.avi\nBandMarching/v_BandMarching_g05_c05.avi\nBandMarching/v_BandMarching_g05_c06.avi\nBandMarching/v_BandMarching_g05_c07.avi\nBandMarching/v_BandMarching_g06_c01.avi\nBandMarching/v_BandMarching_g06_c02.avi\nBandMarching/v_BandMarching_g06_c03.avi\nBandMarching/v_BandMarching_g06_c04.avi\nBandMarching/v_BandMarching_g07_c01.avi\nBandMarching/v_BandMarching_g07_c02.avi\nBandMarching/v_BandMarching_g07_c03.avi\nBandMarching/v_BandMarching_g07_c04.avi\nBandMarching/v_BandMarching_g07_c05.avi\nBandMarching/v_BandMarching_g07_c06.avi\nBandMarching/v_BandMarching_g07_c07.avi\nBaseballPitch/v_BaseballPitch_g01_c01.avi\nBaseballPitch/v_BaseballPitch_g01_c02.avi\nBaseballPitch/v_BaseballPitch_g01_c03.avi\nBaseballPitch/v_BaseballPitch_g01_c04.avi\nBaseballPitch/v_BaseballPitch_g01_c05.avi\nBaseballPitch/v_BaseballPitch_g01_c06.avi\nBaseballPitch/v_BaseballPitch_g02_c01.avi\nBaseballPitch/v_BaseballPitch_g02_c02.avi\nBaseballPitch/v_BaseballPitch_g02_c03.avi\nBaseballPitch/v_BaseballPitch_g02_c04.avi\nBaseballPitch/v_BaseballPitch_g03_c01.avi\nBaseballPitch/v_BaseballPitch_g03_c02.avi\nBaseballPitch/v_BaseballPitch_g03_c03.avi\nBaseballPitch/v_BaseballPitch_g03_c04.avi\nBaseballPitch/v_BaseballPitch_g03_c05.avi\nBaseballPitch/v_BaseballPitch_g03_c06.avi\nBaseballPitch/v_BaseballPitch_g03_c07.avi\nBaseballPitch/v_BaseballPitch_g04_c01.avi\nBaseballPitch/v_BaseballPitch_g04_c02.avi\nBaseballPitch/v_BaseballPitch_g04_c03.avi\nBaseballPitch/v_BaseballPitch_g04_c04.avi\nBaseballPitch/v_BaseballPitch_g04_c05.avi\nBase
ballPitch/v_BaseballPitch_g05_c01.avi\nBaseballPitch/v_BaseballPitch_g05_c02.avi\nBaseballPitch/v_BaseballPitch_g05_c03.avi\nBaseballPitch/v_BaseballPitch_g05_c04.avi\nBaseballPitch/v_BaseballPitch_g05_c05.avi\nBaseballPitch/v_BaseballPitch_g05_c06.avi\nBaseballPitch/v_BaseballPitch_g05_c07.avi\nBaseballPitch/v_BaseballPitch_g06_c01.avi\nBaseballPitch/v_BaseballPitch_g06_c02.avi\nBaseballPitch/v_BaseballPitch_g06_c03.avi\nBaseballPitch/v_BaseballPitch_g06_c04.avi\nBaseballPitch/v_BaseballPitch_g06_c05.avi\nBaseballPitch/v_BaseballPitch_g06_c06.avi\nBaseballPitch/v_BaseballPitch_g06_c07.avi\nBaseballPitch/v_BaseballPitch_g07_c01.avi\nBaseballPitch/v_BaseballPitch_g07_c02.avi\nBaseballPitch/v_BaseballPitch_g07_c03.avi\nBaseballPitch/v_BaseballPitch_g07_c04.avi\nBaseballPitch/v_BaseballPitch_g07_c05.avi\nBaseballPitch/v_BaseballPitch_g07_c06.avi\nBaseballPitch/v_BaseballPitch_g07_c07.avi\nBasketball/v_Basketball_g01_c01.avi\nBasketball/v_Basketball_g01_c02.avi\nBasketball/v_Basketball_g01_c03.avi\nBasketball/v_Basketball_g01_c04.avi\nBasketball/v_Basketball_g01_c05.avi\nBasketball/v_Basketball_g01_c06.avi\nBasketball/v_Basketball_g01_c07.avi\nBasketball/v_Basketball_g02_c01.avi\nBasketball/v_Basketball_g02_c02.avi\nBasketball/v_Basketball_g02_c03.avi\nBasketball/v_Basketball_g02_c04.avi\nBasketball/v_Basketball_g02_c05.avi\nBasketball/v_Basketball_g02_c06.avi\nBasketball/v_Basketball_g03_c01.avi\nBasketball/v_Basketball_g03_c02.avi\nBasketball/v_Basketball_g03_c03.avi\nBasketball/v_Basketball_g03_c04.avi\nBasketball/v_Basketball_g03_c05.avi\nBasketball/v_Basketball_g03_c06.avi\nBasketball/v_Basketball_g04_c01.avi\nBasketball/v_Basketball_g04_c02.avi\nBasketball/v_Basketball_g04_c03.avi\nBasketball/v_Basketball_g04_c04.avi\nBasketball/v_Basketball_g05_c01.avi\nBasketball/v_Basketball_g05_c02.avi\nBasketball/v_Basketball_g05_c03.avi\nBasketball/v_Basketball_g05_c04.avi\nBasketball/v_Basketball_g06_c01.avi\nBasketball/v_Basketball_g06_c02.avi\nBasketball/v_Basketball_g06_
c03.avi\nBasketball/v_Basketball_g06_c04.avi\nBasketball/v_Basketball_g07_c01.avi\nBasketball/v_Basketball_g07_c02.avi\nBasketball/v_Basketball_g07_c03.avi\nBasketball/v_Basketball_g07_c04.avi\nBasketballDunk/v_BasketballDunk_g01_c01.avi\nBasketballDunk/v_BasketballDunk_g01_c02.avi\nBasketballDunk/v_BasketballDunk_g01_c03.avi\nBasketballDunk/v_BasketballDunk_g01_c04.avi\nBasketballDunk/v_BasketballDunk_g01_c05.avi\nBasketballDunk/v_BasketballDunk_g01_c06.avi\nBasketballDunk/v_BasketballDunk_g01_c07.avi\nBasketballDunk/v_BasketballDunk_g02_c01.avi\nBasketballDunk/v_BasketballDunk_g02_c02.avi\nBasketballDunk/v_BasketballDunk_g02_c03.avi\nBasketballDunk/v_BasketballDunk_g02_c04.avi\nBasketballDunk/v_BasketballDunk_g03_c01.avi\nBasketballDunk/v_BasketballDunk_g03_c02.avi\nBasketballDunk/v_BasketballDunk_g03_c03.avi\nBasketballDunk/v_BasketballDunk_g03_c04.avi\nBasketballDunk/v_BasketballDunk_g03_c05.avi\nBasketballDunk/v_BasketballDunk_g03_c06.avi\nBasketballDunk/v_BasketballDunk_g04_c01.avi\nBasketballDunk/v_BasketballDunk_g04_c02.avi\nBasketballDunk/v_BasketballDunk_g04_c03.avi\nBasketballDunk/v_BasketballDunk_g04_c04.avi\nBasketballDunk/v_BasketballDunk_g05_c01.avi\nBasketballDunk/v_BasketballDunk_g05_c02.avi\nBasketballDunk/v_BasketballDunk_g05_c03.avi\nBasketballDunk/v_BasketballDunk_g05_c04.avi\nBasketballDunk/v_BasketballDunk_g05_c05.avi\nBasketballDunk/v_BasketballDunk_g05_c06.avi\nBasketballDunk/v_BasketballDunk_g06_c01.avi\nBasketballDunk/v_BasketballDunk_g06_c02.avi\nBasketballDunk/v_BasketballDunk_g06_c03.avi\nBasketballDunk/v_BasketballDunk_g06_c04.avi\nBasketballDunk/v_BasketballDunk_g07_c01.avi\nBasketballDunk/v_BasketballDunk_g07_c02.avi\nBasketballDunk/v_BasketballDunk_g07_c03.avi\nBasketballDunk/v_BasketballDunk_g07_c04.avi\nBasketballDunk/v_BasketballDunk_g07_c05.avi\nBasketballDunk/v_BasketballDunk_g07_c06.avi\nBenchPress/v_BenchPress_g01_c01.avi\nBenchPress/v_BenchPress_g01_c02.avi\nBenchPress/v_BenchPress_g01_c03.avi\nBenchPress/v_BenchPress_g01_c0
4.avi\nBenchPress/v_BenchPress_g01_c05.avi\nBenchPress/v_BenchPress_g01_c06.avi\nBenchPress/v_BenchPress_g02_c01.avi\nBenchPress/v_BenchPress_g02_c02.avi\nBenchPress/v_BenchPress_g02_c03.avi\nBenchPress/v_BenchPress_g02_c04.avi\nBenchPress/v_BenchPress_g02_c05.avi\nBenchPress/v_BenchPress_g02_c06.avi\nBenchPress/v_BenchPress_g02_c07.avi\nBenchPress/v_BenchPress_g03_c01.avi\nBenchPress/v_BenchPress_g03_c02.avi\nBenchPress/v_BenchPress_g03_c03.avi\nBenchPress/v_BenchPress_g03_c04.avi\nBenchPress/v_BenchPress_g03_c05.avi\nBenchPress/v_BenchPress_g03_c06.avi\nBenchPress/v_BenchPress_g03_c07.avi\nBenchPress/v_BenchPress_g04_c01.avi\nBenchPress/v_BenchPress_g04_c02.avi\nBenchPress/v_BenchPress_g04_c03.avi\nBenchPress/v_BenchPress_g04_c04.avi\nBenchPress/v_BenchPress_g04_c05.avi\nBenchPress/v_BenchPress_g04_c06.avi\nBenchPress/v_BenchPress_g04_c07.avi\nBenchPress/v_BenchPress_g05_c01.avi\nBenchPress/v_BenchPress_g05_c02.avi\nBenchPress/v_BenchPress_g05_c03.avi\nBenchPress/v_BenchPress_g05_c04.avi\nBenchPress/v_BenchPress_g05_c05.avi\nBenchPress/v_BenchPress_g05_c06.avi\nBenchPress/v_BenchPress_g05_c07.avi\nBenchPress/v_BenchPress_g06_c01.avi\nBenchPress/v_BenchPress_g06_c02.avi\nBenchPress/v_BenchPress_g06_c03.avi\nBenchPress/v_BenchPress_g06_c04.avi\nBenchPress/v_BenchPress_g06_c05.avi\nBenchPress/v_BenchPress_g06_c06.avi\nBenchPress/v_BenchPress_g06_c07.avi\nBenchPress/v_BenchPress_g07_c01.avi\nBenchPress/v_BenchPress_g07_c02.avi\nBenchPress/v_BenchPress_g07_c03.avi\nBenchPress/v_BenchPress_g07_c04.avi\nBenchPress/v_BenchPress_g07_c05.avi\nBenchPress/v_BenchPress_g07_c06.avi\nBenchPress/v_BenchPress_g07_c07.avi\nBiking/v_Biking_g01_c01.avi\nBiking/v_Biking_g01_c02.avi\nBiking/v_Biking_g01_c03.avi\nBiking/v_Biking_g01_c04.avi\nBiking/v_Biking_g02_c01.avi\nBiking/v_Biking_g02_c02.avi\nBiking/v_Biking_g02_c03.avi\nBiking/v_Biking_g02_c04.avi\nBiking/v_Biking_g02_c05.avi\nBiking/v_Biking_g02_c06.avi\nBiking/v_Biking_g02_c07.avi\nBiking/v_Biking_g03_c01.avi\nBiking/v_Biking_g
03_c02.avi\nBiking/v_Biking_g03_c03.avi\nBiking/v_Biking_g03_c04.avi\nBiking/v_Biking_g04_c01.avi\nBiking/v_Biking_g04_c02.avi\nBiking/v_Biking_g04_c03.avi\nBiking/v_Biking_g04_c04.avi\nBiking/v_Biking_g04_c05.avi\nBiking/v_Biking_g05_c01.avi\nBiking/v_Biking_g05_c02.avi\nBiking/v_Biking_g05_c03.avi\nBiking/v_Biking_g05_c04.avi\nBiking/v_Biking_g05_c05.avi\nBiking/v_Biking_g05_c06.avi\nBiking/v_Biking_g05_c07.avi\nBiking/v_Biking_g06_c01.avi\nBiking/v_Biking_g06_c02.avi\nBiking/v_Biking_g06_c03.avi\nBiking/v_Biking_g06_c04.avi\nBiking/v_Biking_g06_c05.avi\nBiking/v_Biking_g07_c01.avi\nBiking/v_Biking_g07_c02.avi\nBiking/v_Biking_g07_c03.avi\nBiking/v_Biking_g07_c04.avi\nBiking/v_Biking_g07_c05.avi\nBiking/v_Biking_g07_c06.avi\nBilliards/v_Billiards_g01_c01.avi\nBilliards/v_Billiards_g01_c02.avi\nBilliards/v_Billiards_g01_c03.avi\nBilliards/v_Billiards_g01_c04.avi\nBilliards/v_Billiards_g01_c05.avi\nBilliards/v_Billiards_g01_c06.avi\nBilliards/v_Billiards_g02_c01.avi\nBilliards/v_Billiards_g02_c02.avi\nBilliards/v_Billiards_g02_c03.avi\nBilliards/v_Billiards_g02_c04.avi\nBilliards/v_Billiards_g02_c05.avi\nBilliards/v_Billiards_g02_c06.avi\nBilliards/v_Billiards_g02_c07.avi\nBilliards/v_Billiards_g03_c01.avi\nBilliards/v_Billiards_g03_c02.avi\nBilliards/v_Billiards_g03_c03.avi\nBilliards/v_Billiards_g03_c04.avi\nBilliards/v_Billiards_g03_c05.avi\nBilliards/v_Billiards_g04_c01.avi\nBilliards/v_Billiards_g04_c02.avi\nBilliards/v_Billiards_g04_c03.avi\nBilliards/v_Billiards_g04_c04.avi\nBilliards/v_Billiards_g04_c05.avi\nBilliards/v_Billiards_g04_c06.avi\nBilliards/v_Billiards_g04_c07.avi\nBilliards/v_Billiards_g05_c01.avi\nBilliards/v_Billiards_g05_c02.avi\nBilliards/v_Billiards_g05_c03.avi\nBilliards/v_Billiards_g05_c04.avi\nBilliards/v_Billiards_g05_c05.avi\nBilliards/v_Billiards_g05_c06.avi\nBilliards/v_Billiards_g06_c01.avi\nBilliards/v_Billiards_g06_c02.avi\nBilliards/v_Billiards_g06_c03.avi\nBilliards/v_Billiards_g06_c04.avi\nBilliards/v_Billiards_g06_c05.avi\nBil
liards/v_Billiards_g07_c01.avi\nBilliards/v_Billiards_g07_c02.avi\nBilliards/v_Billiards_g07_c03.avi\nBilliards/v_Billiards_g07_c04.avi\nBlowDryHair/v_BlowDryHair_g01_c01.avi\nBlowDryHair/v_BlowDryHair_g01_c02.avi\nBlowDryHair/v_BlowDryHair_g01_c03.avi\nBlowDryHair/v_BlowDryHair_g01_c04.avi\nBlowDryHair/v_BlowDryHair_g02_c01.avi\nBlowDryHair/v_BlowDryHair_g02_c02.avi\nBlowDryHair/v_BlowDryHair_g02_c03.avi\nBlowDryHair/v_BlowDryHair_g02_c04.avi\nBlowDryHair/v_BlowDryHair_g02_c05.avi\nBlowDryHair/v_BlowDryHair_g03_c01.avi\nBlowDryHair/v_BlowDryHair_g03_c02.avi\nBlowDryHair/v_BlowDryHair_g03_c03.avi\nBlowDryHair/v_BlowDryHair_g03_c04.avi\nBlowDryHair/v_BlowDryHair_g03_c05.avi\nBlowDryHair/v_BlowDryHair_g04_c01.avi\nBlowDryHair/v_BlowDryHair_g04_c02.avi\nBlowDryHair/v_BlowDryHair_g04_c03.avi\nBlowDryHair/v_BlowDryHair_g04_c04.avi\nBlowDryHair/v_BlowDryHair_g04_c05.avi\nBlowDryHair/v_BlowDryHair_g05_c01.avi\nBlowDryHair/v_BlowDryHair_g05_c02.avi\nBlowDryHair/v_BlowDryHair_g05_c03.avi\nBlowDryHair/v_BlowDryHair_g05_c04.avi\nBlowDryHair/v_BlowDryHair_g05_c05.avi\nBlowDryHair/v_BlowDryHair_g06_c01.avi\nBlowDryHair/v_BlowDryHair_g06_c02.avi\nBlowDryHair/v_BlowDryHair_g06_c03.avi\nBlowDryHair/v_BlowDryHair_g06_c04.avi\nBlowDryHair/v_BlowDryHair_g06_c05.avi\nBlowDryHair/v_BlowDryHair_g06_c06.avi\nBlowDryHair/v_BlowDryHair_g06_c07.avi\nBlowDryHair/v_BlowDryHair_g07_c01.avi\nBlowDryHair/v_BlowDryHair_g07_c02.avi\nBlowDryHair/v_BlowDryHair_g07_c03.avi\nBlowDryHair/v_BlowDryHair_g07_c04.avi\nBlowDryHair/v_BlowDryHair_g07_c05.avi\nBlowDryHair/v_BlowDryHair_g07_c06.avi\nBlowDryHair/v_BlowDryHair_g07_c07.avi\nBlowingCandles/v_BlowingCandles_g01_c01.avi\nBlowingCandles/v_BlowingCandles_g01_c02.avi\nBlowingCandles/v_BlowingCandles_g01_c03.avi\nBlowingCandles/v_BlowingCandles_g01_c04.avi\nBlowingCandles/v_BlowingCandles_g02_c01.avi\nBlowingCandles/v_BlowingCandles_g02_c02.avi\nBlowingCandles/v_BlowingCandles_g02_c03.avi\nBlowingCandles/v_BlowingCandles_g02_c04.avi\nBlowingCandles/v_Blow
ingCandles_g03_c01.avi\nBlowingCandles/v_BlowingCandles_g03_c02.avi\nBlowingCandles/v_BlowingCandles_g03_c03.avi\nBlowingCandles/v_BlowingCandles_g03_c04.avi\nBlowingCandles/v_BlowingCandles_g04_c01.avi\nBlowingCandles/v_BlowingCandles_g04_c02.avi\nBlowingCandles/v_BlowingCandles_g04_c03.avi\nBlowingCandles/v_BlowingCandles_g04_c04.avi\nBlowingCandles/v_BlowingCandles_g04_c05.avi\nBlowingCandles/v_BlowingCandles_g05_c01.avi\nBlowingCandles/v_BlowingCandles_g05_c02.avi\nBlowingCandles/v_BlowingCandles_g05_c03.avi\nBlowingCandles/v_BlowingCandles_g05_c04.avi\nBlowingCandles/v_BlowingCandles_g05_c05.avi\nBlowingCandles/v_BlowingCandles_g06_c01.avi\nBlowingCandles/v_BlowingCandles_g06_c02.avi\nBlowingCandles/v_BlowingCandles_g06_c03.avi\nBlowingCandles/v_BlowingCandles_g06_c04.avi\nBlowingCandles/v_BlowingCandles_g06_c05.avi\nBlowingCandles/v_BlowingCandles_g06_c06.avi\nBlowingCandles/v_BlowingCandles_g06_c07.avi\nBlowingCandles/v_BlowingCandles_g07_c01.avi\nBlowingCandles/v_BlowingCandles_g07_c02.avi\nBlowingCandles/v_BlowingCandles_g07_c03.avi\nBlowingCandles/v_BlowingCandles_g07_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g01_c01.avi\nBodyWeightSquats/v_BodyWeightSquats_g01_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g01_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g01_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g02_c01.avi\nBodyWeightSquats/v_BodyWeightSquats_g02_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g02_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g02_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g03_c01.avi\nBodyWeightSquats/v_BodyWeightSquats_g03_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g03_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g03_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g03_c05.avi\nBodyWeightSquats/v_BodyWeightSquats_g04_c01.avi\nBodyWeightSquats/v_BodyWeightSquats_g04_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g04_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g04_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g05_c01.avi\nBodyWeightSqua
ts/v_BodyWeightSquats_g05_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g05_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g05_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g06_c01.avi\nBodyWeightSquats/v_BodyWeightSquats_g06_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g06_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g06_c04.avi\nBodyWeightSquats/v_BodyWeightSquats_g06_c05.avi\nBodyWeightSquats/v_BodyWeightSquats_g07_c01.avi\nBodyWeightSquats/v_BodyWeightSquats_g07_c02.avi\nBodyWeightSquats/v_BodyWeightSquats_g07_c03.avi\nBodyWeightSquats/v_BodyWeightSquats_g07_c04.avi\nBowling/v_Bowling_g01_c01.avi\nBowling/v_Bowling_g01_c02.avi\nBowling/v_Bowling_g01_c03.avi\nBowling/v_Bowling_g01_c04.avi\nBowling/v_Bowling_g01_c05.avi\nBowling/v_Bowling_g01_c06.avi\nBowling/v_Bowling_g01_c07.avi\nBowling/v_Bowling_g02_c01.avi\nBowling/v_Bowling_g02_c02.avi\nBowling/v_Bowling_g02_c03.avi\nBowling/v_Bowling_g02_c04.avi\nBowling/v_Bowling_g03_c01.avi\nBowling/v_Bowling_g03_c02.avi\nBowling/v_Bowling_g03_c03.avi\nBowling/v_Bowling_g03_c04.avi\nBowling/v_Bowling_g03_c05.avi\nBowling/v_Bowling_g03_c06.avi\nBowling/v_Bowling_g03_c07.avi\nBowling/v_Bowling_g04_c01.avi\nBowling/v_Bowling_g04_c02.avi\nBowling/v_Bowling_g04_c03.avi\nBowling/v_Bowling_g04_c04.avi\nBowling/v_Bowling_g05_c01.avi\nBowling/v_Bowling_g05_c02.avi\nBowling/v_Bowling_g05_c03.avi\nBowling/v_Bowling_g05_c04.avi\nBowling/v_Bowling_g05_c05.avi\nBowling/v_Bowling_g05_c06.avi\nBowling/v_Bowling_g05_c07.avi\nBowling/v_Bowling_g06_c01.avi\nBowling/v_Bowling_g06_c02.avi\nBowling/v_Bowling_g06_c03.avi\nBowling/v_Bowling_g06_c04.avi\nBowling/v_Bowling_g06_c05.avi\nBowling/v_Bowling_g06_c06.avi\nBowling/v_Bowling_g06_c07.avi\nBowling/v_Bowling_g07_c01.avi\nBowling/v_Bowling_g07_c02.avi\nBowling/v_Bowling_g07_c03.avi\nBowling/v_Bowling_g07_c04.avi\nBowling/v_Bowling_g07_c05.avi\nBowling/v_Bowling_g07_c06.avi\nBowling/v_Bowling_g07_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_
c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_c06.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g01_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c06.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g02_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c06.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g03_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c06.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g04_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c06.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g05_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g06_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g06_c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g06_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g06_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g06_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g06_c06.avi\nBo
xingPunchingBag/v_BoxingPunchingBag_g06_c07.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c01.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c02.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c03.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c04.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c05.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c06.avi\nBoxingPunchingBag/v_BoxingPunchingBag_g07_c07.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g01_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g01_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g01_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g01_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g02_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g02_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g02_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g02_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g03_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g03_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g03_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g03_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g03_c05.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c05.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c06.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g04_c07.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g05_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g05_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g05_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g05_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g05_c05.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g06_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g06_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g06_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g06_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g06_c05.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g07_c01.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g07_c02.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g07_c03.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g07_c04.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g07_c05.avi\nBoxingSpeedBag/v_Bo
xingSpeedBag_g07_c06.avi\nBoxingSpeedBag/v_BoxingSpeedBag_g07_c07.avi\nBreastStroke/v_BreastStroke_g01_c01.avi\nBreastStroke/v_BreastStroke_g01_c02.avi\nBreastStroke/v_BreastStroke_g01_c03.avi\nBreastStroke/v_BreastStroke_g01_c04.avi\nBreastStroke/v_BreastStroke_g02_c01.avi\nBreastStroke/v_BreastStroke_g02_c02.avi\nBreastStroke/v_BreastStroke_g02_c03.avi\nBreastStroke/v_BreastStroke_g02_c04.avi\nBreastStroke/v_BreastStroke_g03_c01.avi\nBreastStroke/v_BreastStroke_g03_c02.avi\nBreastStroke/v_BreastStroke_g03_c03.avi\nBreastStroke/v_BreastStroke_g03_c04.avi\nBreastStroke/v_BreastStroke_g04_c01.avi\nBreastStroke/v_BreastStroke_g04_c02.avi\nBreastStroke/v_BreastStroke_g04_c03.avi\nBreastStroke/v_BreastStroke_g04_c04.avi\nBreastStroke/v_BreastStroke_g05_c01.avi\nBreastStroke/v_BreastStroke_g05_c02.avi\nBreastStroke/v_BreastStroke_g05_c03.avi\nBreastStroke/v_BreastStroke_g05_c04.avi\nBreastStroke/v_BreastStroke_g06_c01.avi\nBreastStroke/v_BreastStroke_g06_c02.avi\nBreastStroke/v_BreastStroke_g06_c03.avi\nBreastStroke/v_BreastStroke_g06_c04.avi\nBreastStroke/v_BreastStroke_g07_c01.avi\nBreastStroke/v_BreastStroke_g07_c02.avi\nBreastStroke/v_BreastStroke_g07_c03.avi\nBreastStroke/v_BreastStroke_g07_c04.avi\nBrushingTeeth/v_BrushingTeeth_g01_c01.avi\nBrushingTeeth/v_BrushingTeeth_g01_c02.avi\nBrushingTeeth/v_BrushingTeeth_g01_c03.avi\nBrushingTeeth/v_BrushingTeeth_g01_c04.avi\nBrushingTeeth/v_BrushingTeeth_g02_c01.avi\nBrushingTeeth/v_BrushingTeeth_g02_c02.avi\nBrushingTeeth/v_BrushingTeeth_g02_c03.avi\nBrushingTeeth/v_BrushingTeeth_g02_c04.avi\nBrushingTeeth/v_BrushingTeeth_g02_c05.avi\nBrushingTeeth/v_BrushingTeeth_g02_c06.avi\nBrushingTeeth/v_BrushingTeeth_g02_c07.avi\nBrushingTeeth/v_BrushingTeeth_g03_c01.avi\nBrushingTeeth/v_BrushingTeeth_g03_c02.avi\nBrushingTeeth/v_BrushingTeeth_g03_c03.avi\nBrushingTeeth/v_BrushingTeeth_g03_c04.avi\nBrushingTeeth/v_BrushingTeeth_g03_c05.avi\nBrushingTeeth/v_BrushingTeeth_g04_c01.avi\nBrushingTeeth/v_BrushingTeeth_g04_c02.avi\nBrushin
gTeeth/v_BrushingTeeth_g04_c03.avi\nBrushingTeeth/v_BrushingTeeth_g04_c04.avi\nBrushingTeeth/v_BrushingTeeth_g05_c01.avi\nBrushingTeeth/v_BrushingTeeth_g05_c02.avi\nBrushingTeeth/v_BrushingTeeth_g05_c03.avi\nBrushingTeeth/v_BrushingTeeth_g05_c04.avi\nBrushingTeeth/v_BrushingTeeth_g05_c05.avi\nBrushingTeeth/v_BrushingTeeth_g06_c01.avi\nBrushingTeeth/v_BrushingTeeth_g06_c02.avi\nBrushingTeeth/v_BrushingTeeth_g06_c03.avi\nBrushingTeeth/v_BrushingTeeth_g06_c04.avi\nBrushingTeeth/v_BrushingTeeth_g06_c05.avi\nBrushingTeeth/v_BrushingTeeth_g07_c01.avi\nBrushingTeeth/v_BrushingTeeth_g07_c02.avi\nBrushingTeeth/v_BrushingTeeth_g07_c03.avi\nBrushingTeeth/v_BrushingTeeth_g07_c04.avi\nBrushingTeeth/v_BrushingTeeth_g07_c05.avi\nBrushingTeeth/v_BrushingTeeth_g07_c06.avi\nCleanAndJerk/v_CleanAndJerk_g01_c01.avi\nCleanAndJerk/v_CleanAndJerk_g01_c02.avi\nCleanAndJerk/v_CleanAndJerk_g01_c03.avi\nCleanAndJerk/v_CleanAndJerk_g01_c04.avi\nCleanAndJerk/v_CleanAndJerk_g01_c05.avi\nCleanAndJerk/v_CleanAndJerk_g02_c01.avi\nCleanAndJerk/v_CleanAndJerk_g02_c02.avi\nCleanAndJerk/v_CleanAndJerk_g02_c03.avi\nCleanAndJerk/v_CleanAndJerk_g02_c04.avi\nCleanAndJerk/v_CleanAndJerk_g03_c01.avi\nCleanAndJerk/v_CleanAndJerk_g03_c02.avi\nCleanAndJerk/v_CleanAndJerk_g03_c03.avi\nCleanAndJerk/v_CleanAndJerk_g03_c04.avi\nCleanAndJerk/v_CleanAndJerk_g03_c05.avi\nCleanAndJerk/v_CleanAndJerk_g03_c06.avi\nCleanAndJerk/v_CleanAndJerk_g04_c01.avi\nCleanAndJerk/v_CleanAndJerk_g04_c02.avi\nCleanAndJerk/v_CleanAndJerk_g04_c03.avi\nCleanAndJerk/v_CleanAndJerk_g04_c04.avi\nCleanAndJerk/v_CleanAndJerk_g04_c05.avi\nCleanAndJerk/v_CleanAndJerk_g05_c01.avi\nCleanAndJerk/v_CleanAndJerk_g05_c02.avi\nCleanAndJerk/v_CleanAndJerk_g05_c03.avi\nCleanAndJerk/v_CleanAndJerk_g05_c04.avi\nCleanAndJerk/v_CleanAndJerk_g06_c01.avi\nCleanAndJerk/v_CleanAndJerk_g06_c02.avi\nCleanAndJerk/v_CleanAndJerk_g06_c03.avi\nCleanAndJerk/v_CleanAndJerk_g06_c04.avi\nCleanAndJerk/v_CleanAndJerk_g07_c01.avi\nCleanAndJerk/v_CleanAndJerk_g07_c02.avi\nCle
anAndJerk/v_CleanAndJerk_g07_c03.avi\nCleanAndJerk/v_CleanAndJerk_g07_c04.avi\nCleanAndJerk/v_CleanAndJerk_g07_c05.avi\nCliffDiving/v_CliffDiving_g01_c01.avi\nCliffDiving/v_CliffDiving_g01_c02.avi\nCliffDiving/v_CliffDiving_g01_c03.avi\nCliffDiving/v_CliffDiving_g01_c04.avi\nCliffDiving/v_CliffDiving_g01_c05.avi\nCliffDiving/v_CliffDiving_g01_c06.avi\nCliffDiving/v_CliffDiving_g02_c01.avi\nCliffDiving/v_CliffDiving_g02_c02.avi\nCliffDiving/v_CliffDiving_g02_c03.avi\nCliffDiving/v_CliffDiving_g02_c04.avi\nCliffDiving/v_CliffDiving_g03_c01.avi\nCliffDiving/v_CliffDiving_g03_c02.avi\nCliffDiving/v_CliffDiving_g03_c03.avi\nCliffDiving/v_CliffDiving_g03_c04.avi\nCliffDiving/v_CliffDiving_g03_c05.avi\nCliffDiving/v_CliffDiving_g04_c01.avi\nCliffDiving/v_CliffDiving_g04_c02.avi\nCliffDiving/v_CliffDiving_g04_c03.avi\nCliffDiving/v_CliffDiving_g04_c04.avi\nCliffDiving/v_CliffDiving_g05_c01.avi\nCliffDiving/v_CliffDiving_g05_c02.avi\nCliffDiving/v_CliffDiving_g05_c03.avi\nCliffDiving/v_CliffDiving_g05_c04.avi\nCliffDiving/v_CliffDiving_g05_c05.avi\nCliffDiving/v_CliffDiving_g05_c06.avi\nCliffDiving/v_CliffDiving_g05_c07.avi\nCliffDiving/v_CliffDiving_g06_c01.avi\nCliffDiving/v_CliffDiving_g06_c02.avi\nCliffDiving/v_CliffDiving_g06_c03.avi\nCliffDiving/v_CliffDiving_g06_c04.avi\nCliffDiving/v_CliffDiving_g06_c05.avi\nCliffDiving/v_CliffDiving_g06_c06.avi\nCliffDiving/v_CliffDiving_g06_c07.avi\nCliffDiving/v_CliffDiving_g07_c01.avi\nCliffDiving/v_CliffDiving_g07_c02.avi\nCliffDiving/v_CliffDiving_g07_c03.avi\nCliffDiving/v_CliffDiving_g07_c04.avi\nCliffDiving/v_CliffDiving_g07_c05.avi\nCliffDiving/v_CliffDiving_g07_c06.avi\nCricketBowling/v_CricketBowling_g01_c01.avi\nCricketBowling/v_CricketBowling_g01_c02.avi\nCricketBowling/v_CricketBowling_g01_c03.avi\nCricketBowling/v_CricketBowling_g01_c04.avi\nCricketBowling/v_CricketBowling_g01_c05.avi\nCricketBowling/v_CricketBowling_g01_c06.avi\nCricketBowling/v_CricketBowling_g01_c07.avi\nCricketBowling/v_CricketBowling_g02_c01.avi\
nCricketBowling/v_CricketBowling_g02_c02.avi\nCricketBowling/v_CricketBowling_g02_c03.avi\nCricketBowling/v_CricketBowling_g02_c04.avi\nCricketBowling/v_CricketBowling_g02_c05.avi\nCricketBowling/v_CricketBowling_g02_c06.avi\nCricketBowling/v_CricketBowling_g02_c07.avi\nCricketBowling/v_CricketBowling_g03_c01.avi\nCricketBowling/v_CricketBowling_g03_c02.avi\nCricketBowling/v_CricketBowling_g03_c03.avi\nCricketBowling/v_CricketBowling_g03_c04.avi\nCricketBowling/v_CricketBowling_g04_c01.avi\nCricketBowling/v_CricketBowling_g04_c02.avi\nCricketBowling/v_CricketBowling_g04_c03.avi\nCricketBowling/v_CricketBowling_g04_c04.avi\nCricketBowling/v_CricketBowling_g04_c05.avi\nCricketBowling/v_CricketBowling_g05_c01.avi\nCricketBowling/v_CricketBowling_g05_c02.avi\nCricketBowling/v_CricketBowling_g05_c03.avi\nCricketBowling/v_CricketBowling_g05_c04.avi\nCricketBowling/v_CricketBowling_g06_c01.avi\nCricketBowling/v_CricketBowling_g06_c02.avi\nCricketBowling/v_CricketBowling_g06_c03.avi\nCricketBowling/v_CricketBowling_g06_c04.avi\nCricketBowling/v_CricketBowling_g06_c05.avi\nCricketBowling/v_CricketBowling_g07_c01.avi\nCricketBowling/v_CricketBowling_g07_c02.avi\nCricketBowling/v_CricketBowling_g07_c03.avi\nCricketBowling/v_CricketBowling_g07_c04.avi\nCricketShot/v_CricketShot_g01_c01.avi\nCricketShot/v_CricketShot_g01_c02.avi\nCricketShot/v_CricketShot_g01_c03.avi\nCricketShot/v_CricketShot_g01_c04.avi\nCricketShot/v_CricketShot_g01_c05.avi\nCricketShot/v_CricketShot_g01_c06.avi\nCricketShot/v_CricketShot_g01_c07.avi\nCricketShot/v_CricketShot_g02_c01.avi\nCricketShot/v_CricketShot_g02_c02.avi\nCricketShot/v_CricketShot_g02_c03.avi\nCricketShot/v_CricketShot_g02_c04.avi\nCricketShot/v_CricketShot_g02_c05.avi\nCricketShot/v_CricketShot_g02_c06.avi\nCricketShot/v_CricketShot_g02_c07.avi\nCricketShot/v_CricketShot_g03_c01.avi\nCricketShot/v_CricketShot_g03_c02.avi\nCricketShot/v_CricketShot_g03_c03.avi\nCricketShot/v_CricketShot_g03_c04.avi\nCricketShot/v_CricketShot_g03_c05.avi
\nCricketShot/v_CricketShot_g03_c06.avi\nCricketShot/v_CricketShot_g03_c07.avi\nCricketShot/v_CricketShot_g04_c01.avi\nCricketShot/v_CricketShot_g04_c02.avi\nCricketShot/v_CricketShot_g04_c03.avi\nCricketShot/v_CricketShot_g04_c04.avi\nCricketShot/v_CricketShot_g04_c05.avi\nCricketShot/v_CricketShot_g04_c06.avi\nCricketShot/v_CricketShot_g04_c07.avi\nCricketShot/v_CricketShot_g05_c01.avi\nCricketShot/v_CricketShot_g05_c02.avi\nCricketShot/v_CricketShot_g05_c03.avi\nCricketShot/v_CricketShot_g05_c04.avi\nCricketShot/v_CricketShot_g05_c05.avi\nCricketShot/v_CricketShot_g05_c06.avi\nCricketShot/v_CricketShot_g05_c07.avi\nCricketShot/v_CricketShot_g06_c01.avi\nCricketShot/v_CricketShot_g06_c02.avi\nCricketShot/v_CricketShot_g06_c03.avi\nCricketShot/v_CricketShot_g06_c04.avi\nCricketShot/v_CricketShot_g06_c05.avi\nCricketShot/v_CricketShot_g06_c06.avi\nCricketShot/v_CricketShot_g06_c07.avi\nCricketShot/v_CricketShot_g07_c01.avi\nCricketShot/v_CricketShot_g07_c02.avi\nCricketShot/v_CricketShot_g07_c03.avi\nCricketShot/v_CricketShot_g07_c04.avi\nCricketShot/v_CricketShot_g07_c05.avi\nCricketShot/v_CricketShot_g07_c06.avi\nCricketShot/v_CricketShot_g07_c07.avi\nCuttingInKitchen/v_CuttingInKitchen_g01_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g01_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g01_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g01_c04.avi\nCuttingInKitchen/v_CuttingInKitchen_g01_c05.avi\nCuttingInKitchen/v_CuttingInKitchen_g02_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g02_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g02_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g02_c04.avi\nCuttingInKitchen/v_CuttingInKitchen_g03_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g03_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g03_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g03_c04.avi\nCuttingInKitchen/v_CuttingInKitchen_g04_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g04_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g04_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g04_c04.
avi\nCuttingInKitchen/v_CuttingInKitchen_g04_c05.avi\nCuttingInKitchen/v_CuttingInKitchen_g05_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g05_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g05_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g05_c04.avi\nCuttingInKitchen/v_CuttingInKitchen_g05_c05.avi\nCuttingInKitchen/v_CuttingInKitchen_g05_c06.avi\nCuttingInKitchen/v_CuttingInKitchen_g06_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g06_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g06_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g06_c04.avi\nCuttingInKitchen/v_CuttingInKitchen_g06_c05.avi\nCuttingInKitchen/v_CuttingInKitchen_g07_c01.avi\nCuttingInKitchen/v_CuttingInKitchen_g07_c02.avi\nCuttingInKitchen/v_CuttingInKitchen_g07_c03.avi\nCuttingInKitchen/v_CuttingInKitchen_g07_c04.avi\nDiving/v_Diving_g01_c01.avi\nDiving/v_Diving_g01_c02.avi\nDiving/v_Diving_g01_c03.avi\nDiving/v_Diving_g01_c04.avi\nDiving/v_Diving_g01_c05.avi\nDiving/v_Diving_g01_c06.avi\nDiving/v_Diving_g01_c07.avi\nDiving/v_Diving_g02_c01.avi\nDiving/v_Diving_g02_c02.avi\nDiving/v_Diving_g02_c03.avi\nDiving/v_Diving_g02_c04.avi\nDiving/v_Diving_g02_c05.avi\nDiving/v_Diving_g02_c06.avi\nDiving/v_Diving_g02_c07.avi\nDiving/v_Diving_g03_c01.avi\nDiving/v_Diving_g03_c02.avi\nDiving/v_Diving_g03_c03.avi\nDiving/v_Diving_g03_c04.avi\nDiving/v_Diving_g03_c05.avi\nDiving/v_Diving_g03_c06.avi\nDiving/v_Diving_g03_c07.avi\nDiving/v_Diving_g04_c01.avi\nDiving/v_Diving_g04_c02.avi\nDiving/v_Diving_g04_c03.avi\nDiving/v_Diving_g04_c04.avi\nDiving/v_Diving_g04_c05.avi\nDiving/v_Diving_g04_c06.avi\nDiving/v_Diving_g04_c07.avi\nDiving/v_Diving_g05_c01.avi\nDiving/v_Diving_g05_c02.avi\nDiving/v_Diving_g05_c03.avi\nDiving/v_Diving_g05_c04.avi\nDiving/v_Diving_g05_c05.avi\nDiving/v_Diving_g05_c06.avi\nDiving/v_Diving_g06_c01.avi\nDiving/v_Diving_g06_c02.avi\nDiving/v_Diving_g06_c03.avi\nDiving/v_Diving_g06_c04.avi\nDiving/v_Diving_g06_c05.avi\nDiving/v_Diving_g06_c06.avi\nDiving/v_Diving_g06_c07.avi\nDiving/v_Diving_g07_c0
1.avi\nDiving/v_Diving_g07_c02.avi\nDiving/v_Diving_g07_c03.avi\nDiving/v_Diving_g07_c04.avi\nDrumming/v_Drumming_g01_c01.avi\nDrumming/v_Drumming_g01_c02.avi\nDrumming/v_Drumming_g01_c03.avi\nDrumming/v_Drumming_g01_c04.avi\nDrumming/v_Drumming_g01_c05.avi\nDrumming/v_Drumming_g01_c06.avi\nDrumming/v_Drumming_g01_c07.avi\nDrumming/v_Drumming_g02_c01.avi\nDrumming/v_Drumming_g02_c02.avi\nDrumming/v_Drumming_g02_c03.avi\nDrumming/v_Drumming_g02_c04.avi\nDrumming/v_Drumming_g02_c05.avi\nDrumming/v_Drumming_g02_c06.avi\nDrumming/v_Drumming_g02_c07.avi\nDrumming/v_Drumming_g03_c01.avi\nDrumming/v_Drumming_g03_c02.avi\nDrumming/v_Drumming_g03_c03.avi\nDrumming/v_Drumming_g03_c04.avi\nDrumming/v_Drumming_g03_c05.avi\nDrumming/v_Drumming_g04_c01.avi\nDrumming/v_Drumming_g04_c02.avi\nDrumming/v_Drumming_g04_c03.avi\nDrumming/v_Drumming_g04_c04.avi\nDrumming/v_Drumming_g04_c05.avi\nDrumming/v_Drumming_g04_c06.avi\nDrumming/v_Drumming_g04_c07.avi\nDrumming/v_Drumming_g05_c01.avi\nDrumming/v_Drumming_g05_c02.avi\nDrumming/v_Drumming_g05_c03.avi\nDrumming/v_Drumming_g05_c04.avi\nDrumming/v_Drumming_g05_c05.avi\nDrumming/v_Drumming_g05_c06.avi\nDrumming/v_Drumming_g06_c01.avi\nDrumming/v_Drumming_g06_c02.avi\nDrumming/v_Drumming_g06_c03.avi\nDrumming/v_Drumming_g06_c04.avi\nDrumming/v_Drumming_g06_c05.avi\nDrumming/v_Drumming_g06_c06.avi\nDrumming/v_Drumming_g07_c01.avi\nDrumming/v_Drumming_g07_c02.avi\nDrumming/v_Drumming_g07_c03.avi\nDrumming/v_Drumming_g07_c04.avi\nDrumming/v_Drumming_g07_c05.avi\nDrumming/v_Drumming_g07_c06.avi\nDrumming/v_Drumming_g07_c07.avi\nFencing/v_Fencing_g01_c01.avi\nFencing/v_Fencing_g01_c02.avi\nFencing/v_Fencing_g01_c03.avi\nFencing/v_Fencing_g01_c04.avi\nFencing/v_Fencing_g01_c05.avi\nFencing/v_Fencing_g01_c06.avi\nFencing/v_Fencing_g02_c01.avi\nFencing/v_Fencing_g02_c02.avi\nFencing/v_Fencing_g02_c03.avi\nFencing/v_Fencing_g02_c04.avi\nFencing/v_Fencing_g02_c05.avi\nFencing/v_Fencing_g03_c01.avi\nFencing/v_Fencing_g03_c02.avi\nFencing/v_Fencing_
g03_c03.avi\nFencing/v_Fencing_g03_c04.avi\nFencing/v_Fencing_g03_c05.avi\nFencing/v_Fencing_g04_c01.avi\nFencing/v_Fencing_g04_c02.avi\nFencing/v_Fencing_g04_c03.avi\nFencing/v_Fencing_g04_c04.avi\nFencing/v_Fencing_g04_c05.avi\nFencing/v_Fencing_g05_c01.avi\nFencing/v_Fencing_g05_c02.avi\nFencing/v_Fencing_g05_c03.avi\nFencing/v_Fencing_g05_c04.avi\nFencing/v_Fencing_g05_c05.avi\nFencing/v_Fencing_g06_c01.avi\nFencing/v_Fencing_g06_c02.avi\nFencing/v_Fencing_g06_c03.avi\nFencing/v_Fencing_g06_c04.avi\nFencing/v_Fencing_g07_c01.avi\nFencing/v_Fencing_g07_c02.avi\nFencing/v_Fencing_g07_c03.avi\nFencing/v_Fencing_g07_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g01_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g01_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g01_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g01_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g01_c05.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g02_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g02_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g02_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g02_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g02_c05.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g02_c06.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g03_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g03_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g03_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g03_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c05.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c06.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g04_c07.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g05_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g05_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g05_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g0
5_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g05_c05.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g05_c06.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g05_c07.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c04.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c05.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c06.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g06_c07.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g07_c01.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g07_c02.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g07_c03.avi\nFieldHockeyPenalty/v_FieldHockeyPenalty_g07_c04.avi\nFloorGymnastics/v_FloorGymnastics_g01_c01.avi\nFloorGymnastics/v_FloorGymnastics_g01_c02.avi\nFloorGymnastics/v_FloorGymnastics_g01_c03.avi\nFloorGymnastics/v_FloorGymnastics_g01_c04.avi\nFloorGymnastics/v_FloorGymnastics_g01_c05.avi\nFloorGymnastics/v_FloorGymnastics_g02_c01.avi\nFloorGymnastics/v_FloorGymnastics_g02_c02.avi\nFloorGymnastics/v_FloorGymnastics_g02_c03.avi\nFloorGymnastics/v_FloorGymnastics_g02_c04.avi\nFloorGymnastics/v_FloorGymnastics_g03_c01.avi\nFloorGymnastics/v_FloorGymnastics_g03_c02.avi\nFloorGymnastics/v_FloorGymnastics_g03_c03.avi\nFloorGymnastics/v_FloorGymnastics_g03_c04.avi\nFloorGymnastics/v_FloorGymnastics_g04_c01.avi\nFloorGymnastics/v_FloorGymnastics_g04_c02.avi\nFloorGymnastics/v_FloorGymnastics_g04_c03.avi\nFloorGymnastics/v_FloorGymnastics_g04_c04.avi\nFloorGymnastics/v_FloorGymnastics_g04_c05.avi\nFloorGymnastics/v_FloorGymnastics_g05_c01.avi\nFloorGymnastics/v_FloorGymnastics_g05_c02.avi\nFloorGymnastics/v_FloorGymnastics_g05_c03.avi\nFloorGymnastics/v_FloorGymnastics_g05_c04.avi\nFloorGymnastics/v_FloorGymnastics_g06_c01.avi\nFloorGymnastics/v_FloorGymnastics_g06_c02.avi\nFloorGymnastics/v_FloorGymnastics_g06_c03.avi\nFloorGymnastics/v_FloorGymnastics_g06_c04.avi\nFloorGymnastics/v_FloorGy
mnastics_g06_c05.avi\nFloorGymnastics/v_FloorGymnastics_g06_c06.avi\nFloorGymnastics/v_FloorGymnastics_g06_c07.avi\nFloorGymnastics/v_FloorGymnastics_g07_c01.avi\nFloorGymnastics/v_FloorGymnastics_g07_c02.avi\nFloorGymnastics/v_FloorGymnastics_g07_c03.avi\nFloorGymnastics/v_FloorGymnastics_g07_c04.avi\nFloorGymnastics/v_FloorGymnastics_g07_c05.avi\nFloorGymnastics/v_FloorGymnastics_g07_c06.avi\nFloorGymnastics/v_FloorGymnastics_g07_c07.avi\nFrisbeeCatch/v_FrisbeeCatch_g01_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g01_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g01_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g01_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g01_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g01_c06.avi\nFrisbeeCatch/v_FrisbeeCatch_g02_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g02_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g02_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g02_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g02_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g03_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g03_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g03_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g03_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g03_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g04_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g04_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g04_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g04_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g04_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g05_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g05_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g05_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g05_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g05_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g06_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g06_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g06_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g06_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g06_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g07_c01.avi\nFrisbeeCatch/v_FrisbeeCatch_g07_c02.avi\nFrisbeeCatch/v_FrisbeeCatch_g07_c03.avi\nFrisbeeCatch/v_FrisbeeCatch_g07_c04.avi\nFrisbeeCatch/v_FrisbeeCatch_g07_c05.avi\nFrisbeeCatch/v_FrisbeeCatch_g07_c06.avi\nFrontCrawl/v_FrontCrawl_g01_c01.avi\nF
rontCrawl/v_FrontCrawl_g01_c02.avi\nFrontCrawl/v_FrontCrawl_g01_c03.avi\nFrontCrawl/v_FrontCrawl_g01_c04.avi\nFrontCrawl/v_FrontCrawl_g02_c01.avi\nFrontCrawl/v_FrontCrawl_g02_c02.avi\nFrontCrawl/v_FrontCrawl_g02_c03.avi\nFrontCrawl/v_FrontCrawl_g02_c04.avi\nFrontCrawl/v_FrontCrawl_g03_c01.avi\nFrontCrawl/v_FrontCrawl_g03_c02.avi\nFrontCrawl/v_FrontCrawl_g03_c03.avi\nFrontCrawl/v_FrontCrawl_g03_c04.avi\nFrontCrawl/v_FrontCrawl_g03_c05.avi\nFrontCrawl/v_FrontCrawl_g03_c06.avi\nFrontCrawl/v_FrontCrawl_g04_c01.avi\nFrontCrawl/v_FrontCrawl_g04_c02.avi\nFrontCrawl/v_FrontCrawl_g04_c03.avi\nFrontCrawl/v_FrontCrawl_g04_c04.avi\nFrontCrawl/v_FrontCrawl_g04_c05.avi\nFrontCrawl/v_FrontCrawl_g04_c06.avi\nFrontCrawl/v_FrontCrawl_g04_c07.avi\nFrontCrawl/v_FrontCrawl_g05_c01.avi\nFrontCrawl/v_FrontCrawl_g05_c02.avi\nFrontCrawl/v_FrontCrawl_g05_c03.avi\nFrontCrawl/v_FrontCrawl_g05_c04.avi\nFrontCrawl/v_FrontCrawl_g06_c01.avi\nFrontCrawl/v_FrontCrawl_g06_c02.avi\nFrontCrawl/v_FrontCrawl_g06_c03.avi\nFrontCrawl/v_FrontCrawl_g06_c04.avi\nFrontCrawl/v_FrontCrawl_g06_c05.avi\nFrontCrawl/v_FrontCrawl_g07_c01.avi\nFrontCrawl/v_FrontCrawl_g07_c02.avi\nFrontCrawl/v_FrontCrawl_g07_c03.avi\nFrontCrawl/v_FrontCrawl_g07_c04.avi\nFrontCrawl/v_FrontCrawl_g07_c05.avi\nFrontCrawl/v_FrontCrawl_g07_c06.avi\nFrontCrawl/v_FrontCrawl_g07_c07.avi\nGolfSwing/v_GolfSwing_g01_c01.avi\nGolfSwing/v_GolfSwing_g01_c02.avi\nGolfSwing/v_GolfSwing_g01_c03.avi\nGolfSwing/v_GolfSwing_g01_c04.avi\nGolfSwing/v_GolfSwing_g01_c05.avi\nGolfSwing/v_GolfSwing_g01_c06.avi\nGolfSwing/v_GolfSwing_g02_c01.avi\nGolfSwing/v_GolfSwing_g02_c02.avi\nGolfSwing/v_GolfSwing_g02_c03.avi\nGolfSwing/v_GolfSwing_g02_c04.avi\nGolfSwing/v_GolfSwing_g03_c01.avi\nGolfSwing/v_GolfSwing_g03_c02.avi\nGolfSwing/v_GolfSwing_g03_c03.avi\nGolfSwing/v_GolfSwing_g03_c04.avi\nGolfSwing/v_GolfSwing_g03_c05.avi\nGolfSwing/v_GolfSwing_g03_c06.avi\nGolfSwing/v_GolfSwing_g03_c07.avi\nGolfSwing/v_GolfSwing_g04_c01.avi\nGolfSwing/v_GolfSwing_g04_c02.avi\nGolf
Swing/v_GolfSwing_g04_c03.avi\nGolfSwing/v_GolfSwing_g04_c04.avi\nGolfSwing/v_GolfSwing_g04_c05.avi\nGolfSwing/v_GolfSwing_g04_c06.avi\nGolfSwing/v_GolfSwing_g05_c01.avi\nGolfSwing/v_GolfSwing_g05_c02.avi\nGolfSwing/v_GolfSwing_g05_c03.avi\nGolfSwing/v_GolfSwing_g05_c04.avi\nGolfSwing/v_GolfSwing_g05_c05.avi\nGolfSwing/v_GolfSwing_g05_c06.avi\nGolfSwing/v_GolfSwing_g05_c07.avi\nGolfSwing/v_GolfSwing_g06_c01.avi\nGolfSwing/v_GolfSwing_g06_c02.avi\nGolfSwing/v_GolfSwing_g06_c03.avi\nGolfSwing/v_GolfSwing_g06_c04.avi\nGolfSwing/v_GolfSwing_g07_c01.avi\nGolfSwing/v_GolfSwing_g07_c02.avi\nGolfSwing/v_GolfSwing_g07_c03.avi\nGolfSwing/v_GolfSwing_g07_c04.avi\nGolfSwing/v_GolfSwing_g07_c05.avi\nHaircut/v_Haircut_g01_c01.avi\nHaircut/v_Haircut_g01_c02.avi\nHaircut/v_Haircut_g01_c03.avi\nHaircut/v_Haircut_g01_c04.avi\nHaircut/v_Haircut_g02_c01.avi\nHaircut/v_Haircut_g02_c02.avi\nHaircut/v_Haircut_g02_c03.avi\nHaircut/v_Haircut_g02_c04.avi\nHaircut/v_Haircut_g03_c01.avi\nHaircut/v_Haircut_g03_c02.avi\nHaircut/v_Haircut_g03_c03.avi\nHaircut/v_Haircut_g03_c04.avi\nHaircut/v_Haircut_g03_c05.avi\nHaircut/v_Haircut_g03_c06.avi\nHaircut/v_Haircut_g04_c01.avi\nHaircut/v_Haircut_g04_c02.avi\nHaircut/v_Haircut_g04_c03.avi\nHaircut/v_Haircut_g04_c04.avi\nHaircut/v_Haircut_g04_c05.avi\nHaircut/v_Haircut_g05_c01.avi\nHaircut/v_Haircut_g05_c02.avi\nHaircut/v_Haircut_g05_c03.avi\nHaircut/v_Haircut_g05_c04.avi\nHaircut/v_Haircut_g06_c01.avi\nHaircut/v_Haircut_g06_c02.avi\nHaircut/v_Haircut_g06_c03.avi\nHaircut/v_Haircut_g06_c04.avi\nHaircut/v_Haircut_g07_c01.avi\nHaircut/v_Haircut_g07_c02.avi\nHaircut/v_Haircut_g07_c03.avi\nHaircut/v_Haircut_g07_c04.avi\nHaircut/v_Haircut_g07_c05.avi\nHaircut/v_Haircut_g07_c06.avi\nHammering/v_Hammering_g01_c01.avi\nHammering/v_Hammering_g01_c02.avi\nHammering/v_Hammering_g01_c03.avi\nHammering/v_Hammering_g01_c04.avi\nHammering/v_Hammering_g02_c01.avi\nHammering/v_Hammering_g02_c02.avi\nHammering/v_Hammering_g02_c03.avi\nHammering/v_Hammering_g02_c04.avi\nH
ammering/v_Hammering_g03_c01.avi\nHammering/v_Hammering_g03_c02.avi\nHammering/v_Hammering_g03_c03.avi\nHammering/v_Hammering_g03_c04.avi\nHammering/v_Hammering_g03_c05.avi\nHammering/v_Hammering_g04_c01.avi\nHammering/v_Hammering_g04_c02.avi\nHammering/v_Hammering_g04_c03.avi\nHammering/v_Hammering_g04_c04.avi\nHammering/v_Hammering_g04_c05.avi\nHammering/v_Hammering_g05_c01.avi\nHammering/v_Hammering_g05_c02.avi\nHammering/v_Hammering_g05_c03.avi\nHammering/v_Hammering_g05_c04.avi\nHammering/v_Hammering_g06_c01.avi\nHammering/v_Hammering_g06_c02.avi\nHammering/v_Hammering_g06_c03.avi\nHammering/v_Hammering_g06_c04.avi\nHammering/v_Hammering_g06_c05.avi\nHammering/v_Hammering_g06_c06.avi\nHammering/v_Hammering_g07_c01.avi\nHammering/v_Hammering_g07_c02.avi\nHammering/v_Hammering_g07_c03.avi\nHammering/v_Hammering_g07_c04.avi\nHammering/v_Hammering_g07_c05.avi\nHammerThrow/v_HammerThrow_g01_c01.avi\nHammerThrow/v_HammerThrow_g01_c02.avi\nHammerThrow/v_HammerThrow_g01_c03.avi\nHammerThrow/v_HammerThrow_g01_c04.avi\nHammerThrow/v_HammerThrow_g01_c05.avi\nHammerThrow/v_HammerThrow_g01_c06.avi\nHammerThrow/v_HammerThrow_g02_c01.avi\nHammerThrow/v_HammerThrow_g02_c02.avi\nHammerThrow/v_HammerThrow_g02_c03.avi\nHammerThrow/v_HammerThrow_g02_c04.avi\nHammerThrow/v_HammerThrow_g02_c05.avi\nHammerThrow/v_HammerThrow_g02_c06.avi\nHammerThrow/v_HammerThrow_g02_c07.avi\nHammerThrow/v_HammerThrow_g03_c01.avi\nHammerThrow/v_HammerThrow_g03_c02.avi\nHammerThrow/v_HammerThrow_g03_c03.avi\nHammerThrow/v_HammerThrow_g03_c04.avi\nHammerThrow/v_HammerThrow_g03_c05.avi\nHammerThrow/v_HammerThrow_g03_c06.avi\nHammerThrow/v_HammerThrow_g03_c07.avi\nHammerThrow/v_HammerThrow_g04_c01.avi\nHammerThrow/v_HammerThrow_g04_c02.avi\nHammerThrow/v_HammerThrow_g04_c03.avi\nHammerThrow/v_HammerThrow_g04_c04.avi\nHammerThrow/v_HammerThrow_g04_c05.avi\nHammerThrow/v_HammerThrow_g04_c06.avi\nHammerThrow/v_HammerThrow_g04_c07.avi\nHammerThrow/v_HammerThrow_g05_c01.avi\nHammerThrow/v_HammerThrow_g05_c02.
avi\nHammerThrow/v_HammerThrow_g05_c03.avi\nHammerThrow/v_HammerThrow_g05_c04.avi\nHammerThrow/v_HammerThrow_g05_c05.avi\nHammerThrow/v_HammerThrow_g05_c06.avi\nHammerThrow/v_HammerThrow_g06_c01.avi\nHammerThrow/v_HammerThrow_g06_c02.avi\nHammerThrow/v_HammerThrow_g06_c03.avi\nHammerThrow/v_HammerThrow_g06_c04.avi\nHammerThrow/v_HammerThrow_g06_c05.avi\nHammerThrow/v_HammerThrow_g06_c06.avi\nHammerThrow/v_HammerThrow_g06_c07.avi\nHammerThrow/v_HammerThrow_g07_c01.avi\nHammerThrow/v_HammerThrow_g07_c02.avi\nHammerThrow/v_HammerThrow_g07_c03.avi\nHammerThrow/v_HammerThrow_g07_c04.avi\nHammerThrow/v_HammerThrow_g07_c05.avi\nHandstandPushups/v_HandStandPushups_g01_c01.avi\nHandstandPushups/v_HandStandPushups_g01_c02.avi\nHandstandPushups/v_HandStandPushups_g01_c03.avi\nHandstandPushups/v_HandStandPushups_g01_c04.avi\nHandstandPushups/v_HandStandPushups_g02_c01.avi\nHandstandPushups/v_HandStandPushups_g02_c02.avi\nHandstandPushups/v_HandStandPushups_g02_c03.avi\nHandstandPushups/v_HandStandPushups_g02_c04.avi\nHandstandPushups/v_HandStandPushups_g03_c01.avi\nHandstandPushups/v_HandStandPushups_g03_c02.avi\nHandstandPushups/v_HandStandPushups_g03_c03.avi\nHandstandPushups/v_HandStandPushups_g03_c04.avi\nHandstandPushups/v_HandStandPushups_g04_c01.avi\nHandstandPushups/v_HandStandPushups_g04_c02.avi\nHandstandPushups/v_HandStandPushups_g04_c03.avi\nHandstandPushups/v_HandStandPushups_g04_c04.avi\nHandstandPushups/v_HandStandPushups_g05_c01.avi\nHandstandPushups/v_HandStandPushups_g05_c02.avi\nHandstandPushups/v_HandStandPushups_g05_c03.avi\nHandstandPushups/v_HandStandPushups_g05_c04.avi\nHandstandPushups/v_HandStandPushups_g06_c01.avi\nHandstandPushups/v_HandStandPushups_g06_c02.avi\nHandstandPushups/v_HandStandPushups_g06_c03.avi\nHandstandPushups/v_HandStandPushups_g06_c04.avi\nHandstandPushups/v_HandStandPushups_g07_c01.avi\nHandstandPushups/v_HandStandPushups_g07_c02.avi\nHandstandPushups/v_HandStandPushups_g07_c03.avi\nHandstandPushups/v_HandStandPushups_g07_c04.avi\
nHandstandWalking/v_HandstandWalking_g01_c01.avi\nHandstandWalking/v_HandstandWalking_g01_c02.avi\nHandstandWalking/v_HandstandWalking_g01_c03.avi\nHandstandWalking/v_HandstandWalking_g01_c04.avi\nHandstandWalking/v_HandstandWalking_g02_c01.avi\nHandstandWalking/v_HandstandWalking_g02_c02.avi\nHandstandWalking/v_HandstandWalking_g02_c03.avi\nHandstandWalking/v_HandstandWalking_g02_c04.avi\nHandstandWalking/v_HandstandWalking_g03_c01.avi\nHandstandWalking/v_HandstandWalking_g03_c02.avi\nHandstandWalking/v_HandstandWalking_g03_c03.avi\nHandstandWalking/v_HandstandWalking_g03_c04.avi\nHandstandWalking/v_HandstandWalking_g04_c01.avi\nHandstandWalking/v_HandstandWalking_g04_c02.avi\nHandstandWalking/v_HandstandWalking_g04_c03.avi\nHandstandWalking/v_HandstandWalking_g04_c04.avi\nHandstandWalking/v_HandstandWalking_g04_c05.avi\nHandstandWalking/v_HandstandWalking_g05_c01.avi\nHandstandWalking/v_HandstandWalking_g05_c02.avi\nHandstandWalking/v_HandstandWalking_g05_c03.avi\nHandstandWalking/v_HandstandWalking_g05_c04.avi\nHandstandWalking/v_HandstandWalking_g05_c05.avi\nHandstandWalking/v_HandstandWalking_g05_c06.avi\nHandstandWalking/v_HandstandWalking_g05_c07.avi\nHandstandWalking/v_HandstandWalking_g06_c01.avi\nHandstandWalking/v_HandstandWalking_g06_c02.avi\nHandstandWalking/v_HandstandWalking_g06_c03.avi\nHandstandWalking/v_HandstandWalking_g06_c04.avi\nHandstandWalking/v_HandstandWalking_g07_c01.avi\nHandstandWalking/v_HandstandWalking_g07_c02.avi\nHandstandWalking/v_HandstandWalking_g07_c03.avi\nHandstandWalking/v_HandstandWalking_g07_c04.avi\nHandstandWalking/v_HandstandWalking_g07_c05.avi\nHandstandWalking/v_HandstandWalking_g07_c06.avi\nHeadMassage/v_HeadMassage_g01_c01.avi\nHeadMassage/v_HeadMassage_g01_c02.avi\nHeadMassage/v_HeadMassage_g01_c03.avi\nHeadMassage/v_HeadMassage_g01_c04.avi\nHeadMassage/v_HeadMassage_g01_c05.avi\nHeadMassage/v_HeadMassage_g02_c01.avi\nHeadMassage/v_HeadMassage_g02_c02.avi\nHeadMassage/v_HeadMassage_g02_c03.avi\nHeadMassage/v_HeadMas
sage_g02_c04.avi\nHeadMassage/v_HeadMassage_g02_c05.avi\nHeadMassage/v_HeadMassage_g02_c06.avi\nHeadMassage/v_HeadMassage_g02_c07.avi\nHeadMassage/v_HeadMassage_g03_c01.avi\nHeadMassage/v_HeadMassage_g03_c02.avi\nHeadMassage/v_HeadMassage_g03_c03.avi\nHeadMassage/v_HeadMassage_g03_c04.avi\nHeadMassage/v_HeadMassage_g03_c05.avi\nHeadMassage/v_HeadMassage_g03_c06.avi\nHeadMassage/v_HeadMassage_g03_c07.avi\nHeadMassage/v_HeadMassage_g04_c01.avi\nHeadMassage/v_HeadMassage_g04_c02.avi\nHeadMassage/v_HeadMassage_g04_c03.avi\nHeadMassage/v_HeadMassage_g04_c04.avi\nHeadMassage/v_HeadMassage_g05_c01.avi\nHeadMassage/v_HeadMassage_g05_c02.avi\nHeadMassage/v_HeadMassage_g05_c03.avi\nHeadMassage/v_HeadMassage_g05_c04.avi\nHeadMassage/v_HeadMassage_g05_c05.avi\nHeadMassage/v_HeadMassage_g05_c06.avi\nHeadMassage/v_HeadMassage_g06_c01.avi\nHeadMassage/v_HeadMassage_g06_c02.avi\nHeadMassage/v_HeadMassage_g06_c03.avi\nHeadMassage/v_HeadMassage_g06_c04.avi\nHeadMassage/v_HeadMassage_g06_c05.avi\nHeadMassage/v_HeadMassage_g06_c06.avi\nHeadMassage/v_HeadMassage_g06_c07.avi\nHeadMassage/v_HeadMassage_g07_c01.avi\nHeadMassage/v_HeadMassage_g07_c02.avi\nHeadMassage/v_HeadMassage_g07_c03.avi\nHeadMassage/v_HeadMassage_g07_c04.avi\nHeadMassage/v_HeadMassage_g07_c05.avi\nHighJump/v_HighJump_g01_c01.avi\nHighJump/v_HighJump_g01_c02.avi\nHighJump/v_HighJump_g01_c03.avi\nHighJump/v_HighJump_g01_c04.avi\nHighJump/v_HighJump_g01_c05.avi\nHighJump/v_HighJump_g02_c01.avi\nHighJump/v_HighJump_g02_c02.avi\nHighJump/v_HighJump_g02_c03.avi\nHighJump/v_HighJump_g02_c04.avi\nHighJump/v_HighJump_g02_c05.avi\nHighJump/v_HighJump_g02_c06.avi\nHighJump/v_HighJump_g02_c07.avi\nHighJump/v_HighJump_g03_c01.avi\nHighJump/v_HighJump_g03_c02.avi\nHighJump/v_HighJump_g03_c03.avi\nHighJump/v_HighJump_g03_c04.avi\nHighJump/v_HighJump_g04_c01.avi\nHighJump/v_HighJump_g04_c02.avi\nHighJump/v_HighJump_g04_c03.avi\nHighJump/v_HighJump_g04_c04.avi\nHighJump/v_HighJump_g04_c05.avi\nHighJump/v_HighJump_g04_c06.avi\nHighJump
/v_HighJump_g05_c01.avi\nHighJump/v_HighJump_g05_c02.avi\nHighJump/v_HighJump_g05_c03.avi\nHighJump/v_HighJump_g05_c04.avi\nHighJump/v_HighJump_g05_c05.avi\nHighJump/v_HighJump_g06_c01.avi\nHighJump/v_HighJump_g06_c02.avi\nHighJump/v_HighJump_g06_c03.avi\nHighJump/v_HighJump_g06_c04.avi\nHighJump/v_HighJump_g07_c01.avi\nHighJump/v_HighJump_g07_c02.avi\nHighJump/v_HighJump_g07_c03.avi\nHighJump/v_HighJump_g07_c04.avi\nHighJump/v_HighJump_g07_c05.avi\nHighJump/v_HighJump_g07_c06.avi\nHorseRace/v_HorseRace_g01_c01.avi\nHorseRace/v_HorseRace_g01_c02.avi\nHorseRace/v_HorseRace_g01_c03.avi\nHorseRace/v_HorseRace_g01_c04.avi\nHorseRace/v_HorseRace_g02_c01.avi\nHorseRace/v_HorseRace_g02_c02.avi\nHorseRace/v_HorseRace_g02_c03.avi\nHorseRace/v_HorseRace_g02_c04.avi\nHorseRace/v_HorseRace_g03_c01.avi\nHorseRace/v_HorseRace_g03_c02.avi\nHorseRace/v_HorseRace_g03_c03.avi\nHorseRace/v_HorseRace_g03_c04.avi\nHorseRace/v_HorseRace_g03_c05.avi\nHorseRace/v_HorseRace_g04_c01.avi\nHorseRace/v_HorseRace_g04_c02.avi\nHorseRace/v_HorseRace_g04_c03.avi\nHorseRace/v_HorseRace_g04_c04.avi\nHorseRace/v_HorseRace_g04_c05.avi\nHorseRace/v_HorseRace_g04_c06.avi\nHorseRace/v_HorseRace_g05_c01.avi\nHorseRace/v_HorseRace_g05_c02.avi\nHorseRace/v_HorseRace_g05_c03.avi\nHorseRace/v_HorseRace_g05_c04.avi\nHorseRace/v_HorseRace_g06_c01.avi\nHorseRace/v_HorseRace_g06_c02.avi\nHorseRace/v_HorseRace_g06_c03.avi\nHorseRace/v_HorseRace_g06_c04.avi\nHorseRace/v_HorseRace_g06_c05.avi\nHorseRace/v_HorseRace_g06_c06.avi\nHorseRace/v_HorseRace_g07_c01.avi\nHorseRace/v_HorseRace_g07_c02.avi\nHorseRace/v_HorseRace_g07_c03.avi\nHorseRace/v_HorseRace_g07_c04.avi\nHorseRace/v_HorseRace_g07_c05.avi\nHorseRace/v_HorseRace_g07_c06.avi\nHorseRiding/v_HorseRiding_g01_c01.avi\nHorseRiding/v_HorseRiding_g01_c02.avi\nHorseRiding/v_HorseRiding_g01_c03.avi\nHorseRiding/v_HorseRiding_g01_c04.avi\nHorseRiding/v_HorseRiding_g01_c05.avi\nHorseRiding/v_HorseRiding_g01_c06.avi\nHorseRiding/v_HorseRiding_g01_c07.avi\nHorseRiding/v_H
orseRiding_g02_c01.avi\nHorseRiding/v_HorseRiding_g02_c02.avi\nHorseRiding/v_HorseRiding_g02_c03.avi\nHorseRiding/v_HorseRiding_g02_c04.avi\nHorseRiding/v_HorseRiding_g02_c05.avi\nHorseRiding/v_HorseRiding_g02_c06.avi\nHorseRiding/v_HorseRiding_g02_c07.avi\nHorseRiding/v_HorseRiding_g03_c01.avi\nHorseRiding/v_HorseRiding_g03_c02.avi\nHorseRiding/v_HorseRiding_g03_c03.avi\nHorseRiding/v_HorseRiding_g03_c04.avi\nHorseRiding/v_HorseRiding_g03_c05.avi\nHorseRiding/v_HorseRiding_g03_c06.avi\nHorseRiding/v_HorseRiding_g03_c07.avi\nHorseRiding/v_HorseRiding_g04_c01.avi\nHorseRiding/v_HorseRiding_g04_c02.avi\nHorseRiding/v_HorseRiding_g04_c03.avi\nHorseRiding/v_HorseRiding_g04_c04.avi\nHorseRiding/v_HorseRiding_g04_c05.avi\nHorseRiding/v_HorseRiding_g04_c06.avi\nHorseRiding/v_HorseRiding_g04_c07.avi\nHorseRiding/v_HorseRiding_g05_c01.avi\nHorseRiding/v_HorseRiding_g05_c02.avi\nHorseRiding/v_HorseRiding_g05_c03.avi\nHorseRiding/v_HorseRiding_g05_c04.avi\nHorseRiding/v_HorseRiding_g05_c05.avi\nHorseRiding/v_HorseRiding_g05_c06.avi\nHorseRiding/v_HorseRiding_g05_c07.avi\nHorseRiding/v_HorseRiding_g06_c01.avi\nHorseRiding/v_HorseRiding_g06_c02.avi\nHorseRiding/v_HorseRiding_g06_c03.avi\nHorseRiding/v_HorseRiding_g06_c04.avi\nHorseRiding/v_HorseRiding_g06_c05.avi\nHorseRiding/v_HorseRiding_g06_c06.avi\nHorseRiding/v_HorseRiding_g06_c07.avi\nHorseRiding/v_HorseRiding_g07_c01.avi\nHorseRiding/v_HorseRiding_g07_c02.avi\nHorseRiding/v_HorseRiding_g07_c03.avi\nHorseRiding/v_HorseRiding_g07_c04.avi\nHorseRiding/v_HorseRiding_g07_c05.avi\nHorseRiding/v_HorseRiding_g07_c06.avi\nHorseRiding/v_HorseRiding_g07_c07.avi\nHulaHoop/v_HulaHoop_g01_c01.avi\nHulaHoop/v_HulaHoop_g01_c02.avi\nHulaHoop/v_HulaHoop_g01_c03.avi\nHulaHoop/v_HulaHoop_g01_c04.avi\nHulaHoop/v_HulaHoop_g01_c05.avi\nHulaHoop/v_HulaHoop_g01_c06.avi\nHulaHoop/v_HulaHoop_g01_c07.avi\nHulaHoop/v_HulaHoop_g02_c01.avi\nHulaHoop/v_HulaHoop_g02_c02.avi\nHulaHoop/v_HulaHoop_g02_c03.avi\nHulaHoop/v_HulaHoop_g02_c04.avi\nHulaHoop/v_Hul
aHoop_g03_c01.avi\nHulaHoop/v_HulaHoop_g03_c02.avi\nHulaHoop/v_HulaHoop_g03_c03.avi\nHulaHoop/v_HulaHoop_g03_c04.avi\nHulaHoop/v_HulaHoop_g03_c05.avi\nHulaHoop/v_HulaHoop_g04_c01.avi\nHulaHoop/v_HulaHoop_g04_c02.avi\nHulaHoop/v_HulaHoop_g04_c03.avi\nHulaHoop/v_HulaHoop_g04_c04.avi\nHulaHoop/v_HulaHoop_g04_c05.avi\nHulaHoop/v_HulaHoop_g05_c01.avi\nHulaHoop/v_HulaHoop_g05_c02.avi\nHulaHoop/v_HulaHoop_g05_c03.avi\nHulaHoop/v_HulaHoop_g05_c04.avi\nHulaHoop/v_HulaHoop_g06_c01.avi\nHulaHoop/v_HulaHoop_g06_c02.avi\nHulaHoop/v_HulaHoop_g06_c03.avi\nHulaHoop/v_HulaHoop_g06_c04.avi\nHulaHoop/v_HulaHoop_g07_c01.avi\nHulaHoop/v_HulaHoop_g07_c02.avi\nHulaHoop/v_HulaHoop_g07_c03.avi\nHulaHoop/v_HulaHoop_g07_c04.avi\nHulaHoop/v_HulaHoop_g07_c05.avi\nIceDancing/v_IceDancing_g01_c01.avi\nIceDancing/v_IceDancing_g01_c02.avi\nIceDancing/v_IceDancing_g01_c03.avi\nIceDancing/v_IceDancing_g01_c04.avi\nIceDancing/v_IceDancing_g01_c05.avi\nIceDancing/v_IceDancing_g01_c06.avi\nIceDancing/v_IceDancing_g01_c07.avi\nIceDancing/v_IceDancing_g02_c01.avi\nIceDancing/v_IceDancing_g02_c02.avi\nIceDancing/v_IceDancing_g02_c03.avi\nIceDancing/v_IceDancing_g02_c04.avi\nIceDancing/v_IceDancing_g02_c05.avi\nIceDancing/v_IceDancing_g02_c06.avi\nIceDancing/v_IceDancing_g02_c07.avi\nIceDancing/v_IceDancing_g03_c01.avi\nIceDancing/v_IceDancing_g03_c02.avi\nIceDancing/v_IceDancing_g03_c03.avi\nIceDancing/v_IceDancing_g03_c04.avi\nIceDancing/v_IceDancing_g03_c05.avi\nIceDancing/v_IceDancing_g03_c06.avi\nIceDancing/v_IceDancing_g04_c01.avi\nIceDancing/v_IceDancing_g04_c02.avi\nIceDancing/v_IceDancing_g04_c03.avi\nIceDancing/v_IceDancing_g04_c04.avi\nIceDancing/v_IceDancing_g04_c05.avi\nIceDancing/v_IceDancing_g04_c06.avi\nIceDancing/v_IceDancing_g04_c07.avi\nIceDancing/v_IceDancing_g05_c01.avi\nIceDancing/v_IceDancing_g05_c02.avi\nIceDancing/v_IceDancing_g05_c03.avi\nIceDancing/v_IceDancing_g05_c04.avi\nIceDancing/v_IceDancing_g05_c05.avi\nIceDancing/v_IceDancing_g05_c06.avi\nIceDancing/v_IceDancing_g06_c01.av
i\nIceDancing/v_IceDancing_g06_c02.avi\nIceDancing/v_IceDancing_g06_c03.avi\nIceDancing/v_IceDancing_g06_c04.avi\nIceDancing/v_IceDancing_g06_c05.avi\nIceDancing/v_IceDancing_g06_c06.avi\nIceDancing/v_IceDancing_g07_c01.avi\nIceDancing/v_IceDancing_g07_c02.avi\nIceDancing/v_IceDancing_g07_c03.avi\nIceDancing/v_IceDancing_g07_c04.avi\nIceDancing/v_IceDancing_g07_c05.avi\nIceDancing/v_IceDancing_g07_c06.avi\nIceDancing/v_IceDancing_g07_c07.avi\nJavelinThrow/v_JavelinThrow_g01_c01.avi\nJavelinThrow/v_JavelinThrow_g01_c02.avi\nJavelinThrow/v_JavelinThrow_g01_c03.avi\nJavelinThrow/v_JavelinThrow_g01_c04.avi\nJavelinThrow/v_JavelinThrow_g02_c01.avi\nJavelinThrow/v_JavelinThrow_g02_c02.avi\nJavelinThrow/v_JavelinThrow_g02_c03.avi\nJavelinThrow/v_JavelinThrow_g02_c04.avi\nJavelinThrow/v_JavelinThrow_g03_c01.avi\nJavelinThrow/v_JavelinThrow_g03_c02.avi\nJavelinThrow/v_JavelinThrow_g03_c03.avi\nJavelinThrow/v_JavelinThrow_g03_c04.avi\nJavelinThrow/v_JavelinThrow_g04_c01.avi\nJavelinThrow/v_JavelinThrow_g04_c02.avi\nJavelinThrow/v_JavelinThrow_g04_c03.avi\nJavelinThrow/v_JavelinThrow_g04_c04.avi\nJavelinThrow/v_JavelinThrow_g05_c01.avi\nJavelinThrow/v_JavelinThrow_g05_c02.avi\nJavelinThrow/v_JavelinThrow_g05_c03.avi\nJavelinThrow/v_JavelinThrow_g05_c04.avi\nJavelinThrow/v_JavelinThrow_g05_c05.avi\nJavelinThrow/v_JavelinThrow_g05_c06.avi\nJavelinThrow/v_JavelinThrow_g06_c01.avi\nJavelinThrow/v_JavelinThrow_g06_c02.avi\nJavelinThrow/v_JavelinThrow_g06_c03.avi\nJavelinThrow/v_JavelinThrow_g06_c04.avi\nJavelinThrow/v_JavelinThrow_g07_c01.avi\nJavelinThrow/v_JavelinThrow_g07_c02.avi\nJavelinThrow/v_JavelinThrow_g07_c03.avi\nJavelinThrow/v_JavelinThrow_g07_c04.avi\nJavelinThrow/v_JavelinThrow_g07_c05.avi\nJugglingBalls/v_JugglingBalls_g01_c01.avi\nJugglingBalls/v_JugglingBalls_g01_c02.avi\nJugglingBalls/v_JugglingBalls_g01_c03.avi\nJugglingBalls/v_JugglingBalls_g01_c04.avi\nJugglingBalls/v_JugglingBalls_g02_c01.avi\nJugglingBalls/v_JugglingBalls_g02_c02.avi\nJugglingBalls/v_Juggling
Balls_g02_c03.avi\nJugglingBalls/v_JugglingBalls_g02_c04.avi\nJugglingBalls/v_JugglingBalls_g02_c05.avi\nJugglingBalls/v_JugglingBalls_g02_c06.avi\nJugglingBalls/v_JugglingBalls_g03_c01.avi\nJugglingBalls/v_JugglingBalls_g03_c02.avi\nJugglingBalls/v_JugglingBalls_g03_c03.avi\nJugglingBalls/v_JugglingBalls_g03_c04.avi\nJugglingBalls/v_JugglingBalls_g03_c05.avi\nJugglingBalls/v_JugglingBalls_g03_c06.avi\nJugglingBalls/v_JugglingBalls_g03_c07.avi\nJugglingBalls/v_JugglingBalls_g04_c01.avi\nJugglingBalls/v_JugglingBalls_g04_c02.avi\nJugglingBalls/v_JugglingBalls_g04_c03.avi\nJugglingBalls/v_JugglingBalls_g04_c04.avi\nJugglingBalls/v_JugglingBalls_g04_c05.avi\nJugglingBalls/v_JugglingBalls_g05_c01.avi\nJugglingBalls/v_JugglingBalls_g05_c02.avi\nJugglingBalls/v_JugglingBalls_g05_c03.avi\nJugglingBalls/v_JugglingBalls_g05_c04.avi\nJugglingBalls/v_JugglingBalls_g05_c05.avi\nJugglingBalls/v_JugglingBalls_g06_c01.avi\nJugglingBalls/v_JugglingBalls_g06_c02.avi\nJugglingBalls/v_JugglingBalls_g06_c03.avi\nJugglingBalls/v_JugglingBalls_g06_c04.avi\nJugglingBalls/v_JugglingBalls_g06_c05.avi\nJugglingBalls/v_JugglingBalls_g06_c06.avi\nJugglingBalls/v_JugglingBalls_g07_c01.avi\nJugglingBalls/v_JugglingBalls_g07_c02.avi\nJugglingBalls/v_JugglingBalls_g07_c03.avi\nJugglingBalls/v_JugglingBalls_g07_c04.avi\nJugglingBalls/v_JugglingBalls_g07_c05.avi\nJugglingBalls/v_JugglingBalls_g07_c06.avi\nJugglingBalls/v_JugglingBalls_g07_c07.avi\nJumpingJack/v_JumpingJack_g01_c01.avi\nJumpingJack/v_JumpingJack_g01_c02.avi\nJumpingJack/v_JumpingJack_g01_c03.avi\nJumpingJack/v_JumpingJack_g01_c04.avi\nJumpingJack/v_JumpingJack_g01_c05.avi\nJumpingJack/v_JumpingJack_g01_c06.avi\nJumpingJack/v_JumpingJack_g01_c07.avi\nJumpingJack/v_JumpingJack_g02_c01.avi\nJumpingJack/v_JumpingJack_g02_c02.avi\nJumpingJack/v_JumpingJack_g02_c03.avi\nJumpingJack/v_JumpingJack_g02_c04.avi\nJumpingJack/v_JumpingJack_g03_c01.avi\nJumpingJack/v_JumpingJack_g03_c02.avi\nJumpingJack/v_JumpingJack_g03_c03.avi\nJumpingJack/v_Ju
mpingJack_g03_c04.avi\nJumpingJack/v_JumpingJack_g04_c01.avi\nJumpingJack/v_JumpingJack_g04_c02.avi\nJumpingJack/v_JumpingJack_g04_c03.avi\nJumpingJack/v_JumpingJack_g04_c04.avi\nJumpingJack/v_JumpingJack_g05_c01.avi\nJumpingJack/v_JumpingJack_g05_c02.avi\nJumpingJack/v_JumpingJack_g05_c03.avi\nJumpingJack/v_JumpingJack_g05_c04.avi\nJumpingJack/v_JumpingJack_g05_c05.avi\nJumpingJack/v_JumpingJack_g05_c06.avi\nJumpingJack/v_JumpingJack_g06_c01.avi\nJumpingJack/v_JumpingJack_g06_c02.avi\nJumpingJack/v_JumpingJack_g06_c03.avi\nJumpingJack/v_JumpingJack_g06_c04.avi\nJumpingJack/v_JumpingJack_g06_c05.avi\nJumpingJack/v_JumpingJack_g06_c06.avi\nJumpingJack/v_JumpingJack_g06_c07.avi\nJumpingJack/v_JumpingJack_g07_c01.avi\nJumpingJack/v_JumpingJack_g07_c02.avi\nJumpingJack/v_JumpingJack_g07_c03.avi\nJumpingJack/v_JumpingJack_g07_c04.avi\nJumpingJack/v_JumpingJack_g07_c05.avi\nJumpRope/v_JumpRope_g01_c01.avi\nJumpRope/v_JumpRope_g01_c02.avi\nJumpRope/v_JumpRope_g01_c03.avi\nJumpRope/v_JumpRope_g01_c04.avi\nJumpRope/v_JumpRope_g02_c01.avi\nJumpRope/v_JumpRope_g02_c02.avi\nJumpRope/v_JumpRope_g02_c03.avi\nJumpRope/v_JumpRope_g02_c04.avi\nJumpRope/v_JumpRope_g02_c05.avi\nJumpRope/v_JumpRope_g02_c06.avi\nJumpRope/v_JumpRope_g02_c07.avi\nJumpRope/v_JumpRope_g03_c01.avi\nJumpRope/v_JumpRope_g03_c02.avi\nJumpRope/v_JumpRope_g03_c03.avi\nJumpRope/v_JumpRope_g03_c04.avi\nJumpRope/v_JumpRope_g04_c01.avi\nJumpRope/v_JumpRope_g04_c02.avi\nJumpRope/v_JumpRope_g04_c03.avi\nJumpRope/v_JumpRope_g04_c04.avi\nJumpRope/v_JumpRope_g04_c05.avi\nJumpRope/v_JumpRope_g04_c06.avi\nJumpRope/v_JumpRope_g04_c07.avi\nJumpRope/v_JumpRope_g05_c01.avi\nJumpRope/v_JumpRope_g05_c02.avi\nJumpRope/v_JumpRope_g05_c03.avi\nJumpRope/v_JumpRope_g05_c04.avi\nJumpRope/v_JumpRope_g05_c05.avi\nJumpRope/v_JumpRope_g06_c01.avi\nJumpRope/v_JumpRope_g06_c02.avi\nJumpRope/v_JumpRope_g06_c03.avi\nJumpRope/v_JumpRope_g06_c04.avi\nJumpRope/v_JumpRope_g06_c05.avi\nJumpRope/v_JumpRope_g07_c01.avi\nJumpRope/v_JumpRope_g07_c02.av
i\nJumpRope/v_JumpRope_g07_c03.avi\nJumpRope/v_JumpRope_g07_c04.avi\nJumpRope/v_JumpRope_g07_c05.avi\nJumpRope/v_JumpRope_g07_c06.avi\nKayaking/v_Kayaking_g01_c01.avi\nKayaking/v_Kayaking_g01_c02.avi\nKayaking/v_Kayaking_g01_c03.avi\nKayaking/v_Kayaking_g01_c04.avi\nKayaking/v_Kayaking_g01_c05.avi\nKayaking/v_Kayaking_g01_c06.avi\nKayaking/v_Kayaking_g02_c01.avi\nKayaking/v_Kayaking_g02_c02.avi\nKayaking/v_Kayaking_g02_c03.avi\nKayaking/v_Kayaking_g02_c04.avi\nKayaking/v_Kayaking_g03_c01.avi\nKayaking/v_Kayaking_g03_c02.avi\nKayaking/v_Kayaking_g03_c03.avi\nKayaking/v_Kayaking_g03_c04.avi\nKayaking/v_Kayaking_g04_c01.avi\nKayaking/v_Kayaking_g04_c02.avi\nKayaking/v_Kayaking_g04_c03.avi\nKayaking/v_Kayaking_g04_c04.avi\nKayaking/v_Kayaking_g04_c05.avi\nKayaking/v_Kayaking_g04_c06.avi\nKayaking/v_Kayaking_g04_c07.avi\nKayaking/v_Kayaking_g05_c01.avi\nKayaking/v_Kayaking_g05_c02.avi\nKayaking/v_Kayaking_g05_c03.avi\nKayaking/v_Kayaking_g05_c04.avi\nKayaking/v_Kayaking_g06_c01.avi\nKayaking/v_Kayaking_g06_c02.avi\nKayaking/v_Kayaking_g06_c03.avi\nKayaking/v_Kayaking_g06_c04.avi\nKayaking/v_Kayaking_g06_c05.avi\nKayaking/v_Kayaking_g06_c06.avi\nKayaking/v_Kayaking_g06_c07.avi\nKayaking/v_Kayaking_g07_c01.avi\nKayaking/v_Kayaking_g07_c02.avi\nKayaking/v_Kayaking_g07_c03.avi\nKayaking/v_Kayaking_g07_c04.avi\nKnitting/v_Knitting_g01_c01.avi\nKnitting/v_Knitting_g01_c02.avi\nKnitting/v_Knitting_g01_c03.avi\nKnitting/v_Knitting_g01_c04.avi\nKnitting/v_Knitting_g02_c01.avi\nKnitting/v_Knitting_g02_c02.avi\nKnitting/v_Knitting_g02_c03.avi\nKnitting/v_Knitting_g02_c04.avi\nKnitting/v_Knitting_g02_c05.avi\nKnitting/v_Knitting_g03_c01.avi\nKnitting/v_Knitting_g03_c02.avi\nKnitting/v_Knitting_g03_c03.avi\nKnitting/v_Knitting_g03_c04.avi\nKnitting/v_Knitting_g03_c05.avi\nKnitting/v_Knitting_g04_c01.avi\nKnitting/v_Knitting_g04_c02.avi\nKnitting/v_Knitting_g04_c03.avi\nKnitting/v_Knitting_g04_c04.avi\nKnitting/v_Knitting_g04_c05.avi\nKnitting/v_Knitting_g04_c06.avi\nKnitting/v_Knitti
ng_g05_c01.avi\nKnitting/v_Knitting_g05_c02.avi\nKnitting/v_Knitting_g05_c03.avi\nKnitting/v_Knitting_g05_c04.avi\nKnitting/v_Knitting_g05_c05.avi\nKnitting/v_Knitting_g06_c01.avi\nKnitting/v_Knitting_g06_c02.avi\nKnitting/v_Knitting_g06_c03.avi\nKnitting/v_Knitting_g06_c04.avi\nKnitting/v_Knitting_g07_c01.avi\nKnitting/v_Knitting_g07_c02.avi\nKnitting/v_Knitting_g07_c03.avi\nKnitting/v_Knitting_g07_c04.avi\nKnitting/v_Knitting_g07_c05.avi\nLongJump/v_LongJump_g01_c01.avi\nLongJump/v_LongJump_g01_c02.avi\nLongJump/v_LongJump_g01_c03.avi\nLongJump/v_LongJump_g01_c04.avi\nLongJump/v_LongJump_g01_c05.avi\nLongJump/v_LongJump_g01_c06.avi\nLongJump/v_LongJump_g01_c07.avi\nLongJump/v_LongJump_g02_c01.avi\nLongJump/v_LongJump_g02_c02.avi\nLongJump/v_LongJump_g02_c03.avi\nLongJump/v_LongJump_g02_c04.avi\nLongJump/v_LongJump_g02_c05.avi\nLongJump/v_LongJump_g03_c01.avi\nLongJump/v_LongJump_g03_c02.avi\nLongJump/v_LongJump_g03_c03.avi\nLongJump/v_LongJump_g03_c04.avi\nLongJump/v_LongJump_g03_c05.avi\nLongJump/v_LongJump_g03_c06.avi\nLongJump/v_LongJump_g04_c01.avi\nLongJump/v_LongJump_g04_c02.avi\nLongJump/v_LongJump_g04_c03.avi\nLongJump/v_LongJump_g04_c04.avi\nLongJump/v_LongJump_g04_c05.avi\nLongJump/v_LongJump_g04_c06.avi\nLongJump/v_LongJump_g04_c07.avi\nLongJump/v_LongJump_g05_c01.avi\nLongJump/v_LongJump_g05_c02.avi\nLongJump/v_LongJump_g05_c03.avi\nLongJump/v_LongJump_g05_c04.avi\nLongJump/v_LongJump_g05_c05.avi\nLongJump/v_LongJump_g06_c01.avi\nLongJump/v_LongJump_g06_c02.avi\nLongJump/v_LongJump_g06_c03.avi\nLongJump/v_LongJump_g06_c04.avi\nLongJump/v_LongJump_g07_c01.avi\nLongJump/v_LongJump_g07_c02.avi\nLongJump/v_LongJump_g07_c03.avi\nLongJump/v_LongJump_g07_c04.avi\nLongJump/v_LongJump_g07_c05.avi\nLunges/v_Lunges_g01_c01.avi\nLunges/v_Lunges_g01_c02.avi\nLunges/v_Lunges_g01_c03.avi\nLunges/v_Lunges_g01_c04.avi\nLunges/v_Lunges_g01_c05.avi\nLunges/v_Lunges_g01_c06.avi\nLunges/v_Lunges_g01_c07.avi\nLunges/v_Lunges_g02_c01.avi\nLunges/v_Lunges_g02_c02.avi\nLunges/
v_Lunges_g02_c03.avi\nLunges/v_Lunges_g02_c04.avi\nLunges/v_Lunges_g03_c01.avi\nLunges/v_Lunges_g03_c02.avi\nLunges/v_Lunges_g03_c03.avi\nLunges/v_Lunges_g03_c04.avi\nLunges/v_Lunges_g04_c01.avi\nLunges/v_Lunges_g04_c02.avi\nLunges/v_Lunges_g04_c03.avi\nLunges/v_Lunges_g04_c04.avi\nLunges/v_Lunges_g05_c01.avi\nLunges/v_Lunges_g05_c02.avi\nLunges/v_Lunges_g05_c03.avi\nLunges/v_Lunges_g05_c04.avi\nLunges/v_Lunges_g06_c01.avi\nLunges/v_Lunges_g06_c02.avi\nLunges/v_Lunges_g06_c03.avi\nLunges/v_Lunges_g06_c04.avi\nLunges/v_Lunges_g06_c05.avi\nLunges/v_Lunges_g06_c06.avi\nLunges/v_Lunges_g06_c07.avi\nLunges/v_Lunges_g07_c01.avi\nLunges/v_Lunges_g07_c02.avi\nLunges/v_Lunges_g07_c03.avi\nLunges/v_Lunges_g07_c04.avi\nLunges/v_Lunges_g07_c05.avi\nLunges/v_Lunges_g07_c06.avi\nLunges/v_Lunges_g07_c07.avi\nMilitaryParade/v_MilitaryParade_g01_c01.avi\nMilitaryParade/v_MilitaryParade_g01_c02.avi\nMilitaryParade/v_MilitaryParade_g01_c03.avi\nMilitaryParade/v_MilitaryParade_g01_c04.avi\nMilitaryParade/v_MilitaryParade_g01_c05.avi\nMilitaryParade/v_MilitaryParade_g01_c06.avi\nMilitaryParade/v_MilitaryParade_g01_c07.avi\nMilitaryParade/v_MilitaryParade_g02_c01.avi\nMilitaryParade/v_MilitaryParade_g02_c02.avi\nMilitaryParade/v_MilitaryParade_g02_c03.avi\nMilitaryParade/v_MilitaryParade_g02_c04.avi\nMilitaryParade/v_MilitaryParade_g03_c01.avi\nMilitaryParade/v_MilitaryParade_g03_c02.avi\nMilitaryParade/v_MilitaryParade_g03_c03.avi\nMilitaryParade/v_MilitaryParade_g03_c04.avi\nMilitaryParade/v_MilitaryParade_g04_c01.avi\nMilitaryParade/v_MilitaryParade_g04_c02.avi\nMilitaryParade/v_MilitaryParade_g04_c03.avi\nMilitaryParade/v_MilitaryParade_g04_c04.avi\nMilitaryParade/v_MilitaryParade_g05_c01.avi\nMilitaryParade/v_MilitaryParade_g05_c02.avi\nMilitaryParade/v_MilitaryParade_g05_c03.avi\nMilitaryParade/v_MilitaryParade_g05_c04.avi\nMilitaryParade/v_MilitaryParade_g06_c01.avi\nMilitaryParade/v_MilitaryParade_g06_c02.avi\nMilitaryParade/v_MilitaryParade_g06_c03.avi\nMilitaryParade/v_Military
Parade_g06_c04.avi\nMilitaryParade/v_MilitaryParade_g07_c01.avi\nMilitaryParade/v_MilitaryParade_g07_c02.avi\nMilitaryParade/v_MilitaryParade_g07_c03.avi\nMilitaryParade/v_MilitaryParade_g07_c04.avi\nMilitaryParade/v_MilitaryParade_g07_c05.avi\nMilitaryParade/v_MilitaryParade_g07_c06.avi\nMixing/v_Mixing_g01_c01.avi\nMixing/v_Mixing_g01_c02.avi\nMixing/v_Mixing_g01_c03.avi\nMixing/v_Mixing_g01_c04.avi\nMixing/v_Mixing_g01_c05.avi\nMixing/v_Mixing_g01_c06.avi\nMixing/v_Mixing_g01_c07.avi\nMixing/v_Mixing_g02_c01.avi\nMixing/v_Mixing_g02_c02.avi\nMixing/v_Mixing_g02_c03.avi\nMixing/v_Mixing_g02_c04.avi\nMixing/v_Mixing_g02_c05.avi\nMixing/v_Mixing_g02_c06.avi\nMixing/v_Mixing_g03_c01.avi\nMixing/v_Mixing_g03_c02.avi\nMixing/v_Mixing_g03_c03.avi\nMixing/v_Mixing_g03_c04.avi\nMixing/v_Mixing_g03_c05.avi\nMixing/v_Mixing_g03_c06.avi\nMixing/v_Mixing_g03_c07.avi\nMixing/v_Mixing_g04_c01.avi\nMixing/v_Mixing_g04_c02.avi\nMixing/v_Mixing_g04_c03.avi\nMixing/v_Mixing_g04_c04.avi\nMixing/v_Mixing_g04_c05.avi\nMixing/v_Mixing_g04_c06.avi\nMixing/v_Mixing_g04_c07.avi\nMixing/v_Mixing_g05_c01.avi\nMixing/v_Mixing_g05_c02.avi\nMixing/v_Mixing_g05_c03.avi\nMixing/v_Mixing_g05_c04.avi\nMixing/v_Mixing_g05_c05.avi\nMixing/v_Mixing_g05_c06.avi\nMixing/v_Mixing_g05_c07.avi\nMixing/v_Mixing_g06_c01.avi\nMixing/v_Mixing_g06_c02.avi\nMixing/v_Mixing_g06_c03.avi\nMixing/v_Mixing_g06_c04.avi\nMixing/v_Mixing_g06_c05.avi\nMixing/v_Mixing_g06_c06.avi\nMixing/v_Mixing_g07_c01.avi\nMixing/v_Mixing_g07_c02.avi\nMixing/v_Mixing_g07_c03.avi\nMixing/v_Mixing_g07_c04.avi\nMixing/v_Mixing_g07_c05.avi\nMoppingFloor/v_MoppingFloor_g01_c01.avi\nMoppingFloor/v_MoppingFloor_g01_c02.avi\nMoppingFloor/v_MoppingFloor_g01_c03.avi\nMoppingFloor/v_MoppingFloor_g01_c04.avi\nMoppingFloor/v_MoppingFloor_g02_c01.avi\nMoppingFloor/v_MoppingFloor_g02_c02.avi\nMoppingFloor/v_MoppingFloor_g02_c03.avi\nMoppingFloor/v_MoppingFloor_g02_c04.avi\nMoppingFloor/v_MoppingFloor_g02_c05.avi\nMoppingFloor/v_MoppingFloor_g02_c06.
avi\nMoppingFloor/v_MoppingFloor_g03_c01.avi\nMoppingFloor/v_MoppingFloor_g03_c02.avi\nMoppingFloor/v_MoppingFloor_g03_c03.avi\nMoppingFloor/v_MoppingFloor_g03_c04.avi\nMoppingFloor/v_MoppingFloor_g04_c01.avi\nMoppingFloor/v_MoppingFloor_g04_c02.avi\nMoppingFloor/v_MoppingFloor_g04_c03.avi\nMoppingFloor/v_MoppingFloor_g04_c04.avi\nMoppingFloor/v_MoppingFloor_g04_c05.avi\nMoppingFloor/v_MoppingFloor_g04_c06.avi\nMoppingFloor/v_MoppingFloor_g05_c01.avi\nMoppingFloor/v_MoppingFloor_g05_c02.avi\nMoppingFloor/v_MoppingFloor_g05_c03.avi\nMoppingFloor/v_MoppingFloor_g05_c04.avi\nMoppingFloor/v_MoppingFloor_g05_c05.avi\nMoppingFloor/v_MoppingFloor_g06_c01.avi\nMoppingFloor/v_MoppingFloor_g06_c02.avi\nMoppingFloor/v_MoppingFloor_g06_c03.avi\nMoppingFloor/v_MoppingFloor_g06_c04.avi\nMoppingFloor/v_MoppingFloor_g07_c01.avi\nMoppingFloor/v_MoppingFloor_g07_c02.avi\nMoppingFloor/v_MoppingFloor_g07_c03.avi\nMoppingFloor/v_MoppingFloor_g07_c04.avi\nMoppingFloor/v_MoppingFloor_g07_c05.avi\nNunchucks/v_Nunchucks_g01_c01.avi\nNunchucks/v_Nunchucks_g01_c02.avi\nNunchucks/v_Nunchucks_g01_c03.avi\nNunchucks/v_Nunchucks_g01_c04.avi\nNunchucks/v_Nunchucks_g02_c01.avi\nNunchucks/v_Nunchucks_g02_c02.avi\nNunchucks/v_Nunchucks_g02_c03.avi\nNunchucks/v_Nunchucks_g02_c04.avi\nNunchucks/v_Nunchucks_g02_c05.avi\nNunchucks/v_Nunchucks_g02_c06.avi\nNunchucks/v_Nunchucks_g03_c01.avi\nNunchucks/v_Nunchucks_g03_c02.avi\nNunchucks/v_Nunchucks_g03_c03.avi\nNunchucks/v_Nunchucks_g03_c04.avi\nNunchucks/v_Nunchucks_g03_c05.avi\nNunchucks/v_Nunchucks_g03_c06.avi\nNunchucks/v_Nunchucks_g03_c07.avi\nNunchucks/v_Nunchucks_g04_c01.avi\nNunchucks/v_Nunchucks_g04_c02.avi\nNunchucks/v_Nunchucks_g04_c03.avi\nNunchucks/v_Nunchucks_g04_c04.avi\nNunchucks/v_Nunchucks_g04_c05.avi\nNunchucks/v_Nunchucks_g04_c06.avi\nNunchucks/v_Nunchucks_g05_c01.avi\nNunchucks/v_Nunchucks_g05_c02.avi\nNunchucks/v_Nunchucks_g05_c03.avi\nNunchucks/v_Nunchucks_g05_c04.avi\nNunchucks/v_Nunchucks_g06_c01.avi\nNunchucks/v_Nunchucks_g06_c02.a
vi\nNunchucks/v_Nunchucks_g06_c03.avi\nNunchucks/v_Nunchucks_g06_c04.avi\nNunchucks/v_Nunchucks_g07_c01.avi\nNunchucks/v_Nunchucks_g07_c02.avi\nNunchucks/v_Nunchucks_g07_c03.avi\nNunchucks/v_Nunchucks_g07_c04.avi\nParallelBars/v_ParallelBars_g01_c01.avi\nParallelBars/v_ParallelBars_g01_c02.avi\nParallelBars/v_ParallelBars_g01_c03.avi\nParallelBars/v_ParallelBars_g01_c04.avi\nParallelBars/v_ParallelBars_g02_c01.avi\nParallelBars/v_ParallelBars_g02_c02.avi\nParallelBars/v_ParallelBars_g02_c03.avi\nParallelBars/v_ParallelBars_g02_c04.avi\nParallelBars/v_ParallelBars_g03_c01.avi\nParallelBars/v_ParallelBars_g03_c02.avi\nParallelBars/v_ParallelBars_g03_c03.avi\nParallelBars/v_ParallelBars_g03_c04.avi\nParallelBars/v_ParallelBars_g04_c01.avi\nParallelBars/v_ParallelBars_g04_c02.avi\nParallelBars/v_ParallelBars_g04_c03.avi\nParallelBars/v_ParallelBars_g04_c04.avi\nParallelBars/v_ParallelBars_g04_c05.avi\nParallelBars/v_ParallelBars_g04_c06.avi\nParallelBars/v_ParallelBars_g04_c07.avi\nParallelBars/v_ParallelBars_g05_c01.avi\nParallelBars/v_ParallelBars_g05_c02.avi\nParallelBars/v_ParallelBars_g05_c03.avi\nParallelBars/v_ParallelBars_g05_c04.avi\nParallelBars/v_ParallelBars_g05_c05.avi\nParallelBars/v_ParallelBars_g06_c01.avi\nParallelBars/v_ParallelBars_g06_c02.avi\nParallelBars/v_ParallelBars_g06_c03.avi\nParallelBars/v_ParallelBars_g06_c04.avi\nParallelBars/v_ParallelBars_g06_c05.avi\nParallelBars/v_ParallelBars_g06_c06.avi\nParallelBars/v_ParallelBars_g06_c07.avi\nParallelBars/v_ParallelBars_g07_c01.avi\nParallelBars/v_ParallelBars_g07_c02.avi\nParallelBars/v_ParallelBars_g07_c03.avi\nParallelBars/v_ParallelBars_g07_c04.avi\nParallelBars/v_ParallelBars_g07_c05.avi\nParallelBars/v_ParallelBars_g07_c06.avi\nPizzaTossing/v_PizzaTossing_g01_c01.avi\nPizzaTossing/v_PizzaTossing_g01_c02.avi\nPizzaTossing/v_PizzaTossing_g01_c03.avi\nPizzaTossing/v_PizzaTossing_g01_c04.avi\nPizzaTossing/v_PizzaTossing_g02_c01.avi\nPizzaTossing/v_PizzaTossing_g02_c02.avi\nPizzaTossing/v_PizzaTos
sing_g02_c03.avi\nPizzaTossing/v_PizzaTossing_g02_c04.avi\nPizzaTossing/v_PizzaTossing_g02_c05.avi\nPizzaTossing/v_PizzaTossing_g03_c01.avi\nPizzaTossing/v_PizzaTossing_g03_c02.avi\nPizzaTossing/v_PizzaTossing_g03_c03.avi\nPizzaTossing/v_PizzaTossing_g03_c04.avi\nPizzaTossing/v_PizzaTossing_g04_c01.avi\nPizzaTossing/v_PizzaTossing_g04_c02.avi\nPizzaTossing/v_PizzaTossing_g04_c03.avi\nPizzaTossing/v_PizzaTossing_g04_c04.avi\nPizzaTossing/v_PizzaTossing_g04_c05.avi\nPizzaTossing/v_PizzaTossing_g04_c06.avi\nPizzaTossing/v_PizzaTossing_g04_c07.avi\nPizzaTossing/v_PizzaTossing_g05_c01.avi\nPizzaTossing/v_PizzaTossing_g05_c02.avi\nPizzaTossing/v_PizzaTossing_g05_c03.avi\nPizzaTossing/v_PizzaTossing_g05_c04.avi\nPizzaTossing/v_PizzaTossing_g06_c01.avi\nPizzaTossing/v_PizzaTossing_g06_c02.avi\nPizzaTossing/v_PizzaTossing_g06_c03.avi\nPizzaTossing/v_PizzaTossing_g06_c04.avi\nPizzaTossing/v_PizzaTossing_g06_c05.avi\nPizzaTossing/v_PizzaTossing_g07_c01.avi\nPizzaTossing/v_PizzaTossing_g07_c02.avi\nPizzaTossing/v_PizzaTossing_g07_c03.avi\nPizzaTossing/v_PizzaTossing_g07_c04.avi\nPlayingCello/v_PlayingCello_g01_c01.avi\nPlayingCello/v_PlayingCello_g01_c02.avi\nPlayingCello/v_PlayingCello_g01_c03.avi\nPlayingCello/v_PlayingCello_g01_c04.avi\nPlayingCello/v_PlayingCello_g01_c05.avi\nPlayingCello/v_PlayingCello_g01_c06.avi\nPlayingCello/v_PlayingCello_g01_c07.avi\nPlayingCello/v_PlayingCello_g02_c01.avi\nPlayingCello/v_PlayingCello_g02_c02.avi\nPlayingCello/v_PlayingCello_g02_c03.avi\nPlayingCello/v_PlayingCello_g02_c04.avi\nPlayingCello/v_PlayingCello_g02_c05.avi\nPlayingCello/v_PlayingCello_g02_c06.avi\nPlayingCello/v_PlayingCello_g02_c07.avi\nPlayingCello/v_PlayingCello_g03_c01.avi\nPlayingCello/v_PlayingCello_g03_c02.avi\nPlayingCello/v_PlayingCello_g03_c03.avi\nPlayingCello/v_PlayingCello_g03_c04.avi\nPlayingCello/v_PlayingCello_g04_c01.avi\nPlayingCello/v_PlayingCello_g04_c02.avi\nPlayingCello/v_PlayingCello_g04_c03.avi\nPlayingCello/v_PlayingCello_g04_c04.avi\nPlayingCello/v
_PlayingCello_g04_c05.avi\nPlayingCello/v_PlayingCello_g04_c06.avi\nPlayingCello/v_PlayingCello_g04_c07.avi\nPlayingCello/v_PlayingCello_g05_c01.avi\nPlayingCello/v_PlayingCello_g05_c02.avi\nPlayingCello/v_PlayingCello_g05_c03.avi\nPlayingCello/v_PlayingCello_g05_c04.avi\nPlayingCello/v_PlayingCello_g05_c05.avi\nPlayingCello/v_PlayingCello_g05_c06.avi\nPlayingCello/v_PlayingCello_g05_c07.avi\nPlayingCello/v_PlayingCello_g06_c01.avi\nPlayingCello/v_PlayingCello_g06_c02.avi\nPlayingCello/v_PlayingCello_g06_c03.avi\nPlayingCello/v_PlayingCello_g06_c04.avi\nPlayingCello/v_PlayingCello_g06_c05.avi\nPlayingCello/v_PlayingCello_g06_c06.avi\nPlayingCello/v_PlayingCello_g06_c07.avi\nPlayingCello/v_PlayingCello_g07_c01.avi\nPlayingCello/v_PlayingCello_g07_c02.avi\nPlayingCello/v_PlayingCello_g07_c03.avi\nPlayingCello/v_PlayingCello_g07_c04.avi\nPlayingCello/v_PlayingCello_g07_c05.avi\nPlayingDaf/v_PlayingDaf_g01_c01.avi\nPlayingDaf/v_PlayingDaf_g01_c02.avi\nPlayingDaf/v_PlayingDaf_g01_c03.avi\nPlayingDaf/v_PlayingDaf_g01_c04.avi\nPlayingDaf/v_PlayingDaf_g02_c01.avi\nPlayingDaf/v_PlayingDaf_g02_c02.avi\nPlayingDaf/v_PlayingDaf_g02_c03.avi\nPlayingDaf/v_PlayingDaf_g02_c04.avi\nPlayingDaf/v_PlayingDaf_g02_c05.avi\nPlayingDaf/v_PlayingDaf_g02_c06.avi\nPlayingDaf/v_PlayingDaf_g02_c07.avi\nPlayingDaf/v_PlayingDaf_g03_c01.avi\nPlayingDaf/v_PlayingDaf_g03_c02.avi\nPlayingDaf/v_PlayingDaf_g03_c03.avi\nPlayingDaf/v_PlayingDaf_g03_c04.avi\nPlayingDaf/v_PlayingDaf_g04_c01.avi\nPlayingDaf/v_PlayingDaf_g04_c02.avi\nPlayingDaf/v_PlayingDaf_g04_c03.avi\nPlayingDaf/v_PlayingDaf_g04_c04.avi\nPlayingDaf/v_PlayingDaf_g04_c05.avi\nPlayingDaf/v_PlayingDaf_g04_c06.avi\nPlayingDaf/v_PlayingDaf_g04_c07.avi\nPlayingDaf/v_PlayingDaf_g05_c01.avi\nPlayingDaf/v_PlayingDaf_g05_c02.avi\nPlayingDaf/v_PlayingDaf_g05_c03.avi\nPlayingDaf/v_PlayingDaf_g05_c04.avi\nPlayingDaf/v_PlayingDaf_g05_c05.avi\nPlayingDaf/v_PlayingDaf_g05_c06.avi\nPlayingDaf/v_PlayingDaf_g05_c07.avi\nPlayingDaf/v_PlayingDaf_g06_c01.avi\nPl
ayingDaf/v_PlayingDaf_g06_c02.avi\nPlayingDaf/v_PlayingDaf_g06_c03.avi\nPlayingDaf/v_PlayingDaf_g06_c04.avi\nPlayingDaf/v_PlayingDaf_g06_c05.avi\nPlayingDaf/v_PlayingDaf_g06_c06.avi\nPlayingDaf/v_PlayingDaf_g06_c07.avi\nPlayingDaf/v_PlayingDaf_g07_c01.avi\nPlayingDaf/v_PlayingDaf_g07_c02.avi\nPlayingDaf/v_PlayingDaf_g07_c03.avi\nPlayingDaf/v_PlayingDaf_g07_c04.avi\nPlayingDaf/v_PlayingDaf_g07_c05.avi\nPlayingDhol/v_PlayingDhol_g01_c01.avi\nPlayingDhol/v_PlayingDhol_g01_c02.avi\nPlayingDhol/v_PlayingDhol_g01_c03.avi\nPlayingDhol/v_PlayingDhol_g01_c04.avi\nPlayingDhol/v_PlayingDhol_g01_c05.avi\nPlayingDhol/v_PlayingDhol_g01_c06.avi\nPlayingDhol/v_PlayingDhol_g01_c07.avi\nPlayingDhol/v_PlayingDhol_g02_c01.avi\nPlayingDhol/v_PlayingDhol_g02_c02.avi\nPlayingDhol/v_PlayingDhol_g02_c03.avi\nPlayingDhol/v_PlayingDhol_g02_c04.avi\nPlayingDhol/v_PlayingDhol_g02_c05.avi\nPlayingDhol/v_PlayingDhol_g02_c06.avi\nPlayingDhol/v_PlayingDhol_g02_c07.avi\nPlayingDhol/v_PlayingDhol_g03_c01.avi\nPlayingDhol/v_PlayingDhol_g03_c02.avi\nPlayingDhol/v_PlayingDhol_g03_c03.avi\nPlayingDhol/v_PlayingDhol_g03_c04.avi\nPlayingDhol/v_PlayingDhol_g03_c05.avi\nPlayingDhol/v_PlayingDhol_g03_c06.avi\nPlayingDhol/v_PlayingDhol_g03_c07.avi\nPlayingDhol/v_PlayingDhol_g04_c01.avi\nPlayingDhol/v_PlayingDhol_g04_c02.avi\nPlayingDhol/v_PlayingDhol_g04_c03.avi\nPlayingDhol/v_PlayingDhol_g04_c04.avi\nPlayingDhol/v_PlayingDhol_g04_c05.avi\nPlayingDhol/v_PlayingDhol_g04_c06.avi\nPlayingDhol/v_PlayingDhol_g04_c07.avi\nPlayingDhol/v_PlayingDhol_g05_c01.avi\nPlayingDhol/v_PlayingDhol_g05_c02.avi\nPlayingDhol/v_PlayingDhol_g05_c03.avi\nPlayingDhol/v_PlayingDhol_g05_c04.avi\nPlayingDhol/v_PlayingDhol_g05_c05.avi\nPlayingDhol/v_PlayingDhol_g05_c06.avi\nPlayingDhol/v_PlayingDhol_g05_c07.avi\nPlayingDhol/v_PlayingDhol_g06_c01.avi\nPlayingDhol/v_PlayingDhol_g06_c02.avi\nPlayingDhol/v_PlayingDhol_g06_c03.avi\nPlayingDhol/v_PlayingDhol_g06_c04.avi\nPlayingDhol/v_PlayingDhol_g06_c05.avi\nPlayingDhol/v_PlayingDhol_g06_c06.a
vi\nPlayingDhol/v_PlayingDhol_g06_c07.avi\nPlayingDhol/v_PlayingDhol_g07_c01.avi\nPlayingDhol/v_PlayingDhol_g07_c02.avi\nPlayingDhol/v_PlayingDhol_g07_c03.avi\nPlayingDhol/v_PlayingDhol_g07_c04.avi\nPlayingDhol/v_PlayingDhol_g07_c05.avi\nPlayingDhol/v_PlayingDhol_g07_c06.avi\nPlayingDhol/v_PlayingDhol_g07_c07.avi\nPlayingFlute/v_PlayingFlute_g01_c01.avi\nPlayingFlute/v_PlayingFlute_g01_c02.avi\nPlayingFlute/v_PlayingFlute_g01_c03.avi\nPlayingFlute/v_PlayingFlute_g01_c04.avi\nPlayingFlute/v_PlayingFlute_g01_c05.avi\nPlayingFlute/v_PlayingFlute_g01_c06.avi\nPlayingFlute/v_PlayingFlute_g01_c07.avi\nPlayingFlute/v_PlayingFlute_g02_c01.avi\nPlayingFlute/v_PlayingFlute_g02_c02.avi\nPlayingFlute/v_PlayingFlute_g02_c03.avi\nPlayingFlute/v_PlayingFlute_g02_c04.avi\nPlayingFlute/v_PlayingFlute_g02_c05.avi\nPlayingFlute/v_PlayingFlute_g02_c06.avi\nPlayingFlute/v_PlayingFlute_g02_c07.avi\nPlayingFlute/v_PlayingFlute_g03_c01.avi\nPlayingFlute/v_PlayingFlute_g03_c02.avi\nPlayingFlute/v_PlayingFlute_g03_c03.avi\nPlayingFlute/v_PlayingFlute_g03_c04.avi\nPlayingFlute/v_PlayingFlute_g03_c05.avi\nPlayingFlute/v_PlayingFlute_g03_c06.avi\nPlayingFlute/v_PlayingFlute_g03_c07.avi\nPlayingFlute/v_PlayingFlute_g04_c01.avi\nPlayingFlute/v_PlayingFlute_g04_c02.avi\nPlayingFlute/v_PlayingFlute_g04_c03.avi\nPlayingFlute/v_PlayingFlute_g04_c04.avi\nPlayingFlute/v_PlayingFlute_g04_c05.avi\nPlayingFlute/v_PlayingFlute_g04_c06.avi\nPlayingFlute/v_PlayingFlute_g04_c07.avi\nPlayingFlute/v_PlayingFlute_g05_c01.avi\nPlayingFlute/v_PlayingFlute_g05_c02.avi\nPlayingFlute/v_PlayingFlute_g05_c03.avi\nPlayingFlute/v_PlayingFlute_g05_c04.avi\nPlayingFlute/v_PlayingFlute_g05_c05.avi\nPlayingFlute/v_PlayingFlute_g05_c06.avi\nPlayingFlute/v_PlayingFlute_g05_c07.avi\nPlayingFlute/v_PlayingFlute_g06_c01.avi\nPlayingFlute/v_PlayingFlute_g06_c02.avi\nPlayingFlute/v_PlayingFlute_g06_c03.avi\nPlayingFlute/v_PlayingFlute_g06_c04.avi\nPlayingFlute/v_PlayingFlute_g06_c05.avi\nPlayingFlute/v_PlayingFlute_g06_c06.avi\nPla
yingFlute/v_PlayingFlute_g07_c01.avi\nPlayingFlute/v_PlayingFlute_g07_c02.avi\nPlayingFlute/v_PlayingFlute_g07_c03.avi\nPlayingFlute/v_PlayingFlute_g07_c04.avi\nPlayingFlute/v_PlayingFlute_g07_c05.avi\nPlayingFlute/v_PlayingFlute_g07_c06.avi\nPlayingFlute/v_PlayingFlute_g07_c07.avi\nPlayingGuitar/v_PlayingGuitar_g01_c01.avi\nPlayingGuitar/v_PlayingGuitar_g01_c02.avi\nPlayingGuitar/v_PlayingGuitar_g01_c03.avi\nPlayingGuitar/v_PlayingGuitar_g01_c04.avi\nPlayingGuitar/v_PlayingGuitar_g01_c05.avi\nPlayingGuitar/v_PlayingGuitar_g01_c06.avi\nPlayingGuitar/v_PlayingGuitar_g02_c01.avi\nPlayingGuitar/v_PlayingGuitar_g02_c02.avi\nPlayingGuitar/v_PlayingGuitar_g02_c03.avi\nPlayingGuitar/v_PlayingGuitar_g02_c04.avi\nPlayingGuitar/v_PlayingGuitar_g03_c01.avi\nPlayingGuitar/v_PlayingGuitar_g03_c02.avi\nPlayingGuitar/v_PlayingGuitar_g03_c03.avi\nPlayingGuitar/v_PlayingGuitar_g03_c04.avi\nPlayingGuitar/v_PlayingGuitar_g03_c05.avi\nPlayingGuitar/v_PlayingGuitar_g03_c06.avi\nPlayingGuitar/v_PlayingGuitar_g03_c07.avi\nPlayingGuitar/v_PlayingGuitar_g04_c01.avi\nPlayingGuitar/v_PlayingGuitar_g04_c02.avi\nPlayingGuitar/v_PlayingGuitar_g04_c03.avi\nPlayingGuitar/v_PlayingGuitar_g04_c04.avi\nPlayingGuitar/v_PlayingGuitar_g04_c05.avi\nPlayingGuitar/v_PlayingGuitar_g04_c06.avi\nPlayingGuitar/v_PlayingGuitar_g04_c07.avi\nPlayingGuitar/v_PlayingGuitar_g05_c01.avi\nPlayingGuitar/v_PlayingGuitar_g05_c02.avi\nPlayingGuitar/v_PlayingGuitar_g05_c03.avi\nPlayingGuitar/v_PlayingGuitar_g05_c04.avi\nPlayingGuitar/v_PlayingGuitar_g05_c05.avi\nPlayingGuitar/v_PlayingGuitar_g06_c01.avi\nPlayingGuitar/v_PlayingGuitar_g06_c02.avi\nPlayingGuitar/v_PlayingGuitar_g06_c03.avi\nPlayingGuitar/v_PlayingGuitar_g06_c04.avi\nPlayingGuitar/v_PlayingGuitar_g06_c05.avi\nPlayingGuitar/v_PlayingGuitar_g06_c06.avi\nPlayingGuitar/v_PlayingGuitar_g06_c07.avi\nPlayingGuitar/v_PlayingGuitar_g07_c01.avi\nPlayingGuitar/v_PlayingGuitar_g07_c02.avi\nPlayingGuitar/v_PlayingGuitar_g07_c03.avi\nPlayingGuitar/v_PlayingGuitar_g07_c04.a
vi\nPlayingGuitar/v_PlayingGuitar_g07_c05.avi\nPlayingGuitar/v_PlayingGuitar_g07_c06.avi\nPlayingGuitar/v_PlayingGuitar_g07_c07.avi\nPlayingPiano/v_PlayingPiano_g01_c01.avi\nPlayingPiano/v_PlayingPiano_g01_c02.avi\nPlayingPiano/v_PlayingPiano_g01_c03.avi\nPlayingPiano/v_PlayingPiano_g01_c04.avi\nPlayingPiano/v_PlayingPiano_g02_c01.avi\nPlayingPiano/v_PlayingPiano_g02_c02.avi\nPlayingPiano/v_PlayingPiano_g02_c03.avi\nPlayingPiano/v_PlayingPiano_g02_c04.avi\nPlayingPiano/v_PlayingPiano_g03_c01.avi\nPlayingPiano/v_PlayingPiano_g03_c02.avi\nPlayingPiano/v_PlayingPiano_g03_c03.avi\nPlayingPiano/v_PlayingPiano_g03_c04.avi\nPlayingPiano/v_PlayingPiano_g04_c01.avi\nPlayingPiano/v_PlayingPiano_g04_c02.avi\nPlayingPiano/v_PlayingPiano_g04_c03.avi\nPlayingPiano/v_PlayingPiano_g04_c04.avi\nPlayingPiano/v_PlayingPiano_g05_c01.avi\nPlayingPiano/v_PlayingPiano_g05_c02.avi\nPlayingPiano/v_PlayingPiano_g05_c03.avi\nPlayingPiano/v_PlayingPiano_g05_c04.avi\nPlayingPiano/v_PlayingPiano_g06_c01.avi\nPlayingPiano/v_PlayingPiano_g06_c02.avi\nPlayingPiano/v_PlayingPiano_g06_c03.avi\nPlayingPiano/v_PlayingPiano_g06_c04.avi\nPlayingPiano/v_PlayingPiano_g07_c01.avi\nPlayingPiano/v_PlayingPiano_g07_c02.avi\nPlayingPiano/v_PlayingPiano_g07_c03.avi\nPlayingPiano/v_PlayingPiano_g07_c04.avi\nPlayingSitar/v_PlayingSitar_g01_c01.avi\nPlayingSitar/v_PlayingSitar_g01_c02.avi\nPlayingSitar/v_PlayingSitar_g01_c03.avi\nPlayingSitar/v_PlayingSitar_g01_c04.avi\nPlayingSitar/v_PlayingSitar_g02_c01.avi\nPlayingSitar/v_PlayingSitar_g02_c02.avi\nPlayingSitar/v_PlayingSitar_g02_c03.avi\nPlayingSitar/v_PlayingSitar_g02_c04.avi\nPlayingSitar/v_PlayingSitar_g02_c05.avi\nPlayingSitar/v_PlayingSitar_g02_c06.avi\nPlayingSitar/v_PlayingSitar_g03_c01.avi\nPlayingSitar/v_PlayingSitar_g03_c02.avi\nPlayingSitar/v_PlayingSitar_g03_c03.avi\nPlayingSitar/v_PlayingSitar_g03_c04.avi\nPlayingSitar/v_PlayingSitar_g03_c05.avi\nPlayingSitar/v_PlayingSitar_g03_c06.avi\nPlayingSitar/v_PlayingSitar_g03_c07.avi\nPlayingSitar/v_Playing
Sitar_g04_c01.avi\nPlayingSitar/v_PlayingSitar_g04_c02.avi\nPlayingSitar/v_PlayingSitar_g04_c03.avi\nPlayingSitar/v_PlayingSitar_g04_c04.avi\nPlayingSitar/v_PlayingSitar_g04_c05.avi\nPlayingSitar/v_PlayingSitar_g04_c06.avi\nPlayingSitar/v_PlayingSitar_g04_c07.avi\nPlayingSitar/v_PlayingSitar_g05_c01.avi\nPlayingSitar/v_PlayingSitar_g05_c02.avi\nPlayingSitar/v_PlayingSitar_g05_c03.avi\nPlayingSitar/v_PlayingSitar_g05_c04.avi\nPlayingSitar/v_PlayingSitar_g05_c05.avi\nPlayingSitar/v_PlayingSitar_g05_c06.avi\nPlayingSitar/v_PlayingSitar_g05_c07.avi\nPlayingSitar/v_PlayingSitar_g06_c01.avi\nPlayingSitar/v_PlayingSitar_g06_c02.avi\nPlayingSitar/v_PlayingSitar_g06_c03.avi\nPlayingSitar/v_PlayingSitar_g06_c04.avi\nPlayingSitar/v_PlayingSitar_g06_c05.avi\nPlayingSitar/v_PlayingSitar_g06_c06.avi\nPlayingSitar/v_PlayingSitar_g07_c01.avi\nPlayingSitar/v_PlayingSitar_g07_c02.avi\nPlayingSitar/v_PlayingSitar_g07_c03.avi\nPlayingSitar/v_PlayingSitar_g07_c04.avi\nPlayingSitar/v_PlayingSitar_g07_c05.avi\nPlayingSitar/v_PlayingSitar_g07_c06.avi\nPlayingSitar/v_PlayingSitar_g07_c07.avi\nPlayingTabla/v_PlayingTabla_g01_c01.avi\nPlayingTabla/v_PlayingTabla_g01_c02.avi\nPlayingTabla/v_PlayingTabla_g01_c03.avi\nPlayingTabla/v_PlayingTabla_g01_c04.avi\nPlayingTabla/v_PlayingTabla_g02_c01.avi\nPlayingTabla/v_PlayingTabla_g02_c02.avi\nPlayingTabla/v_PlayingTabla_g02_c03.avi\nPlayingTabla/v_PlayingTabla_g02_c04.avi\nPlayingTabla/v_PlayingTabla_g03_c01.avi\nPlayingTabla/v_PlayingTabla_g03_c02.avi\nPlayingTabla/v_PlayingTabla_g03_c03.avi\nPlayingTabla/v_PlayingTabla_g03_c04.avi\nPlayingTabla/v_PlayingTabla_g03_c05.avi\nPlayingTabla/v_PlayingTabla_g04_c01.avi\nPlayingTabla/v_PlayingTabla_g04_c02.avi\nPlayingTabla/v_PlayingTabla_g04_c03.avi\nPlayingTabla/v_PlayingTabla_g04_c04.avi\nPlayingTabla/v_PlayingTabla_g04_c05.avi\nPlayingTabla/v_PlayingTabla_g04_c06.avi\nPlayingTabla/v_PlayingTabla_g05_c01.avi\nPlayingTabla/v_PlayingTabla_g05_c02.avi\nPlayingTabla/v_PlayingTabla_g05_c03.avi\nPlayingTabla/
v_PlayingTabla_g05_c04.avi\nPlayingTabla/v_PlayingTabla_g06_c01.avi\nPlayingTabla/v_PlayingTabla_g06_c02.avi\nPlayingTabla/v_PlayingTabla_g06_c03.avi\nPlayingTabla/v_PlayingTabla_g06_c04.avi\nPlayingTabla/v_PlayingTabla_g07_c01.avi\nPlayingTabla/v_PlayingTabla_g07_c02.avi\nPlayingTabla/v_PlayingTabla_g07_c03.avi\nPlayingTabla/v_PlayingTabla_g07_c04.avi\nPlayingViolin/v_PlayingViolin_g01_c01.avi\nPlayingViolin/v_PlayingViolin_g01_c02.avi\nPlayingViolin/v_PlayingViolin_g01_c03.avi\nPlayingViolin/v_PlayingViolin_g01_c04.avi\nPlayingViolin/v_PlayingViolin_g02_c01.avi\nPlayingViolin/v_PlayingViolin_g02_c02.avi\nPlayingViolin/v_PlayingViolin_g02_c03.avi\nPlayingViolin/v_PlayingViolin_g02_c04.avi\nPlayingViolin/v_PlayingViolin_g03_c01.avi\nPlayingViolin/v_PlayingViolin_g03_c02.avi\nPlayingViolin/v_PlayingViolin_g03_c03.avi\nPlayingViolin/v_PlayingViolin_g03_c04.avi\nPlayingViolin/v_PlayingViolin_g04_c01.avi\nPlayingViolin/v_PlayingViolin_g04_c02.avi\nPlayingViolin/v_PlayingViolin_g04_c03.avi\nPlayingViolin/v_PlayingViolin_g04_c04.avi\nPlayingViolin/v_PlayingViolin_g05_c01.avi\nPlayingViolin/v_PlayingViolin_g05_c02.avi\nPlayingViolin/v_PlayingViolin_g05_c03.avi\nPlayingViolin/v_PlayingViolin_g05_c04.avi\nPlayingViolin/v_PlayingViolin_g06_c01.avi\nPlayingViolin/v_PlayingViolin_g06_c02.avi\nPlayingViolin/v_PlayingViolin_g06_c03.avi\nPlayingViolin/v_PlayingViolin_g06_c04.avi\nPlayingViolin/v_PlayingViolin_g07_c01.avi\nPlayingViolin/v_PlayingViolin_g07_c02.avi\nPlayingViolin/v_PlayingViolin_g07_c03.avi\nPlayingViolin/v_PlayingViolin_g07_c04.avi\nPoleVault/v_PoleVault_g01_c01.avi\nPoleVault/v_PoleVault_g01_c02.avi\nPoleVault/v_PoleVault_g01_c03.avi\nPoleVault/v_PoleVault_g01_c04.avi\nPoleVault/v_PoleVault_g01_c05.avi\nPoleVault/v_PoleVault_g02_c01.avi\nPoleVault/v_PoleVault_g02_c02.avi\nPoleVault/v_PoleVault_g02_c03.avi\nPoleVault/v_PoleVault_g02_c04.avi\nPoleVault/v_PoleVault_g02_c05.avi\nPoleVault/v_PoleVault_g02_c06.avi\nPoleVault/v_PoleVault_g02_c07.avi\nPoleVault/v_PoleVaul
t_g03_c01.avi\nPoleVault/v_PoleVault_g03_c02.avi\nPoleVault/v_PoleVault_g03_c03.avi\nPoleVault/v_PoleVault_g03_c04.avi\nPoleVault/v_PoleVault_g03_c05.avi\nPoleVault/v_PoleVault_g03_c06.avi\nPoleVault/v_PoleVault_g03_c07.avi\nPoleVault/v_PoleVault_g04_c01.avi\nPoleVault/v_PoleVault_g04_c02.avi\nPoleVault/v_PoleVault_g04_c03.avi\nPoleVault/v_PoleVault_g04_c04.avi\nPoleVault/v_PoleVault_g04_c05.avi\nPoleVault/v_PoleVault_g04_c06.avi\nPoleVault/v_PoleVault_g04_c07.avi\nPoleVault/v_PoleVault_g05_c01.avi\nPoleVault/v_PoleVault_g05_c02.avi\nPoleVault/v_PoleVault_g05_c03.avi\nPoleVault/v_PoleVault_g05_c04.avi\nPoleVault/v_PoleVault_g05_c05.avi\nPoleVault/v_PoleVault_g06_c01.avi\nPoleVault/v_PoleVault_g06_c02.avi\nPoleVault/v_PoleVault_g06_c03.avi\nPoleVault/v_PoleVault_g06_c04.avi\nPoleVault/v_PoleVault_g06_c05.avi\nPoleVault/v_PoleVault_g07_c01.avi\nPoleVault/v_PoleVault_g07_c02.avi\nPoleVault/v_PoleVault_g07_c03.avi\nPoleVault/v_PoleVault_g07_c04.avi\nPommelHorse/v_PommelHorse_g01_c01.avi\nPommelHorse/v_PommelHorse_g01_c02.avi\nPommelHorse/v_PommelHorse_g01_c03.avi\nPommelHorse/v_PommelHorse_g01_c04.avi\nPommelHorse/v_PommelHorse_g01_c05.avi\nPommelHorse/v_PommelHorse_g01_c06.avi\nPommelHorse/v_PommelHorse_g01_c07.avi\nPommelHorse/v_PommelHorse_g02_c01.avi\nPommelHorse/v_PommelHorse_g02_c02.avi\nPommelHorse/v_PommelHorse_g02_c03.avi\nPommelHorse/v_PommelHorse_g02_c04.avi\nPommelHorse/v_PommelHorse_g03_c01.avi\nPommelHorse/v_PommelHorse_g03_c02.avi\nPommelHorse/v_PommelHorse_g03_c03.avi\nPommelHorse/v_PommelHorse_g03_c04.avi\nPommelHorse/v_PommelHorse_g04_c01.avi\nPommelHorse/v_PommelHorse_g04_c02.avi\nPommelHorse/v_PommelHorse_g04_c03.avi\nPommelHorse/v_PommelHorse_g04_c04.avi\nPommelHorse/v_PommelHorse_g04_c05.avi\nPommelHorse/v_PommelHorse_g05_c01.avi\nPommelHorse/v_PommelHorse_g05_c02.avi\nPommelHorse/v_PommelHorse_g05_c03.avi\nPommelHorse/v_PommelHorse_g05_c04.avi\nPommelHorse/v_PommelHorse_g06_c01.avi\nPommelHorse/v_PommelHorse_g06_c02.avi\nPommelHorse/v_PommelHorse_
g06_c03.avi\nPommelHorse/v_PommelHorse_g06_c04.avi\nPommelHorse/v_PommelHorse_g07_c01.avi\nPommelHorse/v_PommelHorse_g07_c02.avi\nPommelHorse/v_PommelHorse_g07_c03.avi\nPommelHorse/v_PommelHorse_g07_c04.avi\nPommelHorse/v_PommelHorse_g07_c05.avi\nPommelHorse/v_PommelHorse_g07_c06.avi\nPommelHorse/v_PommelHorse_g07_c07.avi\nPullUps/v_PullUps_g01_c01.avi\nPullUps/v_PullUps_g01_c02.avi\nPullUps/v_PullUps_g01_c03.avi\nPullUps/v_PullUps_g01_c04.avi\nPullUps/v_PullUps_g02_c01.avi\nPullUps/v_PullUps_g02_c02.avi\nPullUps/v_PullUps_g02_c03.avi\nPullUps/v_PullUps_g02_c04.avi\nPullUps/v_PullUps_g03_c01.avi\nPullUps/v_PullUps_g03_c02.avi\nPullUps/v_PullUps_g03_c03.avi\nPullUps/v_PullUps_g03_c04.avi\nPullUps/v_PullUps_g04_c01.avi\nPullUps/v_PullUps_g04_c02.avi\nPullUps/v_PullUps_g04_c03.avi\nPullUps/v_PullUps_g04_c04.avi\nPullUps/v_PullUps_g05_c01.avi\nPullUps/v_PullUps_g05_c02.avi\nPullUps/v_PullUps_g05_c03.avi\nPullUps/v_PullUps_g05_c04.avi\nPullUps/v_PullUps_g06_c01.avi\nPullUps/v_PullUps_g06_c02.avi\nPullUps/v_PullUps_g06_c03.avi\nPullUps/v_PullUps_g06_c04.avi\nPullUps/v_PullUps_g07_c01.avi\nPullUps/v_PullUps_g07_c02.avi\nPullUps/v_PullUps_g07_c03.avi\nPullUps/v_PullUps_g07_c04.avi\nPunch/v_Punch_g01_c01.avi\nPunch/v_Punch_g01_c02.avi\nPunch/v_Punch_g01_c03.avi\nPunch/v_Punch_g01_c04.avi\nPunch/v_Punch_g01_c05.avi\nPunch/v_Punch_g02_c01.avi\nPunch/v_Punch_g02_c02.avi\nPunch/v_Punch_g02_c03.avi\nPunch/v_Punch_g02_c04.avi\nPunch/v_Punch_g03_c01.avi\nPunch/v_Punch_g03_c02.avi\nPunch/v_Punch_g03_c03.avi\nPunch/v_Punch_g03_c04.avi\nPunch/v_Punch_g04_c01.avi\nPunch/v_Punch_g04_c02.avi\nPunch/v_Punch_g04_c03.avi\nPunch/v_Punch_g04_c04.avi\nPunch/v_Punch_g04_c05.avi\nPunch/v_Punch_g05_c01.avi\nPunch/v_Punch_g05_c02.avi\nPunch/v_Punch_g05_c03.avi\nPunch/v_Punch_g05_c04.avi\nPunch/v_Punch_g05_c05.avi\nPunch/v_Punch_g05_c06.avi\nPunch/v_Punch_g05_c07.avi\nPunch/v_Punch_g06_c01.avi\nPunch/v_Punch_g06_c02.avi\nPunch/v_Punch_g06_c03.avi\nPunch/v_Punch_g06_c04.avi\nPunch/v_Punch_g06_c05.av
i\nPunch/v_Punch_g06_c06.avi\nPunch/v_Punch_g06_c07.avi\nPunch/v_Punch_g07_c01.avi\nPunch/v_Punch_g07_c02.avi\nPunch/v_Punch_g07_c03.avi\nPunch/v_Punch_g07_c04.avi\nPunch/v_Punch_g07_c05.avi\nPunch/v_Punch_g07_c06.avi\nPunch/v_Punch_g07_c07.avi\nPushUps/v_PushUps_g01_c01.avi\nPushUps/v_PushUps_g01_c02.avi\nPushUps/v_PushUps_g01_c03.avi\nPushUps/v_PushUps_g01_c04.avi\nPushUps/v_PushUps_g01_c05.avi\nPushUps/v_PushUps_g02_c01.avi\nPushUps/v_PushUps_g02_c02.avi\nPushUps/v_PushUps_g02_c03.avi\nPushUps/v_PushUps_g02_c04.avi\nPushUps/v_PushUps_g03_c01.avi\nPushUps/v_PushUps_g03_c02.avi\nPushUps/v_PushUps_g03_c03.avi\nPushUps/v_PushUps_g03_c04.avi\nPushUps/v_PushUps_g04_c01.avi\nPushUps/v_PushUps_g04_c02.avi\nPushUps/v_PushUps_g04_c03.avi\nPushUps/v_PushUps_g04_c04.avi\nPushUps/v_PushUps_g04_c05.avi\nPushUps/v_PushUps_g05_c01.avi\nPushUps/v_PushUps_g05_c02.avi\nPushUps/v_PushUps_g05_c03.avi\nPushUps/v_PushUps_g05_c04.avi\nPushUps/v_PushUps_g06_c01.avi\nPushUps/v_PushUps_g06_c02.avi\nPushUps/v_PushUps_g06_c03.avi\nPushUps/v_PushUps_g06_c04.avi\nPushUps/v_PushUps_g07_c01.avi\nPushUps/v_PushUps_g07_c02.avi\nPushUps/v_PushUps_g07_c03.avi\nPushUps/v_PushUps_g07_c04.avi\nRafting/v_Rafting_g01_c01.avi\nRafting/v_Rafting_g01_c02.avi\nRafting/v_Rafting_g01_c03.avi\nRafting/v_Rafting_g01_c04.avi\nRafting/v_Rafting_g02_c01.avi\nRafting/v_Rafting_g02_c02.avi\nRafting/v_Rafting_g02_c03.avi\nRafting/v_Rafting_g02_c04.avi\nRafting/v_Rafting_g03_c01.avi\nRafting/v_Rafting_g03_c02.avi\nRafting/v_Rafting_g03_c03.avi\nRafting/v_Rafting_g03_c04.avi\nRafting/v_Rafting_g04_c01.avi\nRafting/v_Rafting_g04_c02.avi\nRafting/v_Rafting_g04_c03.avi\nRafting/v_Rafting_g04_c04.avi\nRafting/v_Rafting_g05_c01.avi\nRafting/v_Rafting_g05_c02.avi\nRafting/v_Rafting_g05_c03.avi\nRafting/v_Rafting_g05_c04.avi\nRafting/v_Rafting_g06_c01.avi\nRafting/v_Rafting_g06_c02.avi\nRafting/v_Rafting_g06_c03.avi\nRafting/v_Rafting_g06_c04.avi\nRafting/v_Rafting_g07_c01.avi\nRafting/v_Rafting_g07_c02.avi\nRafting/v_Rafting_
g07_c03.avi\nRafting/v_Rafting_g07_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g01_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g01_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g01_c03.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g01_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g01_c05.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g02_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g02_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g02_c03.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g02_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g02_c05.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c03.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c05.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c06.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g03_c07.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g04_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g04_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g04_c03.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g04_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g05_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g05_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g05_c03.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g05_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g05_c05.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g05_c06.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c03.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c05.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c06.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g06_c07.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c01.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c02.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c03.
avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c04.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c05.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c06.avi\nRockClimbingIndoor/v_RockClimbingIndoor_g07_c07.avi\nRopeClimbing/v_RopeClimbing_g01_c01.avi\nRopeClimbing/v_RopeClimbing_g01_c02.avi\nRopeClimbing/v_RopeClimbing_g01_c03.avi\nRopeClimbing/v_RopeClimbing_g01_c04.avi\nRopeClimbing/v_RopeClimbing_g02_c01.avi\nRopeClimbing/v_RopeClimbing_g02_c02.avi\nRopeClimbing/v_RopeClimbing_g02_c03.avi\nRopeClimbing/v_RopeClimbing_g02_c04.avi\nRopeClimbing/v_RopeClimbing_g02_c05.avi\nRopeClimbing/v_RopeClimbing_g02_c06.avi\nRopeClimbing/v_RopeClimbing_g03_c01.avi\nRopeClimbing/v_RopeClimbing_g03_c02.avi\nRopeClimbing/v_RopeClimbing_g03_c03.avi\nRopeClimbing/v_RopeClimbing_g03_c04.avi\nRopeClimbing/v_RopeClimbing_g04_c01.avi\nRopeClimbing/v_RopeClimbing_g04_c02.avi\nRopeClimbing/v_RopeClimbing_g04_c03.avi\nRopeClimbing/v_RopeClimbing_g04_c04.avi\nRopeClimbing/v_RopeClimbing_g05_c01.avi\nRopeClimbing/v_RopeClimbing_g05_c02.avi\nRopeClimbing/v_RopeClimbing_g05_c03.avi\nRopeClimbing/v_RopeClimbing_g05_c04.avi\nRopeClimbing/v_RopeClimbing_g05_c05.avi\nRopeClimbing/v_RopeClimbing_g05_c06.avi\nRopeClimbing/v_RopeClimbing_g05_c07.avi\nRopeClimbing/v_RopeClimbing_g06_c01.avi\nRopeClimbing/v_RopeClimbing_g06_c02.avi\nRopeClimbing/v_RopeClimbing_g06_c03.avi\nRopeClimbing/v_RopeClimbing_g06_c04.avi\nRopeClimbing/v_RopeClimbing_g07_c01.avi\nRopeClimbing/v_RopeClimbing_g07_c02.avi\nRopeClimbing/v_RopeClimbing_g07_c03.avi\nRopeClimbing/v_RopeClimbing_g07_c04.avi\nRopeClimbing/v_RopeClimbing_g07_c05.avi\nRowing/v_Rowing_g01_c01.avi\nRowing/v_Rowing_g01_c02.avi\nRowing/v_Rowing_g01_c03.avi\nRowing/v_Rowing_g01_c04.avi\nRowing/v_Rowing_g02_c01.avi\nRowing/v_Rowing_g02_c02.avi\nRowing/v_Rowing_g02_c03.avi\nRowing/v_Rowing_g02_c04.avi\nRowing/v_Rowing_g02_c05.avi\nRowing/v_Rowing_g02_c06.avi\nRowing/v_Rowing_g03_c01.avi\nRowing/v_Rowing_g03_c02.avi\nRowing/v_Rowing_g03_c03.avi\nRowing/v_Row
ing_g03_c04.avi\nRowing/v_Rowing_g03_c05.avi\nRowing/v_Rowing_g03_c06.avi\nRowing/v_Rowing_g03_c07.avi\nRowing/v_Rowing_g04_c01.avi\nRowing/v_Rowing_g04_c02.avi\nRowing/v_Rowing_g04_c03.avi\nRowing/v_Rowing_g04_c04.avi\nRowing/v_Rowing_g04_c05.avi\nRowing/v_Rowing_g04_c06.avi\nRowing/v_Rowing_g05_c01.avi\nRowing/v_Rowing_g05_c02.avi\nRowing/v_Rowing_g05_c03.avi\nRowing/v_Rowing_g05_c04.avi\nRowing/v_Rowing_g06_c01.avi\nRowing/v_Rowing_g06_c02.avi\nRowing/v_Rowing_g06_c03.avi\nRowing/v_Rowing_g06_c04.avi\nRowing/v_Rowing_g07_c01.avi\nRowing/v_Rowing_g07_c02.avi\nRowing/v_Rowing_g07_c03.avi\nRowing/v_Rowing_g07_c04.avi\nRowing/v_Rowing_g07_c05.avi\nSalsaSpin/v_SalsaSpin_g01_c01.avi\nSalsaSpin/v_SalsaSpin_g01_c02.avi\nSalsaSpin/v_SalsaSpin_g01_c03.avi\nSalsaSpin/v_SalsaSpin_g01_c04.avi\nSalsaSpin/v_SalsaSpin_g01_c05.avi\nSalsaSpin/v_SalsaSpin_g01_c06.avi\nSalsaSpin/v_SalsaSpin_g01_c07.avi\nSalsaSpin/v_SalsaSpin_g02_c01.avi\nSalsaSpin/v_SalsaSpin_g02_c02.avi\nSalsaSpin/v_SalsaSpin_g02_c03.avi\nSalsaSpin/v_SalsaSpin_g02_c04.avi\nSalsaSpin/v_SalsaSpin_g02_c05.avi\nSalsaSpin/v_SalsaSpin_g02_c06.avi\nSalsaSpin/v_SalsaSpin_g02_c07.avi\nSalsaSpin/v_SalsaSpin_g03_c01.avi\nSalsaSpin/v_SalsaSpin_g03_c02.avi\nSalsaSpin/v_SalsaSpin_g03_c03.avi\nSalsaSpin/v_SalsaSpin_g03_c04.avi\nSalsaSpin/v_SalsaSpin_g03_c05.avi\nSalsaSpin/v_SalsaSpin_g03_c06.avi\nSalsaSpin/v_SalsaSpin_g04_c01.avi\nSalsaSpin/v_SalsaSpin_g04_c02.avi\nSalsaSpin/v_SalsaSpin_g04_c03.avi\nSalsaSpin/v_SalsaSpin_g04_c04.avi\nSalsaSpin/v_SalsaSpin_g04_c05.avi\nSalsaSpin/v_SalsaSpin_g04_c06.avi\nSalsaSpin/v_SalsaSpin_g05_c01.avi\nSalsaSpin/v_SalsaSpin_g05_c02.avi\nSalsaSpin/v_SalsaSpin_g05_c03.avi\nSalsaSpin/v_SalsaSpin_g05_c04.avi\nSalsaSpin/v_SalsaSpin_g05_c05.avi\nSalsaSpin/v_SalsaSpin_g05_c06.avi\nSalsaSpin/v_SalsaSpin_g06_c01.avi\nSalsaSpin/v_SalsaSpin_g06_c02.avi\nSalsaSpin/v_SalsaSpin_g06_c03.avi\nSalsaSpin/v_SalsaSpin_g06_c04.avi\nSalsaSpin/v_SalsaSpin_g06_c05.avi\nSalsaSpin/v_SalsaSpin_g07_c01.avi\nSalsaSpin/v_Sal
saSpin_g07_c02.avi\nSalsaSpin/v_SalsaSpin_g07_c03.avi\nSalsaSpin/v_SalsaSpin_g07_c04.avi\nSalsaSpin/v_SalsaSpin_g07_c05.avi\nSalsaSpin/v_SalsaSpin_g07_c06.avi\nShavingBeard/v_ShavingBeard_g01_c01.avi\nShavingBeard/v_ShavingBeard_g01_c02.avi\nShavingBeard/v_ShavingBeard_g01_c03.avi\nShavingBeard/v_ShavingBeard_g01_c04.avi\nShavingBeard/v_ShavingBeard_g02_c01.avi\nShavingBeard/v_ShavingBeard_g02_c02.avi\nShavingBeard/v_ShavingBeard_g02_c03.avi\nShavingBeard/v_ShavingBeard_g02_c04.avi\nShavingBeard/v_ShavingBeard_g02_c05.avi\nShavingBeard/v_ShavingBeard_g02_c06.avi\nShavingBeard/v_ShavingBeard_g02_c07.avi\nShavingBeard/v_ShavingBeard_g03_c01.avi\nShavingBeard/v_ShavingBeard_g03_c02.avi\nShavingBeard/v_ShavingBeard_g03_c03.avi\nShavingBeard/v_ShavingBeard_g03_c04.avi\nShavingBeard/v_ShavingBeard_g03_c05.avi\nShavingBeard/v_ShavingBeard_g03_c06.avi\nShavingBeard/v_ShavingBeard_g03_c07.avi\nShavingBeard/v_ShavingBeard_g04_c01.avi\nShavingBeard/v_ShavingBeard_g04_c02.avi\nShavingBeard/v_ShavingBeard_g04_c03.avi\nShavingBeard/v_ShavingBeard_g04_c04.avi\nShavingBeard/v_ShavingBeard_g05_c01.avi\nShavingBeard/v_ShavingBeard_g05_c02.avi\nShavingBeard/v_ShavingBeard_g05_c03.avi\nShavingBeard/v_ShavingBeard_g05_c04.avi\nShavingBeard/v_ShavingBeard_g05_c05.avi\nShavingBeard/v_ShavingBeard_g05_c06.avi\nShavingBeard/v_ShavingBeard_g05_c07.avi\nShavingBeard/v_ShavingBeard_g06_c01.avi\nShavingBeard/v_ShavingBeard_g06_c02.avi\nShavingBeard/v_ShavingBeard_g06_c03.avi\nShavingBeard/v_ShavingBeard_g06_c04.avi\nShavingBeard/v_ShavingBeard_g06_c05.avi\nShavingBeard/v_ShavingBeard_g06_c06.avi\nShavingBeard/v_ShavingBeard_g06_c07.avi\nShavingBeard/v_ShavingBeard_g07_c01.avi\nShavingBeard/v_ShavingBeard_g07_c02.avi\nShavingBeard/v_ShavingBeard_g07_c03.avi\nShavingBeard/v_ShavingBeard_g07_c04.avi\nShavingBeard/v_ShavingBeard_g07_c05.avi\nShavingBeard/v_ShavingBeard_g07_c06.avi\nShavingBeard/v_ShavingBeard_g07_c07.avi\nShotput/v_Shotput_g01_c01.avi\nShotput/v_Shotput_g01_c02.avi\nShotput/v_Shotp
ut_g01_c03.avi\nShotput/v_Shotput_g01_c04.avi\nShotput/v_Shotput_g01_c05.avi\nShotput/v_Shotput_g01_c06.avi\nShotput/v_Shotput_g01_c07.avi\nShotput/v_Shotput_g02_c01.avi\nShotput/v_Shotput_g02_c02.avi\nShotput/v_Shotput_g02_c03.avi\nShotput/v_Shotput_g02_c04.avi\nShotput/v_Shotput_g02_c05.avi\nShotput/v_Shotput_g02_c06.avi\nShotput/v_Shotput_g02_c07.avi\nShotput/v_Shotput_g03_c01.avi\nShotput/v_Shotput_g03_c02.avi\nShotput/v_Shotput_g03_c03.avi\nShotput/v_Shotput_g03_c04.avi\nShotput/v_Shotput_g03_c05.avi\nShotput/v_Shotput_g03_c06.avi\nShotput/v_Shotput_g04_c01.avi\nShotput/v_Shotput_g04_c02.avi\nShotput/v_Shotput_g04_c03.avi\nShotput/v_Shotput_g04_c04.avi\nShotput/v_Shotput_g04_c05.avi\nShotput/v_Shotput_g05_c01.avi\nShotput/v_Shotput_g05_c02.avi\nShotput/v_Shotput_g05_c03.avi\nShotput/v_Shotput_g05_c04.avi\nShotput/v_Shotput_g05_c05.avi\nShotput/v_Shotput_g05_c06.avi\nShotput/v_Shotput_g05_c07.avi\nShotput/v_Shotput_g06_c01.avi\nShotput/v_Shotput_g06_c02.avi\nShotput/v_Shotput_g06_c03.avi\nShotput/v_Shotput_g06_c04.avi\nShotput/v_Shotput_g06_c05.avi\nShotput/v_Shotput_g06_c06.avi\nShotput/v_Shotput_g06_c07.avi\nShotput/v_Shotput_g07_c01.avi\nShotput/v_Shotput_g07_c02.avi\nShotput/v_Shotput_g07_c03.avi\nShotput/v_Shotput_g07_c04.avi\nShotput/v_Shotput_g07_c05.avi\nShotput/v_Shotput_g07_c06.avi\nShotput/v_Shotput_g07_c07.avi\nSkateBoarding/v_SkateBoarding_g01_c01.avi\nSkateBoarding/v_SkateBoarding_g01_c02.avi\nSkateBoarding/v_SkateBoarding_g01_c03.avi\nSkateBoarding/v_SkateBoarding_g01_c04.avi\nSkateBoarding/v_SkateBoarding_g02_c01.avi\nSkateBoarding/v_SkateBoarding_g02_c02.avi\nSkateBoarding/v_SkateBoarding_g02_c03.avi\nSkateBoarding/v_SkateBoarding_g02_c04.avi\nSkateBoarding/v_SkateBoarding_g02_c05.avi\nSkateBoarding/v_SkateBoarding_g02_c06.avi\nSkateBoarding/v_SkateBoarding_g03_c01.avi\nSkateBoarding/v_SkateBoarding_g03_c02.avi\nSkateBoarding/v_SkateBoarding_g03_c03.avi\nSkateBoarding/v_SkateBoarding_g03_c04.avi\nSkateBoarding/v_SkateBoarding_g04_c01.avi\nSkateB
oarding/v_SkateBoarding_g04_c02.avi\nSkateBoarding/v_SkateBoarding_g04_c03.avi\nSkateBoarding/v_SkateBoarding_g04_c04.avi\nSkateBoarding/v_SkateBoarding_g04_c05.avi\nSkateBoarding/v_SkateBoarding_g05_c01.avi\nSkateBoarding/v_SkateBoarding_g05_c02.avi\nSkateBoarding/v_SkateBoarding_g05_c03.avi\nSkateBoarding/v_SkateBoarding_g05_c04.avi\nSkateBoarding/v_SkateBoarding_g06_c01.avi\nSkateBoarding/v_SkateBoarding_g06_c02.avi\nSkateBoarding/v_SkateBoarding_g06_c03.avi\nSkateBoarding/v_SkateBoarding_g06_c04.avi\nSkateBoarding/v_SkateBoarding_g07_c01.avi\nSkateBoarding/v_SkateBoarding_g07_c02.avi\nSkateBoarding/v_SkateBoarding_g07_c03.avi\nSkateBoarding/v_SkateBoarding_g07_c04.avi\nSkateBoarding/v_SkateBoarding_g07_c05.avi\nSkiing/v_Skiing_g01_c01.avi\nSkiing/v_Skiing_g01_c02.avi\nSkiing/v_Skiing_g01_c03.avi\nSkiing/v_Skiing_g01_c04.avi\nSkiing/v_Skiing_g01_c05.avi\nSkiing/v_Skiing_g01_c06.avi\nSkiing/v_Skiing_g02_c01.avi\nSkiing/v_Skiing_g02_c02.avi\nSkiing/v_Skiing_g02_c03.avi\nSkiing/v_Skiing_g02_c04.avi\nSkiing/v_Skiing_g02_c05.avi\nSkiing/v_Skiing_g03_c01.avi\nSkiing/v_Skiing_g03_c02.avi\nSkiing/v_Skiing_g03_c03.avi\nSkiing/v_Skiing_g03_c04.avi\nSkiing/v_Skiing_g03_c05.avi\nSkiing/v_Skiing_g03_c06.avi\nSkiing/v_Skiing_g03_c07.avi\nSkiing/v_Skiing_g04_c01.avi\nSkiing/v_Skiing_g04_c02.avi\nSkiing/v_Skiing_g04_c03.avi\nSkiing/v_Skiing_g04_c04.avi\nSkiing/v_Skiing_g04_c05.avi\nSkiing/v_Skiing_g04_c06.avi\nSkiing/v_Skiing_g04_c07.avi\nSkiing/v_Skiing_g05_c01.avi\nSkiing/v_Skiing_g05_c02.avi\nSkiing/v_Skiing_g05_c03.avi\nSkiing/v_Skiing_g05_c04.avi\nSkiing/v_Skiing_g06_c01.avi\nSkiing/v_Skiing_g06_c02.avi\nSkiing/v_Skiing_g06_c03.avi\nSkiing/v_Skiing_g06_c04.avi\nSkiing/v_Skiing_g06_c05.avi\nSkiing/v_Skiing_g06_c06.avi\nSkiing/v_Skiing_g06_c07.avi\nSkiing/v_Skiing_g07_c01.avi\nSkiing/v_Skiing_g07_c02.avi\nSkiing/v_Skiing_g07_c03.avi\nSkiing/v_Skiing_g07_c04.avi\nSkijet/v_Skijet_g01_c01.avi\nSkijet/v_Skijet_g01_c02.avi\nSkijet/v_Skijet_g01_c03.avi\nSkijet/v_Skijet_g01_c04.avi\
nSkijet/v_Skijet_g02_c01.avi\nSkijet/v_Skijet_g02_c02.avi\nSkijet/v_Skijet_g02_c03.avi\nSkijet/v_Skijet_g02_c04.avi\nSkijet/v_Skijet_g03_c01.avi\nSkijet/v_Skijet_g03_c02.avi\nSkijet/v_Skijet_g03_c03.avi\nSkijet/v_Skijet_g03_c04.avi\nSkijet/v_Skijet_g04_c01.avi\nSkijet/v_Skijet_g04_c02.avi\nSkijet/v_Skijet_g04_c03.avi\nSkijet/v_Skijet_g04_c04.avi\nSkijet/v_Skijet_g05_c01.avi\nSkijet/v_Skijet_g05_c02.avi\nSkijet/v_Skijet_g05_c03.avi\nSkijet/v_Skijet_g05_c04.avi\nSkijet/v_Skijet_g06_c01.avi\nSkijet/v_Skijet_g06_c02.avi\nSkijet/v_Skijet_g06_c03.avi\nSkijet/v_Skijet_g06_c04.avi\nSkijet/v_Skijet_g07_c01.avi\nSkijet/v_Skijet_g07_c02.avi\nSkijet/v_Skijet_g07_c03.avi\nSkijet/v_Skijet_g07_c04.avi\nSkyDiving/v_SkyDiving_g01_c01.avi\nSkyDiving/v_SkyDiving_g01_c02.avi\nSkyDiving/v_SkyDiving_g01_c03.avi\nSkyDiving/v_SkyDiving_g01_c04.avi\nSkyDiving/v_SkyDiving_g02_c01.avi\nSkyDiving/v_SkyDiving_g02_c02.avi\nSkyDiving/v_SkyDiving_g02_c03.avi\nSkyDiving/v_SkyDiving_g02_c04.avi\nSkyDiving/v_SkyDiving_g03_c01.avi\nSkyDiving/v_SkyDiving_g03_c02.avi\nSkyDiving/v_SkyDiving_g03_c03.avi\nSkyDiving/v_SkyDiving_g03_c04.avi\nSkyDiving/v_SkyDiving_g03_c05.avi\nSkyDiving/v_SkyDiving_g04_c01.avi\nSkyDiving/v_SkyDiving_g04_c02.avi\nSkyDiving/v_SkyDiving_g04_c03.avi\nSkyDiving/v_SkyDiving_g04_c04.avi\nSkyDiving/v_SkyDiving_g05_c01.avi\nSkyDiving/v_SkyDiving_g05_c02.avi\nSkyDiving/v_SkyDiving_g05_c03.avi\nSkyDiving/v_SkyDiving_g05_c04.avi\nSkyDiving/v_SkyDiving_g05_c05.avi\nSkyDiving/v_SkyDiving_g06_c01.avi\nSkyDiving/v_SkyDiving_g06_c02.avi\nSkyDiving/v_SkyDiving_g06_c03.avi\nSkyDiving/v_SkyDiving_g06_c04.avi\nSkyDiving/v_SkyDiving_g07_c01.avi\nSkyDiving/v_SkyDiving_g07_c02.avi\nSkyDiving/v_SkyDiving_g07_c03.avi\nSkyDiving/v_SkyDiving_g07_c04.avi\nSkyDiving/v_SkyDiving_g07_c05.avi\nSoccerJuggling/v_SoccerJuggling_g01_c01.avi\nSoccerJuggling/v_SoccerJuggling_g01_c02.avi\nSoccerJuggling/v_SoccerJuggling_g01_c03.avi\nSoccerJuggling/v_SoccerJuggling_g01_c04.avi\nSoccerJuggling/v_SoccerJuggling_g01_c0
5.avi\nSoccerJuggling/v_SoccerJuggling_g02_c01.avi\nSoccerJuggling/v_SoccerJuggling_g02_c02.avi\nSoccerJuggling/v_SoccerJuggling_g02_c03.avi\nSoccerJuggling/v_SoccerJuggling_g02_c04.avi\nSoccerJuggling/v_SoccerJuggling_g02_c05.avi\nSoccerJuggling/v_SoccerJuggling_g02_c06.avi\nSoccerJuggling/v_SoccerJuggling_g03_c01.avi\nSoccerJuggling/v_SoccerJuggling_g03_c02.avi\nSoccerJuggling/v_SoccerJuggling_g03_c03.avi\nSoccerJuggling/v_SoccerJuggling_g03_c04.avi\nSoccerJuggling/v_SoccerJuggling_g04_c01.avi\nSoccerJuggling/v_SoccerJuggling_g04_c02.avi\nSoccerJuggling/v_SoccerJuggling_g04_c03.avi\nSoccerJuggling/v_SoccerJuggling_g04_c04.avi\nSoccerJuggling/v_SoccerJuggling_g04_c05.avi\nSoccerJuggling/v_SoccerJuggling_g04_c06.avi\nSoccerJuggling/v_SoccerJuggling_g05_c01.avi\nSoccerJuggling/v_SoccerJuggling_g05_c02.avi\nSoccerJuggling/v_SoccerJuggling_g05_c03.avi\nSoccerJuggling/v_SoccerJuggling_g05_c04.avi\nSoccerJuggling/v_SoccerJuggling_g05_c05.avi\nSoccerJuggling/v_SoccerJuggling_g05_c06.avi\nSoccerJuggling/v_SoccerJuggling_g06_c01.avi\nSoccerJuggling/v_SoccerJuggling_g06_c02.avi\nSoccerJuggling/v_SoccerJuggling_g06_c03.avi\nSoccerJuggling/v_SoccerJuggling_g06_c04.avi\nSoccerJuggling/v_SoccerJuggling_g06_c05.avi\nSoccerJuggling/v_SoccerJuggling_g07_c01.avi\nSoccerJuggling/v_SoccerJuggling_g07_c02.avi\nSoccerJuggling/v_SoccerJuggling_g07_c03.avi\nSoccerJuggling/v_SoccerJuggling_g07_c04.avi\nSoccerJuggling/v_SoccerJuggling_g07_c05.avi\nSoccerJuggling/v_SoccerJuggling_g07_c06.avi\nSoccerJuggling/v_SoccerJuggling_g07_c07.avi\nSoccerPenalty/v_SoccerPenalty_g01_c01.avi\nSoccerPenalty/v_SoccerPenalty_g01_c02.avi\nSoccerPenalty/v_SoccerPenalty_g01_c03.avi\nSoccerPenalty/v_SoccerPenalty_g01_c04.avi\nSoccerPenalty/v_SoccerPenalty_g01_c05.avi\nSoccerPenalty/v_SoccerPenalty_g01_c06.avi\nSoccerPenalty/v_SoccerPenalty_g02_c01.avi\nSoccerPenalty/v_SoccerPenalty_g02_c02.avi\nSoccerPenalty/v_SoccerPenalty_g02_c03.avi\nSoccerPenalty/v_SoccerPenalty_g02_c04.avi\nSoccerPenalty/v_SoccerPenalty_g02
_c05.avi\nSoccerPenalty/v_SoccerPenalty_g03_c01.avi\nSoccerPenalty/v_SoccerPenalty_g03_c02.avi\nSoccerPenalty/v_SoccerPenalty_g03_c03.avi\nSoccerPenalty/v_SoccerPenalty_g03_c04.avi\nSoccerPenalty/v_SoccerPenalty_g03_c05.avi\nSoccerPenalty/v_SoccerPenalty_g04_c01.avi\nSoccerPenalty/v_SoccerPenalty_g04_c02.avi\nSoccerPenalty/v_SoccerPenalty_g04_c03.avi\nSoccerPenalty/v_SoccerPenalty_g04_c04.avi\nSoccerPenalty/v_SoccerPenalty_g04_c05.avi\nSoccerPenalty/v_SoccerPenalty_g05_c01.avi\nSoccerPenalty/v_SoccerPenalty_g05_c02.avi\nSoccerPenalty/v_SoccerPenalty_g05_c03.avi\nSoccerPenalty/v_SoccerPenalty_g05_c04.avi\nSoccerPenalty/v_SoccerPenalty_g05_c05.avi\nSoccerPenalty/v_SoccerPenalty_g05_c06.avi\nSoccerPenalty/v_SoccerPenalty_g05_c07.avi\nSoccerPenalty/v_SoccerPenalty_g06_c01.avi\nSoccerPenalty/v_SoccerPenalty_g06_c02.avi\nSoccerPenalty/v_SoccerPenalty_g06_c03.avi\nSoccerPenalty/v_SoccerPenalty_g06_c04.avi\nSoccerPenalty/v_SoccerPenalty_g06_c05.avi\nSoccerPenalty/v_SoccerPenalty_g06_c06.avi\nSoccerPenalty/v_SoccerPenalty_g06_c07.avi\nSoccerPenalty/v_SoccerPenalty_g07_c01.avi\nSoccerPenalty/v_SoccerPenalty_g07_c02.avi\nSoccerPenalty/v_SoccerPenalty_g07_c03.avi\nSoccerPenalty/v_SoccerPenalty_g07_c04.avi\nSoccerPenalty/v_SoccerPenalty_g07_c05.avi\nSoccerPenalty/v_SoccerPenalty_g07_c06.avi\nStillRings/v_StillRings_g01_c01.avi\nStillRings/v_StillRings_g01_c02.avi\nStillRings/v_StillRings_g01_c03.avi\nStillRings/v_StillRings_g01_c04.avi\nStillRings/v_StillRings_g01_c05.avi\nStillRings/v_StillRings_g02_c01.avi\nStillRings/v_StillRings_g02_c02.avi\nStillRings/v_StillRings_g02_c03.avi\nStillRings/v_StillRings_g02_c04.avi\nStillRings/v_StillRings_g03_c01.avi\nStillRings/v_StillRings_g03_c02.avi\nStillRings/v_StillRings_g03_c03.avi\nStillRings/v_StillRings_g03_c04.avi\nStillRings/v_StillRings_g03_c05.avi\nStillRings/v_StillRings_g03_c06.avi\nStillRings/v_StillRings_g03_c07.avi\nStillRings/v_StillRings_g04_c01.avi\nStillRings/v_StillRings_g04_c02.avi\nStillRings/v_StillRings_g04_c03.av
i\nStillRings/v_StillRings_g04_c04.avi\nStillRings/v_StillRings_g05_c01.avi\nStillRings/v_StillRings_g05_c02.avi\nStillRings/v_StillRings_g05_c03.avi\nStillRings/v_StillRings_g05_c04.avi\nStillRings/v_StillRings_g06_c01.avi\nStillRings/v_StillRings_g06_c02.avi\nStillRings/v_StillRings_g06_c03.avi\nStillRings/v_StillRings_g06_c04.avi\nStillRings/v_StillRings_g07_c01.avi\nStillRings/v_StillRings_g07_c02.avi\nStillRings/v_StillRings_g07_c03.avi\nStillRings/v_StillRings_g07_c04.avi\nSumoWrestling/v_SumoWrestling_g01_c01.avi\nSumoWrestling/v_SumoWrestling_g01_c02.avi\nSumoWrestling/v_SumoWrestling_g01_c03.avi\nSumoWrestling/v_SumoWrestling_g01_c04.avi\nSumoWrestling/v_SumoWrestling_g02_c01.avi\nSumoWrestling/v_SumoWrestling_g02_c02.avi\nSumoWrestling/v_SumoWrestling_g02_c03.avi\nSumoWrestling/v_SumoWrestling_g02_c04.avi\nSumoWrestling/v_SumoWrestling_g03_c01.avi\nSumoWrestling/v_SumoWrestling_g03_c02.avi\nSumoWrestling/v_SumoWrestling_g03_c03.avi\nSumoWrestling/v_SumoWrestling_g03_c04.avi\nSumoWrestling/v_SumoWrestling_g04_c01.avi\nSumoWrestling/v_SumoWrestling_g04_c02.avi\nSumoWrestling/v_SumoWrestling_g04_c03.avi\nSumoWrestling/v_SumoWrestling_g04_c04.avi\nSumoWrestling/v_SumoWrestling_g05_c01.avi\nSumoWrestling/v_SumoWrestling_g05_c02.avi\nSumoWrestling/v_SumoWrestling_g05_c03.avi\nSumoWrestling/v_SumoWrestling_g05_c04.avi\nSumoWrestling/v_SumoWrestling_g06_c01.avi\nSumoWrestling/v_SumoWrestling_g06_c02.avi\nSumoWrestling/v_SumoWrestling_g06_c03.avi\nSumoWrestling/v_SumoWrestling_g06_c04.avi\nSumoWrestling/v_SumoWrestling_g06_c05.avi\nSumoWrestling/v_SumoWrestling_g06_c06.avi\nSumoWrestling/v_SumoWrestling_g06_c07.avi\nSumoWrestling/v_SumoWrestling_g07_c01.avi\nSumoWrestling/v_SumoWrestling_g07_c02.avi\nSumoWrestling/v_SumoWrestling_g07_c03.avi\nSumoWrestling/v_SumoWrestling_g07_c04.avi\nSumoWrestling/v_SumoWrestling_g07_c05.avi\nSumoWrestling/v_SumoWrestling_g07_c06.avi\nSumoWrestling/v_SumoWrestling_g07_c07.avi\nSurfing/v_Surfing_g01_c01.avi\nSurfing/v_Surfing_g01_c
02.avi\nSurfing/v_Surfing_g01_c03.avi\nSurfing/v_Surfing_g01_c04.avi\nSurfing/v_Surfing_g01_c05.avi\nSurfing/v_Surfing_g01_c06.avi\nSurfing/v_Surfing_g01_c07.avi\nSurfing/v_Surfing_g02_c01.avi\nSurfing/v_Surfing_g02_c02.avi\nSurfing/v_Surfing_g02_c03.avi\nSurfing/v_Surfing_g02_c04.avi\nSurfing/v_Surfing_g02_c05.avi\nSurfing/v_Surfing_g02_c06.avi\nSurfing/v_Surfing_g03_c01.avi\nSurfing/v_Surfing_g03_c02.avi\nSurfing/v_Surfing_g03_c03.avi\nSurfing/v_Surfing_g03_c04.avi\nSurfing/v_Surfing_g04_c01.avi\nSurfing/v_Surfing_g04_c02.avi\nSurfing/v_Surfing_g04_c03.avi\nSurfing/v_Surfing_g04_c04.avi\nSurfing/v_Surfing_g05_c01.avi\nSurfing/v_Surfing_g05_c02.avi\nSurfing/v_Surfing_g05_c03.avi\nSurfing/v_Surfing_g05_c04.avi\nSurfing/v_Surfing_g06_c01.avi\nSurfing/v_Surfing_g06_c02.avi\nSurfing/v_Surfing_g06_c03.avi\nSurfing/v_Surfing_g06_c04.avi\nSurfing/v_Surfing_g07_c01.avi\nSurfing/v_Surfing_g07_c02.avi\nSurfing/v_Surfing_g07_c03.avi\nSurfing/v_Surfing_g07_c04.avi\nSwing/v_Swing_g01_c01.avi\nSwing/v_Swing_g01_c02.avi\nSwing/v_Swing_g01_c03.avi\nSwing/v_Swing_g01_c04.avi\nSwing/v_Swing_g01_c05.avi\nSwing/v_Swing_g02_c01.avi\nSwing/v_Swing_g02_c02.avi\nSwing/v_Swing_g02_c03.avi\nSwing/v_Swing_g02_c04.avi\nSwing/v_Swing_g02_c05.avi\nSwing/v_Swing_g03_c01.avi\nSwing/v_Swing_g03_c02.avi\nSwing/v_Swing_g03_c03.avi\nSwing/v_Swing_g03_c04.avi\nSwing/v_Swing_g04_c01.avi\nSwing/v_Swing_g04_c02.avi\nSwing/v_Swing_g04_c03.avi\nSwing/v_Swing_g04_c04.avi\nSwing/v_Swing_g04_c05.avi\nSwing/v_Swing_g04_c06.avi\nSwing/v_Swing_g04_c07.avi\nSwing/v_Swing_g05_c01.avi\nSwing/v_Swing_g05_c02.avi\nSwing/v_Swing_g05_c03.avi\nSwing/v_Swing_g05_c04.avi\nSwing/v_Swing_g05_c05.avi\nSwing/v_Swing_g05_c06.avi\nSwing/v_Swing_g05_c07.avi\nSwing/v_Swing_g06_c01.avi\nSwing/v_Swing_g06_c02.avi\nSwing/v_Swing_g06_c03.avi\nSwing/v_Swing_g06_c04.avi\nSwing/v_Swing_g06_c05.avi\nSwing/v_Swing_g06_c06.avi\nSwing/v_Swing_g06_c07.avi\nSwing/v_Swing_g07_c01.avi\nSwing/v_Swing_g07_c02.avi\nSwing/v_Swing_g07_c03.avi\nSwing
/v_Swing_g07_c04.avi\nSwing/v_Swing_g07_c05.avi\nSwing/v_Swing_g07_c06.avi\nSwing/v_Swing_g07_c07.avi\nTableTennisShot/v_TableTennisShot_g01_c01.avi\nTableTennisShot/v_TableTennisShot_g01_c02.avi\nTableTennisShot/v_TableTennisShot_g01_c03.avi\nTableTennisShot/v_TableTennisShot_g01_c04.avi\nTableTennisShot/v_TableTennisShot_g01_c05.avi\nTableTennisShot/v_TableTennisShot_g01_c06.avi\nTableTennisShot/v_TableTennisShot_g02_c01.avi\nTableTennisShot/v_TableTennisShot_g02_c02.avi\nTableTennisShot/v_TableTennisShot_g02_c03.avi\nTableTennisShot/v_TableTennisShot_g02_c04.avi\nTableTennisShot/v_TableTennisShot_g03_c01.avi\nTableTennisShot/v_TableTennisShot_g03_c02.avi\nTableTennisShot/v_TableTennisShot_g03_c03.avi\nTableTennisShot/v_TableTennisShot_g03_c04.avi\nTableTennisShot/v_TableTennisShot_g03_c05.avi\nTableTennisShot/v_TableTennisShot_g04_c01.avi\nTableTennisShot/v_TableTennisShot_g04_c02.avi\nTableTennisShot/v_TableTennisShot_g04_c03.avi\nTableTennisShot/v_TableTennisShot_g04_c04.avi\nTableTennisShot/v_TableTennisShot_g04_c05.avi\nTableTennisShot/v_TableTennisShot_g04_c06.avi\nTableTennisShot/v_TableTennisShot_g04_c07.avi\nTableTennisShot/v_TableTennisShot_g05_c01.avi\nTableTennisShot/v_TableTennisShot_g05_c02.avi\nTableTennisShot/v_TableTennisShot_g05_c03.avi\nTableTennisShot/v_TableTennisShot_g05_c04.avi\nTableTennisShot/v_TableTennisShot_g05_c05.avi\nTableTennisShot/v_TableTennisShot_g05_c06.avi\nTableTennisShot/v_TableTennisShot_g05_c07.avi\nTableTennisShot/v_TableTennisShot_g06_c01.avi\nTableTennisShot/v_TableTennisShot_g06_c02.avi\nTableTennisShot/v_TableTennisShot_g06_c03.avi\nTableTennisShot/v_TableTennisShot_g06_c04.avi\nTableTennisShot/v_TableTennisShot_g06_c05.avi\nTableTennisShot/v_TableTennisShot_g06_c06.avi\nTableTennisShot/v_TableTennisShot_g07_c01.avi\nTableTennisShot/v_TableTennisShot_g07_c02.avi\nTableTennisShot/v_TableTennisShot_g07_c03.avi\nTableTennisShot/v_TableTennisShot_g07_c04.avi\nTaiChi/v_TaiChi_g01_c01.avi\nTaiChi/v_TaiChi_g01_c02.avi\nTaiChi
/v_TaiChi_g01_c03.avi\nTaiChi/v_TaiChi_g01_c04.avi\nTaiChi/v_TaiChi_g02_c01.avi\nTaiChi/v_TaiChi_g02_c02.avi\nTaiChi/v_TaiChi_g02_c03.avi\nTaiChi/v_TaiChi_g02_c04.avi\nTaiChi/v_TaiChi_g03_c01.avi\nTaiChi/v_TaiChi_g03_c02.avi\nTaiChi/v_TaiChi_g03_c03.avi\nTaiChi/v_TaiChi_g03_c04.avi\nTaiChi/v_TaiChi_g04_c01.avi\nTaiChi/v_TaiChi_g04_c02.avi\nTaiChi/v_TaiChi_g04_c03.avi\nTaiChi/v_TaiChi_g04_c04.avi\nTaiChi/v_TaiChi_g05_c01.avi\nTaiChi/v_TaiChi_g05_c02.avi\nTaiChi/v_TaiChi_g05_c03.avi\nTaiChi/v_TaiChi_g05_c04.avi\nTaiChi/v_TaiChi_g06_c01.avi\nTaiChi/v_TaiChi_g06_c02.avi\nTaiChi/v_TaiChi_g06_c03.avi\nTaiChi/v_TaiChi_g06_c04.avi\nTaiChi/v_TaiChi_g07_c01.avi\nTaiChi/v_TaiChi_g07_c02.avi\nTaiChi/v_TaiChi_g07_c03.avi\nTaiChi/v_TaiChi_g07_c04.avi\nTennisSwing/v_TennisSwing_g01_c01.avi\nTennisSwing/v_TennisSwing_g01_c02.avi\nTennisSwing/v_TennisSwing_g01_c03.avi\nTennisSwing/v_TennisSwing_g01_c04.avi\nTennisSwing/v_TennisSwing_g01_c05.avi\nTennisSwing/v_TennisSwing_g01_c06.avi\nTennisSwing/v_TennisSwing_g01_c07.avi\nTennisSwing/v_TennisSwing_g02_c01.avi\nTennisSwing/v_TennisSwing_g02_c02.avi\nTennisSwing/v_TennisSwing_g02_c03.avi\nTennisSwing/v_TennisSwing_g02_c04.avi\nTennisSwing/v_TennisSwing_g02_c05.avi\nTennisSwing/v_TennisSwing_g02_c06.avi\nTennisSwing/v_TennisSwing_g02_c07.avi\nTennisSwing/v_TennisSwing_g03_c01.avi\nTennisSwing/v_TennisSwing_g03_c02.avi\nTennisSwing/v_TennisSwing_g03_c03.avi\nTennisSwing/v_TennisSwing_g03_c04.avi\nTennisSwing/v_TennisSwing_g03_c05.avi\nTennisSwing/v_TennisSwing_g03_c06.avi\nTennisSwing/v_TennisSwing_g03_c07.avi\nTennisSwing/v_TennisSwing_g04_c01.avi\nTennisSwing/v_TennisSwing_g04_c02.avi\nTennisSwing/v_TennisSwing_g04_c03.avi\nTennisSwing/v_TennisSwing_g04_c04.avi\nTennisSwing/v_TennisSwing_g04_c05.avi\nTennisSwing/v_TennisSwing_g04_c06.avi\nTennisSwing/v_TennisSwing_g04_c07.avi\nTennisSwing/v_TennisSwing_g05_c01.avi\nTennisSwing/v_TennisSwing_g05_c02.avi\nTennisSwing/v_TennisSwing_g05_c03.avi\nTennisSwing/v_TennisSwing_g05_c04.avi\nTenn
isSwing/v_TennisSwing_g05_c05.avi\nTennisSwing/v_TennisSwing_g05_c06.avi\nTennisSwing/v_TennisSwing_g05_c07.avi\nTennisSwing/v_TennisSwing_g06_c01.avi\nTennisSwing/v_TennisSwing_g06_c02.avi\nTennisSwing/v_TennisSwing_g06_c03.avi\nTennisSwing/v_TennisSwing_g06_c04.avi\nTennisSwing/v_TennisSwing_g06_c05.avi\nTennisSwing/v_TennisSwing_g06_c06.avi\nTennisSwing/v_TennisSwing_g06_c07.avi\nTennisSwing/v_TennisSwing_g07_c01.avi\nTennisSwing/v_TennisSwing_g07_c02.avi\nTennisSwing/v_TennisSwing_g07_c03.avi\nTennisSwing/v_TennisSwing_g07_c04.avi\nTennisSwing/v_TennisSwing_g07_c05.avi\nTennisSwing/v_TennisSwing_g07_c06.avi\nTennisSwing/v_TennisSwing_g07_c07.avi\nThrowDiscus/v_ThrowDiscus_g01_c01.avi\nThrowDiscus/v_ThrowDiscus_g01_c02.avi\nThrowDiscus/v_ThrowDiscus_g01_c03.avi\nThrowDiscus/v_ThrowDiscus_g01_c04.avi\nThrowDiscus/v_ThrowDiscus_g02_c01.avi\nThrowDiscus/v_ThrowDiscus_g02_c02.avi\nThrowDiscus/v_ThrowDiscus_g02_c03.avi\nThrowDiscus/v_ThrowDiscus_g02_c04.avi\nThrowDiscus/v_ThrowDiscus_g02_c05.avi\nThrowDiscus/v_ThrowDiscus_g02_c06.avi\nThrowDiscus/v_ThrowDiscus_g02_c07.avi\nThrowDiscus/v_ThrowDiscus_g03_c01.avi\nThrowDiscus/v_ThrowDiscus_g03_c02.avi\nThrowDiscus/v_ThrowDiscus_g03_c03.avi\nThrowDiscus/v_ThrowDiscus_g03_c04.avi\nThrowDiscus/v_ThrowDiscus_g04_c01.avi\nThrowDiscus/v_ThrowDiscus_g04_c02.avi\nThrowDiscus/v_ThrowDiscus_g04_c03.avi\nThrowDiscus/v_ThrowDiscus_g04_c04.avi\nThrowDiscus/v_ThrowDiscus_g05_c01.avi\nThrowDiscus/v_ThrowDiscus_g05_c02.avi\nThrowDiscus/v_ThrowDiscus_g05_c03.avi\nThrowDiscus/v_ThrowDiscus_g05_c04.avi\nThrowDiscus/v_ThrowDiscus_g05_c05.avi\nThrowDiscus/v_ThrowDiscus_g06_c01.avi\nThrowDiscus/v_ThrowDiscus_g06_c02.avi\nThrowDiscus/v_ThrowDiscus_g06_c03.avi\nThrowDiscus/v_ThrowDiscus_g06_c04.avi\nThrowDiscus/v_ThrowDiscus_g06_c05.avi\nThrowDiscus/v_ThrowDiscus_g06_c06.avi\nThrowDiscus/v_ThrowDiscus_g06_c07.avi\nThrowDiscus/v_ThrowDiscus_g07_c01.avi\nThrowDiscus/v_ThrowDiscus_g07_c02.avi\nThrowDiscus/v_ThrowDiscus_g07_c03.avi\nThrowDiscus/v_T
hrowDiscus_g07_c04.avi\nThrowDiscus/v_ThrowDiscus_g07_c05.avi\nThrowDiscus/v_ThrowDiscus_g07_c06.avi\nThrowDiscus/v_ThrowDiscus_g07_c07.avi\nTrampolineJumping/v_TrampolineJumping_g01_c01.avi\nTrampolineJumping/v_TrampolineJumping_g01_c02.avi\nTrampolineJumping/v_TrampolineJumping_g01_c03.avi\nTrampolineJumping/v_TrampolineJumping_g01_c04.avi\nTrampolineJumping/v_TrampolineJumping_g02_c01.avi\nTrampolineJumping/v_TrampolineJumping_g02_c02.avi\nTrampolineJumping/v_TrampolineJumping_g02_c03.avi\nTrampolineJumping/v_TrampolineJumping_g02_c04.avi\nTrampolineJumping/v_TrampolineJumping_g02_c05.avi\nTrampolineJumping/v_TrampolineJumping_g02_c06.avi\nTrampolineJumping/v_TrampolineJumping_g03_c01.avi\nTrampolineJumping/v_TrampolineJumping_g03_c02.avi\nTrampolineJumping/v_TrampolineJumping_g03_c03.avi\nTrampolineJumping/v_TrampolineJumping_g03_c04.avi\nTrampolineJumping/v_TrampolineJumping_g04_c01.avi\nTrampolineJumping/v_TrampolineJumping_g04_c02.avi\nTrampolineJumping/v_TrampolineJumping_g04_c03.avi\nTrampolineJumping/v_TrampolineJumping_g04_c04.avi\nTrampolineJumping/v_TrampolineJumping_g04_c05.avi\nTrampolineJumping/v_TrampolineJumping_g05_c01.avi\nTrampolineJumping/v_TrampolineJumping_g05_c02.avi\nTrampolineJumping/v_TrampolineJumping_g05_c03.avi\nTrampolineJumping/v_TrampolineJumping_g05_c04.avi\nTrampolineJumping/v_TrampolineJumping_g06_c01.avi\nTrampolineJumping/v_TrampolineJumping_g06_c02.avi\nTrampolineJumping/v_TrampolineJumping_g06_c03.avi\nTrampolineJumping/v_TrampolineJumping_g06_c04.avi\nTrampolineJumping/v_TrampolineJumping_g07_c01.avi\nTrampolineJumping/v_TrampolineJumping_g07_c02.avi\nTrampolineJumping/v_TrampolineJumping_g07_c03.avi\nTrampolineJumping/v_TrampolineJumping_g07_c04.avi\nTrampolineJumping/v_TrampolineJumping_g07_c05.avi\nTyping/v_Typing_g01_c01.avi\nTyping/v_Typing_g01_c02.avi\nTyping/v_Typing_g01_c03.avi\nTyping/v_Typing_g01_c04.avi\nTyping/v_Typing_g01_c05.avi\nTyping/v_Typing_g01_c06.avi\nTyping/v_Typing_g01_c07.avi\nTyping/v_Typing_g02_c01.
avi\nTyping/v_Typing_g02_c02.avi\nTyping/v_Typing_g02_c03.avi\nTyping/v_Typing_g02_c04.avi\nTyping/v_Typing_g02_c05.avi\nTyping/v_Typing_g02_c06.avi\nTyping/v_Typing_g03_c01.avi\nTyping/v_Typing_g03_c02.avi\nTyping/v_Typing_g03_c03.avi\nTyping/v_Typing_g03_c04.avi\nTyping/v_Typing_g03_c05.avi\nTyping/v_Typing_g03_c06.avi\nTyping/v_Typing_g03_c07.avi\nTyping/v_Typing_g04_c01.avi\nTyping/v_Typing_g04_c02.avi\nTyping/v_Typing_g04_c03.avi\nTyping/v_Typing_g04_c04.avi\nTyping/v_Typing_g05_c01.avi\nTyping/v_Typing_g05_c02.avi\nTyping/v_Typing_g05_c03.avi\nTyping/v_Typing_g05_c04.avi\nTyping/v_Typing_g05_c05.avi\nTyping/v_Typing_g05_c06.avi\nTyping/v_Typing_g06_c01.avi\nTyping/v_Typing_g06_c02.avi\nTyping/v_Typing_g06_c03.avi\nTyping/v_Typing_g06_c04.avi\nTyping/v_Typing_g06_c05.avi\nTyping/v_Typing_g06_c06.avi\nTyping/v_Typing_g06_c07.avi\nTyping/v_Typing_g07_c01.avi\nTyping/v_Typing_g07_c02.avi\nTyping/v_Typing_g07_c03.avi\nTyping/v_Typing_g07_c04.avi\nTyping/v_Typing_g07_c05.avi\nTyping/v_Typing_g07_c06.avi\nUnevenBars/v_UnevenBars_g01_c01.avi\nUnevenBars/v_UnevenBars_g01_c02.avi\nUnevenBars/v_UnevenBars_g01_c03.avi\nUnevenBars/v_UnevenBars_g01_c04.avi\nUnevenBars/v_UnevenBars_g02_c01.avi\nUnevenBars/v_UnevenBars_g02_c02.avi\nUnevenBars/v_UnevenBars_g02_c03.avi\nUnevenBars/v_UnevenBars_g02_c04.avi\nUnevenBars/v_UnevenBars_g03_c01.avi\nUnevenBars/v_UnevenBars_g03_c02.avi\nUnevenBars/v_UnevenBars_g03_c03.avi\nUnevenBars/v_UnevenBars_g03_c04.avi\nUnevenBars/v_UnevenBars_g04_c01.avi\nUnevenBars/v_UnevenBars_g04_c02.avi\nUnevenBars/v_UnevenBars_g04_c03.avi\nUnevenBars/v_UnevenBars_g04_c04.avi\nUnevenBars/v_UnevenBars_g05_c01.avi\nUnevenBars/v_UnevenBars_g05_c02.avi\nUnevenBars/v_UnevenBars_g05_c03.avi\nUnevenBars/v_UnevenBars_g05_c04.avi\nUnevenBars/v_UnevenBars_g06_c01.avi\nUnevenBars/v_UnevenBars_g06_c02.avi\nUnevenBars/v_UnevenBars_g06_c03.avi\nUnevenBars/v_UnevenBars_g06_c04.avi\nUnevenBars/v_UnevenBars_g07_c01.avi\nUnevenBars/v_UnevenBars_g07_c02.avi\nUnevenBars/v_Uneve
nBars_g07_c03.avi\nUnevenBars/v_UnevenBars_g07_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g01_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g01_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g01_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g01_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g02_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g02_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g02_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g02_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g03_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g03_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g03_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g03_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c05.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c06.avi\nVolleyballSpiking/v_VolleyballSpiking_g04_c07.avi\nVolleyballSpiking/v_VolleyballSpiking_g05_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g05_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g05_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g05_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g05_c05.avi\nVolleyballSpiking/v_VolleyballSpiking_g06_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g06_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g06_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g06_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c01.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c02.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c03.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c04.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c05.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c06.avi\nVolleyballSpiking/v_VolleyballSpiking_g07_c07.avi\nWalkingWithDog/v_WalkingWithDog_g01_c01.avi\nWalkingWithDog/v_WalkingWithDog_g01_c02.avi\nWalkingWithDog/v_WalkingWithDog_g01_c03.avi\nWalkingWithDog/v_Walking
WithDog_g01_c04.avi\nWalkingWithDog/v_WalkingWithDog_g02_c01.avi\nWalkingWithDog/v_WalkingWithDog_g02_c02.avi\nWalkingWithDog/v_WalkingWithDog_g02_c03.avi\nWalkingWithDog/v_WalkingWithDog_g02_c04.avi\nWalkingWithDog/v_WalkingWithDog_g02_c05.avi\nWalkingWithDog/v_WalkingWithDog_g02_c06.avi\nWalkingWithDog/v_WalkingWithDog_g03_c01.avi\nWalkingWithDog/v_WalkingWithDog_g03_c02.avi\nWalkingWithDog/v_WalkingWithDog_g03_c03.avi\nWalkingWithDog/v_WalkingWithDog_g03_c04.avi\nWalkingWithDog/v_WalkingWithDog_g03_c05.avi\nWalkingWithDog/v_WalkingWithDog_g04_c01.avi\nWalkingWithDog/v_WalkingWithDog_g04_c02.avi\nWalkingWithDog/v_WalkingWithDog_g04_c03.avi\nWalkingWithDog/v_WalkingWithDog_g04_c04.avi\nWalkingWithDog/v_WalkingWithDog_g04_c05.avi\nWalkingWithDog/v_WalkingWithDog_g05_c01.avi\nWalkingWithDog/v_WalkingWithDog_g05_c02.avi\nWalkingWithDog/v_WalkingWithDog_g05_c03.avi\nWalkingWithDog/v_WalkingWithDog_g05_c04.avi\nWalkingWithDog/v_WalkingWithDog_g05_c05.avi\nWalkingWithDog/v_WalkingWithDog_g06_c01.avi\nWalkingWithDog/v_WalkingWithDog_g06_c02.avi\nWalkingWithDog/v_WalkingWithDog_g06_c03.avi\nWalkingWithDog/v_WalkingWithDog_g06_c04.avi\nWalkingWithDog/v_WalkingWithDog_g06_c05.avi\nWalkingWithDog/v_WalkingWithDog_g07_c01.avi\nWalkingWithDog/v_WalkingWithDog_g07_c02.avi\nWalkingWithDog/v_WalkingWithDog_g07_c03.avi\nWalkingWithDog/v_WalkingWithDog_g07_c04.avi\nWalkingWithDog/v_WalkingWithDog_g07_c05.avi\nWalkingWithDog/v_WalkingWithDog_g07_c06.avi\nWallPushups/v_WallPushups_g01_c01.avi\nWallPushups/v_WallPushups_g01_c02.avi\nWallPushups/v_WallPushups_g01_c03.avi\nWallPushups/v_WallPushups_g01_c04.avi\nWallPushups/v_WallPushups_g02_c01.avi\nWallPushups/v_WallPushups_g02_c02.avi\nWallPushups/v_WallPushups_g02_c03.avi\nWallPushups/v_WallPushups_g02_c04.avi\nWallPushups/v_WallPushups_g03_c01.avi\nWallPushups/v_WallPushups_g03_c02.avi\nWallPushups/v_WallPushups_g03_c03.avi\nWallPushups/v_WallPushups_g03_c04.avi\nWallPushups/v_WallPushups_g03_c05.avi\nWallPushups/v_WallPushups_g04_c0
1.avi\nWallPushups/v_WallPushups_g04_c02.avi\nWallPushups/v_WallPushups_g04_c03.avi\nWallPushups/v_WallPushups_g04_c04.avi\nWallPushups/v_WallPushups_g05_c01.avi\nWallPushups/v_WallPushups_g05_c02.avi\nWallPushups/v_WallPushups_g05_c03.avi\nWallPushups/v_WallPushups_g05_c04.avi\nWallPushups/v_WallPushups_g05_c05.avi\nWallPushups/v_WallPushups_g06_c01.avi\nWallPushups/v_WallPushups_g06_c02.avi\nWallPushups/v_WallPushups_g06_c03.avi\nWallPushups/v_WallPushups_g06_c04.avi\nWallPushups/v_WallPushups_g06_c05.avi\nWallPushups/v_WallPushups_g06_c06.avi\nWallPushups/v_WallPushups_g06_c07.avi\nWallPushups/v_WallPushups_g07_c01.avi\nWallPushups/v_WallPushups_g07_c02.avi\nWallPushups/v_WallPushups_g07_c03.avi\nWallPushups/v_WallPushups_g07_c04.avi\nWallPushups/v_WallPushups_g07_c05.avi\nWallPushups/v_WallPushups_g07_c06.avi\nWritingOnBoard/v_WritingOnBoard_g01_c01.avi\nWritingOnBoard/v_WritingOnBoard_g01_c02.avi\nWritingOnBoard/v_WritingOnBoard_g01_c03.avi\nWritingOnBoard/v_WritingOnBoard_g01_c04.avi\nWritingOnBoard/v_WritingOnBoard_g01_c05.avi\nWritingOnBoard/v_WritingOnBoard_g01_c06.avi\nWritingOnBoard/v_WritingOnBoard_g01_c07.avi\nWritingOnBoard/v_WritingOnBoard_g02_c01.avi\nWritingOnBoard/v_WritingOnBoard_g02_c02.avi\nWritingOnBoard/v_WritingOnBoard_g02_c03.avi\nWritingOnBoard/v_WritingOnBoard_g02_c04.avi\nWritingOnBoard/v_WritingOnBoard_g02_c05.avi\nWritingOnBoard/v_WritingOnBoard_g02_c06.avi\nWritingOnBoard/v_WritingOnBoard_g02_c07.avi\nWritingOnBoard/v_WritingOnBoard_g03_c01.avi\nWritingOnBoard/v_WritingOnBoard_g03_c02.avi\nWritingOnBoard/v_WritingOnBoard_g03_c03.avi\nWritingOnBoard/v_WritingOnBoard_g03_c04.avi\nWritingOnBoard/v_WritingOnBoard_g03_c05.avi\nWritingOnBoard/v_WritingOnBoard_g03_c06.avi\nWritingOnBoard/v_WritingOnBoard_g03_c07.avi\nWritingOnBoard/v_WritingOnBoard_g04_c01.avi\nWritingOnBoard/v_WritingOnBoard_g04_c02.avi\nWritingOnBoard/v_WritingOnBoard_g04_c03.avi\nWritingOnBoard/v_WritingOnBoard_g04_c04.avi\nWritingOnBoard/v_WritingOnBoard_g05_c01.avi\nWrit
ingOnBoard/v_WritingOnBoard_g05_c02.avi\nWritingOnBoard/v_WritingOnBoard_g05_c03.avi\nWritingOnBoard/v_WritingOnBoard_g05_c04.avi\nWritingOnBoard/v_WritingOnBoard_g05_c05.avi\nWritingOnBoard/v_WritingOnBoard_g05_c06.avi\nWritingOnBoard/v_WritingOnBoard_g06_c01.avi\nWritingOnBoard/v_WritingOnBoard_g06_c02.avi\nWritingOnBoard/v_WritingOnBoard_g06_c03.avi\nWritingOnBoard/v_WritingOnBoard_g06_c04.avi\nWritingOnBoard/v_WritingOnBoard_g06_c05.avi\nWritingOnBoard/v_WritingOnBoard_g06_c06.avi\nWritingOnBoard/v_WritingOnBoard_g06_c07.avi\nWritingOnBoard/v_WritingOnBoard_g07_c01.avi\nWritingOnBoard/v_WritingOnBoard_g07_c02.avi\nWritingOnBoard/v_WritingOnBoard_g07_c03.avi\nWritingOnBoard/v_WritingOnBoard_g07_c04.avi\nWritingOnBoard/v_WritingOnBoard_g07_c05.avi\nWritingOnBoard/v_WritingOnBoard_g07_c06.avi\nWritingOnBoard/v_WritingOnBoard_g07_c07.avi\nYoYo/v_YoYo_g01_c01.avi\nYoYo/v_YoYo_g01_c02.avi\nYoYo/v_YoYo_g01_c03.avi\nYoYo/v_YoYo_g01_c04.avi\nYoYo/v_YoYo_g01_c05.avi\nYoYo/v_YoYo_g01_c06.avi\nYoYo/v_YoYo_g01_c07.avi\nYoYo/v_YoYo_g02_c01.avi\nYoYo/v_YoYo_g02_c02.avi\nYoYo/v_YoYo_g02_c03.avi\nYoYo/v_YoYo_g02_c04.avi\nYoYo/v_YoYo_g02_c05.avi\nYoYo/v_YoYo_g03_c01.avi\nYoYo/v_YoYo_g03_c02.avi\nYoYo/v_YoYo_g03_c03.avi\nYoYo/v_YoYo_g03_c04.avi\nYoYo/v_YoYo_g03_c05.avi\nYoYo/v_YoYo_g03_c06.avi\nYoYo/v_YoYo_g04_c01.avi\nYoYo/v_YoYo_g04_c02.avi\nYoYo/v_YoYo_g04_c03.avi\nYoYo/v_YoYo_g04_c04.avi\nYoYo/v_YoYo_g04_c05.avi\nYoYo/v_YoYo_g05_c01.avi\nYoYo/v_YoYo_g05_c02.avi\nYoYo/v_YoYo_g05_c03.avi\nYoYo/v_YoYo_g05_c04.avi\nYoYo/v_YoYo_g05_c05.avi\nYoYo/v_YoYo_g06_c01.avi\nYoYo/v_YoYo_g06_c02.avi\nYoYo/v_YoYo_g06_c03.avi\nYoYo/v_YoYo_g06_c04.avi\nYoYo/v_YoYo_g07_c01.avi\nYoYo/v_YoYo_g07_c02.avi\nYoYo/v_YoYo_g07_c03.avi\nYoYo/v_YoYo_g07_c04.avi\n"
  },
  {
    "path": "braincog/datasets/scripts/ucf101_dvs_preprocessing.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/12/20 20:16\n# User      : Floyed\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : ucf101_dvs_preprocessing.py\n# explain   :\n\nimport os\nimport shutil\n\n\nROOT_DIR = '/data/datasets/UCF101_DVS/UCF101_DVS'\ntrain_path = os.path.join(ROOT_DIR, 'train')\nval_path = os.path.join(ROOT_DIR, 'val')\nval_fname = 'testlist01.txt'\n\ncls_path = os.listdir(train_path)\n\nif not os.path.exists(val_path):\n    os.mkdir(val_path)\n    for cls_name in cls_path:\n        os.mkdir(os.path.join(val_path, cls_name))\n\nf = open(val_fname, 'r')\n\nfor fname in f.readlines():\n    fname = fname[:-4] + 'mat'\n    fname.replace('Billards', 'Billiards')\n    src = os.path.join(train_path, fname)\n    dst = os.path.join(val_path, fname)\n    try:\n        shutil.move(src, dst)\n    except:\n        print('[Warning] Cannot find {}.'.format(src))\n    print('[Moving] {} -> {}.'.format(src, dst))\n"
  },
  {
    "path": "braincog/datasets/ucf101_dvs/__init__.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/1/30 21:04\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : __init__.py.py\n# explain   :\n\nfrom .ucf101_dvs import UCF101DVS\n\n__all__ = [\n    'UCF101DVS'\n]"
  },
  {
    "path": "braincog/datasets/ucf101_dvs/ucf101_dvs.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/1/30 21:05\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : ucf51_dvs.py\n# explain   :\n# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/12/20 20:47\n# User      : Floyed\n# Product   : PyCharm\n# Project   : tonic\n# File      : ucf101dvs.py\n# explain   :\n\nimport os\nimport numpy as np\nfrom numpy.lib import recfunctions\nimport scipy.io as scio\nfrom typing import Tuple, Any, Optional\nfrom tonic.dataset import Dataset\nfrom tonic.download_utils import extract_archive\n\n\nclass UCF101DVS(Dataset):\n    \"\"\"ASL-DVS dataset <https://github.com/PIX2NVS/NVS2Graph>. Events have (txyp) ordering.\n    ::\n\n        @inproceedings{bi2019graph,\n            title={Graph-based Object Classification for Neuromorphic Vision Sensing},\n            author={Bi, Y and Chadha, A and Abbas, A and and Bourtsoulatze, E and Andreopoulos, Y},\n            booktitle={2019 IEEE International Conference on Computer Vision (ICCV)},\n            year={2019},\n            organization={IEEE}\n        }\n\n    Parameters:\n        save_to (string): Location to save files to on disk.\n        transform (callable, optional): A callable of transforms to apply to the data.\n        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.\n    \"\"\"\n\n    sensor_size = (240, 180, 2)\n    dtype = np.dtype([(\"t\", int), (\"x\", int), (\"y\", int), (\"p\", int)])\n    ordering = dtype.names\n    folder_name = 'UCF101DVS'\n    def __init__(self, save_to, train=False, transform=None, target_transform=None):\n        super(UCF101DVS, self).__init__(\n            save_to, transform=transform, target_transform=target_transform\n        )\n\n        if not self._check_exists():\n            raise NotImplementedError(\n                'Please manually download the dataset from'\n               
 ' https://www.dropbox.com/sh/ie75dn246cacf6n/AACoU-_zkGOAwj51lSCM0JhGa?dl=0 '\n                'and extract it to {}'.format(self.location_on_system))\n\n        if train:\n            self.location_on_system = os.path.join(self.location_on_system, 'train')\n        else:\n            self.location_on_system = os.path.join(self.location_on_system, 'val')\n\n        classes = os.listdir(self.location_on_system)\n        self.int_classes = dict(zip(classes, range(len(classes))))\n\n        for path, dirs, files in os.walk(self.location_on_system):\n            dirs.sort()\n            files.sort()\n            for file in files:\n                if file.endswith(\"mat\"):\n                    fsize = os.path.getsize(path + '/' + file) / float(1024)\n                    if fsize < 1:\n                        # print('{} size {} K'.format(file, fsize))\n                        continue\n                    self.data.append(path + \"/\" + file)\n                    self.targets.append(self.int_classes[path.split('/')[-1]])\n\n    def __getitem__(self, index: int) -> Tuple[Any, Any]:\n        \"\"\"\n        Returns:\n            (events, target) where target is index of the target class.\n        \"\"\"\n        events, target = scio.loadmat(self.data[index]), self.targets[index]\n        events = np.column_stack(\n            [\n                events[\"ts\"],\n                events[\"x\"],\n                self.sensor_size[1] - 1 - events[\"y\"],\n                events[\"pol\"],\n            ]\n        )\n        events = np.lib.recfunctions.unstructured_to_structured(events, self.dtype)\n        if self.transform is not None:\n            events = self.transform(events)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return events, target\n\n    def __len__(self):\n        return len(self.data)\n\n    def _check_exists(self):\n        print(self.folder_name)\n        return 
self._folder_contains_at_least_n_files_of_type(\n            13523, \".mat\"\n        )\n"
  },
  {
    "path": "braincog/datasets/utils.py",
    "content": "import torch\nfrom einops import repeat\nfrom braincog.datasets.gen_input_signal import lambda_max\n\n\ndef rescale(x, factor=None):\n    \"\"\"\n    数据放缩函数\n    :param x: 输入的tensor\n    :param factor: 缩放因子\n    :return: 缩放后的数据\n    \"\"\"\n    if factor:\n        x *= factor\n    else:\n        x *= lambda_max\n    return x\n\n\ndef dvs_channel_check_expend(x):\n    \"\"\"\n    检查是否存在DVS数据缺失, N-Car中有的数据会缺少一个通道\n    :param x: 输入的tensor\n    :return: 补全之后的数据\n    \"\"\"\n    if x.shape[1] == 1:\n        return repeat(x, 'b c w h -> b (r c) w h', r=2)\n    else:\n        return x\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/__init__.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/9/1 16:43\n# User      : Floyed\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : __init__.py.py\n# explain   :\n\nimport os\nimport numpy as np\nfrom .genotypes import PRIMITIVES, Genotype\n\nforward_edge_num = sum(1 for i in range(3) for n in range(2 + i))\nbackward_edge_num = sum(1 for i in range(3) for n in range(i))\nnum_ops = len(PRIMITIVES)\ntype_num = len(PRIMITIVES) // 2\n# edge_num = [2, 3, 4]\n\n# node_id: (forward) 2, 3, 4\n# node_id: (backward) 3, 2\nedge_num = [2, 3, 4, 1, 2]\n\n\ndef parse(weights, operation_set,\n          op_threshold, parse_method,\n          steps, reduction=False,\n          back_connection=False):\n    global k_best\n    gene = []\n    if parse_method == 'darts':\n        n = 2\n        start = 0\n        for i in range(steps):  # step = 4\n            end = start + n\n            W = weights[start:end].copy()\n            edges = sorted(range(i + 2), key=lambda x: -\n                           max(W[x][k] for k in range(len(W[x]))))[:2]\n            for j in edges:\n                k_best = None\n                for k in range(len(W[j])):\n                    if k_best is None or W[j][k] > W[j][k_best]:\n                        k_best = k\n                # geno item : (operation, node idx)\n                gene.append((operation_set[k_best], j))\n            start = end\n            n += 1\n\n    elif parse_method == 'bio_darts':\n        weights_backward = weights[forward_edge_num:]\n        weights_forward = weights[:forward_edge_num]\n\n        # forward\n        n = 2\n        start = 0\n\n        # idx = np.argsort(weights_forward[:, 0]).tolist()\n        # if reduction:\n        #     idx.remove(0)\n        #     idx.remove(1)\n        # weights_forward[:, 0] = 0.\n        # weights_forward[idx[-2:], 0] = 1.\n\n        for i in range(steps):  # step = 4\n            end = start + n\n            W = weights_forward[start:end].copy()\n     
       edges = sorted(range(i + 2), key=lambda x: -\n                           max(W[x][k] for k in range(len(W[x]))))[:2]\n            k_best = None\n            idx = np.argsort(W[edges[0]])\n            gene.append((operation_set[idx[-1]], edges[0]))\n            idx = np.argsort(W[edges[1]])\n            gene.append((operation_set[idx[-1]], edges[1]))\n            #\n            # op_name = operation_set[idx[-1]]\n            # idx = np.argsort(W[edges[1]])\n            # if 'skip' in op_name:\n            #     gene.append((operation_set[idx[-1]], edges[1]))\n            # elif '_n' in op_name:\n            #     for k in reversed(idx):\n            #         if '_n' not in operation_set[k]:\n            #             gene.append((operation_set[k], edges[1]))\n            #             break\n            # else:\n            #     for k in reversed(idx):\n            #         if '_n' in operation_set[k]:\n            #             gene.append((operation_set[k], edges[1]))\n            #             break\n\n            start = end\n            n += 1\n\n        if back_connection:\n            # backward\n            n = 1\n            start = 0\n            for i in range(1, steps):\n                end = start + n\n                W = weights_backward[start:end].copy()\n                edges = sorted(range(i), key=lambda x: -\n                               max(W[x][k] for k in range(len(W[x]))))[0]\n                idx = np.argsort(W[edges])\n                gene.append((operation_set[idx[-1]] + '_back', edges + 2))\n\n                start = end\n                n += 1\n\n    elif 'threshold' in parse_method:\n        n = 2\n        start = 0\n        for i in range(steps):  # step = 4\n            end = start + n\n            W = weights[start:end].copy()\n            if 'edge' in parse_method:\n                edges = list(range(i + 2))\n            else:  # select edges using darts methods\n                edges = sorted(range(i + 2), key=lambda x: 
-\n                               max(W[x][k] for k in range(len(W[x]))))[:2]\n\n            for j in edges:\n                if 'edge' in parse_method:  # OP_{prob > T} AND |Edge| <= 2\n                    topM = sorted(enumerate(W[j]), key=lambda x: x[1])[-2:]\n                    for k, v in topM:  # Get top M = 2 operations for one edge\n                        if W[j][k] >= op_threshold:\n                            gene.append((operation_set[k], i + 2, j))\n                # max( OP_{prob > T} ) and |Edge| <= 2\n                elif 'sparse' in parse_method:\n                    k_best = None\n                    for k in range(len(W[j])):\n                        if k_best is None or W[j][k] > W[j][k_best]:\n                            k_best = k\n                    if W[j][k_best] >= op_threshold:\n                        gene.append((operation_set[k_best], i + 2, j))\n                else:\n                    raise NotImplementedError(\n                        \"Not support parse method: {}\".format(parse_method))\n            start = end\n            n += 1\n    return gene\n\n\ndef parse_genotype(alphas, steps, multiplier, path=None,\n                   parse_method='threshold_sparse', op_threshold=0.85):\n    alphas_normal, alphas_reduce = alphas\n    gene_normal = parse(alphas_normal, PRIMITIVES,\n                        op_threshold, parse_method, steps)\n    gene_reduce = parse(alphas_reduce, PRIMITIVES,\n                        op_threshold, parse_method, steps)\n    concat = range(2 + steps - multiplier, steps + 2)\n    genotype = Genotype(\n        normal=gene_normal, normal_concat=concat,\n        reduce=gene_reduce, reduce_concat=concat\n    )\n\n    if path is not None:\n        if not os.path.exists(path):\n            os.makedirs(path)\n        print('Architecture parsing....\\n', genotype)\n        save_path = os.path.join(\n            path, parse_method + '_' + str(op_threshold) + '.txt')\n        with open(save_path, \"w+\") as f:\n     
       f.write(str(genotype))\n            print('Save in :', save_path)\n\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/architect.py",
    "content": "import torch\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport numpy as np\nfrom numpy.linalg import eigvals\nfrom braincog.model_zoo.NeuEvo.model_search import calc_weight, calc_loss\n\n\ndef normalize(x):\n    mu = np.average(x)\n    sigma = np.std(x)\n    return (x - mu) / sigma\n\n\ndef _concat(xs):\n    return torch.cat([x.view(-1) for x in xs])\n\n\nclass Architect(object):\n    def __init__(self, model, args):\n\n        self.network_momentum = args.momentum\n        self.network_weight_decay = args.weight_decay\n        self.model = model\n        self.optimizer = torch.optim.AdamW(self.model.arch_parameters(),\n                                           lr=args.arch_learning_rate,\n                                           betas=(args.arch_lr_gamma, 0.999),\n                                           weight_decay=args.arch_weight_decay)\n        # self.optimizer = torch.optim.SGD(self.model.arch_parameters(), lr=args.arch_learning_rate)\n        self.hessian = None\n        self.grads = None\n\n    def step(self, input_valid, target_valid):\n        self.optimizer.zero_grad()\n        aux_input = torch.cat([calc_loss(self.model.alphas_normal)], dim=0)\n        loss, loss1, loss2 = self.model._loss(\n            input_valid, target_valid, aux_input)\n        # loss = self.model._loss(input_valid, target_valid)\n        loss.backward()\n        self.optimizer.step()\n        return loss1, loss2\n\n    def compute_Hw(self, input_valid, target_valid):\n        self.zero_grads(self.model.parameters())\n        self.zero_grads(self.model.arch_parameters())\n        aux_input = torch.cat(\n            [F.softmax(self.model.alphas_normal, dim=-1)], dim=0)\n        loss = self.model._loss(input_valid, target_valid, aux_input)\n        self.hessian = self._hessian(loss, self.model.arch_parameters())\n        return self.hessian\n\n    def zero_grads(self, parameters):\n        for p in parameters:\n            if p.grad 
is not None:\n                p.grad.detach_()\n                p.grad.zero_()\n\n    def compute_eigenvalues(self):\n        self.compute_Hw()\n        return eigvals(self.hessian.cpu().data.numpy())\n\n    def _hessian(self, outputs, inputs, out=None, allow_unused=False):\n        if torch.is_tensor(inputs):\n            inputs = [inputs]\n        else:\n            inputs = list(inputs)\n\n        n = sum(p.numel() for p in inputs)\n        if out is None:\n            out = torch.tensor(torch.zeros(n, n)).type_as(outputs)\n\n        ai = 0\n        for i, inp in enumerate(inputs):\n            [grad] = torch.autograd.grad(outputs, inp, create_graph=True,\n                                         allow_unused=allow_unused)\n            grad = grad.contiguous().view(-1) + self.network_weight_decay * inp.view(-1)\n            for j in range(inp.numel()):\n                if grad[j].requires_grad:\n                    row = self.gradient(\n                        grad[j], inputs[i:], retain_graph=True)[j:]\n                else:\n                    n = sum(x.numel() for x in inputs[i:]) - j\n                    row = Variable(torch.zeros(n)).type_as(grad[j])\n\n                out.data[ai, ai:].add_(row.clone().type_as(out).data)\n                if ai + 1 < n:\n                    out.data[ai + 1:,\n                             ai].add_(row.clone().type_as(out).data[1:])\n                del row\n                ai += 1\n            del grad\n        return out\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/genotypes.py",
    "content": "from collections import namedtuple\n\nimport torch\n\nGenotype = namedtuple('Genotype', 'normal normal_concat')\n\n\"\"\"\nOperation sets\n\"\"\"\n\nPRIMITIVES = [\n    'conv_3x3_p',\n    # 'max_pool_3x3',\n    # 'avg_pool_3x3',\n    # 'def_conv_3x3',\n    # 'def_conv_5x5',\n    # 'sep_conv_3x3',\n    # 'sep_conv_5x5',\n    # 'dil_conv_3x3',\n    # 'dil_conv_5x5',\n\n    # 'max_pool_3x3_p',\n    # 'avg_pool_3x3_p',\n    'conv_3x3_p',\n    'conv_5x5_p',\n    # 'conv_3x3_p_p',\n    # 'sep_conv_3x3_p',\n    # 'sep_conv_5x5_p',\n    # 'dil_conv_3x3_p',\n    # 'dil_conv_5x5_p',\n    # 'def_conv_3x3_p',\n    # 'def_conv_5x5_p',n\n\n    # 'max_pool_3x3_n',\n    # 'avg_pool_3x3_n',\n    'conv_3x3_n',\n    'conv_5x5_n',\n    # 'conv_3x3_p_n',\n    # 'sep_conv_3x3_n',\n    # 'sep_conv_5x5_n',\n    # 'dil_conv_3x3_n',\n    # 'dil_conv_5x5_n',\n    # 'def_conv_3x3_n',\n    # 'def_conv_5x5_n',\n\n    # 'transformer',\n]\n\"\"\"====== SnnMlp Archirtecture By Other Methods\"\"\"\n\nmlp1 = Genotype(\n    normal=[\n        ('mlp', 0), ('conv_3x3_p', 1),  # 2\n        ('mlp', 1), ('mlp', 0),  # 3\n        ('conv_3x3_p', 2), ('mlp', 3),  # 4\n        ('mlp_back', 2),\n        ('conv_3x3_p_back', 2)\n    ],\n    normal_concat=range(2, 5)\n)\n\nmlp2 = Genotype(\n    normal=[\n        ('mlp', 0), ('conv_3x3_p', 1),\n        ('conv_3x3_p', 2), ('mlp_p', 1),\n        # ('mlp_n', 1), ('conv_3x3_p', 2),\n        ('mlp_back', 2)\n    ],\n    normal_concat=range(2, 4)\n)\n\n\n\"\"\"====== SNN Archirtecture By Other Methods\"\"\"\n\ndvsc10_new_skip22 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('conv_3x3_p', 0),  # 2\n        ('conv_5x5_p', 1), ('conv_3x3_p', 2),  # 3\n        ('conv_3x3_p', 0), ('conv_3x3_p', 3),  # 4\n        ('conv_3x3_n_back', 2), ('conv_3x3_p_back', 3)  # 3, 4\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip22 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('conv_3x3_p', 0),\n        ('conv_5x5_n', 1), ('conv_3x3_p', 2),\n      
  ('conv_5x5_n', 0), ('conv_3x3_p', 3),\n        ('conv_3x3_n_back', 0), ('conv_3x3_p_back', 1)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip21 = Genotype(\n    normal=[\n        ('conv_3x3_n', 0), ('conv_5x5_p', 1),  # 2\n        ('conv_3x3_p', 1), ('conv_5x5_p', 2),  # 3\n        ('conv_5x5_n', 2), ('conv_3x3_p', 1),  # 4\n        # ('conv_3x3_p_back', 2), ('conv_5x5_p_back', 2)\n    ],\n    normal_concat=range(2, 5)\n)\n\n\ndvsc10_new_skip20 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_n', 1),\n        ('conv_3x3_n', 2), ('conv_5x5_p', 0),\n        ('conv_3x3_p', 2), ('conv_3x3_n', 3),\n        ('conv_3x3_p_back', 2),\n        ('conv_5x5_p_back', 3)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip19 = Genotype(\n    normal=[\n        ('conv_5x5_n', 0), ('conv_3x3_p', 1),\n        ('conv_5x5_n', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_3x3_p_back', 2),\n        ('conv_5x5_p_back', 2)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip18 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_3x3_p', 1),\n        ('conv_5x5_p', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_5x5_n_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip17 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('conv_5x5_n', 0),\n        ('conv_5x5_n', 2), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 2), ('avg_pool_3x3_p', 3),\n        ('avg_pool_3x3_p_back', 2), ('conv_3x3_p_back', 2)\n    ],\n    normal_concat=range(2, 5)\n)\ndvsc10_new_skip16 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_n', 1),\n        ('conv_3x3_n', 2), ('avg_pool_3x3_p', 0),\n        ('conv_3x3_p', 2), ('avg_pool_3x3_n', 3),\n        ('conv_3x3_p_back', 2),\n        ('conv_3x3_p_back', 3)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip15 = Genotype(\n    normal=[\n        ('conv_5x5_n', 0), 
('conv_5x5_p', 1),\n        ('conv_5x5_p', 2), ('conv_5x5_n', 1),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_5x5_p_back', 2),\n        ('conv_3x3_p_back', 3)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip14 = Genotype(\n    normal=[\n        ('conv_5x5_n', 1), ('conv_3x3_p', 0),\n        ('conv_5x5_p', 1), ('conv_3x3_p', 2),\n        ('conv_3x3_p', 2), ('conv_5x5_n', 1),\n        ('conv_3x3_n_back', 2),\n        ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip13 = Genotype(\n    normal=[\n        ('conv_5x5_n', 1), ('conv_3x3_p', 0),\n        ('conv_5x5_n', 1), ('conv_3x3_p', 2),\n        ('conv_3x3_p', 2), ('conv_5x5_n', 1),\n        ('conv_3x3_n_back', 2),\n        ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip12 = Genotype(\n    normal=[\n        ('conv_5x5_n', 0), ('conv_3x3_p', 1),\n        ('conv_5x5_n', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_3x3_n_back', 2),\n        ('conv_3x3_p_back', 2)\n    ],\n    normal_concat=range(2, 5)\n)\n# dvsc10_new_skip12 = Genotype(\n#     normal=[\n#         ('conv_3x3_p', 0), ('conv_3x3_n', 1),\n#         ('conv_3x3_p', 1), ('conv_5x5_n', 2),\n#         ('conv_3x3_n', 3), ('conv_3x3_p', 0),\n#         ('conv_5x5_p_back', 2), ('conv_3x3_p_back', 3)\n#     ],\n#     normal_concat=range(2, 5)\n# )\n\ndvsc10_new_skip11 = Genotype(normal=[\n    ('conv_3x3_n', 0), ('conv_5x5_n', 1),\n    ('conv_5x5_p', 0), ('conv_3x3_n', 2),\n    ('conv_3x3_p', 2), ('conv_5x5_n', 0),\n    ('conv_3x3_n_back', 2),\n    ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip10 = Genotype(\n    normal=[\n        ('conv_5x5_n', 1), ('conv_3x3_p', 0),\n        ('conv_5x5_p', 2), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 2), ('conv_5x5_n', 1),\n        ('conv_3x3_n_back', 2),\n        ('conv_3x3_n_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip9 = Genotype(\n    normal=[\n      
  ('conv_5x5_p', 1), ('conv_5x5_n', 0),\n        ('conv_5x5_p', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_3x3_p_back', 2),\n        ('conv_5x5_n_back', 3)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip8 = Genotype(\n    normal=[\n        ('conv_5x5_n', 0), ('conv_5x5_n', 1),\n        ('conv_3x3_n', 2), ('conv_5x5_p', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_n', 1),\n        ('conv_5x5_n_back', 2),\n        ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip7 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),\n        ('conv_3x3_n', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_5x5_n_back', 2),\n        ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip6 = Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_n', 1),\n        ('conv_5x5_p', 2), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_n_back', 2), ('conv_3x3_n_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip5 = Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),\n        ('conv_3x3_n', 2), ('conv_3x3_n', 0),\n        ('conv_3x3_p', 2), ('conv_3x3_p', 3),\n        ('conv_5x5_n_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip4 = Genotype(\n    normal=[\n        ('conv_5x5_n', 1), ('conv_5x5_p', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 2), ('conv_5x5_n', 0),\n        ('conv_3x3_p_back', 2),\n        ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip3 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_3x3_p', 1),\n        ('conv_3x3_n', 2), ('conv_3x3_n', 0),\n        ('conv_3x3_p', 2), ('conv_5x5_p', 3),\n        ('conv_5x5_n_back', 2),\n        ('conv_3x3_p_back', 3)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip2 = Genotype(\n 
   normal=[\n        ('avg_pool_3x3_p', 0), ('avg_pool_3x3_p', 1),\n        ('avg_pool_3x3_p', 0), ('avg_pool_3x3_p', 1),\n        ('conv_3x3_p', 2), ('avg_pool_3x3_n', 0),\n        ('avg_pool_3x3_n_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip1 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_3x3_n', 1),\n        ('conv_3x3_n', 2), ('conv_3x3_p', 1),\n        ('conv_5x5_p', 1), ('conv_3x3_p', 2),\n        ('conv_3x3_p_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_skip = Genotype(\n    normal=[\n        ('conv_3x3_n', 1), ('conv_3x3_p', 0),\n        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),\n        ('conv_3x3_p', 2), ('conv_3x3_n', 0),\n        ('conv_3x3_p_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_base0 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 1), ('avg_pool_3x3_p', 0),\n        ('avg_pool_3x3_n', 2), ('avg_pool_3x3_p', 1),\n        ('avg_pool_3x3_n', 2), ('avg_pool_3x3_n', 3),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_base1 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('conv_5x5_n', 0),\n        ('conv_5x5_p', 1), ('conv_3x3_p', 0),\n        ('conv_5x5_n', 1), ('conv_3x3_p', 0),\n        ('avg_pool_3x3_p_back', 2),\n        ('conv_3x3_p_back', 3)\n    ],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_base2 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_3x3_p', 1),\n        ('conv_5x5_n', 1), ('avg_pool_3x3_p', 0),\n        ('avg_pool_3x3_n', 3), ('conv_5x5_n', 1),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new_base3 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 0), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 1), ('conv_3x3_n', 0),\n        ('conv_5x5_p', 1), ('conv_3x3_n', 0),\n        ('conv_3x3_p_back', 
2),\n        ('avg_pool_3x3_n_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_grad2 = Genotype(\n    normal=[\n        ('avg_pool_3x3_n', 1), ('conv_5x5_p', 0),\n        ('conv_5x5_n', 1), ('conv_5x5_n', 0),\n        ('conv_3x3_p', 3), ('conv_5x5_n', 1),\n        ('conv_5x5_p_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_grad1 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),\n        ('avg_pool_3x3_n', 2), ('avg_pool_3x3_n', 1),\n        ('avg_pool_3x3_p', 2), ('conv_5x5_n', 1),\n        ('conv_5x5_p_back', 2),\n        ('conv_3x3_p_back', 3)],\n    normal_concat=range(2, 5))\n\ndvsg_new2 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),\n        ('conv_3x3_p', 1), ('conv_3x3_p', 0),\n        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 3)],\n    normal_concat=range(2, 5))\n\ndvsg_new1 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),\n        ('conv_3x3_p', 1), ('conv_3x3_p', 0),\n        ('conv_3x3_p', 1),  ('avg_pool_3x3_p', 0),\n        ('avg_pool_3x3_n_back', 2),\n        ('conv_5x5_n_back', 3)],\n    normal_concat=range(2, 5))\n\ndvscal_new1 = Genotype(\n    normal=[\n        ('conv_5x5_n', 0), ('conv_5x5_n', 1),\n        ('conv_5x5_n', 1), ('conv_5x5_p', 0),\n        ('avg_pool_3x3_p', 1), ('conv_5x5_p', 0),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new8 = Genotype(\n    normal=[('conv_5x5_p', 0), ('conv_5x5_p', 1),\n            ('conv_3x3_p', 0), ('conv_5x5_n', 1),\n            ('conv_5x5_p', 0), ('conv_5x5_n', 1),\n            ('avg_pool_3x3_n_back', 2),\n            ('avg_pool_3x3_n_back', 3)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new7 = Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 0), ('conv_5x5_n', 1),\n        ('conv_5x5_p', 
0), ('conv_5x5_n', 1),\n        ('conv_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 2)],\n    normal_concat=range(2, 5))\n\ndvsc10_new6 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('conv_3x3_p', 0),\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),\n        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 2)],\n    normal_concat=range(2, 5))\n\ndvsc10_new5 = Genotype(\n    normal=[\n        ('conv_5x5_p', 1), ('conv_3x3_p', 0),\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),\n        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_n_back', 2)],\n    normal_concat=range(2, 5))\n\ndvsc10_new4 = Genotype(\n    normal=[\n        ('conv_3x3_n', 1), ('conv_3x3_p', 0),\n        ('conv_5x5_p', 1), ('conv_5x5_p', 0),\n        ('conv_5x5_p', 1), ('conv_5x5_p', 0),\n        ('avg_pool_3x3_p_back', 2),\n        ('avg_pool_3x3_n_back', 2)],\n    normal_concat=range(2, 5),\n)\n\ndvsc10_new3 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 0), ('conv_3x3_n', 1),\n        ('conv_3x3_n', 1), ('conv_3x3_n', 0),\n        ('avg_pool_3x3_p', 2), ('conv_3x3_n', 1),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_p', 2)],\n    normal_concat=range(2, 5),\n)\n\ndvsc10_new2 = Genotype(normal=[\n    ('conv_3x3_p', 0), ('conv_3x3_n', 1),\n    ('conv_3x3_n', 1), ('avg_pool_3x3_p', 0),\n    ('avg_pool_3x3_p', 2), ('conv_3x3_n', 1),\n    ('avg_pool_3x3_n_back', 2),\n    ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5),\n)\n\ndvsc10_new1 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),\n        ('avg_pool_3x3_p', 0), ('conv_3x3_n', 1),\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),\n        ('conv_3x3_p_back', 2),\n        ('conv_3x3_n_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ndvsc10_new0 = Genotype(\n    normal=[\n        ('conv_3x3_p', 1), ('avg_pool_3x3_p', 0),\n        ('avg_pool_3x3_p', 2), ('conv_3x3_n', 
1),\n        ('conv_3x3_p', 0), ('conv_3x3_p', 3),\n        ('conv_3x3_p_back', 2),\n        ('conv_3x3_n_back', 3)],\n    normal_concat=range(2, 5)\n)\ncifar_new_skip1 = Genotype(\n    normal=[\n        ('conv_5x5_n', 0), ('conv_5x5_p', 1),\n        ('avg_pool_3x3_p', 0), ('avg_pool_3x3_n', 2),\n        ('avg_pool_3x3_p', 2), ('conv_5x5_p', 0),\n        ('avg_pool_3x3_n_back', 2),\n        ('avg_pool_3x3_p_back', 3)\n    ],\n    normal_concat=range(2, 5))\n\ncifar_new1 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 1), ('avg_pool_3x3_p', 0),\n        ('conv_3x3_n', 0), ('avg_pool_3x3_p', 1),\n        ('avg_pool_3x3_p', 2), ('conv_3x3_p', 0),\n        ('avg_pool_3x3_n_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5)\n)\n\ncifar_new2 = Genotype(\n    normal=[\n        ('conv_3x3_n', 0), ('avg_pool_3x3_p', 1),\n        ('conv_3x3_p', 0), ('avg_pool_3x3_p', 1),\n        ('conv_3x3_p', 2), ('conv_3x3_n', 0),\n        ('conv_3x3_n_back', 2),\n        ('conv_3x3_p_back', 2)],\n    normal_concat=range(2, 5),\n)\n\ncifar_new0 = Genotype(\n    normal=[\n        ('avg_pool_3x3_p', 1), ('avg_pool_3x3_n', 0),  # 2, 3\n        ('conv_3x3_n', 0), ('avg_pool_3x3_p', 1),  # 4, 5\n        ('conv_3x3_p', 2), ('conv_3x3_n', 3),  # 6 , 7\n        ('avg_pool_3x3_n_back', 2),\n        ('conv_3x3_p_back', 1)],\n    normal_concat=range(2, 5)\n)\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/model.py",
    "content": "from functools import partial\nfrom typing import List, Type\n\nfrom braincog.model_zoo.NeuEvo.operations import *\nfrom braincog.model_zoo.NeuEvo.genotypes import Genotype\nfrom braincog.base.utils import drop_path\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.model_zoo.base_module import BaseModule\n\n\nclass MlpCell(BaseModule):\n    def __init__(\n        self,\n        genotype: Genotype,\n        C: int,\n        input_dim: int,\n        output_dim: int,\n        encode_type: str = 'direct',\n        activation_fn: Type[nn.Module] = LIFNode,\n        squash_output: bool = False,\n        back_connection: bool = True,\n        step: int = 10,\n        **kwargs\n    ):\n        super(MlpCell, self).__init__(\n            step=step,\n            encode_type=encode_type,\n            layer_by_layer=True\n        )\n        # print(activation_fn, step)\n        self.act_fun = partial(activation_fn, step=step, layer_by_layer=self.layer_by_layer, **kwargs)\n        self.back_connection = back_connection\n        op_names, indices = zip(*genotype.normal)\n        concat = genotype.normal_concat\n        self._compile(C, op_names, indices, concat)\n\n        self.feature = nn.Sequential(\n            nn.Linear(input_dim, C),\n            self.act_fun(),\n        )\n        if output_dim > 0:\n            self.output_fn = nn.Linear(self.multiplier * C, output_dim)\n        elif squash_output:\n            self.output_fn = nn.Tanh()\n        else:\n            self.output_fn = nn.Identity()\n\n    def _compile(self, C, op_names, indices, concat):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, (name, index) in 
enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                back_begin_index = i\n                break\n            op = OPS_Mlp[name](C, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS_Mlp[name.replace('_back', '')](\n                    C, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 2\n\n    def _forward_once(self, s0, s1, drop_prob):\n\n        states = [s0, s1]\n        for i in range(self._steps):\n            h1 = states[self._indices_forward[2 * i]]\n            h2 = states[self._indices_forward[2 * i + 1]]\n            op1 = self._ops[2 * i]\n            op2 = self._ops[2 * i + 1]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)\n            s = h1 + h2\n            if self.back_connection:\n                if i != 0:\n                    s_back = self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]\n                           ] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n\n        outputs = []\n        for i in self._concat:\n            outputs.append(rearrange(states[i], '(t b) c -> t b c', t=self.step))\n        outputs = torch.cat(outputs, dim=2)  # T, B, C\n\n        return outputs\n\n    
def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n            x = self.feature(inputs)\n            x = self._forward_once(x, x, 0.)\n            x = self.output_fn(x)\n            x = x.mean(0)\n        else:\n            raise NotImplementedError\n\n        return x\n\n\nclass Cell(nn.Module):\n    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun, back_connection):\n        # print(C_prev_prev, C_prev, C, reduction)\n        super(Cell, self).__init__()\n        self.act_fun = act_fun\n        self.back_connection = back_connection\n        self.reduction = reduction\n        if reduction:\n            self.fun = FactorizedReduce(\n                C_prev, C * 3, act_fun=act_fun\n            )\n            self.multiplier = 3\n        else:\n            if reduction_prev:\n                self.preprocess0 = FactorizedReduce(\n                    C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(\n                    C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n            self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n            op_names, indices = zip(*genotype.normal)\n            concat = genotype.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, (name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                back_begin_index = i\n                break\n            stride = 2 if reduction and index < 2 else 1\n            op = 
OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 2\n\n    def forward(self, s0, s1, drop_prob):\n        if self.reduction:\n            return self.fun(s1)\n\n        s0 = self.preprocess0(s0)\n        s1 = self.preprocess1(s1)\n\n        states = [s0, s1]\n        for i in range(self._steps):\n            h1 = states[self._indices_forward[2 * i]]\n            h2 = states[self._indices_forward[2 * i + 1]]\n            op1 = self._ops[2 * i]\n            op2 = self._ops[2 * i + 1]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)\n            s = h1 + h2\n            if self.back_connection:\n                if i != 0:\n                    s_back = self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]\n                           ] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        outputs = torch.cat([states[i]\n                            for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\n\nclass DCOCell(nn.Module):\n    def __init__(self, genotype, C_prev_prev, C_prev, C, 
reduction, reduction_prev, act_fun):\n        super(DCOCell, self).__init__()\n        self.act_fun = act_fun\n\n        if reduction_prev:\n            self.preprocess0 = FactorizedReduce(C_prev_prev, C)\n        else:\n            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)\n        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)\n\n        if reduction:\n            op_names, tos, froms = zip(*genotype.reduce)\n        else:\n            op_names, tos, froms = zip(*genotype.normal)\n        self._compile(C, op_names, tos, froms, reduction)\n\n    def _compile(self, C, op_names, tos, froms, reduction):\n        self._ops = nn.ModuleDict()\n        for name_i, to_i, from_i in zip(op_names, tos, froms):\n            stride = 2 if reduction and from_i < 2 else 1\n            op = OPS[name_i](C, stride, True, act_fun=self.act_fun)\n            if str(to_i) in self._ops.keys():\n                if str(from_i) in self._ops[str(to_i)]:\n                    self._ops[str(to_i)][str(from_i)] += [op]\n                else:\n                    self._ops[str(to_i)][str(from_i)] = nn.ModuleList()\n                    self._ops[str(to_i)][str(from_i)] += [op]\n            else:\n                self._ops[str(to_i)] = nn.ModuleDict()\n                self._ops[str(to_i)][str(from_i)] = nn.ModuleList()\n                self._ops[str(to_i)][str(from_i)] += [op]\n\n        # TODO: Some intermediate node maybe no selected during search.\n        self.multiplier = len(self._ops)\n\n    def forward(self, s0, s1, drop_prob):\n        s0 = self.preprocess0(s0)\n        s1 = self.preprocess1(s1)\n\n        states = {}\n        states['0'] = s0\n        states['1'] = s1\n\n        # get all the operations in current intermediate node\n        for to_i, ops in self._ops.items():\n            h = []\n            for from_i, op_i in ops.items():\n                # each edge may no more than one operation\n                if from_i not in states:\n                    # 
print('Exist the isolate node, which id is {}, we need ignore it!'.format(from_i))\n                    continue\n                h += [sum([op(states[from_i])\n                          for op in op_i if from_i in states])]\n            out = sum(h)\n            if self.training and drop_prob > 0:\n                out = drop_path(out, drop_prob)\n            states[to_i] = out\n\n        outputs = torch.cat([v for v in states.values()][2:], dim=1)\n        # return outputs\n        return outputs\n\n\nclass AuxiliaryHeadCIFAR(nn.Module):\n    def __init__(self, C, num_classes, act_fun):\n        \"\"\"assuming inputs size 8x8\"\"\"\n        super(AuxiliaryHeadCIFAR, self).__init__()\n        self.act_fun = act_fun\n        self.features = nn.Sequential(\n            # nn.ReLU(inplace=True),\n            self.act_fun(),\n            # image size = 2 x 2\n            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),\n            nn.Conv2d(C, 128, 1, bias=False),\n            nn.BatchNorm2d(128),\n            # nn.ReLU(inplace=True),\n            self.act_fun(),\n            nn.Conv2d(128, 768, 2, bias=False),\n            nn.BatchNorm2d(768),\n            # nn.ReLU(inplace=True)\n            self.act_fun()\n        )\n        self.classifier = nn.Linear(768, num_classes)\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.classifier(x.view(x.size(0), -1))\n        return x\n\n\nclass AuxiliaryHeadImageNet(nn.Module):\n\n    def __init__(self, C, num_classes):\n        \"\"\"assuming inputs size 14x14\"\"\"\n        super(AuxiliaryHeadImageNet, self).__init__()\n        self.features = nn.Sequential(\n            nn.ReLU(inplace=True),\n            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),\n            nn.Conv2d(C, 128, 1, bias=False),\n            nn.BatchNorm2d(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 768, 2, bias=False),\n            # NOTE: This batchnorm was omitted in my 
earlier implementation due to a typo.\n            # Commenting it out for consistency with the experiments in the paper.\n            # nn.BatchNorm2d(768),\n            nn.ReLU(inplace=True)\n        )\n        self.classifier = nn.Linear(768, num_classes)\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.classifier(x.view(x.size(0), -1))\n        return x\n\n\n@register_model\nclass NetworkCIFAR(BaseModule):\n\n    def __init__(self,\n                 C,\n                 num_classes,\n                 layers,\n                 auxiliary,\n                 genotype,\n                 parse_method='darts',\n                 step=1,\n                 node_type='ReLUNode',\n                 **kwargs):\n        super(NetworkCIFAR, self).__init__(\n            step=step,\n            num_classes=num_classes,\n            **kwargs\n        )\n        if isinstance(node_type, str):\n            self.act_fun = eval(node_type)\n        else:\n            self.act_fun = node_type\n        self.act_fun = partial(self.act_fun, **kwargs)\n\n        if 'back_connection' in kwargs.keys():\n            self.back_connection = kwargs['back_connection']\n        else:\n            self.back_connection = False\n\n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n        self.dataset = kwargs['dataset']\n\n        if self.layer_by_layer:\n            self.flatten = nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n\n        self._layers = layers\n        self._auxiliary = auxiliary\n        self.drop_path_prob = 0\n\n        stem_multiplier = 3\n        C_curr = stem_multiplier * C\n        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':\n            self.stem = nn.Sequential(\n                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            # 
self.reduce_idx = [\n            #     layers // 4,\n            #     layers // 2,\n            #     3 * layers // 4\n            # ]\n            self.reduce_idx = [1, 3, 5, 7]\n        else:\n            self.stem = nn.Sequential(\n                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            self.reduce_idx = [layers // 4,\n                               layers // 2,\n                               3 * layers // 4]\n\n        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C\n\n        self.cells = nn.ModuleList()\n        reduction_prev = False\n        for i in range(layers):\n            if i in self.reduce_idx:\n                C_curr *= 2\n                reduction = True\n            else:\n                reduction = False\n            if parse_method == 'darts':\n                cell = Cell(genotype, C_prev_prev, C_prev, C_curr,\n                            reduction, reduction_prev,\n                            act_fun=self.act_fun, back_connection=self.back_connection)\n            else:\n                cell = DCOCell(genotype, C_prev_prev, C_prev, C_curr,\n                               reduction, reduction_prev, act_fun=self.act_fun)\n            reduction_prev = reduction\n            self.cells += [cell]\n            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n            if i == 2 * layers // 3:\n                C_to_auxiliary = C_prev\n\n        if auxiliary:\n            self.auxiliary_head = AuxiliaryHeadCIFAR(\n                C_to_auxiliary, num_classes, act_fun=self.act_fun)\n        self.global_pooling = nn.Sequential(\n            self.act_fun(), nn.AdaptiveAvgPool2d(1))\n\n        if self.spike_output:\n            self.classifier = nn.Sequential(\n                nn.Linear(C_prev, 10 * num_classes),\n                self.act_fun())\n            self.vote = VotingLayer(10)\n        else:\n            self.classifier = 
nn.Linear(C_prev, num_classes)\n            self.vote = nn.Identity()\n\n        # self.classifier = nn.Linear(C_prev, num_classes)\n        # self.vote = nn.Identity()\n\n    def forward(self, inputs):\n        logits_aux = None\n        inputs = self.encoder(inputs)\n        if not self.layer_by_layer:\n            outputs = []\n            output_aux = []\n            self.reset()\n            for t in range(self.step):\n                x = inputs[t]\n                s0 = s1 = self.stem(x)\n                for i, cell in enumerate(self.cells):\n                    s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                    # print(s0.shape, s1.shape)\n                    # if i == 2 * self._layers // 3:\n                    #     if self._auxiliary and self.training:\n                    #         logits_aux = self.auxiliary_head(s1)\n                out = self.global_pooling(s1)\n                out = self.classifier(self.flatten(out))\n                logits = self.vote(out)\n                outputs.append(logits)\n                output_aux.append(logits_aux)\n            return sum(outputs) / len(outputs)\n            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))\n        else:\n            s0 = s1 = self.stem(inputs)\n            for i, cell in enumerate(self.cells):\n                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                if i == 2 * self._layers // 3:\n                    if self._auxiliary and self.training:\n                        logits_aux = self.auxiliary_head(s1)\n            out = self.global_pooling(s1)\n            out = self.classifier(self.flatten(out))\n            out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)\n            logits = self.vote(out)\n            return logits\n\n\n@register_model\nclass NetworkImageNet(BaseModule):\n\n    def __init__(self,\n                 C,\n                 num_classes,\n                 layers,\n                 auxiliary,\n      
           genotype,\n                 step=1,\n                 node_type='ReLUNode',\n                 **kwargs):\n        super(NetworkImageNet, self).__init__(\n            step=step,\n            num_classes=num_classes,\n            **kwargs)\n\n        if isinstance(node_type, str):\n            self.act_fun = eval(node_type)\n        else:\n            self.act_fun = node_type\n        self.act_fun = partial(self.act_fun, **kwargs)\n\n        if 'back_connection' in kwargs.keys():\n            self.back_connection = kwargs['back_connection']\n        else:\n            self.back_connection = False\n\n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n\n        if self.layer_by_layer:\n            self.flatten = nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n\n        self._layers = layers\n        self._auxiliary = auxiliary\n        self.drop_path_prob = 0\n\n        self.stem0 = nn.Sequential(\n            nn.Conv2d(3, C // 2, kernel_size=3,\n                      stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(C // 2),\n            # nn.ReLU(inplace=True),\n            self.act_fun(),\n            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(C),\n        )\n\n        self.stem1 = nn.Sequential(\n            # nn.ReLU(inplace=True),\n            self.act_fun(),\n            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(C),\n        )\n\n        C_prev_prev, C_prev, C_curr = C, C, C\n\n        self.cells = nn.ModuleList()\n        reduction_prev = True\n        for i in range(layers):\n            if i in [layers // 3, 2 * layers // 3]:\n                C_curr *= 2\n                reduction = True\n            else:\n                reduction = False\n            cell = Cell(genotype, C_prev_prev, C_prev,\n                        C_curr, reduction, reduction_prev,\n                        
act_fun=self.act_fun, back_connection=self.back_connection)\n            reduction_prev = reduction\n            self.cells += [cell]\n            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n        self.global_pooling = nn.AvgPool2d(7)\n        self.classifier = nn.Linear(C_prev, num_classes)\n\n    def forward(self, inputs):\n        outputs = []\n        self.reset()\n        for t in range(self.step):\n            s0 = self.stem0(inputs)\n            s1 = self.stem1(s0)\n            for i, cell in enumerate(self.cells):\n                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n            out = self.global_pooling(s1)\n            logits = self.classifier(self.flatten(out))\n            outputs.append(logits)\n        return sum(outputs) / len(outputs)\n\n\nif __name__ == '__main__':\n    from braincog.model_zoo.NeuEvo.genotypes import mlp2\n    cell = MlpCell(mlp2, C=128, input_dim=17, output_dim=-1)\n    x = torch.rand(4, 17)\n    out = cell(x)\n    print(out)\n    print(out.shape)\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/model_search.py",
    "content": "from functools import partial\nfrom braincog.model_zoo.NeuEvo.operations import *\nfrom torch.autograd import Variable\nfrom braincog.model_zoo.NeuEvo.genotypes import PRIMITIVES\nfrom braincog.model_zoo.NeuEvo.genotypes import Genotype\nfrom . import parse\n\nfrom braincog.base.connection.layer import VotingLayer\nfrom braincog.base.node.node import *\nfrom braincog.model_zoo.base_module import BaseModule\nfrom . import forward_edge_num\nfrom . import edge_num\n\n\ndef calc_weight(x):\n    tmp0 = torch.split(x[0], edge_num, dim=0)\n    tmp1 = torch.split(x[1], edge_num, dim=0)\n    res = []\n    for i in range(len(edge_num)):\n        res.append(\n            torch.softmax(tmp0[i].view(-1), dim=-1).view(tmp0[i].shape)\n            + torch.softmax(tmp1[i].view(-1), dim=-1).view(tmp1[i].shape)\n        )\n    return torch.cat(res, dim=0)\n\n\ndef calc_loss(x):\n    tmp0 = torch.split(x[0], edge_num, dim=0)\n    tmp1 = torch.split(x[1], edge_num, dim=0)\n    res = []\n    for i in range(len(edge_num)):\n        res.append(\n            torch.softmax(tmp0[i].view(-1), dim=-1).view(tmp0[i].shape)\n            - torch.softmax(tmp1[i].view(-1), dim=-1).view(tmp1[i].shape)\n        )\n    return torch.cat(res, dim=0)\n\n\nclass darts_fun(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inputs, weights):  # feature map / arch weight\n        output = inputs * weights\n        ctx.save_for_backward(inputs, weights)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):  # error signal\n        grad_inputs, grad_weights = None, None\n\n        inputs, weights = ctx.saved_tensors\n\n        if ctx.needs_input_grad[0]:\n            grad_inputs = grad_output * weights\n        if ctx.needs_input_grad[1]:\n            if torch.min(inputs) < -1e-12 and torch.max(inputs) > 1e-12:\n                inputs = torch.abs(inputs) / 2.\n            else:\n                inputs = torch.abs(inputs)\n            grad_weights = 
-inputs.mean()\n\n        return grad_inputs, grad_weights\n\n\nclass MixedOp(nn.Module):\n    def __init__(self, C, stride, act_fun):\n        super(MixedOp, self).__init__()\n        self._ops = nn.ModuleList()\n        for primitive in PRIMITIVES:\n            op = OPS[primitive](C, stride, False, act_fun)\n            if 'pool' in primitive:\n                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))\n            self._ops.append(op)\n\n        self.multiply = darts_fun.apply\n\n    def forward(self, x, weights):\n        feature_map = []\n        for i, op in enumerate(self._ops):\n            res = op(x)\n            feature_map.append(res)\n        return sum(self.multiply(mp, w) for w, mp in zip(weights, feature_map))\n\n\nclass Cell(nn.Module):\n\n    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun, back_connection):\n        super(Cell, self).__init__()\n        self.reduction = reduction\n        self.back_connection = back_connection\n        if reduction:\n            self.fun = FactorizedReduce(\n                C_prev, C * multiplier, affine=True, act_fun=act_fun, positive=1\n            )\n        else:\n            if reduction_prev:\n                self.preprocess0 = FactorizedReduce(\n                    C_prev_prev, C, affine=False, act_fun=act_fun, positive=1)\n            else:\n                self.preprocess0 = ReLUConvBN(\n                    C_prev_prev, C, 1, 1, 0, affine=False, act_fun=act_fun, positive=1)\n            self.preprocess1 = ReLUConvBN(\n                C_prev, C, 1, 1, 0, affine=False, act_fun=act_fun, positive=1)\n            self._steps = steps\n            self._multiplier = multiplier\n\n            self._ops = nn.ModuleList()\n\n            for i in range(self._steps):\n                for j in range(2 + i):\n                    stride = 2 if reduction and j < 2 else 1\n                    op = MixedOp(C, stride, act_fun)\n                    
self._ops.append(op)\n\n            if self.back_connection:\n                self._ops_back = nn.ModuleList()\n                for i in range(self._steps):\n                    for j in range(i):\n                        op = MixedOp(C, 1, act_fun)\n                        self._ops_back.append(op)\n\n    def forward(self, s0, s1, weights):\n        if self.reduction:\n            return self.fun(s1)\n\n        s0 = self.preprocess0(s0)\n        s1 = self.preprocess1(s1)\n\n        states = [s0, s1]\n        offset = 0\n        offset_back = 0\n\n        weights_forward = weights[:forward_edge_num]\n        weights_backward = weights[forward_edge_num:]\n        for i in range(self._steps):\n            s = sum(self._ops[offset + j](h, weights_forward[offset + j])\n                    for j, h in enumerate(states))\n            offset += len(states)\n\n            if self.back_connection:\n                for j in range(2, len(states)):\n                    # print(j, len(states), offset_back, len(self._ops_back))\n                    states[j] = states[j] + \\\n                        self._ops_back[offset_back](\n                            s, weights_backward[offset_back])\n                    offset_back += 1\n\n            states.append(s)\n\n        outputs = torch.cat(states[-self._multiplier:], dim=1)\n        return outputs\n\n\nclass Network(BaseModule):\n\n    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3,\n                 parse_method='bio_darts', op_threshold=None, step=1, node_type='ReLUNode', **kwargs):\n\n        super().__init__(\n            step=step,\n            encode_type='direct',\n            **kwargs\n        )\n\n        self.act_fun = eval(node_type)\n        self.act_fun = partial(self.act_fun, **kwargs)\n\n        self._C = C\n        self._num_classes = num_classes\n        self._layers = layers\n        self._criterion = criterion\n        self._steps = steps\n        self._multiplier 
= multiplier\n        self.parse_method = parse_method\n        self.op_threshold = op_threshold\n        self.fire_rate_per_step = [0.] * self.step\n        self.forward_step = 0\n        self.record_fire_rate = False\n        if 'back_connection' in kwargs.keys():\n            self.back_connection = kwargs['back_connection']\n        else:\n            self.back_connection = False\n        self.dataset = kwargs['dataset']\n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n\n        C_curr = stem_multiplier * C\n\n        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':\n            self.stem = nn.Sequential(\n                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            self.reduce_idx = [layers // 3,\n                               2 * layers // 3]\n        else:\n            self.stem = nn.Sequential(\n                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            self.reduce_idx = [1, 3, 5]\n\n        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C\n        self.cells = nn.ModuleList()\n        reduction_prev = False\n        for i in range(layers):\n            if i in self.reduce_idx:\n                C_curr *= 2\n                reduction = True\n            else:\n                reduction = False\n            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, self.act_fun,\n                        self.back_connection)\n            reduction_prev = reduction\n            self.cells += [cell]\n\n            C_prev_prev, C_prev = C_prev, multiplier * C_curr\n        self.global_pooling = nn.Sequential(\n            self.act_fun(), nn.AdaptiveAvgPool2d(1))\n        if self.spike_output:\n            self.classifier = nn.Sequential(\n                
nn.Linear(C_prev, 10 * num_classes),\n                self.act_fun())\n            self.vote = VotingLayer(10)\n        else:\n            self.classifier = nn.Linear(C_prev, num_classes)\n            self.vote = nn.Identity()\n        self._initialize_alphas()\n\n    def new(self):\n        model_new = Network(self._C, self._num_classes,\n                            self._layers, self._criterion).cuda()\n        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):\n            x.data.copy_(y.data)\n        return model_new\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n\n        self.reset()\n        if not self.training:\n            self.fire_rate.clear()\n\n        outputs = []\n        for t in range(self.step):\n            x = inputs[t]\n            s0 = s1 = self.stem(x)\n            for i, cell in enumerate(self.cells):\n                if not cell.reduction:\n                    weights = calc_weight(self.alphas_normal)\n                    s0, s1 = s1, cell(s0, s1, weights)\n                else:\n                    s0, s1 = s1, cell(s0, s1, None)\n            out = self.global_pooling(s1)\n            out = self.classifier(out.view(out.size(0), -1))\n            logits = self.vote(out)\n            outputs.append(logits)\n        # print(self.get_fire_rate_avg(), self.fire_rate_per_step, len(self.fire_rate_per_step))\n        if self.record_fire_rate:\n            self.forward_step += 1\n        return sum(outputs) / len(outputs)\n\n    def reset_fire_rate_record(self):\n        self.fire_rate_per_step = [0.] 
* self.step\n        self.forward_step = 0\n\n    def get_fire_per_step(self):\n        return [x / self.forward_step for x in self.fire_rate_per_step]\n\n    def _loss(self, input1, target1, input2):\n        logits = self(input1)\n        return self._criterion(logits, target1, input2)\n    # def _loss(self, input1, target1):\n    #     logits = self(input1)\n    #     return self._criterion(logits, target1)\n\n    def _initialize_alphas(self):\n        # k = 2 + 3 + 4 + 5 = 14\n        k = sum(1 for i in range(self._steps) for n in range(2 + i))\n        if self.back_connection:\n            k += sum(1 for i in range(self._steps) for n in range(i))\n        num_ops = len(PRIMITIVES)\n\n        self.alphas_normal = Variable(\n            0.5 * torch.randn(2, k, num_ops).cuda(), requires_grad=True)\n\n        # init the history\n        self.alphas_normal_history = {}\n        mm = 0\n        last_id = 1\n        node_id = 0\n        for i in range(k):\n            for j in range(num_ops):\n                self.alphas_normal_history['edge: {}, op: {}'.format(\n                    (node_id, mm), PRIMITIVES[j])] = []\n            if mm == last_id:\n                mm = 0\n                last_id += 1\n                node_id += 1\n            else:\n                mm += 1\n\n    def arch_parameters(self):\n        return [self.alphas_normal]\n\n    def genotype(self):\n\n        # alphas_normal\n        gene_normal = parse(calc_weight(self.alphas_normal).data.cpu().numpy(),\n                            PRIMITIVES, self.op_threshold, self.parse_method,\n                            self._steps, reduction=False, back_connection=self.back_connection)\n\n        concat = range(2 + self._steps - self._multiplier, self._steps + 2)\n        genotype = Genotype(\n            normal=gene_normal, normal_concat=concat,\n        )\n        return genotype\n\n    def states(self):\n        return {\n            'alphas_normal': self.alphas_normal,\n            
'alphas_normal_history': self.alphas_normal_history,\n            'criterion': self._criterion\n        }\n\n    def restore(self, states):\n        self.alphas_normal = states['alphas_normal']\n        self.alphas_normal_history = states['alphas_normal_history']\n\n    def update_history(self):\n\n        mm = 0\n        last_id = 1\n        node_id = 0\n        weights1 = calc_weight(self.alphas_normal).data.cpu().numpy()\n\n        k, num_ops = weights1.shape\n        for i in range(k):\n            for j in range(num_ops):\n                self.alphas_normal_history['edge: {}, op: {}'.format((node_id, mm), PRIMITIVES[j])].append(\n                    float(weights1[i][j]))\n            if mm == last_id:\n                mm = 0\n                last_id += 1\n                node_id += 1\n            else:\n                mm += 1\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/operations.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import *\nimport torch.nn.functional as F\nfrom torch import einsum\nfrom einops import rearrange\nfrom braincog.model_zoo.base_module import DeformConvPack\nfrom braincog.model_zoo.base_module import BaseLinearModule\n\n\n# from mmcv.ops import ModulatedDeformConv2dPack\n\n\ndef si_relu(x, positive):\n    if positive == 1:\n        return torch.where(x > 0., x, torch.zeros_like(x))\n    elif positive == 0:\n        return x\n    elif positive == -1:\n        return torch.where(x < 0., x, torch.zeros_like(x))\n    else:\n        raise ValueError\n\n\nclass SiReLU(nn.Module):\n    def __init__(self, positive=0):\n        super().__init__()\n        self.positive = positive\n\n    def forward(self, x):\n        return si_relu(x, self.positive)\n\n\ndef weight_init(m):\n    if isinstance(m, nn.Conv2d):\n        torch.nn.init.xavier_normal(m.weight.data, gain=0.1)\n        torch.nn.init.constant(m.bias.data, 0.)\n\nOPS_Mlp = {\n    'mlp': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=0),\n    'mlp_p': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=1),\n    'mlp_n': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=-1),\n\n    'skip_connect': lambda C, act_fun:\n        Identity(positive=0),\n    'skip_connect_p': lambda C, act_fun:\n        Identity(positive=1),\n    'skip_connect_n': lambda C, act_fun:\n        Identity(positive=-1),\n}\n\nOPS = {\n    'avg_pool_3x3': lambda C, stride, affine, act_fun: nn.AvgPool2d(3, stride=stride, padding=1,\n                                                                    count_include_pad=False),\n    'conv_3x3': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=0),\n    'conv_5x5': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, 
stride=stride, affine=affine, act_fun=act_fun, positive=0),\n    'max_pool_3x3': lambda C, stride, affine, act_fun: nn.MaxPool2d(3, stride=stride, padding=1),\n    'skip_connect': lambda C, stride, affine, act_fun:\n        Identity(positive=0) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun),\n    'sep_conv_3x3': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),\n    'sep_conv_5x5': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),\n    'sep_conv_7x7': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=0),\n    'dil_conv_3x3': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=0),\n    'dil_conv_5x5': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=0),\n    'def_conv_3x3': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),\n    'def_conv_5x5': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),\n\n    'avg_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),\n        SiReLU(positive=1)\n    ),\n    'max_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.MaxPool2d(3, stride=stride, padding=1),\n        SiReLU(positive=1)\n    ),\n    'conv_3x3_p': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=1),\n    'conv_5x5_p': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, 
positive=1),\n    'skip_connect_p': lambda C, stride, affine, act_fun:\n        Identity(positive=1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_3x3_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_5x5_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_7x7_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=1),\n    'dil_conv_3x3_p': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=1),\n    'dil_conv_5x5_p': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=1),\n    'def_conv_3x3_p': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),\n    'def_conv_5x5_p': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),\n\n    'avg_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),\n        SiReLU(positive=-1)\n    ),\n    'max_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(\n            nn.MaxPool2d(3, stride=stride, padding=1),\n            SiReLU(positive=-1)\n    ),\n    'conv_3x3_n': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=-1),\n    'conv_5x5_n': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=-1),\n    'skip_connect_n': lambda C, stride, affine, act_fun:\n        Identity(positive=-1) if stride == 1 
else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_3x3_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_5x5_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_7x7_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=-1),\n    'dil_conv_3x3_n': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'dil_conv_5x5_n': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'def_conv_3x3_n': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),\n    'def_conv_5x5_n': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),\n\n    'conv_7x1_1x7': lambda C, stride, affine, act_fun: nn.Sequential(\n        # nn.ReLU(inplace=False),\n        act_fun(),\n        nn.Conv2d(C, C, (1, 7), stride=(1, stride),\n                  padding=(0, 3), bias=False),\n        nn.Conv2d(C, C, (7, 1), stride=(stride, 1),\n                  padding=(3, 0), bias=False),\n        nn.BatchNorm2d(C, affine=affine)\n    ),\n    'transformer': lambda C, stride, affine, act_fun:\n        FactorizedReduce(\n            C, C, affine=affine, act_fun=act_fun) if stride != 1 else TransformerEncoderLayer(C),\n}\n\n\nclass SiMLP(nn.Module):\n    def __init__(self, c_in, c_out, act_fun=nn.ReLU, positive=0, *args, **kwargs):\n        super(SiMLP, self).__init__()\n        self.op = nn.Sequential(\n            nn.Linear(c_in, c_out, bias=True),\n            act_fun()\n        )\n        self.positive = positive\n\n    def forward(self, x):\n        out = 
self.op(si_relu(x, self.positive))\n        return out\n\n\nclass ReLUConvBN(nn.Module):\n    \"\"\"\n    ReLu -> Conv2d -> BatchNorm2d\n    \"\"\"\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(ReLUConvBN, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            # act_fun(),\n            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,\n                      padding=padding, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine)\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass DilConv(nn.Module):\n    \"\"\"\n    Dilation Convolution ： ReLU -> DilConv -> Conv2d -> BatchNorm2d\n    \"\"\"\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, act_fun=nn.ReLU, positive=0):\n        super(DilConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,\n                      groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine),\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass SepConv(nn.Module):\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(SepConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_in, 
kernel_size=kernel_size,\n                      stride=stride, padding=padding, groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_in, affine=affine),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,\n                      stride=1, padding=padding, groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine),\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass Identity(nn.Module):\n\n    def __init__(self, positive=0):\n        super(Identity, self).__init__()\n        self.positive = positive\n\n    def forward(self, x):\n        return si_relu(x, self.positive)\n\n\nclass Zero(nn.Module):\n\n    def __init__(self, stride):\n        super(Zero, self).__init__()\n        self.stride = stride\n\n    def forward(self, x):\n        if self.stride == 1:\n            return x.mul(0.)\n        return x[:, :, ::self.stride, ::self.stride].mul(0.)  
# N * C * W * H\n\n\nclass FactorizedReduce(nn.Module):\n\n    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):\n        super(FactorizedReduce, self).__init__()\n        assert C_out % 2 == 0\n        # self.relu = nn.ReLU(inplace=False)\n        self.activation = act_fun()\n        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,\n                                stride=2, padding=1, bias=False)\n        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3,\n                                stride=2, padding=1, bias=False)\n        self.bn = nn.BatchNorm2d(C_out, affine=affine)\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        # x = self.relu(x)\n        x = self.activation(x)\n        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.bn(out)\n        out = si_relu(out, self.positive)\n        return out\n\n\nclass DeformConv(nn.Module):\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(DeformConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            DeformConvPack(C_in, C_out, kernel_size=kernel_size,\n                           stride=stride, padding=padding, bias=True),\n            nn.BatchNorm2d(C_out, affine=affine)\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass Attention(Module):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    \"\"\"\n\n    def __init__(self, dim, num_heads=4, attention_dropout=0.1, projection_dropout=0.1):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n\n   
     self.qkv = Linear(dim, dim * 3, bias=False)\n        self.attn_drop = Dropout(attention_dropout)\n        self.proj = Linear(dim, dim)\n        self.proj_drop = Dropout(projection_dropout)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //\n                                  self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass TransformerEncoderLayer(Module):\n    \"\"\"\n    Inspired by torch.nn.TransformerEncoderLayer and\n    rwightman's timm package.\n    \"\"\"\n\n    def __init__(self, d_model, nhead=4, dim_feedforward=256, dropout=0.1,\n                 attention_dropout=0.1, drop_path_rate=0.1):\n        super(TransformerEncoderLayer, self).__init__()\n        self.pre_norm = LayerNorm(d_model)\n        self.self_attn = Attention(dim=d_model, num_heads=nhead,\n                                   attention_dropout=attention_dropout, projection_dropout=dropout)\n        dim_feedforward = d_model\n        self.linear1 = Linear(d_model, dim_feedforward)\n        self.dropout1 = Dropout(dropout)\n        self.norm1 = LayerNorm(d_model)\n        self.linear2 = Linear(dim_feedforward, d_model)\n        self.dropout2 = Dropout(dropout)\n\n        self.drop_path = DropPath(\n            drop_path_rate) if drop_path_rate > 0 else Identity()\n\n        self.activation = F.gelu\n\n    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        # print(src.shape)\n        c = src.shape[-1]\n        src = rearrange(src, 'b d r c -> b (r c) d')\n        # print(src.shape)\n        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))\n        src = self.norm1(src)\n        
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))\n        src = src + self.drop_path(self.dropout2(src2))\n        src = rearrange(src, 'b (r c) d -> b d r c', c=c)\n        return src\n\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0. or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    # work with diff dim tensors, not just 2D ConvNets\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n    random_tensor = keep_prob + \\\n        torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(Module):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "braincog/model_zoo/NeuEvo/others.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2023/5/22 13:32\n# User      : yu\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : others.py\n# explain   :\nfrom functools import partial\nimport torch\nimport torch.nn as nn\nfrom copy import deepcopy\n\nfrom timm.models import register_model\n\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import WSConv2d\nfrom braincog.datasets import is_dvs_data\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule\n\n@register_model\nclass CIFARNet_Wu(BaseModule):\n\n    def __init__(\n            self, num_classes=10,\n            node_type=LIFNode,\n            step=4,\n            encode_type='direct',\n            *args,\n            **kwargs,\n    ):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.dataset = kwargs['dataset']\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        channels = 32\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n            out_size = 2 ** 2\n        else:\n            init_channel = 2\n            out_size = 3 ** 2\n\n        self.feature = nn.Sequential(\n            BaseConvModule(init_channel, channels, node=self.node),\n            BaseConvModule(channels, channels * 2, node=self.node),\n            nn.AvgPool2d(2, 2),\n            self.node(),\n            BaseConvModule(channels * 2, channels * 4, node=self.node),\n            nn.AvgPool2d(2, 2),\n            # self.node(),\n            BaseConvModule(channels * 4, channels * 8, node=self.node),\n            BaseConvModule(channels * 8, channels * 4, node=self.node),\n            nn.Flatten(),\n        )\n\n        self.fc = nn.Sequential(\n            nn.Linear(channels * 4 * 8 * 8, channels * 8, bias=False),\n            self.node(),\n            nn.Linear(channels * 8, channels * 4, 
bias=False),\n            self.node(),\n            nn.Linear(channels * 4, num_classes, bias=False)\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs).contiguous()\n        self.reset()\n        outputs = []\n        for t in range(self.step):\n            x = inputs[t]\n            x = self.feature(x)\n            x = self.fc(x)\n            outputs.append(x)\n        return sum(outputs) / len(outputs)\n\n@register_model\nclass CIFARNet_Fang(BaseModule):\n\n    def __init__(\n            self, num_classes=10,\n            node_type=LIFNode,\n            step=4,\n            encode_type='direct',\n            *args,\n            **kwargs,\n    ):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.dataset = kwargs['dataset']\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        channels = 32\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n        else:\n            init_channel = 2\n\n        self.feature = nn.Sequential(\n            BaseConvModule(init_channel, channels, node=self.node),\n            BaseConvModule(channels, channels, node=self.node),\n            BaseConvModule(channels, channels, node=self.node),\n            nn.MaxPool2d(2, 2),\n            BaseConvModule(channels, channels, node=self.node),\n            BaseConvModule(channels, channels, node=self.node),\n            BaseConvModule(channels, channels, node=self.node),\n            nn.MaxPool2d(2, 2),\n            nn.Flatten(),\n        )\n\n        self.fc = nn.Sequential(\n            nn.Linear(channels * 8 * 8, channels * 8, bias=False),\n            self.node(),\n            nn.Linear(channels * 8, channels, bias=False),\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs).contiguous()\n        self.reset()\n\n        outputs = []\n        for t in range(self.step):\n            x 
= inputs[t]\n            x = self.feature(x)\n            x = self.fc(x)\n            outputs.append(x)\n        return sum(outputs) / len(outputs)\n\n@register_model\nclass DVS_CIFARNet_Fang(BaseModule):\n\n    def __init__(\n            self, num_classes=10,\n            node_type=LIFNode,\n            step=10,\n            encode_type='direct',\n            *args,\n            **kwargs,\n    ):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.dataset = kwargs['dataset']\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        channels = 128\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n        else:\n            init_channel = 2\n\n        self.feature = nn.Sequential(\n            BaseConvModule(init_channel, channels, node=self.node),\n            nn.MaxPool2d(2, 2),\n            BaseConvModule(channels, channels, node=self.node),\n            nn.MaxPool2d(2, 2),\n            BaseConvModule(channels, channels, node=self.node),\n            nn.MaxPool2d(2, 2),\n            BaseConvModule(channels, channels, node=self.node),\n            nn.MaxPool2d(2, 2),\n            nn.Flatten(),\n        )\n\n        self.fc = nn.Sequential(\n            nn.Linear(channels * 8 * 8, channels * 4, bias=False),\n            self.node(),\n            nn.Linear(channels * 4, channels, bias=False),\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs).contiguous()\n        self.reset()\n\n        outputs = []\n        for t in range(self.step):\n            x = inputs[t]\n            x = self.feature(x)\n            x = self.fc(x)\n            outputs.append(x)\n        return sum(outputs) / len(outputs)"
  },
  {
    "path": "braincog/model_zoo/__init__.py",
    "content": "__all__ = ['convnet', 'resnet', 'base_module', 'glsnn', 'qsnn', 'resnet19_snn']\n\nfrom . import (\n    convnet,\n    resnet,\n    base_module,\n    glsnn,\n    qsnn,\n    resnet19_snn\n)\n"
  },
  {
    "path": "braincog/model_zoo/backeinet.py",
    "content": "import numpy as np\nfrom timm.models import register_model\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\n\n\nclass MNISTNet(BaseModule):\n    def __init__(self, step=20, encode_type='rate', if_back=True, if_ei=True, data='mnist', *args, **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.if_back = if_back\n        self.if_ei = if_ei\n        if data == 'mnist':\n            self.cfg_conv = ((1, 15, 5, 1, 0), (15, 40, 5, 1, 0))\n            self.cfg_fc = (300, 10)\n            self.cfg_kernel = (24, 8, 4)\n            cfg_backei = 2\n        if data == 'fashion':\n            self.cfg_conv = ((1, 32, 5, 1, 2), (32, 64, 5, 1, 2))\n            self.cfg_fc = (1024, 10)\n            self.cfg_kernel = (28, 14, 7)\n            cfg_backei = 1\n        self.feature = nn.Sequential(\n            nn.Conv2d(self.cfg_conv[0][0], self.cfg_conv[0][1], self.cfg_conv[0][2], self.cfg_conv[0][3],\n                      self.cfg_conv[0][4]),\n            BackEINode(channel=self.cfg_conv[0][1], if_back=self.if_back, if_ei=self.if_ei, cfg_backei=cfg_backei),\n            nn.AvgPool2d(2),\n            nn.Conv2d(self.cfg_conv[1][0], self.cfg_conv[1][1], self.cfg_conv[1][2], self.cfg_conv[1][3],\n                      self.cfg_conv[1][4]),\n            BackEINode(channel=self.cfg_conv[1][1], if_back=self.if_back, if_ei=self.if_ei, cfg_backei=cfg_backei),\n            nn.AvgPool2d(2),\n            nn.Flatten(),\n            nn.Linear(self.cfg_kernel[2] * self.cfg_kernel[2] * self.cfg_conv[1][1], self.cfg_fc[0]),\n            BackEINode(if_back=False, if_ei=False),\n            nn.Linear(self.cfg_fc[0], self.cfg_fc[1]),\n            BackEINode(if_back=False, if_ei=False)\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n        if not self.training:\n         
   self.fire_rate.clear()\n        outputs = []\n        step = self.step\n        for t in range(step):\n            x = inputs[t]\n            x = self.feature(x)\n            outputs.append(x)\n\n        return sum(outputs) / len(outputs)\n\n\nclass CIFARNet(BaseModule):\n    def __init__(self, step=20, encode_type='rate', if_back=True, if_ei=True, *args, **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.if_back = if_back\n        self.if_ei = if_ei\n        self.feature = nn.Sequential(\n            nn.Conv2d(3, 128, 3, 1, 1),\n            BackEINode(channel=128, if_back=self.if_back, if_ei=self.if_ei, cfg_backei=1),\n            nn.Dropout(0.5),\n            nn.AvgPool2d(2),\n\n            nn.Conv2d(128, 256, 3, 1, 1),\n            BackEINode(channel=256, if_back=self.if_back, if_ei=self.if_ei, cfg_backei=1),\n            nn.Dropout(0.5),\n            nn.AvgPool2d(2),\n\n            nn.Conv2d(256, 512, 3, 1, 1),\n            BackEINode(channel=512, if_back=self.if_back, if_ei=self.if_ei, cfg_backei=1),\n            nn.Dropout(0.5),\n            nn.AvgPool2d(2),\n\n            nn.Flatten(),\n            nn.Linear(4 * 4 * 512, 1024),\n            BackEINode(if_back=False, if_ei=False),\n            nn.Dropout(0.5),\n\n            nn.Linear(1024, 10),\n            BackEINode(if_back=False, if_ei=False)\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n        if not self.training:\n            self.fire_rate.clear()\n        outputs = []\n        step = self.step\n        for t in range(step):\n            x = inputs[t]\n            x = self.feature(x)\n            outputs.append(x)\n\n        return sum(outputs) / len(outputs)\n"
  },
  {
    "path": "braincog/model_zoo/base_module.py",
    "content": "from functools import partial\nfrom torchvision.ops import DeformConv2d\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\n\n\nclass BaseLinearModule(nn.Module):\n    \"\"\"\n    线性模块\n    :param in_features: 输入尺寸\n    :param out_features: 输出尺寸\n    :param bias: 是否有Bias, 默认 ``False``\n    :param node: 神经元类型, 默认 ``LIFNode``\n    :param args:\n    :param kwargs:\n    \"\"\"\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 bias=True,\n                 node=LIFNode,\n                 *args,\n                 **kwargs):\n        super().__init__()\n        if node is None:\n            raise TypeError\n\n        self.groups = kwargs['groups'] if 'groups' in kwargs else 1\n        if self.groups == 1:\n            self.fc = nn.Linear(in_features=in_features,\n                                out_features=out_features, bias=bias)\n        else:\n            self.fc = nn.ModuleList()\n            for i in range(self.groups):\n                self.fc.append(nn.Linear(\n                    in_features=in_features,\n                    out_features=out_features,\n                    bias=bias\n                ))\n        self.node = partial(node, **kwargs)()\n\n    def forward(self, x):\n        if self.groups == 1:  # (t b) c\n            outputs = self.fc(x)\n\n        else: # b (c t)\n            x = rearrange(x, 'b (c t) -> t b c', t=self.groups)\n            outputs = []\n            for i in range(self.groups):\n                outputs.append(self.fc[i](x[i]))\n            outputs = torch.stack(outputs) # t b c\n            outputs = rearrange(outputs, 't b c -> b (c t)')\n\n        return self.node(outputs)\n\n\nclass BaseConvModule(nn.Module):\n    \"\"\"\n    SNN卷积模块\n    :param in_channels: 输入通道数\n    :param out_channels: 输出通道数\n    :param kernel_size: kernel size\n    :param stride: stride\n    :param 
padding: padding\n    :param bias: Bias\n    :param node: 神经元类型\n    :param kwargs:\n    \"\"\"\n    def __init__(self,\n                 in_channels: int,\n                 out_channels: int,\n                 kernel_size=(3, 3),\n                 stride=(1, 1),\n                 padding=(1, 1),\n                 bias=False,\n                 node=PLIFNode,\n                 **kwargs):\n\n        super().__init__()\n\n        if node is None:\n            raise TypeError\n\n        self.groups = kwargs['groups'] if 'groups' in kwargs else 1\n        self.conv = nn.Conv2d(in_channels=in_channels * self.groups,\n                              out_channels=out_channels * self.groups,\n                              kernel_size=kernel_size,\n                              padding=padding,\n                              stride=stride,\n                              bias=bias)\n \n\n        self.bn = nn.BatchNorm2d(out_channels * self.groups)\n\n        self.node = partial(node, **kwargs)()\n\n        self.activation = nn.Identity()\n\n    def forward(self, x):\n        # origin_shape = x.shape\n        # if len(origin_shape) > 4:\n        #     x = x.reshape(np.prod(origin_shape[0:-3]), *origin_shape[-3:])\n        x = self.conv(x)\n        x = self.bn(x)\n        # if len(origin_shape) > 4:\n        #     x = x.reshape(*origin_shape[0:-3], *x.shape[-3:])\n\n        x = self.node(x)\n        return x\n\n\nclass BaseModule(nn.Module, abc.ABC):\n    \"\"\"\n    SNN抽象类, 所有的SNN都要继承这个类, 以实现一些基础方法\n    :param step: 仿真步长\n    :param encode_type: 数据编码类型\n    :param layer_by_layer: 是否layer wise地进行前向推理\n    :param temporal_flatten: 是否将时间维度和channel合并\n    :param args:\n    :param kwargs:\n    \"\"\"\n    def __init__(self,\n                 step,\n                 encode_type,\n                 layer_by_layer=False,\n                 temporal_flatten=False,\n                 *args,\n                 **kwargs):\n        super(BaseModule, self).__init__()\n        self.step = step\n   
     # print(kwargs['layer_by_layer'])\n        self.layer_by_layer = layer_by_layer\n\n        self.temporal_flatten = temporal_flatten\n        encode_step = self.step\n\n        if temporal_flatten is True:\n            self.init_channel_mul = self.step\n            self.step = 1\n        else:  # origin\n            self.init_channel_mul = 1\n\n        self.encoder = Encoder(encode_step, encode_type, temporal_flatten=self.temporal_flatten, layer_by_layer=self.layer_by_layer, **kwargs)\n\n        self.kwargs = kwargs\n        self.warm_up = False\n\n        self.fire_rate = []\n\n    def reset(self):\n        \"\"\"\n        重置所有神经元的膜电位\n        :return:\n        \"\"\"\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n    def set_attr(self, attr, val):\n        \"\"\"\n        设置神经元的属性\n        :param attr: 属性名称\n        :param val: 设置的属性值\n        :return:\n        \"\"\"\n        for mod in self.modules():\n            if isinstance(mod, BaseNode):\n                if hasattr(mod, attr):\n                    setattr(mod, attr, val)\n                else:\n                    ValueError('{} do not has {}'.format(self, attr))\n\n    def get_threshold(self):\n        \"\"\"\n        获取所有神经元的阈值\n        :return:\n        \"\"\"\n        outputs = []\n        for mod in self.modules():\n            if isinstance(mod, BaseNode):\n                thresh = (mod.get_thres())\n                outputs.append(thresh)\n        return outputs\n\n    def get_fp(self, temporal_info=False):\n        \"\"\"\n        获取所有神经元的状态\n        :param temporal_info: 是否要读取神经元的时间维度状态, False会把时间维度拍平\n        :return: 所有神经元的状态, List\n        \"\"\"\n        outputs = []\n        for mod in self.modules():\n            if isinstance(mod, BaseNode):\n                if temporal_info:\n                    outputs.append(mod.feature_map)#[l,[t,[b,w,h]]]\n                else:\n                    outputs.append(sum(mod.feature_map) / 
len(mod.feature_map))\n        return outputs\n\n    def get_mem(self, temporal_info=False):\n        \"\"\"\n        获取所有神经元的模电势\n        :param temporal_info: 是否要读取神经元的时间维度状态, False会把时间维度拍平\n        :return: 所有神经元的状态, List\n        \"\"\"\n        outputs = []\n        for mod in self.modules():\n            if isinstance(mod, BaseNode):\n                if temporal_info:\n                    outputs.append(mod.mem_collect)#[l,[t,[b,w,h]]]\n                else:\n                    outputs.append(sum(mod.mem_collect) / len(mod.mem_collect))\n        return outputs\n\n    def get_fire_rate(self, requires_grad=False):\n        \"\"\"\n        获取神经元的fire-rate\n        :param requires_grad: 是否需要梯度信息, 默认为 ``False`` 会截断梯度\n        :return: 所有神经元的fire-rate\n        \"\"\"\n        outputs = []\n        fp = self.get_attr('feature_map')\n        for f in fp:\n            if requires_grad is False:\n                if len(f) == 0:\n                    return torch.tensor([0.])\n                outputs.append(((sum(f) / len(f)).detach() > 0.).float().mean())\n            else:\n                outputs.append(((sum(f) / len(f)) > 0.).float().mean())\n        if len(outputs) == 0:\n            return torch.tensor([0.])\n        return torch.stack(outputs)\n\n    def get_tot_spike(self):\n        \"\"\"\n        获取神经元总的脉冲数量\n        :return:\n        \"\"\"\n        tot_spike = 0\n        batch_size = 1\n        fp = self.get_attr('feature_map')\n        for f in fp:\n            if len(f) == 0:\n                break\n            tot_spike += sum(f).sum()\n            batch_size = f[0].shape[0]\n        return tot_spike / batch_size\n\n    def get_spike_info(self):\n        \"\"\"\n        获取神经元的脉冲信息, 主要用于绘图\n        :return:\n        \"\"\"\n        spike_feature_list = self.get_fp(temporal_info=True)\n        avg, var, spike = [], [], []\n        avg_per_step = []\n        for spike_feature in spike_feature_list:\n            avg_list = []\n            for spike_t in 
spike_feature:\n                avg_list.append(float(spike_t.mean()))\n            avg_per_step.append(avg_list)\n\n            spike_feature = sum(spike_feature)\n            num = np.prod(spike_feature.shape)\n            avg.append(float(spike_feature.sum()))\n            var.append(float(spike_feature.std()))\n            lst = []\n            for t in range(self.step + 1):\n                lst.append(float((spike_feature == t).sum() / num))\n\n            spike.append(lst)\n\n        return avg, var, spike, avg_per_step\n\n    def set_requires_fp(self, flag):\n        for mod in self.modules():\n            if hasattr(mod, 'requires_fp'):\n                mod.requires_fp = flag\n\n    def set_requires_mem(self, flag):\n        for mod in self.modules():\n            if hasattr(mod, 'requires_mem'):\n                mod.requires_mem = flag        \n\n    def get_attr(self, attr):\n        \"\"\"\n        获取神经元的某一属性值\n        :param attr: 属性名称\n        :return: 对应属性的值, List\n        \"\"\"\n        outputs = []\n        for mod in self.modules():\n            if hasattr(mod, attr):\n                outputs.append(getattr(mod, attr))\n        return outputs\n\n    @staticmethod\n    def forward(self, inputs):\n        pass\n\n\nclass DeformConvPack(nn.Module):\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 padding,\n                 stride,\n                 bias,\n                 *args,\n                 **kwargs):\n        super(DeformConvPack, self).__init__()\n        self.in_channels = in_channels\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.padding = padding\n\n        if isinstance(self.kernel_size, tuple) or isinstance(self.kernel_size, list):\n            self.receptive_field = self.kernel_size[0]\n        else:\n            self.receptive_field = self.kernel_size\n            self.kernel_size = (self.kernel_size, 
self.kernel_size)\n\n        self.receptive_field = 4 * (self.receptive_field // 2)\n\n        self.conv_offset = nn.Conv2d(\n            self.in_channels,\n            3 * self.kernel_size[0] * self.kernel_size[1],\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            bias=True)\n        self.deform_conv = DeformConv2d(in_channels=in_channels,\n                                        out_channels=out_channels,\n                                        kernel_size=kernel_size,\n                                        padding=padding,\n                                        stride=stride,\n                                        bias=bias)\n        self.init_weights()\n\n    def init_weights(self):\n        if hasattr(self, 'conv_offset'):\n            self.conv_offset.weight.data.zero_()\n            self.conv_offset.bias.data.zero_()\n\n    def forward(self, x):\n        out = self.conv_offset(x)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat((o1, o2), dim=1)\n        offset = self.receptive_field * (torch.sigmoid(offset) - 0.5)\n        mask = torch.sigmoid(mask)\n        return self.deform_conv(x, offset, mask)\n"
  },
  {
    "path": "braincog/model_zoo/bdmsnn.py",
    "content": "\r\nimport torch\r\nfrom torch import nn\r\n\r\nfrom braincog.base.node.node import IFNode, SimHHNode\r\nfrom braincog.base.learningrule.STDP import STDP, MutliInputSTDP\r\nfrom braincog.base.connection.CustomLinear import CustomLinear\r\nfrom braincog.base.brainarea.basalganglia import basalganglia\r\n\r\nimport pygame\r\nfrom pygame.locals import *\r\nfrom collections import deque\r\nfrom random import randint\r\n#os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\r\n\r\n\r\nclass BDMSNN(nn.Module):\r\n    def __init__(self, num_state, num_action, weight_exc, weight_inh, node_type):\r\n        \"\"\"\r\n        定义BDM-SNN网络\r\n        :param num_state: 状态个数\r\n        :param num_action: 动作个数\r\n        :param weight_exc: 兴奋性连接权重\r\n        :param weight_inh: 抑制性连接权重\r\n        \"\"\"\r\n        super().__init__()\r\n        # parameters\r\n        BG = basalganglia(num_state, num_action, weight_exc, weight_inh, node_type)\r\n        dm_connection = BG.getweight()\r\n        dm_mask = BG.getmask()\r\n        # input-dlpfc\r\n        con_matrix9 = torch.eye((num_state), dtype=torch.float)\r\n        dm_connection.append(CustomLinear(weight_exc * con_matrix9, con_matrix9))\r\n        dm_mask.append(con_matrix9)\r\n        # gpi-th\r\n        con_matrix10 = torch.eye((num_action), dtype=torch.float)\r\n        dm_mask.append(con_matrix10)\r\n        dm_connection.append(CustomLinear(weight_inh * con_matrix10, con_matrix10))\r\n        # th-pm\r\n        dm_mask.append(con_matrix10)\r\n        dm_connection.append(CustomLinear(weight_exc * con_matrix10, con_matrix10))\r\n        # dlpfc-th\r\n        con_matrix11 = torch.ones((num_state, num_action), dtype=torch.float)\r\n        dm_mask.append(con_matrix11)\r\n        dm_connection.append(CustomLinear(0.2 * weight_exc * con_matrix11, con_matrix11))\r\n        # pm-pm\r\n        con_matrix3 = torch.ones((num_action, num_action), dtype=torch.float)\r\n        con_matrix4 = torch.eye((num_action), 
dtype=torch.float)\r\n        con_matrix5 = con_matrix3 - con_matrix4\r\n        con_matrix5 = con_matrix5\r\n        dm_mask.append(con_matrix5)\r\n        dm_connection.append(CustomLinear(5 * weight_inh * con_matrix5, con_matrix5))\r\n        # dlpfc thalamus pm +bg\r\n        self.weight_exc = weight_exc\r\n        self.num_subDM = 8\r\n        self.connection = dm_connection\r\n        self.mask = dm_mask\r\n        self.node = BG.node\r\n        self.node_type = node_type\r\n        if self.node_type == \"hh\":\r\n            self.node.extend([SimHHNode() for i in range(self.num_subDM - BG.num_subBG)])\r\n            self.node[6].g_Na = torch.tensor(12)\r\n            self.node[6].g_K = torch.tensor(3.6)\r\n            self.node[6].g_L = torch.tensor(0.03)\r\n        if self.node_type == \"lif\":\r\n            self.node.extend([IFNode() for i in range(self.num_subDM - BG.num_subBG)])\r\n        self.learning_rule = BG.learning_rule\r\n        self.learning_rule.append(MutliInputSTDP(self.node[5], [self.connection[10], self.connection[12]]))  # gpi-丘脑\r\n        self.learning_rule.append(MutliInputSTDP(self.node[6], [self.connection[11], self.connection[13]]))  # pm\r\n        self.learning_rule.append(STDP(self.node[7], self.connection[9]))\r\n\r\n        out_shape=[self.connection[0].weight.shape[1],self.connection[1].weight.shape[1],self.connection[2].weight.shape[1],self.connection[4].weight.shape[1],self.connection[3].weight.shape[1],self.connection[10].weight.shape[1],self.connection[11].weight.shape[1],self.connection[9].weight.shape[1]]\r\n        self.out = []\r\n        self.dw = []\r\n        for i in range(self.num_subDM):\r\n            self.out.append(torch.zeros((out_shape[i]), dtype=torch.float))\r\n            self.dw.append(torch.zeros((out_shape[i]), dtype=torch.float))\r\n\r\n    def forward(self, input):\r\n        \"\"\"\r\n        根据输入得到网络的输出\r\n        :param input: 输入\r\n        :return: 网络的输出\r\n        \"\"\"\r\n        self.out[7] 
= self.node[7](self.connection[9](input))\r\n        self.out[0], self.dw[0] = self.learning_rule[0](self.out[7])\r\n        self.out[1], self.dw[1] = self.learning_rule[1](self.out[7])\r\n        self.out[2], self.dw[2] = self.learning_rule[2](self.out[7], self.out[3])\r\n        self.out[3], self.dw[3] = self.learning_rule[3](self.out[1], self.out[2])\r\n        self.out[4], self.dw[4] = self.learning_rule[4](self.out[0], self.out[3], self.out[2])\r\n        self.out[5], self.dw[5] = self.learning_rule[5](self.out[4], self.out[7])\r\n        self.out[6], self.dw[6] = self.learning_rule[6](self.out[5], self.out[6])\r\n        br = [\"StrD1\", \"StrD2\", \"STN\", \"Gpe\", \"Gpi\", \"thalamus\", \"PM\", \"DLPFC\"]\r\n        for i in range(self.num_subDM):\r\n            if torch.max(self.out[i]) > 0 and self.node_type == \"hh\":\r\n                self.node[i].n_reset()\r\n            print(\"every areas:\", br[i], self.out[i])\r\n        return self.out[6], self.dw\r\n\r\n    def UpdateWeight(self, i, s, num_action, dw):\r\n        \"\"\"\r\n        更新网络中第i组连接的权重\r\n        :param i:要更新的连接组索引\r\n        :param s:传入状态\r\n        :param dw:更新权重的量\r\n        :return:\r\n        \"\"\"\r\n        if self.node_type == \"hh\":\r\n            self.connection[i].update(0.2 * self.weight_exc * dw)\r\n            self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)\r\n            self.connection[i].weight.data[s, :] = self.connection[i].weight.data[s, :] * self.weight_exc\r\n        if self.node_type == \"lif\":\r\n            dw_mean = dw[s, [s * num_action, s * num_action + 1]].mean()\r\n            dw_std = dw[s, [s * num_action, s * num_action + 1]].std()\r\n            dw[s, [s * num_action, s * num_action + 1]] = (dw[s, [s * num_action,s * num_action + 1]] - dw_mean) / dw_std\r\n            dw[s, :] = dw[s, :] * self.mask[i][s, :]\r\n            
self.connection[i].update(dw)\r\n            self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)\r\n        if i in [0, 1, 2, 6, 7, 11, 12]:\r\n            self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, 0, None)\r\n        if i in [3, 4, 5, 8, 10]:\r\n            self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, None, 0)\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        reset神经元或学习法则的中间量\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subDM):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n\r\n    def getweight(self):\r\n        \"\"\"\r\n        获取网络的连接(包括权值等)\r\n        :return: 网络的连接\r\n        \"\"\"\r\n        return self.connection\r\n"
  },
  {
    "path": "braincog/model_zoo/convnet.py",
    "content": "import abc\nfrom functools import partial\nfrom torch.nn import functional as F\nimport torchvision\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\n\n\nclass BaseConvNet(BaseModule, abc.ABC):\n    def __init__(self,\n                 step,\n                 input_channels,\n                 num_classes,\n                 encode_type,\n                 spike_output: bool,\n                 out_channels: list,\n                 block_depth: list,\n                 node_list: list,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.num_cls = num_classes\n        self.spike_output = spike_output\n        self.groups = kwargs['n_groups'] if 'n_groups' in kwargs else 1\n        if not spike_output:\n            node_list.append(nn.Identity)\n            out_channels.append(self.num_cls)\n            self.vote = nn.Identity()\n            # self.vote = nn.Sequential(\n            #     nn.Linear(self.step, 32),\n            #     nn.ReLU(),\n            #     nn.Linear(32, 1)\n            # )\n        else:\n            out_channels.append(10 * self.num_cls)\n            self.vote = VotingLayer(10)\n\n        # check list length\n        if len(node_list) != len(out_channels):\n            raise ValueError\n        self.input_channels = input_channels\n        self.out_channels = out_channels\n        self.block_depth = block_depth\n        self.node_list = node_list\n        self.feature = self._create_feature()\n        self.fc = self._create_fc()\n        if self.layer_by_layer:\n            self.flatten = nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n\n    @staticmethod\n    def _create_feature(self):\n        raise 
NotImplementedError\n\n    @staticmethod\n    def _create_fc(self):\n        raise NotImplementedError\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n        if not self.training:\n            self.fire_rate.clear()\n\n        if not self.layer_by_layer:\n            outputs = []\n            if self.warm_up:\n                step = 1\n            else:\n                step = self.step\n\n            for t in range(step):\n                x = inputs[t]\n                x = self.feature(x)\n                x = self.flatten(x)\n                x = self.fc(x)\n                x = self.vote(x)\n                outputs.append(x)\n\n            return sum(outputs) / len(outputs)\n            # outputs = torch.stack(outputs)\n            # outputs = rearrange(outputs, 't b c -> b c t')\n            # outputs = self.vote(outputs).squeeze()\n            # return outputs\n\n        else:\n            x = self.feature(inputs)\n            x = self.flatten(x)\n            x = self.fc(x)\n            if self.groups == 1:\n                x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\n            else:\n                x = rearrange(x, 'b (c t) -> t b c', t=self.step).mean(0)\n            x = self.vote(x)\n            return x\n\n\nclass MNISTConvNet(BaseConvNet):\n    def __init__(self,\n                 step,\n                 input_channels,\n                 num_classes,\n                 encode_type,\n                 block_depth,\n                 spike_output: bool,\n                 out_channels: list,\n                 node_list: list,\n                 *args,\n                 **kwargs):\n        self.feature_size = 28\n        super().__init__(step,\n                         input_channels,\n                         num_classes,\n                         encode_type,\n                         spike_output,\n                         out_channels,\n                         block_depth,\n                       
  node_list,\n                         *args,\n                         **kwargs)\n\n    def _create_feature(self):\n        feature_depth = len(self.node_list) - 2\n\n        feature = [BaseConvModule(\n            self.input_channels, self.out_channels[0], node=self.node_list[0])]\n        if self.block_depth[0] != 1:\n            feature.extend(\n                [BaseConvModule(self.out_channels[0], self.out_channels[0], node=self.node_list[0])] * (\n                    self.block_depth[0] - 1),\n            )\n        feature.append(nn.AvgPool2d(2))\n        self.feature_size = self.feature_size // 2\n\n        for i in range(1, feature_depth):\n            feature.append(BaseConvModule(\n                self.out_channels[i - 1], self.out_channels[i], node=self.node_list[i]))\n            if self.block_depth[i] != 1:\n                feature.extend(\n                    [BaseConvModule(self.out_channels[i], self.out_channels[i], node=self.node_list[i])] * (\n                        self.block_depth[0] - 1),\n                )\n            feature.append(nn.AvgPool2d(2))\n            feature.append(self.node_list[0]())\n            self.feature_size = self.feature_size // 2\n\n        return nn.Sequential(*feature)\n\n    def _create_fc(self):\n        fc = nn.Sequential(\n            NDropout(.5),\n            BaseLinearModule(self.out_channels[-3] * self.feature_size * self.feature_size, self.out_channels[-2],\n                             node=self.node_list[-2]),\n            NDropout(.5),\n            BaseLinearModule(\n                self.out_channels[-2], self.out_channels[-1], node=self.node_list[-1])\n        )\n        return fc\n\n\nclass CifarConvNet(BaseConvNet):\n    def __init__(self,\n                 step,\n                 input_channels,\n                 num_classes,\n                 encode_type,\n                 spike_output: bool,\n                 out_channels: list,\n                 node_list: list,\n                 block_depth: 
list,\n                 *args,\n                 **kwargs):\n        super().__init__(step,\n                         input_channels,\n                         num_classes,\n                         encode_type,\n                         spike_output,\n                         out_channels,\n                         block_depth,\n                         node_list,\n                         *args,\n                         **kwargs)\n\n    def _create_feature(self):\n        feature_depth = len(self.node_list) - 1\n\n        feature = [BaseConvModule(\n            self.input_channels * self.init_channel_mul, self.out_channels[0], node=self.node_list[0], groups=self.groups)]\n        if self.block_depth[0] != 1:\n            feature.extend(\n                [BaseConvModule(self.out_channels[0], self.out_channels[0], node=self.node_list[0], groups=self.groups)] * (\n                    self.block_depth[0] - 1),\n            )\n        feature.append(nn.AvgPool2d(2))\n        for i in range(1, feature_depth - 1):\n            feature.append(BaseConvModule(\n                self.out_channels[i - 1], self.out_channels[i], node=self.node_list[i], groups=self.groups))\n            if self.block_depth[i] != 1:\n                feature.extend(\n                    [BaseConvModule(self.out_channels[i], self.out_channels[i], node=self.node_list[i], groups=self.groups)] * (\n                        self.block_depth[i] - 1),\n                )\n            feature.append(nn.AvgPool2d(2))\n\n        feature.append(BaseConvModule(\n            self.out_channels[-3], self.out_channels[-2], node=self.node_list[-2], groups=self.groups))\n        if self.block_depth[feature_depth - 1] != 1:\n            feature.extend(\n                [BaseConvModule(self.out_channels[-2], self.out_channels[-2], node=self.node_list[-2], groups=self.groups)] * (\n                    self.block_depth[feature_depth - 1] - 1),\n            )\n        feature.append(nn.AdaptiveAvgPool2d((1, 1)))\n\n      
  return nn.Sequential(*feature)\n\n    def _create_fc(self):\n        fc = nn.Sequential(\n            # NDropout(.5),\n            BaseLinearModule(\n                self.out_channels[-2], self.out_channels[-1], node=self.node_list[-1], groups=self.groups)\n        )\n        return fc\n\n\n@register_model\ndef mnist_convnet(step,\n                  encode_type,\n                  spike_output: bool,\n                  node_type,\n                  *args,\n                  **kwargs):\n    out_channels = [128, 128, 2048]\n    block_depth = [1, 1]\n    node_cls = partial(node_type, step=step, **kwargs)\n    if spike_output:\n        node_list = [node_cls] * (len(out_channels) + 1)\n    else:\n        node_list = [node_cls] * (len(out_channels))\n\n    return MNISTConvNet(step=step,\n                        input_channels=1,\n                        encode_type=encode_type,\n                        block_depth=block_depth,\n                        node_list=node_list,\n                        out_channels=out_channels,\n                        spike_output=spike_output,\n                        **kwargs)\n\n\n@register_model\ndef cifar_convnet(step,\n                  encode_type,\n                  spike_output: bool,\n                  node_type,\n                  *args,\n                  **kwargs):\n    out_channels = [256, 256, 512, 1024]\n    # out_channels = [64, 128, 128, 256]\n    block_depth = [2, 2, 2, 2]\n    # print(kwargs)\n    node_cls = partial(node_type, step=step, **kwargs)\n    # print(node_cls)\n    if spike_output:\n        node_list = [node_cls] * (len(out_channels) + 1)\n    else:\n        node_list = [node_cls] * (len(out_channels))\n\n    return CifarConvNet(step=step,\n                        input_channels=3,\n                        encode_type=encode_type,\n                        node_list=node_list,\n                        block_depth=block_depth,\n                        out_channels=out_channels,\n                        
spike_output=spike_output,\n                        **kwargs)\n\n\n@register_model\ndef dvs_convnet(step,\n                encode_type,\n                spike_output: bool,\n                node_type,\n                num_classes,\n                *args,\n                **kwargs):\n    out_channels = [128, 256, 256, 512, 512]\n    block_depth = [2, 1, 2, 1, 2]\n\n    # out_channels = [40, 80, 80, 160, 160]\n    # out_channels = [256, 512, 512, 1024, 1024]\n    # out_channels = [64, 128, 128, 256, 256]\n    # block_depth = [4, 2, 4, 2, 4]\n\n    # out_channels = [128, 256, 512, 512]\n    # block_depth = [2, 2, 2, 2]\n    node_cls = partial(node_type, step=step, **kwargs)\n    if spike_output:\n        node_list = [node_cls] * (len(out_channels) + 1)\n        # node_list[-2] = partial(DoubleSidePLIFNode, step=step, **kwargs)\n    else:\n        node_list = [node_cls] * (len(out_channels))\n        # node_list[-1] = partial(DoubleSidePLIFNode, step=step, **kwargs)\n\n    return CifarConvNet(step=step,\n                        input_channels=2,\n                        num_classes=num_classes,\n                        encode_type=encode_type,\n                        node_list=node_list,\n                        block_depth=block_depth,\n                        out_channels=out_channels,\n                        spike_output=spike_output,\n                        **kwargs)\n"
  },
  {
    "path": "braincog/model_zoo/fc_snn.py",
    "content": "from functools import partial\nfrom torch.nn import functional as F\nimport torchvision\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.datasets import is_dvs_data\n\n\nclass STSC_Attention(nn.Module):\n    def __init__(self, n_channel: int, dimension: int = 2, time_rf: int = 4, reduction: int = 2):\n\n        super().__init__()\n        assert dimension == 4 or dimension == 2, 'dimension must be 4 or 2'\n\n        self.dimension = dimension\n\n        if self.dimension == 4:\n            self.avg_pool = nn.AdaptiveAvgPool2d(1)\n\n        self.time_padding = (time_rf - 1) // 2\n        self.n_channels = n_channel\n        r_channel = n_channel // reduction\n        self.recv_T = nn.Conv1d(n_channel, r_channel, kernel_size=time_rf, padding=self.time_padding, groups=1,\n                                bias=True)\n        self.recv_C = nn.Sequential(\n            nn.ReLU(),\n            nn.Linear(r_channel, n_channel, bias=False),\n        )\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x_seq: torch.Tensor):\n        assert x_seq.dim() == 3 or x_seq.dim() == 5, ValueError(\n            f'expected 3D or 5D input with shape [T, B, N] or [T, B, C, H, W], but got input with shape {x_seq.shape}')\n        x_seq_C = x_seq.transpose(0, 1)  # x_seq_C.shape = [B, T, N] or [B, T, C, H, W]\n        x_seq_T = x_seq_C.transpose(1, 2)  # x_seq_T.shape = [B, C, N] or [B, C, T, H, W]\n\n        if self.dimension == 2:\n            recv_h_T = self.recv_T(x_seq_T)\n            recv_h_C = self.recv_C(recv_h_T.transpose(1, 2))\n            D_ = 1 - self.sigmoid(recv_h_C)\n            D = D_.transpose(0, 1)\n\n        elif self.dimension == 4:\n            avgout_C = self.avg_pool(x_seq_C).view(\n                [x_seq_C.shape[0], 
x_seq_C.shape[1], x_seq_C.shape[2]])  # avgout_C.shape = [N, T, C]\n            avgout_T = avgout_C.transpose(1, 2)\n            recv_h_T = self.recv_T(avgout_T)\n            recv_h_C = self.recv_C(recv_h_T.transpose(1, 2))\n            D_ = 1 - self.sigmoid(recv_h_C)\n            D = D_.transpose(0, 1)\n\n        return D\n\n\nclass STSC_Temporal_Conv(nn.Module):\n    def __init__(self, channels: int, dimension: int = 2, time_rf: int = 2):\n\n        super().__init__()\n        assert dimension == 4 or dimension == 2, 'dimension must be 4 or 2'\n        self.dimension = dimension\n\n        time_padding = (time_rf - 1) // 2\n        self.time_padding = time_padding\n\n        if dimension == 4:\n            kernel_size = (time_rf, 1, 1)\n            padding = (time_padding, 0, 0)\n            self.conv = nn.Conv3d(channels, channels, kernel_size=kernel_size, padding=padding, groups=channels,\n                                  bias=False)\n        else:\n            kernel_size = time_rf\n            self.conv = nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=time_padding, groups=channels,\n                                  bias=False)\n\n    def forward(self, x_seq: torch.Tensor):\n        assert x_seq.dim() == 3 or x_seq.dim() == 5, ValueError(\n            f'expected 3D or 5D input with shape [T, B, N] or [T, B, C, H, W], but got input with shape {x_seq.shape}')\n\n        # x_seq.shape = [T, B, N] or [T, B, C, H, W]\n\n        x_seq = x_seq.transpose(0, 1)  # x_seq.shape = [B, T, N] or [B, T, C, H, W]\n        x_seq = x_seq.transpose(1, 2)  # x_seq.shape = [B, N, T] or [B, C, T, H, W]\n        x_seq = self.conv(x_seq)\n        x_seq = x_seq.transpose(1, 2)  # x_seq.shape = [B, T, N] or [B, T, C, H, W]\n        x_seq = x_seq.transpose(0, 1)  # x_seq.shape = [T, B, N] or [T, B, C, H, W]\n\n        return x_seq\n\n\nclass STSC(nn.Module):\n    def __init__(self, in_channel: int, dimension: int = 2, time_rf_conv: int = 5, time_rf_at: int = 3, 
use_gate=True,\n                 use_filter=True, reduction: int = 1):\n\n        super().__init__()\n\n        assert dimension == 4 or dimension == 2, 'dimension must be 4 or 2'\n        self.dimension = dimension\n\n        self.time_rf_conv = time_rf_conv\n        self.time_rf_at = time_rf_at\n\n        if use_filter:\n            self.temporal_conv = STSC_Temporal_Conv(in_channel, time_rf=time_rf_conv, dimension=dimension)\n\n        if use_gate:\n            self.spatio_temporal_attention = STSC_Attention(in_channel, time_rf=time_rf_at, reduction=reduction,\n                                                            dimension=dimension)\n\n        self.use_gate = use_gate\n        self.use_filter = use_filter\n\n    def forward(self, x_seq: torch.Tensor):\n        assert x_seq.dim() == 3 or x_seq.dim() == 5, ValueError(\n            f'expected 3D or 5D input with shape [T, B, N] or [T, B, C, H, W], but got input with shape {x_seq.shape}')\n\n        if self.use_filter:\n            # Filitering\n            x_seq_conv = self.temporal_conv(x_seq)\n        else:\n            # without filtering\n            x_seq_conv = x_seq\n\n        if self.dimension == 2:\n            if self.use_gate:\n                # Gating\n                x_seq_D = self.spatio_temporal_attention(x_seq)\n                y_seq = x_seq_conv * x_seq_D\n            else:\n                # without gating\n                y_seq = x_seq_conv\n        else:\n            if self.use_gate:\n                # Gating\n                x_seq_D = self.spatio_temporal_attention(x_seq)\n                y_seq = x_seq_conv * x_seq_D[:, :, :, None, None]  # broadcast\n            else:\n                # without gating\n                y_seq = x_seq_conv\n\n        return y_seq\n\n\n@register_model\nclass SHD_SNN(BaseModule):\n    \"\"\"\n    在SHD数据集上的SNN基准网络：Input-128FC-128FC-100FC-Voting-20.\n    STSC是增强时序信息的模块, 参考https://www.frontiersin.org/articles/10.3389/fnins.2022.1079357.\n    
不加STSC模块的acc在78%左右\n    \"\"\"\n    def __init__(self,\n                 num_classes=20,\n                 step=15,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n\n        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False\n\n        self.num_classes = num_classes\n        self.tet_loss = kwargs['tet_loss'] if 'tet_loss' in kwargs else False\n\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        self.dataset = kwargs['dataset']\n        self.ts_conv = STSC(700, dimension=2, time_rf_conv=5, time_rf_at=3, use_gate=True, use_filter=True)\n        self.fc = nn.Sequential(\n            nn.Linear(700, 128),\n            partial(self.node, **kwargs)(),\n            nn.Linear(128, 128),\n            partial(self.node, **kwargs)(),\n            nn.Linear(128, 100),\n            partial(self.node, **kwargs)(),\n            VotingLayer(5)\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n            inputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)\n            inputs = self.ts_conv(inputs)\n            x = rearrange(inputs, 't b c -> (t b) c', t=self.step)\n            x = self.fc(x)\n            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\n            return x\n\n        else:\n            outputs = []\n            inputs = self.ts_conv(inputs)\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.fc(x)\n                outputs.append(x)\n            return sum(outputs) / len(outputs)\n"
  },
  {
    "path": "braincog/model_zoo/glsnn.py",
    "content": "import abc\r\nfrom functools import partial\r\nfrom timm.models import register_model\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.connection.layer import *\r\nfrom braincog.base.encoder.encoder import *\r\nfrom braincog.model_zoo.base_module import BaseModule, BaseLinearModule, BaseConvModule\r\nfrom braincog.utils import rand_ortho, mse\r\nfrom torch import autograd\r\n\r\n\r\nclass BaseGLSNN(BaseModule):\r\n    \"\"\"\r\n    The fully connected model of the GLSNN\r\n    :param input_size: the shape of the input\r\n    :param hidden_sizes: list, the number of neurons of each layer in the hidden layers\r\n    :param ouput_size: the number of the output layers\r\n    \"\"\"\r\n\r\n    def __init__(self, input_size=784, hidden_sizes=[800] * 3, output_size=10, opt=None):\r\n        super().__init__(step=opt.step, encode_type=opt.encode_type)\r\n        network_sizes = [input_size] + hidden_sizes + [output_size]\r\n        feedforward = []\r\n        for ind in range(len(network_sizes) - 1):\r\n            feedforward.append(\r\n                BaseLinearModule(in_features=network_sizes[ind], out_features=network_sizes[ind + 1], node=LIFNode))\r\n        self.ff = nn.ModuleList(feedforward)\r\n        feedback = []\r\n        for ind in range(1, len(network_sizes) - 2):\r\n            feedback.append(nn.Linear(network_sizes[-1], network_sizes[ind]))\r\n        self.fb = nn.ModuleList(feedback)\r\n\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Linear):\r\n                out_, in_ = m.weight.shape\r\n                m.weight.data = torch.Tensor(rand_ortho((out_, in_), np.sqrt(6. 
/ (out_ + in_))))\r\n                m.bias.data.zero_()\r\n        self.step = opt.step\r\n        self.lr_target = opt.lr_target\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        process the information in the forward manner\r\n        :param x: the input\r\n        \"\"\"\r\n        self.reset()\r\n        x = x.view(x.shape[0], 784)\r\n        sumspikes = [0] * (len(self.ff) + 1)\r\n        sumspikes[0] = x\r\n        for ind, mod in enumerate(self.ff):\r\n            for t in range(self.step):\r\n                spike = mod(sumspikes[ind])\r\n                sumspikes[ind + 1] += spike\r\n            sumspikes[ind + 1] = sumspikes[ind + 1] / self.step\r\n        return sumspikes\r\n\r\n    def feedback(self, ff_value, y_label):\r\n        \"\"\"\r\n        process information in the feedback manner and get target\r\n        :param ff_value: the feedforward value of each layer\r\n        :param y_label: the label of the corresponding input\r\n        \"\"\"\r\n        fb_value = []\r\n        cost = mse(ff_value[-1], y_label)\r\n        P = ff_value[-1]\r\n        h_ = ff_value[-2] - self.lr_target * torch.autograd.grad(cost, ff_value[-2], retain_graph=True)[0]\r\n        fb_value.append(h_)\r\n        for i in range(len(self.fb) - 1, -1, -1):\r\n            h = ff_value[i + 1]\r\n            h_ = h - self.fb[i](P - y_label)\r\n            fb_value.append(h_)\r\n        return fb_value, cost\r\n\r\n    def set_gradient(self, x, y):\r\n        \"\"\"\r\n        get the corresponding update of each layer\r\n        \"\"\"\r\n        ff_value = self.forward(x)\r\n\r\n        fb_value, cost = self.feedback(ff_value, y)\r\n\r\n        ff_value = ff_value[1:]\r\n        len_ff = len(self.ff)\r\n        for idx, layer in enumerate(self.ff):\r\n            if idx == len_ff - 1:\r\n                layer.fc.weight.grad, layer.fc.bias.grad = autograd.grad(cost, layer.fc.parameters())\r\n            else:\r\n                in1 = ff_value[idx]\r\n               
 in2 = fb_value[len(fb_value) - 1 - idx]\r\n                loss_local = mse(in1, in2.detach())\r\n                layer.fc.weight.grad, layer.fc.bias.grad = autograd.grad(loss_local, layer.fc.parameters())\r\n        return ff_value, cost\r\n\r\n    def forward_parameters(self):\r\n        res = []\r\n        for layer in self.ff:\r\n            res += layer.parameters()\r\n        return res\r\n\r\n    def feedback_parameters(self):\r\n        res = []\r\n        for layer in self.fb:\r\n            res += layer.parameters()\r\n        return res\r\n\r\n\r\nif __name__ == '__main__':\r\n    net = BaseGLSNN()\r\n    print(net)\r\n"
  },
  {
    "path": "braincog/model_zoo/linearNet.py",
    "content": "import torch.nn.functional as F\r\n\r\nfrom braincog.base.strategy.surrogate import *\r\nfrom braincog.base.node.node import IFNode\r\nfrom braincog.base.learningrule.STDP import STDP, MutliInputSTDP\r\n\r\n\r\nclass droDMTrainNet(nn.Module):\r\n    \"\"\"\r\n    Drosophila Training network: compound eye-KC-MBON\r\n    \"\"\"\r\n\r\n    def __init__(self, connection):\r\n        \"\"\"\r\n        根据传入的连接 构建训练网络\r\n        :param connection: 训练网络的连接\r\n        \"\"\"\r\n\r\n        super().__init__()\r\n        trace_stdp = 0.99\r\n        self.num_subMB = 3\r\n        self.node = [IFNode() for i in range(self.num_subMB)]\r\n        self.connection = connection\r\n        self.learning_rule = []\r\n        self.learning_rule.append(STDP(self.node[0], self.connection[0], trace_stdp))\r\n        self.learning_rule.append(STDP(self.node[1], self.connection[1], trace_stdp))\r\n        self.learning_rule.append(MutliInputSTDP(self.node[2], [self.connection[2], self.connection[3]], trace_stdp))\r\n\r\n        self.out_vis = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)\r\n        self.out_KC = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\r\n        self.out_MBON = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)\r\n\r\n    def forward(self, input):\r\n        \"\"\"\r\n        根据输入得到输出\r\n        :param input: 输入电流\r\n        :return: 网络的输出，以及网络运行产生的STDP可塑性\r\n        \"\"\"\r\n        self.out_vis = self.node[0](self.connection[0](input))\r\n        self.out_KC, dw_kc = self.learning_rule[1](self.out_vis)\r\n        self.out_MBON, dw_mbon = self.learning_rule[2](self.out_KC, self.out_MBON)\r\n        return self.out_MBON, dw_kc[0], dw_mbon[0]\r\n\r\n    def UpdateWeight(self, i, dw):\r\n        \"\"\"\r\n        更新网络中第i组连接的权重\r\n        :param i: 要更新的连接的索引\r\n        :param dw: 更新的量\r\n        :return: None\r\n        \"\"\"\r\n        self.connection[i].update(dw)\r\n        
self.connection[i].weight.data = F.normalize(self.connection[i].weight.data.float(), p=1, dim=1)\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        reset神经元或学习法则的中间量\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subMB):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n\r\n    def getweight(self):\r\n        \"\"\"\r\n        获取网络的连接(包括权值等)\r\n        :return: 网络的连接\r\n        \"\"\"\r\n        return self.connection\r\n"
  },
  {
    "path": "braincog/model_zoo/nonlinearNet.py",
    "content": "import torch.nn.functional as F\r\n\r\nfrom braincog.base.strategy.surrogate import *\r\nfrom braincog.base.node.node import IFNode\r\nfrom braincog.base.learningrule.STDP import STDP, MutliInputSTDP\r\n\r\n\r\nclass droDMTestNet(nn.Module):\r\n    \"\"\"\r\n    Drosophila Testing Network: compound eye-KC-MBON  DA-GABA-MB\r\n    \"\"\"\r\n\r\n    def __init__(self, connection):\r\n        \"\"\"\r\n        根据传入的连接 构建测试网络\r\n        :param connection: 测试网络的连接\r\n        \"\"\"\r\n        super().__init__()\r\n        trace_stdp = 0.99\r\n        self.num_subMB = 5\r\n        self.node = [IFNode() for i in range(self.num_subMB)]\r\n        self.connection = connection\r\n        self.learning_rule = []\r\n        self.learning_rule.append(STDP(self.node[0], self.connection[0], trace_stdp))\r\n        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[5]], trace_stdp))\r\n        self.learning_rule.append(MutliInputSTDP(self.node[2], [self.connection[2], self.connection[3], self.connection[9]], trace_stdp))\r\n        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[6]], trace_stdp))\r\n        self.learning_rule.append(MutliInputSTDP(self.node[4], [self.connection[7], self.connection[8]], trace_stdp))\r\n\r\n        self.out_vis = torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)\r\n        self.out_KC = torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\r\n        self.out_MBON = torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)\r\n        self.out_APL = torch.zeros((self.connection[4].weight.shape[1]), dtype=torch.float)\r\n        self.out_DA = torch.zeros((self.connection[7].weight.shape[1]), dtype=torch.float)\r\n\r\n    def forward(self, input, input_da):\r\n        \"\"\"\r\n        根据输入得到输出\r\n        :param input: 输入电流\r\n        :return: 网络的输出，以及网络运行产生的STDP可塑性\r\n        \"\"\"\r\n        self.out_vis = 
self.node[0](self.connection[0](input))\r\n        self.out_KC, dw_kc = self.learning_rule[1](self.out_vis, self.out_APL)\r\n        self.out_MBON, dw_mbon = self.learning_rule[2](self.out_KC, self.out_MBON, self.out_DA)\r\n        self.out_APL, dw_apl = self.learning_rule[3](self.out_KC, self.out_DA)\r\n        self.out_DA, dw_da = self.learning_rule[4](self.out_APL, input_da)\r\n        return self.out_MBON, dw_kc[1], dw_apl[0]\r\n\r\n    def UpdateWeight(self, i, dw):\r\n        \"\"\"\r\n        更新网络中第i组连接的权重\r\n        :param i: 要更新的连接的索引\r\n        :param dw: 更新的量\r\n        :return: None\r\n        \"\"\"\r\n        self.connection[i].update(dw)\r\n        self.connection[i].weight.data = F.normalize(self.connection[i].weight.data.float(), p=1, dim=0)\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        reset神经元或学习法则的中间量\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subMB):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n\r\n    def getweight(self):\r\n        \"\"\"\r\n        获取网络的连接(包括权值等)\r\n        :return: 网络的连接\r\n        \"\"\"\r\n        return self.connection\r\n"
  },
  {
    "path": "braincog/model_zoo/qsnn.py",
    "content": "import numpy as np\nfrom scipy.linalg import orth\nfrom scipy.special import expit\nfrom scipy.signal import fftconvolve\nimport torch\nfrom torch.nn import Parameter\nimport torch.nn as nn\nfrom braincog.datasets.gen_input_signal import lambda_max, dt\nfrom braincog.base.encoder import QSEncoder\n\ngamma = 0.1\nbeta = 1.0\ntheta = 3.0\n\n# kernel parameters\ntau_s = 4.0  # synaptic time constant\ntau_L = 10.0  # leak time constant\n\n# conductance parameters\ng_B = 0.6                                   # basal conductance\ng_A = 0.05                                  # apical conductance\ng_L = 1.0 / tau_L                             # leak conductance\ng_D = g_B                                   # dendritic conductance in output layer\n\nk_D = g_D / (g_L + g_D)\n\n\nSTEPS = int(50 / dt)\nSLEN = 20                      # spike time length\n\n# --- sigmoid function --- #\n\n\ndef sigma(x):\n    return torch.sigmoid(x)\n\n# def sigma(x):\n#     return gamma * np.log(1+np.exp(beta*(x-theta)))\n\n\ndef deriv_sigma(x):\n    return sigma(x) * (1.0 - sigma(x))\n\n\n# kernel parameters\ntau_s = 4.0                                                   # synaptic time constant\ntau_L = 10.0                                                # leak time constant\n# --- kernel function --- #\nmem = STEPS\n\n\ndef kappa(x):\n    return np.exp(-x / tau_s)\n\n\ndef get_kappas(n):\n    return np.array([kappa(i + 1) for i in range(n)])\n\n\nkappas = get_kappas(mem // 2)  # initialize kappas array\nkernel = np.zeros(mem)\n\nkernel[:mem // 2] = kappas[:]\nkernel[mem // 2:] = -np.flipud(kappas)[:]\n\n\nW_MIN = -1.0\nW_MAX = 1.0\n\n\nclass Net(nn.Module):\n    \"\"\"\n    两房室脉冲神经网络\n    \"\"\"\n    def __init__(self, net_size):\n        super().__init__()\n        self.input_size = net_size[0]\n        self.hidden_layers = nn.ModuleList([Hidden_layer(net_size[i], net_size[i + 1], net_size[-1]) for i in range(len(net_size) - 2)])\n        self.out_layer = 
Output_layer(net_size[-2], net_size[-1])\n        self.kernel = torch.from_numpy(kernel[:, np.newaxis]).cuda()\n        self.qs_code = QSEncoder\n\n    def update_state(self, input_, label, test):\n        if len(self.hidden_layers) > 1:\n            self.hidden_layers[0].update_state(input_, self.out_layer.spike_rate, test=test)\n            for i in range(len(self.hidden_layers) - 2):\n                self.hidden_layers[i + 1].update_state(self.hidden_layers[i].spike_rate, self.out_layer.spike_rate, test=test)\n\n            self.hidden_layers[-1].update_state(self.hidden_layers[-2].spike_rate, self.out_layer.spike_rate, test=test)\n        else:\n            self.hidden_layers[0].update_state(input_, self.out_layer.spike_rate, test=test)\n\n        self.out_layer.update_state(self.hidden_layers[-1].spike_rate, label, test=test)\n\n    def routine(self,\n                input_,\n                input_delta,\n                image_ori,\n                image_ori_delta,\n                shift,\n                label,\n                test=False,\n                noise=False,\n                noise_rate=None):\n        \"\"\"\n        网络信息处理过程\n        :param input_: 输入图片\n        :param input_delta: 输入扰动图片，用于计算相位\n        :param image_ori: 原始图片\n        :param image_ori_delta: 原始扰动图片\n        :param shift: 是否反转背景\n        :param label: 输入数据分类标签\n        :param test: 是否是测试阶段\n        :param noise: 是否增加噪声\n        :param noise_rate: 噪声比例\n        \"\"\"\n        encoder = self.qs_code(lambda_max, STEPS, SLEN, shift, noise, noise_rate)\n        input_ = encoder(input_, input_delta, image_ori, image_ori_delta)\n        input_ = torch.from_numpy(input_).to(self.kernel.device)\n        psp = torch.mm(input_, self.kernel).abs().float()\n\n        for i in range(STEPS):\n            self.update_state(psp, label, test=test)\n\n    def update_weight(self, lr, t, beta, eps):\n        self.out_layer.update_weight(lr, t, beta, eps)\n        if len(self.hidden_layers) > 1:\n     
       self.hidden_layers[-1].update_weight(self.out_layer.delta, lr, t, beta, eps)\n            for i in range(len(self.hidden_layers) - 1):\n                self.hidden_layers[-(i + 2)].update_weight(self.hidden_layers[-(i + 1)].delta, lr, t, beta, eps)\n        else:\n            self.hidden_layers[0].update_weight(self.out_layer.delta, lr, t, beta, eps)\n\n    def predict(self,\n                input_,\n                input_delta,\n                image_ori,\n                image_ori_delta,\n                shift,\n                noise,\n                noise_rate=0):\n        self.routine(input_,\n                     input_delta,\n                     image_ori=image_ori,\n                     image_ori_delta=image_ori_delta,\n                     shift=shift,\n                     label=None,\n                     test=True,\n                     noise=noise,\n                     noise_rate=noise_rate)\n\n        pred = torch.argmax(self.out_layer.spike_rate.flatten())\n        return pred\n\n\nclass Hidden_layer(nn.Module):\n    \"\"\"\n    隐藏层两房室网络\n    \"\"\"\n    def __init__(self, input_size, neu_num, fb_neus):\n        super().__init__()\n        self.basal_linear = nn.Linear(input_size, neu_num)\n        nn.init.uniform_(self.basal_linear.weight, -0.1, 0.1)\n        nn.init.uniform_(self.basal_linear.bias, -0.1, 0.1)\n        self.soma_V = 0.0\n        self.basal_V = 0.0\n        # for adam\n        self.m = 0.0\n        self.v = 0.0\n        self.m_hat = 0.0\n        self.v_hat = 0.0\n        self.m_b = 0.0\n        self.v_b = 0.0\n        self.m_b_hat = 0.0\n        self.v_b_hat = 0.0\n        # backprop\n        self.delta = 0.0\n\n    def update_state(self, basal_input, apical_input, test):\n        self.basal_input = basal_input.T  # [1, 781]\n        self.basal_V = self.basal_linear(basal_input.T)\n        self.soma_V = self.soma_V + 1 / tau_L * (-self.soma_V + g_B / g_L * (self.basal_V - self.soma_V)) * dt\n        self.spike_rate = 
lambda_max * sigma(self.soma_V)\n\n    def update_weight(self, delta_, lr, t, beta, eps):\n        weight_dot = lambda_max * k_D * delta_ * deriv_sigma(k_D * self.basal_V)  # [1, 500]\n        self.delta = torch.mm(weight_dot, self.basal_linear.weight.data)  # [500, 784] x [1, 500]\n        weight_delta = weight_dot[:, :, None] * self.basal_input[:, None, :]\n        bias_delta = weight_dot\n        self.m = beta[0] * self.m + (1 - beta[0]) * weight_delta\n        self.v = beta[1] * self.v + (1 - beta[1]) * torch.square(weight_delta)\n        self.m_hat = self.m / (1 - beta[0] ** t)\n        self.v_hat = self.v / (1 - beta[1] ** t)\n        self.m_b = beta[0] * self.m_b + (1 - beta[0]) * bias_delta\n        self.v_b = beta[1] * self.v_b + (1 - beta[1]) * torch.square(bias_delta)\n        self.m_b_hat = self.m_b / (1 - beta[0] ** t)\n        self.v_b_hat = self.v_b / (1 - beta[1] ** t)\n        # update weight\n        weight_delta = lr * self.m_hat / (torch.sqrt(self.v_hat) + eps)\n        bias_delta = lr * self.m_b_hat / (torch.sqrt(self.v_b_hat) + eps)\n        self.basal_linear.weight.data.sub_(weight_delta.mean(0))\n        self.basal_linear.bias.data.sub_(bias_delta.mean(0))\n\n\nclass Output_layer(nn.Module):\n    \"\"\"\n    输出层两房室网络\n    \"\"\"\n    def __init__(self, input_size, neu_num):\n        super().__init__()\n        self.basal_linear = nn.Linear(input_size, neu_num)\n        nn.init.uniform_(self.basal_linear.weight, -0.1, 0.1)\n        nn.init.uniform_(self.basal_linear.bias, -0.1, 0.1)\n        self.soma_V = 0.0\n        self.basal_V = 0.0\n        self.spike_rate = 0.0\n        # adam\n        self.m = 0.0\n        self.v = 0.0\n        self.m_hat = 0.0\n        self.v_hat = 0.0\n        self.m_b = 0.0\n        self.v_b = 0.0\n        self.m_b_hat = 0.0\n        self.v_b_hat = 0.0\n        # backprop\n        self.delta = 0.0\n\n    def update_state(self, basal_input, I, test):\n        self.basal_input = basal_input\n        self.basal_V = 
self.basal_linear(basal_input)\n        if test:\n            self.soma_V = self.soma_V + 1 / tau_L * (-self.soma_V + g_B / g_L * (self.basal_V - self.soma_V)) * dt\n        else:\n            self.soma_V = self.soma_V + 1 / tau_L * (-self.soma_V + g_B / g_L * (self.basal_V - self.soma_V) +\n                                                     I - self.soma_V) * dt\n        self.spike_rate = lambda_max * sigma(self.soma_V)\n\n    def update_weight(self, lr, t, beta, eps):\n        weight_dot = lambda_max * k_D * (sigma(k_D * self.basal_V) - sigma(self.soma_V)) * deriv_sigma(k_D * self.basal_V)\n        self.delta = torch.mm(weight_dot, self.basal_linear.weight.data)     # [1, 500]\n        bias_delta = weight_dot  # [1, 10]\n        weight_delta = weight_dot[:, :, None] * self.basal_input[:, None, :]   # [1, 10, 500]\n        self.m = beta[0] * self.m + (1 - beta[0]) * weight_delta\n        self.v = beta[1] * self.v + (1 - beta[1]) * torch.square(weight_delta)\n        self.m_hat = self.m / (1 - beta[0] ** t)\n        self.v_hat = self.v / (1 - beta[1] ** t)\n        self.m_b = beta[0] * self.m_b + (1 - beta[0]) * bias_delta\n        self.v_b = beta[1] * self.v_b + (1 - beta[1]) * torch.square(bias_delta)\n        self.m_b_hat = self.m_b / (1 - beta[0] ** t)\n        self.v_b_hat = self.v_b / (1 - beta[1] ** t)\n        # update weight\n        weight_delta = lr * self.m_hat / (torch.sqrt(self.v_hat) + eps)\n        bias_delta = lr * self.m_b_hat / (torch.sqrt(self.v_b_hat) + eps)\n        self.basal_linear.weight.data.sub_(weight_delta.mean(0))\n        self.basal_linear.bias.data.sub_(bias_delta.mean(0))\n"
  },
  {
    "path": "braincog/model_zoo/resnet.py",
    "content": "'''\nDeep Residual Learning for Image Recognition\nhttps://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n'''\nimport os\nimport sys\nfrom functools import partial\nfrom timm.models import register_model\nfrom timm.models.layers import trunc_normal_, DropPath\nfrom braincog.model_zoo.base_module import *\nfrom braincog.base.node.node import *\n\n__all__ = [\n    'ResNet',\n    'resnet18',\n    'resnet34_half',\n    'resnet34',\n    'resnet50_half',\n    'resnet50',\n    'resnet101',\n    'resnet152',\n    'resnext50_32x4d',\n    'resnext101_32x8d',\n    'wide_resnet50_2',\n    'wide_resnet101_2',\n]\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    '''3x3 convolution with padding'''\n    return nn.Conv2d(in_planes,\n                     out_planes,\n                     kernel_size=3,\n                     stride=stride,\n                     padding=dilation,\n                     groups=groups,\n                     bias=False,\n                     dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    '''1x1 convolution'''\n    return nn.Conv2d(in_planes,\n                     out_planes,\n                     kernel_size=1,\n                     stride=stride,\n                     bias=False)\n\n\nclass BasicBlock(nn.Module):\n    \"\"\"\n    ResNet的基础模块, 采用identity-connection的方式.\n    :param inplanes: 输出通道数\n    :param planes: 内部通道数量\n    :param stride: stride\n    :param downsample: 是否降采样\n    :param groups: 分组卷积\n    :param base_width: 基础通道数量\n    :param dilation: 空洞卷积\n    :param norm_layer: Norm的方式\n    :param node: 神经元类型, 默认为 ``LIFNode``\n    \"\"\"\n\n    expansion = 1\n    __constants__ = ['downsample']\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 stride=1,\n                 downsample=None,\n                 groups=1,\n                 base_width=64,\n                 dilation=1,\n                 norm_layer=None,\n  
               node=LIFNode):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError(\n                'BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\n                'Dilation > 1 not supported in BasicBlock')\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.bn1 = norm_layer(inplanes)\n        self.node1 = node()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        # self.relu = nn.ReLU(inplace=False)\n        self.node2 = node()\n        self.bn2 = norm_layer(planes)\n        self.conv2 = conv3x3(planes, planes)\n\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.bn1(x)\n        out = self.node1(out)\n        out = self.conv1(out)\n\n        out = self.bn2(out)\n        out = self.node2(out)\n        out = self.conv2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    \"\"\"\n    ResNet的Botteneck模块, 采用identity-connection的方式.\n    :param inplanes: 输出通道数\n    :param planes: 内部通道数量\n    :param stride: stride\n    :param downsample: 是否降采样\n    :param groups: 分组卷积\n    :param base_width: 基础通道数量\n    :param dilation: 空洞卷积\n    :param norm_layer: Norm的方式\n    :param node: 神经元类型, 默认为 ``LIFNode``\n    \"\"\"\n    expansion = 4\n    __constants__ = ['downsample']\n\n    def __init__(self,\n                 inplanes,\n                 planes,\n                 stride=1,\n                 downsample=None,\n                 groups=1,\n                 base_width=64,\n                 dilation=1,\n                 norm_layer=None,\n                 node=torch.nn.Identity):\n        
super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.bn1 = norm_layer(inplanes)\n        self.conv1 = conv1x1(inplanes, width)\n\n        self.bn2 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n\n        self.bn3 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n\n        # self.relu = nn.ReLU(inplace=False)\n        self.downsample = downsample\n        self.stride = stride\n        self.node1 = node()\n        self.node2 = node()\n        self.node3 = node()\n\n    def forward(self, x):\n        identity = x\n\n        out = self.bn1(x)\n        out = self.node1(out)\n        out = self.conv1(out)\n\n        out = self.bn2(out)\n        out = self.node2(out)\n        out = self.conv2(out)\n\n        out = self.bn3(out)\n        out = self.node3(out)\n        out = self.conv3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        return out\n\n\nclass ResNet(BaseModule):\n    \"\"\"\n    ResNet-SNN\n    :param block: Block类型\n    :param layers: block 层数\n    :param inplanes: 输入通道数量\n    :param num_classes: 输出类别数\n    :param zero_init_residual: 是否使用零初始化\n    :param groups: 卷积分组\n    :param width_per_group: 每一组的宽度\n    :param replace_stride_with_dilation: 是否使用stride替换dilation\n    :param norm_layer: Norm 方式, 默认为 ``BatchNorm``\n    :param step: 仿真步长, 默认为 ``8``\n    :param encode_type: 编码方式, 默认为 ``direct``\n    :param spike_output: 是否使用脉冲输出, 默认为 ``False``\n    :param args:\n    :param kwargs:\n    \"\"\"\n    def __init__(self,\n                 block,\n                 layers,\n                 inplanes=64,\n                 num_classes=10,\n                 zero_init_residual=False,\n                 
groups=1,\n                 width_per_group=64,\n                 replace_stride_with_dilation=None,\n                 norm_layer=None,\n                 step=8,\n                 encode_type='direct',\n                 spike_output=False,\n                 *args,\n                 **kwargs):\n        super().__init__(\n            step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n        self.spike_output = spike_output\n        self.num_classes = num_classes\n\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        # print('inplanes %d' % inplanes)\n        self.inplanes = inplanes\n        self.interplanes = [\n            self.inplanes, self.inplanes * 2, self.inplanes * 4,\n            self.inplanes * 8\n        ]\n        self.dilation = 1\n\n        self.node = kwargs['node_type']\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs)\n\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError('replace_stride_with_dilation should be None '\n                             'or a 3-element tuple, got {}'.format(\n                                 replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.static_data = False\n\n        self.dataset = kwargs['dataset']\n        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101' or self.dataset == 'NCARS' or self.dataset == 'DVSG':\n            self.conv1 = nn.Conv2d(2 * self.init_channel_mul,\n                                   self.inplanes,\n                                   kernel_size=3,\n              
                     padding=1,\n                                   bias=False)\n        elif self.dataset == 'imnet':\n            self.conv1 = nn.Conv2d(3 * self.init_channel_mul,\n                                   self.inplanes,\n                                   kernel_size=7,\n                                   stride=2,\n                                   padding=3,\n                                   bias=False)\n            self.static_data = True\n        elif self.dataset == 'esimnet':\n            reconstruct = kwargs[\"reconstruct\"] if \"reconstruct\" in kwargs else False\n            print(reconstruct)\n            if reconstruct:\n                self.conv1 = nn.Conv2d(1 * self.init_channel_mul,\n                                       self.inplanes,\n                                       kernel_size=7,\n                                       stride=2,\n                                       padding=3,\n                                       bias=False)\n                self.static_data = True\n            else:\n                self.conv1 = nn.Conv2d(2 * self.init_channel_mul,\n                                       self.inplanes,\n                                       kernel_size=7,\n                                       stride=2,\n                                       padding=3,\n                                       bias=False)\n                self.static_data = True\n        elif self.dataset == 'cifar10' or self.dataset == 'cifar100':\n            self.conv1 = nn.Conv2d(3 * self.init_channel_mul,\n                                   self.inplanes,\n                                   kernel_size=3,\n                                   padding=1,\n                                   bias=False)\n            self.static_data = True\n\n        # self.relu = nn.ReLU(inplace=False)\n        self.layer1 = self._make_layer(\n            block, self.interplanes[0], layers[0], node=self.node)\n        self.layer2 = self._make_layer(block,\n           
                            self.interplanes[1],\n                                       layers[1],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[0], node=self.node)\n        self.layer3 = self._make_layer(block,\n                                       self.interplanes[2],\n                                       layers[2],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[1], node=self.node)\n        self.layer4 = self._make_layer(block,\n                                       self.interplanes[3],\n                                       layers[3],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[2], node=self.node)\n\n        self.bn1 = norm_layer(self.inplanes)\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n\n        if self.spike_output:\n            self.fc = nn.Linear(\n                self.interplanes[3] * block.expansion, num_classes * 10)\n            self.node2 = self.node()\n            self.vote = VotingLayer(10)\n        else:\n            self.fc = nn.Linear(\n                self.interplanes[3] * block.expansion, num_classes\n            )\n            self.node2 = nn.Identity()\n            self.vote = nn.Identity()\n\n        self.warm_up = False\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight,\n                                        mode='fan_out',\n                                        nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an 
identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, node=torch.nn.Identity):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            if block == BasicBlock:\n                downsample = nn.Sequential(\n                    norm_layer(self.inplanes),\n                    self.node(),\n                    conv1x1(self.inplanes, planes * block.expansion, stride),\n                )\n            elif block == Bottleneck:\n                downsample = nn.Sequential(\n                    norm_layer(self.inplanes),\n                    self.node(),\n                    conv1x1(self.inplanes, planes * block.expansion, stride),\n                )\n            else:\n                raise NotImplementedError\n\n        layers = [block(self.inplanes, planes, stride, downsample, self.groups,\n                        self.base_width, previous_dilation, norm_layer, node=node)]\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(self.inplanes,\n                      planes,\n                      groups=self.groups,\n                      base_width=self.base_width,\n                      dilation=self.dilation,\n                      norm_layer=norm_layer, node=node))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n     
   self.reset()\n\n        if self.layer_by_layer:\n\n            x = self.conv1(inputs)\n            x = self.layer1(x)\n            x = self.layer2(x)\n            x = self.layer3(x)\n            x = self.layer4(x)\n\n            x = self.bn1(x)\n            # x = self.node1(x)\n            x = self.avgpool(x)\n\n            x = torch.flatten(x, 1)\n            # print(x.shape)\n            x = self.fc(x)\n            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\n            x = self.node2(x)\n            x = self.vote(x)\n\n            return x\n\n        else:\n            outputs = []\n\n            if self.warm_up:\n                step = 1\n            else:\n                step = self.step\n            for t in range(step):\n                x = inputs[t]\n\n                x = self.conv1(x)\n\n                x = self.layer1(x)\n                x = self.layer2(x)\n                x = self.layer3(x)\n                x = self.layer4(x)\n\n                x = self.bn1(x)\n                # x = self.node1(x)\n                x = self.avgpool(x)\n\n                x = torch.flatten(x, 1)\n                x = self.fc(x)\n\n                x = self.node2(x)\n                x = self.vote(x)\n\n                outputs.append(x)\n\n            return sum(outputs) / len(outputs)\n\n\ndef _resnet(arch, block, layers, pretrained=False, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    # only load state_dict()\n    if pretrained:\n        raise NotImplementedError\n\n    return model\n\n\n@register_model\ndef resnet9(pretrained=False, **kwargs):\n    return _resnet('resnet9', BasicBlock, [1, 1, 1, 1], pretrained, **kwargs)\n\n\n@register_model\ndef resnet18(pretrained=False, **kwargs):\n    # kwargs['inplanes'] = 96\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)\n\n\n@register_model\ndef resnet34_half(pretrained=False, **kwargs):\n    kwargs['inplanes'] = 32\n    return _resnet('resnet34_half', BasicBlock, [3, 4, 
6, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef resnet34(pretrained=False, **kwargs):\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)\n\n\n@register_model\ndef resnet50_half(pretrained=False, **kwargs):\n    kwargs['inplanes'] = 32\n    return _resnet('resnet50_half', Bottleneck, [3, 4, 6, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef resnet50(pretrained=False, **kwargs):\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)\n\n\n@register_model\ndef resnet101(pretrained=False, **kwargs):\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef resnet152(pretrained=False, **kwargs):\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef resnext50_32x4d(pretrained=False, **kwargs):\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef resnext101_32x8d(pretrained=False, **kwargs):\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef wide_resnet50_2(pretrained=False, **kwargs):\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,\n                   **kwargs)\n\n\n@register_model\ndef wide_resnet101_2(pretrained=False, **kwargs):\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,\n                   **kwargs)\n\n\nif __name__ == '__main__':\n    net = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1000)\n    image_h, image_w = 224, 224\n    from thop import profile\n    from thop import clever_format\n\n    
flops, params = profile(net,\n                            inputs=(torch.randn(1, 3, image_h, image_w),),\n                            verbose=False)\n    flops, params = clever_format([flops, params], '%.3f')\n    out = net(torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))\n    print(f'1111, flops: {flops}, params: {params},out_shape: {out.shape}')\n"
  },
  {
    "path": "braincog/model_zoo/resnet19_snn.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/7/26 19:33\n# User      : Floyed\n# Product   : PyCharm\n# Project   : braincog\n# File      : resnet19_snn.py\n# explain   :\n\nimport os\nimport sys\nfrom functools import partial\nimport numpy as np\nfrom timm.models import register_model\nfrom timm.models.layers import trunc_normal_, DropPath\nfrom braincog.model_zoo.base_module import *\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.datasets import is_dvs_data\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 node=LIFNode, base_width=64, dilation=1, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = ThresholdDependentBatchNorm2d\n        # if groups != 1 or base_width != 64:\n        #     raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        # if dilation > 1:\n        #     raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(num_features=planes, alpha=1.)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(num_features=planes, alpha=np.sqrt(.5))\n        self.downsample = downsample\n        self.stride 
= stride\n        self.node1 = node()\n        self.node2 = node()\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.node1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.node2(out)\n\n        return out\n\n\n\nclass ResNet(BaseModule):\n    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, width_per_group=128,\n                 replace_stride_with_dilation=None, norm_layer=None, step=4, encode_type='direct', node_type=LIFNode,\n                 *args, **kwargs):\n\n        super().__init__(\n            step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n\n        super().__init__(step, encode_type, *args, **kwargs)\n        if not self.layer_by_layer:\n            raise ValueError('ResNet-SNN only support for layer-wise mode, because of tdBN')\n\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        self.dataset = kwargs['dataset']\n        if is_dvs_data(self.dataset):\n            data_channel = 2\n        else:\n            data_channel = 3\n\n        if norm_layer is None:\n            norm_layer = ThresholdDependentBatchNorm2d\n        self._norm_layer = partial(norm_layer,   step=step)\n        self.sum_output=kwargs[\"sum_output\"] if \"sum_output\"in kwargs else True \n        self.inplanes = 128\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise 
ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(data_channel, self.inplanes, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = self._norm_layer(num_features=self.inplanes, alpha=np.sqrt(.5))\n        self.layer1 = self._make_layer(block, 128, layers[0])\n        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n\n        self.fc1 = nn.Linear(512 * block.expansion, 256)\n        self.fc2 = nn.Linear(256, num_classes)\n        self.node1 = self.node()\n        self.node2 = self.node()\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n                elif isinstance(m, nn.Conv2d):\n                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n        norm_layer = self._norm_layer\n        # downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample 
= nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(num_features=planes * block.expansion, alpha=np.sqrt(.5)),\n            )\n        else:\n            downsample = nn.Sequential(\n                norm_layer(num_features=planes * block.expansion, alpha=np.sqrt(.5)),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, groups=self.groups,\n                            base_width=self.base_width, norm_layer=norm_layer, node=self.node))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer, node=self.node))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x):\n\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.node1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc1(x)\n        x = self.node2(x)\n        x = self.fc2(x)\n        \n        if self.sum_output:x= rearrange(x, '(t b) c -> b c t', t=self.step).mean(-1)\n        else :x=  rearrange(x, '(t b) c -> t b c ', t=self.step)\n        return x\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        return self._forward_impl(inputs)\n\n\ndef _resnet(arch, block, layers, pretrained, progress, norm=ThresholdDependentBatchNorm2d, **kwargs):\n    tdBN = partial(norm, layer_by_layer=kwargs['layer_by_layer'], threshold=kwargs['threshold'])\n    model = ResNet(block, layers, norm_layer=tdBN, **kwargs)\n    if pretrained:\n        raise NotImplementedError\n    return model\n\n\n@register_model\ndef resnet19(pretrained=False, 
progress=True, norm=ThresholdDependentBatchNorm2d, **kwargs):\n    return _resnet('resnet19', BasicBlock, [3, 3, 2], pretrained, progress, norm=norm, **kwargs)\n\n\nif __name__ == '__main__':\n    net = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=1000)\n    image_h, image_w = 224, 224\n    from thop import profile\n    from thop import clever_format\n\n    flops, params = profile(net,\n                            inputs=(torch.randn(1, 3, image_h, image_w),),\n                            verbose=False)\n    flops, params = clever_format([flops, params], '%.3f')\n    out = net(torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))\n    print(f'1111, flops: {flops}, params: {params},out_shape: {out.shape}')\n"
  },
  {
    "path": "braincog/model_zoo/rsnn.py",
    "content": "\r\nimport torch\r\nfrom torch import nn\r\n\r\nfrom braincog.base.node.node import IFNode\r\nfrom braincog.base.learningrule.STDP import STDP,MutliInputSTDP\r\nfrom braincog.base.connection.CustomLinear import CustomLinear\r\n\r\n\r\nfrom collections import deque\r\nfrom random import randint\r\n\r\nclass RSNN(nn.Module):\r\n    def __init__(self,num_state,num_action):\r\n        super().__init__()\r\n        # parameters\r\n        rsnn_mask=[]\r\n        rsnn_con=[]\r\n        con_matrix1 = torch.ones((num_state,num_action), dtype=torch.float)\r\n        rsnn_mask.append(con_matrix1)\r\n        rsnn_con.append(CustomLinear(torch.randn(num_state,num_action), con_matrix1))\r\n\r\n        self.num_subR=2\r\n        self.connection = rsnn_con\r\n        self.mask=rsnn_mask\r\n        self.node = [IFNode() for i in range(self.num_subR)]\r\n        self.learning_rule = []\r\n        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[0]]))\r\n\r\n        self.weight_trace = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n        \r\n        self.out_in = torch.zeros((num_state), dtype=torch.float)\r\n        self.out = torch.zeros((self.connection[0].weight.size()[1]), dtype=torch.float)\r\n        self.dw = torch.zeros((self.connection[0].weight.size()), dtype=torch.float)\r\n\r\n    def forward(self, input):\r\n        input=torch.tensor(input, dtype=torch.float)\r\n        self.out_in=self.node[0](input)\r\n        self.out,self.dw = self.learning_rule[0](self.out_in)\r\n        return self.out,self.dw\r\n\r\n    def UpdateWeight(self,reward):\r\n        self.weight_trace[self.weight_trace>0]=self.weight_trace[self.weight_trace>0]*reward\r\n        self.weight_trace[self.weight_trace < 0] = -1*self.weight_trace[self.weight_trace < 0] * reward\r\n        self.connection[0].update(self.weight_trace)\r\n        for i in range(self.connection[0].weight.size()[1]):\r\n            self.connection[0].weight.data[:, i] = 
(self.connection[0].weight.data[:, i] - torch.min(self.connection[0].weight.data[:, i])) / (torch.max(self.connection[0].weight.data[:, i]) - torch.min(self.connection[0].weight.data[:, i]))\r\n        self.connection[0].weight.data= self.connection[0].weight.data * 0.5\r\n    def reset(self):\r\n        for i in range(self.num_subR):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n    def getweight(self):\r\n        return self.connection\r\n"
  },
  {
    "path": "braincog/model_zoo/sew_resnet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom copy import deepcopy\n\ntry:\n    from torchvision.models.utils import load_state_dict_from_url\nexcept ImportError:\n    from torchvision._internally_replaced_utils import load_state_dict_from_url\nfrom braincog.base.node import *\nfrom braincog.model_zoo.base_module import *\nfrom braincog.datasets import is_dvs_data\nfrom timm.models import register_model\n__all__ = ['SEWResNet', 'sew_resnet18', 'sew_resnet34', 'sew_resnet50', 'sew_resnet101',\n           'sew_resnet152', 'sew_resnext50_32x4d', 'sew_resnext101_32x8d',\n           'sew_wide_resnet50_2', 'sew_wide_resnet101_2']\n\nmodel_urls = {\n    \"resnet18\": \"https://download.pytorch.org/models/resnet18-f37072fd.pth\",\n    \"resnet34\": \"https://download.pytorch.org/models/resnet34-b627a593.pth\",\n    \"resnet50\": \"https://download.pytorch.org/models/resnet50-0676ba61.pth\",\n    \"resnet101\": \"https://download.pytorch.org/models/resnet101-63fe2227.pth\",\n    \"resnet152\": \"https://download.pytorch.org/models/resnet152-394f9c45.pth\",\n    \"resnext50_32x4d\": \"https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth\",\n    \"resnext101_32x8d\": \"https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth\",\n    \"wide_resnet50_2\": \"https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth\",\n    \"wide_resnet101_2\": \"https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth\",\n}\n\n# modified by https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py\n\ndef sew_function(x: torch.Tensor, y: torch.Tensor, cnf:str):\n    if cnf == 'ADD':\n        return x + y\n    elif cnf == 'AND':\n        return x * y\n    elif cnf == 'IAND':\n        return x * (1. 
- y)\n    else:\n        raise NotImplementedError\n\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.node1 = node()\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.node2 = node()\n        self.downsample = downsample\n        if downsample is not None:\n            self.downsample_sn = node()\n        self.stride = stride\n        self.cnf = cnf\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.node1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.node2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample_sn(self.downsample(x))\n\n        out = sew_function(identity, out, 
self.cnf)\n\n        return out\n\n    def extra_repr(self) -> str:\n        return super().extra_repr() + f'cnf={self.cnf}'\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.node1 = node()\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.node2 = node()\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.node3 = node()\n        self.downsample = downsample\n        if downsample is not None:\n            self.downsample_sn = node()\n        self.stride = stride\n        self.cnf = cnf\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.node1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.node2(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n        out = 
self.node3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample_sn(self.downsample(x))\n\n        out = sew_function(out, identity, self.cnf)\n\n        return out\n\n    def extra_repr(self) -> str:\n        return super().extra_repr() + f'cnf={self.cnf}'\n\n\nclass SEWResNet(BaseModule):\n    def __init__(self, block, layers, num_classes=1000, step=8,encode_type=\"direct\",zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None, cnf: str = None,   *args,**kwargs):\n        super().__init__(            \n            step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.num_classes = num_classes\n\n        self.node = kwargs['node_type']\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n        self.once=kwargs[\"once\"] if \"once\"in kwargs else False\n        self.sum_output=kwargs[\"sum_output\"] if \"sum_output\"in kwargs else True\n        self.dataset = kwargs['dataset']\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n        else:\n            init_channel = 2\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n\n \n\n        self.conv1 = 
nn.Conv2d(init_channel, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.node1 = self.node()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0], cnf=cnf, node=self.node, **kwargs)\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2], cnf=cnf, node=self.node, **kwargs)\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, 
cnf: str=None, node: callable = None, **kwargs):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, inputs):\n        # See note [TorchScript super()]\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n\n            x = self.conv1(inputs)\n            \n            x = self.bn1(x)\n            x = self.node1(x)\n            x = self.maxpool(x)\n\n            x = self.layer1(x)\n            x = self.layer2(x)\n            x = self.layer3(x)\n            x = self.layer4(x)\n\n            x = self.avgpool(x)\n\n            x = torch.flatten(x, 1)\n\n            \n            x = self.fc(x)\n            \n            x = rearrange(x, '(t b) c -> t b c', t=self.step)\n            #print(x)\n            if self.sum_output:x=x.mean(0)\n \n\n            return x\n\n        else:\n            outputs=[]\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.conv1(x)\n                x = 
self.bn1(x)\n                x = self.node1(x)\n                x = self.maxpool(x)\n\n                x = self.layer1(x)\n                x = self.layer2(x)\n                x = self.layer3(x)\n                x = self.layer4(x)\n\n                x = self.avgpool(x)\n                x = torch.flatten(x, 1)\n                \n                \n                x = self.fc(x)\n\n                outputs.append(x)\n            if not self.sum_output:return outputs\n            return sum(outputs) / len(outputs)\n\n    def _forward_once(self,x):\n        # inputs = self.encoder(inputs)\n        # x = inputs[t]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.node1(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        \n        \n        x = self.fc(x)\n        return x\n    def forward(self, x):\n        if self.once:return self._forward_once(x)\n        return self._forward_impl(x)\n\nclass SEWResNet19(BaseModule):\n    def __init__(self, block, layers, num_classes=1000, step=8,encode_type=\"direct\",zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None, cnf: str = None,   *args,**kwargs):\n        super().__init__(            \n            step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.num_classes = num_classes\n\n        self.node = kwargs['node_type']\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n        self.once=kwargs[\"once\"] if \"once\"in kwargs else False\n        self.sum_output=kwargs[\"sum_output\"] if \"sum_output\"in kwargs else True\n        self.dataset = 
kwargs['dataset']\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n        else:\n            init_channel = 2\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n\n \n\n        self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.node1 = self.node() \n        self.layer1 = self._make_layer(block, 128, layers[0], cnf=cnf, node=self.node, **kwargs)\n        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs)\n        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs) \n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc1 = nn.Linear(512 * block.expansion, 256)\n        self.fc2 = nn.Linear(256, num_classes)\n         \n        self.node2 = self.node()\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # 
Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str=None, node: callable = None, **kwargs):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, inputs):\n        # See note [TorchScript super()]\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n\n            x = self.conv1(inputs)\n            \n            x = self.bn1(x)\n            x = self.node1(x) \n\n            x = 
self.layer1(x)\n            x = self.layer2(x)\n            x = self.layer3(x) \n\n            x = self.avgpool(x)\n\n            x = torch.flatten(x, 1)\n\n            x = self.fc1(x)\n            x = self.node2(x)\n            x = self.fc2(x)\n            \n            x = rearrange(x, '(t b) c -> t b c', t=self.step)\n            #print(x)\n            if self.sum_output:x=x.mean(0)\n \n\n            return x\n\n        else:\n            outputs=[]\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.conv1(x)\n                x = self.bn1(x)\n                x = self.node1(x)\n\n                x = self.layer1(x)\n                x = self.layer2(x)\n                x = self.layer3(x)\n\n                x = self.avgpool(x)\n                x = torch.flatten(x, 1)\n                \n                x = self.fc1(x)\n                x = self.node2(x)\n                x = self.fc2(x)\n\n                outputs.append(x)\n            if not self.sum_output:return outputs\n            return sum(outputs) / len(outputs)\n\n    def _forward_once(self,x):\n        # inputs = self.encoder(inputs)\n        # x = inputs[t]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.node1(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        \n        \n        x = self.fc(x)\n        return x\n    def forward(self, x):\n        if self.once:return self._forward_once(x)\n        return self._forward_impl(x)\n \nclass SEWResNetCifar(BaseModule):\n    def __init__(self, block, layers, num_classes=1000, step=8,encode_type=\"direct\",zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None, cnf: str = None,   *args,**kwargs):\n        super().__init__(            \n            
step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.num_classes = num_classes\n\n        self.node = kwargs['node_type']\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n        self.once=kwargs[\"once\"] if \"once\"in kwargs else False\n        self.sum_output=kwargs[\"sum_output\"] if \"sum_output\"in kwargs else True\n        self.dataset = kwargs['dataset']\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n        else:\n            init_channel = 2\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n\n \n\n        self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.node1 = self.node()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 128, layers[0], cnf=cnf, node=self.node, **kwargs)\n        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs)\n        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,\n                       
                dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs)\n \n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str=None, node: callable = None, **kwargs):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, 
planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, inputs):\n        # See note [TorchScript super()]\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n\n            x = self.conv1(inputs)\n            \n            x = self.bn1(x)\n            x = self.node1(x)\n\n            x = self.layer1(x)\n            x = self.layer2(x)\n            x = self.layer3(x)\n\n            x = self.avgpool(x)\n\n            x = torch.flatten(x, 1)\n\n            \n            x = self.fc(x)\n            \n            x = rearrange(x, '(t b) c -> t b c', t=self.step)\n            #print(x)\n            if self.sum_output:x=x.mean(0)\n \n\n            return x\n\n        else:\n            outputs=[]\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.conv1(x)\n                x = self.bn1(x)\n                x = self.node1(x)\n\n                x = self.layer1(x)\n                x = self.layer2(x)\n                x = self.layer3(x)\n\n                x = self.avgpool(x)\n                x = torch.flatten(x, 1)\n                \n                \n                x = self.fc(x)\n\n                outputs.append(x)\n            if not self.sum_output:return outputs\n            return sum(outputs) / len(outputs)\n\n    def _forward_once(self,x):\n        # inputs = self.encoder(inputs)\n        # x = inputs[t]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.node1(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        \n        \n        x = self.fc(x)\n        return x\n    def forward(self, 
x):\n        if self.once:return self._forward_once(x)\n        return self._forward_impl(x)\n\n\ndef _sew_resnet(arch, block, layers, pretrained, progress, cnf,  **kwargs):\n    model = SEWResNet(block, layers, cnf=cnf,  **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef sew_resnet19(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-18\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-18 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_ modified by the ResNet-18 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n\n    return SEWResNet19( BasicBlock, [3,3, 2],  cnf=cnf, **kwargs)\n \n@register_model\ndef sew_resnet18(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-18\n    :rtype: 
torch.nn.Module\n    The spike-element-wise ResNet-18 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_ modified by the ResNet-18 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n\n    return _sew_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, cnf, **kwargs)\n\n@register_model\ndef sew_resnet20(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-34\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-34 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-34 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return SEWResNetCifar(  BasicBlock, [3,3,3],  cnf=cnf,  **kwargs)\n@register_model\ndef sew_resnet32(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-34\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-34 
`\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-34 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return SEWResNetCifar(  BasicBlock, [5,5,5],  cnf=cnf,  **kwargs)\n@register_model\ndef sew_resnet44(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-34\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-34 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-34 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return SEWResNetCifar(  BasicBlock, [7,7,7],  cnf=cnf,  **kwargs)\n@register_model\ndef sew_resnet56(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-34\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-34 `\"Deep Residual Learning in Spiking Neural Networks\" 
<https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-34 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return SEWResNetCifar(  BasicBlock, [9,9,9],  cnf=cnf,  **kwargs)\n@register_model\ndef sew_resnet34(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-34\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-34 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-34 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return _sew_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, cnf,  **kwargs)\n\n@register_model\ndef sew_resnet50(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-50\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-50 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by 
the ResNet-50 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return _sew_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, cnf, **kwargs)\n\n@register_model\ndef sew_resnet101(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-101\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-101 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-101 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return _sew_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, cnf, **kwargs)\n\n@register_model\ndef sew_resnet152(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a single step neuron\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-152\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-152 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNet-152 model 
from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n    return _sew_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, cnf,  **kwargs)\n\n@register_model\ndef sew_resnext50_32x4d(pretrained=False, progress=True, cnf: str = None, **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a single step neuron\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNeXt-50 32x4d\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNeXt-50 32x4d `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the ResNeXt-50 32x4d model from `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _sew_resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, cnf,  **kwargs)\n\n@register_model\ndef sew_resnext34_32x4d(pretrained=False, progress=True, cnf: str = None,   **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a single step neuron\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNeXt-101 32x8d\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNeXt-101 32x8d `\"Deep Residual 
Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_ modified by the ResNeXt-101 32x8d model from `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _sew_resnet('resnext34_32x4d', BasicBlock, [3, 4, 6, 3], pretrained, progress, cnf,   **kwargs)\n\n@register_model\ndef sew_resnext101_32x8d(pretrained=False, progress=True, cnf: str = None, node: callable=None, **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a single step neuron\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNeXt-101 32x8d\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNeXt-101 32x8d `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_ modified by the ResNeXt-101 32x8d model from `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _sew_resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, cnf, node, **kwargs)\n\n@register_model\ndef sew_wide_resnet50_2(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a single step 
neuron\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking Wide ResNet-50-2\n    :rtype: torch.nn.Module\n    The spike-element-wise Wide ResNet-50-2 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the Wide ResNet-50-2 model from `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _sew_resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, cnf,  **kwargs)\n\n@register_model\ndef sew_wide_resnet101_2(pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of spike-element-wise function\n    :type cnf: str\n    :param node: a single step neuron\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking Wide ResNet-101-2\n    :rtype: torch.nn.Module\n    The spike-element-wise Wide ResNet-101-2 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_\n    modified by the Wide ResNet-101-2 model from `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _sew_resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, cnf,  **kwargs)\n"
  },
  {
    "path": "braincog/model_zoo/vgg_snn.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/7/26 18:56\n# User      : Floyed\n# Product   : PyCharm\n# Project   : BrainCog\n# File      : vgg_snn.py\n# explain   :\n\nfrom functools import partial\nfrom torch.nn import functional as F\nimport torchvision\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.datasets import is_dvs_data\n\n\n@register_model\nclass SNN7_tiny(BaseModule):\n    def __init__(self,\n                 num_classes=10,\n                 step=8,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n\n        self.num_classes = num_classes\n\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        self.dataset = kwargs['dataset']\n        assert not is_dvs_data(self.dataset), 'SNN7_tiny only support static datasets now'\n\n        self.feature = nn.Sequential(\n            BaseConvModule(3, 16, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(16, 64, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.MaxPool2d(2),\n            BaseConvModule(64, 128, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(128, 128, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.MaxPool2d(2),\n            BaseConvModule(128, 256, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(256, 256, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.MaxPool2d(2),\n            BaseConvModule(256, 512, kernel_size=(3, 3), 
padding=(1, 1), node=self.node),\n        )\n        self.fc = nn.Sequential(\n            nn.Flatten(),\n            nn.Linear(512 * 4 * 4, self.num_classes),\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n            x = self.feature(inputs)\n            x = self.fc(x)\n            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\n            return x\n\n        else:\n            outputs = []\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.feature(x)\n                x = self.fc(x)\n                outputs.append(x)\n\n            return sum(outputs) / len(outputs)\n\n\n@register_model\nclass SNN5(BaseModule):\n    def __init__(self,\n                 num_classes=10,\n                 step=8,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n\n        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False\n\n        self.num_classes = num_classes\n\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        self.dataset = kwargs['dataset']\n        if not is_dvs_data(self.dataset):\n            init_channel = 3\n        else:\n            init_channel = 2\n\n        self.feature = nn.Sequential(\n            BaseConvModule(init_channel, 16, kernel_size=(3, 3), padding=(1, 1), node=self.node, n_preact=self.n_preact),\n            BaseConvModule(16, 64, kernel_size=(5, 5), padding=(2, 2), node=self.node, n_preact=self.n_preact),\n            nn.AvgPool2d(2),\n            BaseConvModule(64, 128, kernel_size=(5, 5), padding=(2, 2), node=self.node, n_preact=self.n_preact),\n            nn.AvgPool2d(2),\n            BaseConvModule(128, 256, kernel_size=(3, 3), padding=(1, 
1), node=self.node, n_preact=self.n_preact),\n            nn.AvgPool2d(2),\n            BaseConvModule(256, 512, kernel_size=(3, 3), padding=(1, 1), node=self.node, n_preact=self.n_preact),\n            nn.AvgPool2d(2),\n        )\n        self.fc = nn.Sequential(\n            nn.Flatten(),\n            nn.Linear(512 * 3 * 3, self.num_classes),\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n            x = self.feature(inputs)\n            x = self.fc(x)\n            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\n            return x\n\n        else:\n            outputs = []\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.feature(x)\n                x = self.fc(x)\n                outputs.append(x)\n\n            return sum(outputs) / len(outputs)\n\n\n@register_model\nclass VGG_SNN(BaseModule):\n    def __init__(self,\n                 num_classes=10,\n                 step=8,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n\n        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False\n\n        self.num_classes = num_classes\n\n        self.node = node_type\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n\n        self.dataset = kwargs['dataset']\n        if not is_dvs_data(self.dataset):\n            raise NotImplementedError('VGG-SNN model is only for DVS data, but current datasets is {}'.format(self.dataset))\n\n        self.feature = nn.Sequential(\n            BaseConvModule(2, 64, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(64, 128, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.AvgPool2d(2),\n            
BaseConvModule(128, 256, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(256, 256, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.AvgPool2d(2),\n            BaseConvModule(256, 512, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(512, 512, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.AvgPool2d(2),\n            BaseConvModule(512, 512, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            BaseConvModule(512, 512, kernel_size=(3, 3), padding=(1, 1), node=self.node),\n            nn.AvgPool2d(2),\n        )\n        self.fc = nn.Sequential(\n            nn.Flatten(),\n            nn.Linear(512 * 3 * 3, self.num_classes),\n        )\n\n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n            x = self.feature(inputs)\n            x = self.fc(x)\n            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\n            return x\n\n        else:\n            outputs = []\n            for t in range(self.step):\n                x = inputs[t]\n                x = self.feature(x)\n                x = self.fc(x)\n                outputs.append(x)\n\n            return sum(outputs) / len(outputs)\n"
  },
  {
    "path": "braincog/utils.py",
    "content": "import os\r\nimport random\r\nimport math\r\nimport csv\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom PIL import Image\r\nimport matplotlib.pyplot as plt\r\nimport torchvision.transforms as transforms\r\n\r\n\r\ndef setup_seed(seed):\r\n    \"\"\"\r\n    为CPU，GPU，所有GPU，numpy，python设置随机数种子，并禁止hash随机化\r\n    :param seed: seed value\r\n    :return:\r\n    \"\"\"\r\n    torch.manual_seed(seed)\r\n    torch.cuda.manual_seed(seed)\r\n    torch.cuda.manual_seed_all(seed)\r\n    np.random.seed(seed)\r\n    random.seed(seed)\r\n\r\n\r\n    torch.backends.cudnn.benchmark = False\r\n\r\n    torch.backends.cudnn.deterministic = True\r\n\r\n\r\n    os.environ['PYTHONHASHSEED'] = str(seed)\r\n\r\n\r\n\r\ndef random_gradient(model: nn.Module, sigma: float):\r\n    \"\"\"\r\n    为梯度添加噪声\r\n    :param model: 模型\r\n    :param sigma: 噪声方差\r\n    :return:\r\n    \"\"\"\r\n    for param in model.parameters():\r\n        if param.grad is None:\r\n            continue\r\n        noise = torch.randn_like(param) * sigma\r\n        param.grad = param.grad + noise\r\n\r\n\r\nclass AverageMeter(object):\r\n    def __init__(self):\r\n        self.reset()\r\n\r\n    def reset(self):\r\n        self.avg = 0\r\n        self.sum = 0\r\n        self.cnt = 0\r\n\r\n    def update(self, val, n=1):\r\n        self.sum += val * n\r\n        self.cnt += n\r\n        self.avg = self.sum / self.cnt\r\n\r\nclass TensorGather(object):\r\n    def __init__(self):\r\n        self.reset()\r\n    def reset(self):\r\n        self.gather=None\r\n\r\n    def update(self, val):\r\n        if self.gather is not None:self.gather=torch.cat([self.gather,val],dim=0)\r\n        else:self.gather=val\r\n \r\ndef accuracy(output, target, topk=(1,)):\r\n    \"\"\"Compute the top1 and top5 accuracy\r\n    \"\"\"\r\n    maxk = max(topk)\r\n    batch_size = target.size(0)\r\n    # Return the k largest elements of the given input tensor\r\n    # along a given dimension -> N * k\r\n    _, 
pred = output.topk(maxk, 1, True, True)\r\n    pred = pred.t()\r\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\r\n    res = []\r\n    for k in topk:\r\n        correct_k = correct[:k].reshape(-1).float().sum(0)\r\n        res.append(correct_k.mul_(100.0 / batch_size))\r\n    return res\r\n\r\n\r\ndef mse(x, y):\r\n    out = (x - y).pow(2).sum(-1, keepdim=True).mean()\r\n    return out\r\n\r\n\r\ndef rand_ortho(shape, irange):\r\n    A = - irange + 2 * irange * np.random.rand(*shape)\r\n    U, s, V = np.linalg.svd(A, full_matrices=True)\r\n    return np.dot(U, np.dot(np.eye(U.shape[1], V.shape[0]), V))\r\n\r\n\r\ndef adjust_surrogate_coeff(epoch, tot_epochs):\r\n    T_min, T_max = 1e-3, 1e1\r\n    Kmin, Kmax = math.log(T_min) / math.log(10), math.log(T_max) / math.log(10)\r\n    t = torch.tensor([math.pow(10, Kmin + (Kmax - Kmin) / tot_epochs * epoch)]).float().cuda()\r\n    k = torch.tensor([1]).float().cuda()\r\n    if k < 1:\r\n        k = 1 / t\r\n    return t, k\r\n\r\n\r\ndef save_feature_map(x, dir=''):\r\n    for idx, layer in enumerate(x):\r\n        layer = layer.cpu()\r\n        for batch in range(layer.shape[0]):\r\n            for channel in range(layer.shape[1]):\r\n                fname = '{}_{}_{}_{}.jpg'.format(\r\n                    idx, batch, channel, layer.shape[-1])\r\n                fp = layer[batch, channel]\r\n                plt.tight_layout()\r\n                plt.axis('off')\r\n                plt.imshow(fp, cmap='inferno')\r\n                plt.savefig(os.path.join(dir, fname),\r\n                            bbox_inches='tight', pad_inches=0)\r\n\r\n\r\ndef save_spike_info(fname, epoch, batch_idx, step, avg, var, spike, avg_per_step):\r\n    \"\"\"\r\n    对spike-info格式进行调整, 便于保存\r\n    :param fname: 输出文件名\r\n    :param epoch: epoch\r\n    :param batch_idx: batch index\r\n    :param step: 仿真步长\r\n    :param avg: 平均脉冲发放率\r\n    :param var: 脉冲发放率的方差\r\n    :param spike:\r\n    :param avg_per_step:\r\n    :return:\r\n    
\"\"\"\r\n    if not os.path.exists(fname):\r\n        f = open(fname, mode='w', encoding='utf8', newline='')\r\n        writer = csv.writer(f)\r\n        head = ['epoch', 'batch', 'layer', 'avg', 'var']\r\n        head.extend(['st_{}'.format(i) for i in range(step + 1)])  # spike times\r\n        head.extend(['as_{}'.format(i) for i in range(step)])  # avg spike per time\r\n        writer.writerow(head)\r\n\r\n    else:\r\n        f = open(fname, mode='a', encoding='utf8', newline='')\r\n        writer = csv.writer(f)\r\n\r\n    for layer in range(len(avg)):\r\n        lst = [epoch, batch_idx, layer, avg[layer], var[layer]]\r\n        lst.extend(spike[layer])\r\n        lst.extend(avg_per_step[layer])\r\n        lst = [str(x) for x in lst]\r\n        writer.writerow(lst)\r\n\r\n\r\ndef calc_aurc(confidences, labels):\r\n    \r\n \r\n    predictions = torch.argmax(confidences, dim=1)\r\n    max_confs = torch.max(confidences, dim=1)[0]\r\n\r\n    n = len(labels)\r\n\r\n    indices = torch.argsort(max_confs)\r\n\r\n    labels, predictions, confidences = labels[indices].flip(dims=[0]), predictions[indices].flip(dims=[0]), confidences[indices].flip(dims=[0])\r\n    risk_cov = torch.divide(torch.cumsum(labels != predictions,dim=0).float(), torch.arange(1, n+1).cuda())\r\n    nrisk = torch.sum(labels != predictions)\r\n    aurc = torch.mean(risk_cov)\r\n    opt_aurc = (1./n) * torch.sum(torch.divide(torch.arange(1, nrisk + 1).cuda().float(), n - nrisk + torch.arange(1, nrisk + 1).cuda()))\r\n    eaurc = aurc - opt_aurc\r\n            \r\n    return aurc, eaurc\r\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=source\nset BUILDDIR=build\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.https://www.sphinx-doc.org/\n\texit /b 1\n)\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\n\n:end\npopd\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\nimport warnings\nwarnings.filterwarnings(\"ignore\")\nsys.path.insert(0, os.path.abspath('../../braincog'))\n\n\n# -- Project information -----------------------------------------------------\n\nproject = 'braincog'\ncopyright = '2022, Brain-Inspired-Cognitive-Intelligence-Engine(BrainCog)'\nauthor = 'Brain-Inspired-Cognitive-Intelligence-Engine'\n\n# The full version, including alpha/beta/rc tags\nrelease = '0.2.7.11'\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.autodoc',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.doctest',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.todo',\n    'sphinx.ext.coverage',\n    'sphinx.ext.mathjax',\n    'recommonmark',\n#     'sphinx_markdown_tables',\n]\n\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The language for content autogenerated by Sphinx. 
Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = 'zh_CN'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = []\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_rtd_theme'\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n"
  },
  {
    "path": "docs/source/examples/Brain_Cognitive_Function_Simulation/drosophila.md",
    "content": "\n# Drosophila-inspired decision-making SNN\n\n## Run\n \"drosophila.py\"  includes the training phase and testing phase.\n\n```shell\npython drosophila.py\n```\n* Training Phase\n\ngreen-upright T is safe and blue-inverted T is dangerous\n\n* Testing Phase \n\nFor linear pathway and nonlinear pathway, choose between blue-upright T and green-inverted T, and count the PI values under different color intensity\n\n## Results\n\nThe following picture shows the linear (a) and nonlinear (b) pathways, the training and testing phases (c), and the PI values on different color intensities (d).\n\n![description](./dro.jpg)"
  },
  {
    "path": "docs/source/examples/Brain_Cognitive_Function_Simulation/index.rst",
    "content": "Brain_Cognitive_Function_Simulation\n======================================\n\n.. toctree::\n    :maxdepth: 2\n\n    drosophila"
  },
  {
    "path": "docs/source/examples/Decision_Making/BDM_SNN.md",
    "content": "\n# Brain-inspired Decision-Making Spiking Neural Network\n\n## Run\n \"BDM-SNN.py\"  includes the multi-brain regions coordinated decision-making spiking neural network with LIF neurons.\n\n \"BDM-SNN-hh.py\"  includes the BDM-SNN with simplified HH neurons.\n\n \"BDM-SNN-UAV.py\"  includes the BDM-SNN applied to the UAV (DJI Tello talent), users need to define the reinforcement learning task.\n\n```shell\npython BDM-SNN.py\n\npython BDM-SNN-hh.py\n\npython BDM-SNN-UAV.py\n```\n\n## Results\n \"BDM-SNN.py\"  and  \"BDM-SNN-hh.py\"  have been verified on Flappy Bird game. BDM-SNN could stably pass the pipeline on the first try.\n\n![description](./bdm.png)"
  },
  {
    "path": "docs/source/examples/Decision_Making/RL.md",
    "content": "# PL-SDQN\n\nThis repository contains code from our paper [**Solving the Spike Feature Information Vanishing Problem in Spiking Deep Q Network with Potential Based Normalization**]. If you use our code or refer to this project, please cite this paper.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n* gym\n* atari-py\n* opencv-python\n* tianshou\n\n## Train\n\n```python \npython ./sdqn/main.py\n```\n\n\n"
  },
  {
    "path": "docs/source/examples/Decision_Making/index.rst",
    "content": "Decision Making\n======================================\n\n.. toctree::\n    :maxdepth: 2\n\n    RL\n    BDM_SNN"
  },
  {
    "path": "docs/source/examples/Knowledge_Representation_and_Reasoning/CKRGSNN.md",
    "content": "# Commonsense Knowledge Representation SNN\n\n(https://arxiv.org/abs/2207.05561)\n\nThis repository contains code from our paper [**Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning**] preprint in: https://arxiv.org/abs/2207.05561. If you use our code or refer to this project, please cite this paper.\n\n\n\n\n## Requirements\n\n* python=3.8\n* numpy\n* scipy\n* turicreate\n* pytorch >= 1.7.0\n* torchvision\n\n\n## Dataset\n\nConceptNet: https://github.com/commonsense/conceptnet5\n\n\n## Run\n\n```shell\npython main.py\n```\n\nThis module selects core knowledge in ConceptNet to form the sub_Concept.csv file as the input of the model. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.\n"
  },
  {
    "path": "docs/source/examples/Knowledge_Representation_and_Reasoning/CRSNN.md",
    "content": "# Causal Reasoning SNN\n(https://doi.org/10.1109/IJCNN52387.2021.9534102)\n\nThis repository contains code from our paper [**A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks**] published in 2021 International Joint Conference on Neural Networks (IJCNN).  https://ieeexplore.ieee.org/abstract/document/9534102. If you use our code or refer to this project, please cite this paper.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n\n## Run\n\n```shell\npython main.py\n```\n\n\nThis module builds an example of a brain-like causal inference spiking neural network model. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.\n"
  },
  {
    "path": "docs/source/examples/Knowledge_Representation_and_Reasoning/SPSNN.md",
    "content": "#  Sequence Production SNN\n[DOI: 10.3389/fncom.2021.612041](https://doi.org/10.3389/fncom.2021.612041)\n\nThis repository contains code from our paper [**Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP**] published in Frontiers in Computational Neuroscience. https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full. If you use our code or refer to this project, please cite this paper.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n\n\n\n## Run\n\n```shell\npython main.py file\n```\n\n\nThis module builds a sequence Production spiking neural network model, realizing the memory and reconstruction functions for arbitrary symbol sequences. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.\n"
  },
  {
    "path": "docs/source/examples/Knowledge_Representation_and_Reasoning/index.rst",
    "content": "Knowledge Representation and Reasoning\n=========================================\n\n.. toctree::\n    :maxdepth: 2\n\n    musicMemory\n    SPSNN\n    CRSNN\n    CKRGSNN"
  },
  {
    "path": "docs/source/examples/Knowledge_Representation_and_Reasoning/musicMemory.md",
    "content": "# Music Memory\n\n数据集：http://www.piano-midi.de/\n\n自行下载数据集，数据使用方法见task下的示例。"
  },
  {
    "path": "docs/source/examples/Multi-scale_Brain_Structure_Simulation/Corticothalamic_minicolumn.md",
    "content": "# Corticothalamic minicolumn\n\n## Description\nThe anatomical data is saved in the \"tool\" package. The **main.py** creates the network of the minicolumn depending on the anatomical data.\nA file named **\"fire.csv\"** will be generated to record the firing result of neurons in each time step.\n\n## Requirements\n* numpy\n* scipy\n* pytorch >= 1.7.0\n\n```shell\npython main.py\n```"
  },
  {
    "path": "docs/source/examples/Multi-scale_Brain_Structure_Simulation/HumanBrain.md",
    "content": "# Human Brain Simulation\n\n## Description\nHuman Brain Simulation is a large scale brain modeling framework depending on braincog framework.\n\n## Requirements:\n* numpy >= 1.21.2\n* scipy >= 1.8.0\n* h5py >= 3.6.0\n* torch >= 1.10\n* torchvision >= 0.12.0\n* torchaudio  >= 0.11.0\n* timm >= 0.5.4\n* matplotlib >= 3.5.1\n* einops >= 0.4.1\n* thop >= 0.0.31\n* pyyaml >= 6.0\n* loris >= 0.5.3\n* pandas >= 1.4.2  \n* tonic (special)\n* pandas >= 1.4.2  \n\n## Example:\n\n```shell \ncd ~/examples/Multi-scale Brain Structure Simulation/HumanBrain/\npython brainSimHum.py\n```\n\n## Parameters:\nTo simulate the models (both human and macaque brain), the parameters of the neuron number in each region and the connectome power between regions can be set flexibly in the main function (nsz and asz) of the .py files.\n"
  },
  {
    "path": "docs/source/examples/Multi-scale_Brain_Structure_Simulation/Human_PFC.md",
    "content": "# Human PFC\n\n## Input:\n\n* 程序输入六层皮质柱的电生理数据，数据文件名中的数字表示神经元数量。程序默认有背景电流输入。其中data+数字命名的是随机输入刺激的文件，输入图片刺激的文件分别有人类参数的和小鼠参数的。分别都支持四种形状的图片文件，圆形、正方形、三角形和星形。\n\n链接：https://drive.google.com/drive/folders/1AVc2aNTxkcsGAPlq1SuWtatGzyQRPCmp?usp=sharing\n\n\n\n## output\n\n* 程序生成的各个神经元放电时间点记录的数据文件\n\n## application:\n\n* 程序中可以修改每次PFC模型放电环境的情况\n\n"
  },
  {
    "path": "docs/source/examples/Multi-scale_Brain_Structure_Simulation/MacaqueBrain.md",
    "content": "# Macaque Brain Simulation\n\n## Description\nMacaque Brain Simulation is a large scale brain modeling framework depending on braincog framework.\n\n## Requirements:\n* numpy >= 1.21.2\n* scipy >= 1.8.0\n* h5py >= 3.6.0\n* torch >= 1.10\n* torchvision >= 0.12.0\n* torchaudio  >= 0.11.0\n* timm >= 0.5.4\n* matplotlib >= 3.5.1\n* einops >= 0.4.1\n* thop >= 0.0.31\n* pyyaml >= 6.0\n* loris >= 0.5.3\n* pandas >= 1.4.2  \n* tonic (special)\n* pandas >= 1.4.2  \n\n## Example:\n\n```shell\ncd ~/examples/Multi-scale Brain Structure Simulation/MacaqueBrain/\npython brainSimMaq.py\n```\n\n## Parameters:\nTo simulate the models (both human and macaque brain), the parameters of the neuron number in each region and the connectome power between regions can be set flexibly in the main function (nsz and asz) of the .py files.\n"
  },
  {
    "path": "docs/source/examples/Multi-scale_Brain_Structure_Simulation/index.rst",
    "content": "Multi-scale Brain Structure Simulation\n=========================================\n\n.. toctree::\n    :maxdepth: 2\n\n    MacaqueBrain\n    HumanBrain\n    mouse_brain\n    Human_PFC\n    Corticothalamic_minicolumn"
  },
  {
    "path": "docs/source/examples/Multi-scale_Brain_Structure_Simulation/mouse_brain.md",
    "content": "# Mouse Brain\n\n## Input:\n\n* 程序输入213个脑区之间的连接权重的表格。放在谷歌网盘上面，名称是'W_213.xlsx'。\n\n链接：[https://drive.google.com/drive/folders/1snPbpiVBpVuRgRYcl4AG4v49NgKKwBtA?usp=sharing](https://drive.google.com/drive/folders/1snPbpiVBpVuRgRYcl4AG4v49NgKKwBtA?usp=sharing)\n\n\n## output\n\n* 程序生成的各个神经元放电时间点记录的数据mat文件，数据点的数量大需要用画图软件显示结果。\n\n## application:\n\n* 程序中可以修改神经元的数量模拟时间的情况\n"
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/Conversion.md",
    "content": "# Conversion Method\nTraining deep spiking neural networks with ANN-SNN conversion.\nReplace ReLU and MaxPooling in the PyTorch model so that the original ANN can be converted into an SNN to accomplish complex tasks.\n\n## Results\n```shell\npython CIFAR10_VGG16.py\npython converted_CIFAR10.py\n```\n\nYou should first run the `CIFAR10_VGG16.py` to get a well-trained ANN.\nThen `converted_CIFAR10.py` can be used to run the SNN inference process."
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/MultisensoryIntegration.md",
    "content": "# Multisensory Integration DEMO\n\nIn `MultisensoryIntegrationDEMO_AM.py` and `MultisensoryIntegrationDEMO_AM.py`, we implement the SNNs based multisensory integration framework. To load the dataset, preprocess it and get the weights with the function `get_concept_datase_dic_and_initial_weights_lst()`​. We use `IMNet`​ or ​`AMNet`​ to describe the structure of the IM/AM model. For presynaptic neuron, we use the function `convert_vec_into_spike_trains()​` to generate the spike trains.\n\nWhile for postsynaptic neuron, we use the function `reducing_tol_binarycode()` to get the multisensory integrated output for each concept. And  *tol* is the only parameter.\n\nIn `measure_and_visualization.py`, we will measure and visualize the results.\n\n## Multisensory Dataset\n\nWhen implement the model in braincog, we use the famous multisensory dataset--BBSR.\n\nSome examples are as follows:\n\n| Concept   | Visual      | Somatic   | Audiation   | Taste    | Smell    |\n| --------- | ----------- | --------- | ----------- | -------- | -------- |\n| advantage | 0.213333333 | 0.032     | 0           | 0        | 0        |\n| arm       | 2.5111112   | 2.2733334 | 0.133333286 | 0.233333 | 0.4      |\n| ball      | 1.9580246   | 2.3111112 | 0.523809429 | 0.185185 | 0.111111 |\n| baseball  | 2.2714286   | 2.6071428 | 0.352040714 | 0.071429 | 0.392857 |\n| bee       | 2.795698933 | 2.4129034 | 2.096774286 | 0.290323 | 0.419355 |\n| beer      | 1.4866666   | 2.2533334 | 0.190476286 | 5.8      | 4.6      |\n| bird      | 2.7632184   | 2.027586  | 3.064039286 | 1.068966 | 0.517241 |\n| car       | 2.521839133 | 2.9517244 | 2.216748857 | 0        | 2.206897 |\n| foot      | 2.664444533 | 2.58      | 0.380952429 | 0.433333 | 3        |\n| honey     | 1.757142867 | 2.3214286 | 0.015306143 | 5.642857 | 4.535714 |\n\n## How to Run\n\nTo get the multisensory integrated vectors:\n\n```\ncd examples/MultisensoryIntegration/code\npython MultisensoryIntegrationDEMO_AM.py\npython 
MultisensoryIntegrationDEMO_IM.py\n```\n\nTo measure and analysis the vectors:\n\n```\ncd examples/MultisensoryIntegration/code\npython measure_and_visualization.py\n```\n\n\n"
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/QSNN.md",
    "content": "# Quantum superposition inspired spiking neural network\n\nThis repository contains code from our paper [**Quantum superposition inspired spiking neural network**] published in iScience. https://doi.org/10.1016/j.isci.2021.102880. If you use our code or refer to this project, please cite this paper.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n## Data preparation\n\nFirst modify the ```DATA_DIR='path/to/datasets'``` in ```examples/QSNN/main.py``` to the root directory of your MNIST datasets.\n\n\n## Train\n\n```shell  \npython ./main.py\n```\n"
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/UnsupervisedSTDP.md",
    "content": "# Unsupervised STDP\nThis is an example of training an unsupervised STDP-based spiking neural network. We use an STB-STDP algorithm to train the SNN, along with multiple adaptive mechanisms.\n \n## How to run\npython codef.py\n\n## Result\nWe train the model on MNIST and FashionMNIST, and the best accuracy for MNIST is 97.9%, for FashionMNIST is 87.0%."
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/img_cls/bp.md",
    "content": "# Script for training high-performance SNNs based on back propagation \nThis is an example of training high-performance SNNs using the braincog.\nIt is able to train high performance SNNs on CIFAR10, DVS-CIFAR10, ImageNet and other datasets, and reach the advanced level. \n\n## Install braincog  \n\n```shell\ngit clone https://github.com/xxx/Brain-Cog.git\ncd braincog \npython setup install --user \n```\n\n## Examples of training\n\n```shell\ncd examples/img_cls/bp \npython main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGate --device 0 \n```\n\n## Benchmark \n\nWe provide a benchmark of SNNs trained with braincog and the corresponding scripts. \nThis provides an open, fair platform for comparison of subsequent SNNs on classification tasks. \n\n**Note**: The results may vary due to random seeding and software version issues. \n\n\n### CIFAR10 \n\n| ID  | Dataset | Node-type  | Config |    Model    | Batch Size |   Accuracy   | Script                                                                                                                                     |\n|:----|:-------:|:----------:|:------:|:-----------:|:----------:|:------------:|:-------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | CIFAR10 |  IF+Atan   |   -    |   convnet   |    128     |    95.54     | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```    |\n| 1   | CIFAR10 |  LIF+Atan  |   -    |   convnet   |    128     |    91.92     | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```   |\n| 1   | CIFAR10 | PLIF+Atan  |   -    |   convert   |    128     |    93.32     | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 
--batch-size 128 --act-fun AtanGrad --device 0```  |\n| 1   | CIFAR10 |  IF+Atan   |   -    |  resnet18   |    128     | 89.76/89.80  | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```         |\n| 1   | CIFAR10 |  LIF+Atan  |   -    |  resnet18   |    128     | 89.93/89.88  | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```        |\n| 1   | CIFAR10 | PLIF+Atan  |   -    |  resnet18   |    128     | 92.64/ 90.65 | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   | CIFAR10 |  IF+QGate  |   -    | dvs_convnet |    128     |    95.73     | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```   |\n| 1   | CIFAR10 | LIF+QGate  |   -    | dvs_convnet |    128     |    96.04     | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```  |\n| 1   | CIFAR10 | PLIF+QGate |   -    | dvs_convnet |    128     | 96.04/95.84  | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |\n| 1   | CIFAR10 |  IF+QGate  |   -    |  resnet18   |    128     |    89.19     | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```        |\n| 1   | CIFAR10 | LIF+QGate  |   -    |  resnet18   |    128     | 90.95/90.68  | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```       |\n| 1   | CIFAR10 | PLIF+QGate |   -    |  resnet18   |    128     | 90.97/91.02  | ```python main.py --model resnet18 --node-type PLIFNode --dataset 
cifar10 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```      |\n\n\n### CIFAR100 \n\n| ID  | Dataset  | Node-type  | Config |    Model    | Batch Size | Accuracy | Script                                                                                                                                                        |\n|:----|:--------:|:----------:|:------:|:-----------:|:----------:|:--------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | CIFAR100 |  IF+Atan   |   -    | dvs_convnet |    128     |  76.52   | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```    |\n| 1   | CIFAR100 |  LIF+Atan  |   -    | dvs_convnet |    128     |  71.89   | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```   |\n| 1   | CIFAR100 | PLIF+Atan  |   -    | dvs_convnet |    128     |  72.82   | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```  |\n| 1   | CIFAR100 |  IF+Atan   |   -    |  resnet18   |    128     |  62.47   | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```         |\n| 1   | CIFAR100 |  LIF+Atan  |   -    |  resnet18   |    128     |  62.63   | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```        |\n| 1   | CIFAR100 | PLIF+Atan  |   -    |  resnet18   |    128     |  62.71   | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad 
--device 0```       |\n| 1   | CIFAR100 |  IF+QGate  |   -    | dvs_convnet |    128     |  76.44   | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```   |\n| 1   | CIFAR100 | LIF+QGate  |   -    | dvs_convnet |    128     |  77.73   | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```  |\n| 1   | CIFAR100 | PLIF+QGate |   -    | dvs_convnet |    128     |  77.25   | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0``` |\n| 1   | CIFAR100 |  IF+QGate  |   -    |  resnet18   |    128     |  60.01   | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```        |\n| 1   | CIFAR100 | LIF+QGate  |   -    |  resnet18   |    128     |  61.33   | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```       |\n| 1   | CIFAR100 | PLIF+QGate |   -    |  resnet18   |    128     |  62.32   | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGrad --device 0```      |\n\n\n### DVS-CIFAR10\n\n| ID  |   Dataset   | Node-type  | Config |    Model    | Batch Size | FLOPS |  Accuracy   | Script                                                                                                                                   |\n|:----|:-----------:|:----------:|:------:|:-----------:|:----------:|:-----:|:-----------:|:-----------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | DVS-CIFAR10 |  IF+Atan   |   -    | 
dvs_convnet |    128     | 7503  |    65.90    | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```    |\n| 1   | DVS-CIFAR10 |  LIF+Atan  |   -    | dvs_convnet |    128     | 7503  |    82.10    | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```   |\n| 1   | DVS-CIFAR10 | PLIF+Atan  |   -    | dvs_convnet |    128     | 7503  |    81.90    | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```  |\n| 1   | DVS-CIFAR10 |  IF+Atan   |   -    |  resnet18   |    128     | 3149  |    69.10    | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   | DVS-CIFAR10 |  LIF+Atan  |   -    |  resnet18   |    128     | 3149  |    78.50    | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```      |\n| 1   | DVS-CIFAR10 | PLIF+Atan  |   -    |  resnet18   |    128     | 3149  |    77.70    | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```     |\n| 1   | DVS-CIFAR10 |  IF+QGate  |   -    | dvs_convnet |    128     | 7503  |    68.30    | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```   |\n| 1   | DVS-CIFAR10 | LIF+QGate  |   -    | dvs_convnet |    128     | 7503  | 82.60/82.90 | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```  |\n| 1   | DVS-CIFAR10 | PLIF+QGate |   -    | dvs_convnet |    128     | 7503  |    83.20    | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 
--batch-size 128 --act-fun QGateGrad --device 0``` |\n| 1   | DVS-CIFAR10 |  IF+QGate  |   -    |  resnet18   |    128     | 3149  | 65.70/66.80 | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```      |\n| 1   | DVS-CIFAR10 | LIF+QGate  |   -    |  resnet18   |    128     | 3149  | 79.00/79.40 | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```     |\n| 1   | DVS-CIFAR10 | PLIF+QGate |   -    |  resnet18   |    128     | 3149  | 78.10/78.20 | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```    |\n\n\n### DVS-Gesture\n\n| ID  | Dataset | Node-type  | Config |    Model    | Batch Size |  Accuracy   | Script                                                                                                                                                  |\n|:----|:-------:|:----------:|:------:|:-----------:|:----------:|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   |  DVS-G  |  IF+Atan   |   -    | dvs_convnet |    128     |    64.77    | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```    |\n| 1   |  DVS-G  |  LIF+Atan  |   -    | dvs_convnet |    128     |    91.28    | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```   |\n| 1   |  DVS-G  | PLIF+Atan  |   -    | dvs_convnet |    128     |    91.67    | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```  |\n| 1   |  DVS-G  |  IF+Atan   |   -    |  
resnet18   |    128     |    63.25    | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   |  DVS-G  |  LIF+Atan  |   -    |  resnet18   |    128     |    91.29    | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```      |\n| 1   |  DVS-G  | PLIF+Atan  |   -    |  resnet18   |    128     |    90.15    | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```     |\n| 1   |  DVS-G  |  IF+QGate  |   -    | dvs_convnet |    128     |    48.48    | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0```   |\n| 1   |  DVS-G  | LIF+QGate  |   -    | dvs_convnet |    128     | 92.05/92.42 | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0```  |\n| 1   |  DVS-G  | PLIF+QGate |   -    | dvs_convnet |    128     |    91.28    | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |\n| 1   |  DVS-G  |  IF+QGate  |   -    |  resnet18   |    128     |    57.95    | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0```      |\n| 1   |  DVS-G  | LIF+QGate  |   -    |  resnet18   |    128     |    90.91    | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0```     |\n| 1   |  DVS-G  | PLIF+QGate |   -    |  resnet18   |    128     |    92.42    | ```python main.py --num-classes 11 --model resnet18 --node-type 
PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGrad --device 0```    |\n\n### NCALTECH101\n\n| ID  |   Dataset   | Node-type  | Config |    Model    | Batch Size |  Accuracy   | Script                                                                                                                                                          |\n|:----|:-----------:|:----------:|:------:|:-----------:|:----------:|:-----------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | NCALTECH101 |  IF+QGate  |   -    | dvs_convnet |    128     | 23.09/51.15 | ```python main.py --num-classes 100 --model dvs_convnet --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```   |\n| 1   | NCALTECH101 | LIF+QGate  |   -    | dvs_convnet |    128     | 72.78/75.09 | ```python main.py --num-classes 100 --model dvs_convnet --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```  |\n| 1   | NCALTECH101 | PLIF+QGate |   -    | dvs_convnet |    128     | 74.61/76.79 | ```python main.py --num-classes 100 --model dvs_convnet --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0``` |\n| 1   | NCALTECH101 |  IF+QGate  | -/mix  |  resnet18   |    128     | 61.24/60.87 | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```      |\n| 1   | NCALTECH101 | LIF+QGate  | -/mix  |  resnet18   |    128     | 66.22/70.84 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```     |\n| 1   | NCALTECH101 | PLIF+QGate | -/mix  |  resnet18   |    128     | 69.62/69.87 | ```python main.py --num-classes 100 --model resnet18 --node-type 
PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGrad --device 0```    |\n\nNote: \n1. resnet18 is used here by adding a maximum pooling after the initial convolution layer.\nHowever, in the final version of braincog, we remove this pooling layer.\n2. mix refers to the use of EventMix as a data augmentation method.\n3. We will continue to add other results.\n\n\n### Citation \nIf you find this package helpful, please consider citing it:\n\n```BibTex\n@software{name,\n  author       = {xxx},\n  title        = {braincog: xxx},\n  month        = jul,\n  year         = 2022,\n  note         = {{Documentation available under \n                   https://xxx.readthedocs.io}},\n  publisher    = {Zenodo},\n  version      = {xxx},\n  doi          = {xxx},\n  url          = {xxx}\n}\n```"
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/img_cls/glsnn.md",
    "content": "# SNN with global feedback connections\nTraining a deep spiking neural network with the global \nfeedback connections and the local optimization learning rules. It is a little different from our original paper. \n\n## Results\n```shell\npython cls_glsnn.py\n```\nWe train the model for 100 epochs, and the best accuracy for MNIST is 98.23\\%, for FashionMNIST is 89.68\\%.\n![image](result_zdc.png)"
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/img_cls/index.rst",
    "content": "Examples for Image Classification\n=================================\n\n.. toctree::\n    :maxdepth: 2\n\n    bp\n    glsnn"
  },
  {
    "path": "docs/source/examples/Perception_and_Learning/index.rst",
    "content": "Perception and Learning\n=================================\n\n.. toctree::\n    :maxdepth: 2\n\n    img_cls/index\n    Conversion\n    UnsupervisedSTDP\n    QSNN\n    MultisensoryIntegration"
  },
  {
    "path": "docs/source/examples/Social_Cognition/Mirror_Test.md",
    "content": "# Mirror Test\n\nThe mirror_test.py implements the core code of the Multi-Robots Mirror Self-Recognition Test in \"Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition\".\n\nThe experiment is: three robots with identical appearance move their arms randomly in front of the mirror at the same time. \n\nIn the training stage, according to the spiking time difference of neurons in IPLM and IPLV, the robot learns the correlations between self-generated actions and visual feedbacks in motion by learning with spike timing dependent plasticity (STDP) mechanism. \n\nIn the test stage, the robot can predict the visual feedback generated by its arm movement according to the training results. With the InsulaNet, the robot can identify which mirror image belongs to it.\n\nIn the result, Motion Detection shows the results of visual detection, and Motion Prediction shows the visual feedback generated by itself. The red line in the figure indicates that the robot determines that the corresponding mirror belongs to itself.\n\nDifferences from the original article:\nSince there is no motion error under the simulation conditions, the theta_threshold is set to zero.\n\n\n### Citation \nIf you find this package helpful, please consider citing it:\n\n```BibTex\n@article{zeng2018toward,\n  title={Toward robot self-consciousness (ii): brain-inspired robot bodily self model for self-recognition},\n  author={Zeng, Yi and Zhao, Yuxuan and Bai, Jun and Xu, Bo},\n  journal={Cognitive Computation},\n  volume={10},\n  number={2},\n  pages={307--320},\n  year={2018},\n  publisher={Springer}\n}\n```"
  },
  {
    "path": "docs/source/examples/Social_Cognition/ToM.md",
    "content": "# ToM\n\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n* pygame\n\n## Run\n### Train \n* the file to be run: main_both.py \n* args:\n    * the path to save net_NPC: --save_net_N\n    * the path to save net_a: --save_net_a\n    * time steps: --T\n\n```bash\npython main_both.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both\n```\n\n### Test\n\n```bash\npython main_ToM.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both\n```\n\n"
  },
  {
    "path": "docs/source/examples/Social_Cognition/index.rst",
    "content": "Social Cognition\n======================================\n\n.. toctree::\n    :maxdepth: 2\n\n    ToM\n    Mirror_Test\n"
  },
  {
    "path": "docs/source/examples/index.rst",
    "content": "Examples\n=================================\n\n.. toctree::\n    :maxdepth: 2\n\n    Perception_and_Learning/index\n    Brain_Cognitive_Function_Simulation/index\n    Decision_Making/index\n    Knowledge_Representation_and_Reasoning/index\n    Social_Cognition/index\n    Multi-scale_Brain_Structure_Simulation/index"
  },
  {
    "path": "docs/source/index.rst",
    "content": ".. braincog documentation master file, created by\n   sphinx-quickstart on Sun Apr 10 21:02:06 2022.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nWelcome to braincog's documentation!\n====================================\n\n.. toctree::\n   :maxdepth: 2\n    \n   \n   tutorial/index\n   braincog\n   examples/index\n   \n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/source/modules.rst",
    "content": "braincog\n========\n\n.. toctree::\n   :maxdepth: 4\n\n   braincog\n"
  },
  {
    "path": "docs/source/setup.rst",
    "content": "setup module\n============\n\n.. automodule:: setup\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs.md",
    "content": "# Sphinx 文档教程 \n\n## 安装  \n\n131 braincog 环境已经装好了\n```shell\n    pip install sphinx sphinx-rtd-theme recommonmark\n``` \n\n## 配置 \n\n已经配置好了, 直接用就行了\n```shell\n    sphinx-quickstart\n``` \n\n## 编译\n\n### braincog 之中的, 编译在Brain Docs之中\n\n1. 重新从 repo 中抓取 ```rst``` 文本\n```shell \n    cd braincog/docs \n    rm -rf ./source/braincog*rst\n    sphinx-apidoc -o ./source/ ../braincog -f \n```\n\n2. 编译 html \n\n```shell\n\n    make clean\n    make html \n```\n\n### Examples 的编译 \n\n1. 在 ```braincog/docs/source/index.rst``` 中, ```img_cls/Tutorial``` 后面一行添加 ``xxx/Tutorial``.\n2. 然后在 ``Brain/docs/source`` 下面添加 ``xxx.md`` 文件, 要和上面的 ``xxx`` 同名.\n3. 用 [Markdown](https://markdown.com.cn/basic-syntax/) 语法, 编写教程, 怎么用, 效果是啥. \n4. 编译html \n\n```shell\n    make clean \n    make html\n```\n\n## 查看 \n\n编译好的文件可以在 ```braincog/docs/build/html``` 中查看. \n\n## 上传\n\n在130服务器上面:\n\n```shell \n    sudo cp braincog/docs/build/html/* /var/www/html\n```\n就可以更新文档了, 并在 [172.18.116.130](http://172.18.116.130/index.html) 中看到.\n"
  },
  {
    "path": "documents/Data_engine.md",
    "content": "# BrainCog Data Engine\n\nIn addition to the static datasets, BrainCog supports the commonly used neuromorphic\ndatasets, such as DVSGesture, DVSCIFAR10, NCALTECH101, ES-ImageNet.\nAlso, the neuromorphic dataset N-Omniglot for few-shot learning is also integrated into \nBrainCog.\n\n**[DVSGesture](https://openaccess.thecvf.com/content_cvpr_2017/papers/Amir_A_Low_Power_CVPR_2017_paper.pdf)**\n\nThis dataset contains 11 hand gestures from 29 subjects under 3 illumination conditions recorded using a DVS128. \n\n**[DVSCIFAR10](https://www.frontiersin.org/articles/10.3389/fnins.2017.00309/full)**\n\nThis dataset converts 10,000 frame-based images in the CIFAR10 dataset into 10,000 event streams using a dynamic vision sensor.\n\n**[NCALTECH101](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full)**\n\nThe NCaltech101 dataset is captured by mounting the ATIS sensor on a motorized pan-tilt unit and having the sensor move while it views Caltech101 examples on an LCD monitor. \nThe \"Faces\" class has been removed from N-Caltech101, leaving 100 object classes plus a background class\n\n**[ES-ImageNet](https://arxiv.org/abs/2110.12211)**\n\nThe dataset is converted with Omnidirectional Discrete Gradient (ODG) from 1,300,000 frame-based images in the ImageNet dataset into event-stream samples, which has 1000 categories. \n\n**[N-Omniglot](https://www.nature.com/articles/s41597-022-01851-z)**\n\nThis dataset contains 1,623 categories of handwritten characters, with only 20 samples per class. \nThe dataset is acquired with the DVS acquisition platform to shoot videos (generated from the original Omniglot dataset) played on the monitor, and use the Robotic Process Automation (RPA) software to collect the data automatically.\n \nYou can easily use them in the braincog/datasets folder, taking DVSCIFAR10 as an example\n```python\nloader_train, loader_eval,_,_ = get_dvsc10_data(batch_size=128,step=10)\n```"
  },
  {
    "path": "documents/Lectures.md",
    "content": "# Lectures\n\n- [BrainCog Talk] Beginning BrainCog Lecture 32. Structural Modeling and Neural Activity Simulation of Mammalian Corticothalamic Functional Minicolumn on BrainCog [[English Version](https://www.youtube.com/watch?v=xwcu3yHe4FQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=30), [Chinese Version](https://www.bilibili.com/video/BV1aN411v7BG/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 31. A Brain-inspired Theory of Mind Spiking Neural Network Improves Multi-agent Cooperation [[English Version](https://www.youtube.com/watch?v=xwcu3yHe4FQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=30), [Chinese Version](https://www.bilibili.com/video/BV11m4y1N7Bh/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 30. FPGA-based Spiking Neural Networks Acceleration Using BrainCog [[English Version](https://www.youtube.com/watch?v=xwcu3yHe4FQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=30), [Chinese Version](https://www.bilibili.com/video/BV1mF411Q7Rr/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 29. Using BrainCog to Realize Robot's Intention Prediction for Users [[English Version](https://www.youtube.com/watch?v=vlwRpq-HwDo&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=29), [Chinese Version](https://www.bilibili.com/video/BV1oz4y1H7QL/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 28. 
Reward-modulated brain-inspired spiking neural network to empower group for nature-inspired self-organized obstacle avoidance implemented by Braincog [[English Version](https://www.youtube.com/watch?v=2X2KkG0KXo4&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=28), [Chinese Version](https://www.bilibili.com/video/BV16h411N7tG/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 27. Drosophila-inspired Linear and Nonlinear Decision-Making Spiking Neural Network implemented by Braincog  [[English Version](https://www.youtube.com/watch?v=j49q33dMbYk&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=27), [Chinese Version](https://www.bilibili.com/video/BV1Ch4y147i6/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 26. Spiking CapsNet based on BrainCog [[English Version](https://www.youtube.com/watch?v=Gn1PMMFnD6M&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=26), [Chinese Version](https://www.bilibili.com/video/BV1M24y1N7Re/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 25. SNNs with adaptive self-feedback and balanced excitatory–inhibitory neurons based on BrainCog [[English Version](https://www.youtube.com/watch?v=pqeagXtYk8w&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=25), [Chinese Version](https://www.bilibili.com/video/BV18g4y1j7WA/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 24. DVS Augmentation Strategies based on BrainCog [[English Version](https://www.youtube.com/watch?v=aok8B34DNVs&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=24), [Chinese Version](https://www.bilibili.com/video/BV1324y1L7QP/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 23. 
The Implement of Object Detection and Semantic Segmentation Based on SNNs with Braincog [[English Version](https://www.youtube.com/watch?v=2kh7rNwQkkY&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=23), [Chinese Version](https://www.bilibili.com/video/BV1XT411r7ak/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 22. BrainCog Data Engine: Spatio-temporal Sequence Data N-Omniglot and Its Applications [[English Version](https://www.youtube.com/watch?v=J5qBIX9f8Ns&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=22), [Chinese Version](https://www.bilibili.com/video/BV1uT411a733/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 21. Dynamic structural development for SNNs incorporating constraints, pruning and regeneration based on BrainCog [[English Version](https://www.youtube.com/watch?v=OccvvdyiOgY&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=21), [Chinese Version](https://www.bilibili.com/video/BV1Zv4y1Y72h/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 20. Developmental Plasticity-inspired Adaptive Pruning for SNNs based on BrainCog [[English Version](https://www.youtube.com/watch?v=ckrSW_2vqeE&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=20), [Chinese Version](https://www.bilibili.com/video/BV1iD4y1P7XE/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 19. Multi-brain areas Coordinated Brain-inspired Affective Empathy Spiking Neural Network Based on Braincog [[English Version](https://www.youtube.com/watch?v=dB-hQ-5RJ6U&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=19), [Chinese Version](https://www.bilibili.com/video/BV16Y411q7b2/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 18. 
Application of the Prefrontal Cortex Column Model in Working Memory Task with BrainCog [[English Version](https://www.youtube.com/watch?v=6xzURpSFkxk&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=18), [Chinese Version](https://www.bilibili.com/video/BV1As4y1t7wF/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 17. A Brain-inspired Theory of Mind Model Based on BrainCog for Reducing Other Agents’ Safety Risks [[English Version](https://www.youtube.com/watch?v=16Csw03bTjY&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=17), [Chinese Version](https://www.bilibili.com/video/BV12A411f74Q/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 16. Brain-inspired Bodily Self-perception Model Based on BrainCog [[English Version](https://www.youtube.com/watch?v=mV_iMsDEQsg&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=16), [Chinese Version](https://www.bilibili.com/video/BV1pK411i7Jk/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 15. SNN-based Music Memory and Generation Based on BrainCog [[English Version](https://www.youtube.com/watch?v=c0Tcs1B5xho&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=15), [Chinese Version](https://www.bilibili.com/video/BV1M3411X7C2/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 14. The Implement of Multisensory Concept Learning Framework Based on SNNs with Braincog [[English Version](https://www.youtube.com/watch?v=c9UfOCGzFPQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=14), [Chinese Version](https://www.bilibili.com/video/BV1Ae411P7tY/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 13. 
Symbolic Representation and Reasoning SNN Based on Braincog [[English Version](https://www.youtube.com/watch?v=1iosjPkOBRo&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=13), [Chinese Version](https://www.bilibili.com/video/BV1DW4y1p7kg/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 12. Unsupervised STDP-based Spiking Neural Networks Based on BrainCog [[English Version](https://www.youtube.com/watch?v=pzPJ1XOEB9U&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=12), [Chinese Version](https://www.bilibili.com/video/BV1MR4y1Z7Be/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 11. Backpropagation with Spatiotemporal Adjustment for Training Deep Spiking Neural Networks through BrainCog [[English Version](https://www.youtube.com/watch?v=Fm85BjQszng&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=11), [Chinese Version](https://www.bilibili.com/video/BV1EK411Z71F/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 10. Multi-brain Areas Coordinated Brain-inspired Decision-Making Spiking Neural Network Based on Braincog [[English Version](https://www.youtube.com/watch?v=uVCcZHzzN3U&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=10), [Chinese Version](https://www.bilibili.com/video/BV1jD4y1x7PU/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 9. Spiking Neural Networks with Global Feedback Connections Based on BrainCog [[English Version](https://www.youtube.com/watch?v=g_qelwoQsD8&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=9), [Chinese Version](https://www.bilibili.com/video/BV1qv4y1D74Y/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 8. 
Converting Artificial Neural Network to Spiking Neural Network through BrainCog [[English Version](https://www.youtube.com/watch?v=cxiKyQ7F8UE&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=8), [Chinese Version](https://www.bilibili.com/video/BV17e4y147H6/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 7. Implementing Quantum Superposition Inspired Spatio-temporal Spike Encoding through BrainCog [[English Version](https://www.youtube.com/watch?v=5T-2Yyr9a0s&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=7), [Chinese Version](https://www.bilibili.com/video/BV1BG41177Z6/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n- [BrainCog Talk] Beginning BrainCog Lecture 6. Implementing spiking deep Q network through Braincog [[English Version](https://www.youtube.com/watch?v=jCsOBtiN-q0&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=6), [Chinese Version](https://www.bilibili.com/video/BV1XN4y1c7Kn/?spm_id_from=333.337.search-card.all.click&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 5. Advanced BrainCog System Functions [[English Version](https://www.youtube.com/watch?v=VJBORFl6dTQ&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=5), [Chinese Version](https://www.bilibili.com/video/BV1tT411N7c9/?spm_id_from=333.337.search-card.all.click&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 4. Creating Cognitive SNNs for Brain Areas [[English Version](https://www.youtube.com/watch?v=gYxG1D1b4Zo&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=4), [Chinese Version](https://www.bilibili.com/video/BV19d4y1679Y/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 3. 
Creating SNNs Easily and Quickly [[English Version](https://www.youtube.com/watch?v=k3byUIp4O24&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=3), [Chinese Version](https://www.bilibili.com/video/BV1Be4y1874W/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 2. Computational Modeling of Spiking Neurons [[English Version](https://www.youtube.com/watch?v=5jPsPkyFTY8&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=2), [Chinese Version](https://www.bilibili.com/video/BV16K411f7vQ/?spm_id_from=333.788&vd_source=ffca9a0cf41b21082e79f7f6ad9a5301)]\n- [BrainCog Talk] Beginning BrainCog Lecture 1. Installing and Deploying BrainCog platform [[English Version](https://www.youtube.com/watch?v=XkHq-MbKo20&list=PLNXUFsTshMlYTW6oleY5YjVEfnoQSw0N7&index=1), [Chinese Version](https://www.bilibili.com/video/BV1AW4y1b7v1/?spm_id_from=333.788&vd_source=ac84fc93f3c82d1cd96648079077afd3)]\n"
  },
  {
    "path": "documents/Pub_brain_inspired_AI.md",
    "content": "# Publications Using BrainCog \n## Brain Inspired AI\n\n\n### Perception and Leanring\n| Papers                                                                                                                                                                                                                                           | Codes                                                                                                                                                                                                                                                          | Publisher                                                                         |\n| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------- |\n| [BackEISNN: A deep spiking neural network with adaptive self-feedback and balanced excitatory–inhibitory neurons](https://www.sciencedirect.com/science/article/pii/S0893608022002520)                                                           | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py)                                   | Neural Networks                                                                   |\n| [Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes](https://www.ijcai.org/proceedings/2022/345)           
                                                                                                           | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py)                       | IJCAI                                                                             |\n| [Spike Calibration: Fast and Accurate Conversion of Spiking Neural Network for Object Detection and Segmentation](https://arxiv.org/abs/2207.02702)                                                                                              | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py)                       | Arxiv                                                                             |\n| [An unsupervised STDP-based spiking neural network inspired by biologically plausible learning rules and connections](https://www.sciencedirect.com/science/article/pii/S0893608023003301)                                                       | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP)                                                     | Neural Networks                                                                   |\n| [Multisensory Concept Learning Framework Based on Spiking Neural Networks](https://www.frontiersin.org/articles/10.3389/fnsys.2022.845177/full)                                                                                                  | 
[https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration)                                       | Frontiers in Systems Neuroscience                                                 |\n| [Backpropagation with biologically plausible spatiotemporal adjustment for training deep spiking neural networks](https://www.sciencedirect.com/science/article/pii/S2666389922001192)                                                           | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp)                                                                 | Patterns                                                                          |\n| [Quantum superposition inspired spiking neural network](https://www.cell.com/iscience/fulltext/S2589-0042(21)00848-8)                                                                                                                            | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN)                                                                             | iScience                                                                          |\n| [GLSNN: A Multi-Layer Spiking Neural Network Based on Global Feedback Alignment and Local STDP Plasticity](https://www.frontiersin.org/articles/10.3389/fncom.2020.576841/full)                                                                  | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn)                                           
                | Frontiers in Neuroscience                                                         |\n| [Spiking CapsNet: A spiking neural network with a biologically plausible routing rule between capsules](https://www.sciencedirect.com/science/article/pii/S002002552200843X)                                                                     | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet)                                       | Information Sciences                                                              |\n| [N-Omniglot, a large-scale neuromorphic dataset for spatio-temporal sparse few-shot learning](https://www.nature.com/articles/s41597-022-01851-z)                                                                                                | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot)                                                                                                 | Scientific Data                                                                   |\n| [EventMix: An efficient data augmentation strategy for event-based learning](https://www.sciencedirect.com/science/article/pii/S0020025523007557)                                                                                                | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py)                                                                                               | Information Sciences                                                              |\n| [Challenging deep learning models with image distortion based on the abutting grating 
illusion](https://www.sciencedirect.com/science/article/pii/S2666389923000260)                                                                             | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion) | Patterns                                                                          |\n| [MSAT: Biologically Inspired Multi-Stage Adaptive Threshold for Conversion of Spiking Neural Networks](https://arxiv.org/abs/2303.13080)                                                                                                         | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion)                                 | Arxiv                                                                             |\n| [Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer](https://arxiv.org/abs/2303.13077)                                                                                                         | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs)                                     | Arxiv                                                                             |\n| [Improving Stability and Performance of Spiking Neural Networks Through Enhancing Temporal Consistency](https://www.sciencedirect.com/science/article/abs/pii/S0031320324008458)                                                                 | [https://github.com/Brain-Cog-Lab/ETC](https://github.com/Brain-Cog-Lab/ETC)      
                                                                                                                                                                             | Pattern Recognition                                                               |\n| [Spiking Generative Adversarial Network with Attention Scoring Decoding](https://www.sciencedirect.com/science/article/abs/pii/S0893608024003472)                                                                                                | [https://github.com/Brain-Cog-Lab/sgad](https://github.com/Brain-Cog-Lab/sgad)                                                                                                                                                                                 | Neural Networks                                                                   |\n| [Temporal Knowledge Sharing Enables Spiking Neural Network Learning from Past and Future](https://ieeexplore.ieee.org/document/10462632)                                                                                                         | [https://github.com/Brain-Cog-Lab/TKS](https://github.com/Brain-Cog-Lab/TKS)                                                                                                                                                                                   | IEEE Transactions on Artificial Intelligence                                      |\n| [TIM: An Efficient Temporal Interaction Module for Spiking Transformer](https://www.ijcai.org/proceedings/2024/0347.pdf)                                                                                                                         | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM)                                                                                                                               | IJCAI2024                                                       
                  |\n| [Exploiting Nonlinear Dendritic Adaptive Computation in Training Deep Spiking Neural Networks](https://www.sciencedirect.com/science/article/pii/S0893608023006202)                                                                              | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412)                                                                                       | Neural Networks                                                                   |\n| [Are Conventional SNNs Really Efficient? A Perspective from Network Quantization](https://openaccess.thecvf.com/content/CVPR2024/papers/Shen_Are_Conventional_SNNs_Really_Efficient_A_Perspective_from_Network_Quantization_CVPR_2024_paper.pdf) |                                                                                                                                                                                                                                                                | 2024 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)        |\n| [CACE-Net: Co-guidance Attention and Contrastive Enhancement for Effective Audio-Visual Event Localization](https://dl.acm.org/doi/10.1145/3664647.3681503)                                                                                      | [https://github.com/Brain-Cog-Lab/CACE-Net](https://github.com/Brain-Cog-Lab/CACE-Net)                                                                                                                                                                         | Proceedings of the 32nd ACM International Conference on Multimedia, 2024: 985-993 |\n| [Parallel Spiking Unit for Efficient Training of Spiking Neural Networks](https://arxiv.org/abs/2402.00449)                                                                                                             
                         | [https://github.com/Brain-Cog-Lab/PSU](https://github.com/Brain-Cog-Lab/PSU)                                                                                                                                                                                   | IJCNN 2024                                                                        |\n| [Spiking Neural Networks with Consistent Mapping Relations Allow High-Accuracy Inference](https://www.sciencedirect.com/science/article/abs/pii/S0020025524007369)                                                                               | [https://github.com/Brain-Cog-Lab/casc](https://github.com/Brain-Cog-Lab/casc)                                                                                                                                                                                 | Information Sciences                                                              |\n| [Directly training temporal Spiking Neural Network with sparse surrogate gradient](https://www.sciencedirect.com/science/article/abs/pii/S0893608024004234)                                                                                      | [https://github.com/Brain-Cog-Lab/msg](https://github.com/Brain-Cog-Lab/msg)                                                                                                                                                                                   | Neural Networks                                                                   |\n\n\n### Knowledge Representation and Reasoning\n| Papers                                                                                                                                                                                | Codes                                                                                                         | Publisher                                                                           |\n| 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |\n| [A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/9534102)                                                     | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CRSNN       | IJCNN2021                                                                           |\n| [Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP](https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full)                      | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/SPSNN       | Frontiers in Computational Neuroscience                                             |\n| [Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory](https://www.frontiersin.org/articles/10.3389/fncom.2020.00051/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/musicMemory | Frontiers in Computational Neuroscience                                             |\n| [Stylistic Composition of Melodies Based on a Brain-Inspired Spiking Neural Network](https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full)                             | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN         | Frontiers in Systems Neuroscience                                                   |\n| [Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge 
Representation and Reasoning](https://arxiv.org/abs/2207.05561 )                                              | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CKRGSNN     | Arxiv                                                                               |\n| [An Efficient Knowledge Transfer Strategy for Spiking Neural Networks from Static to Event Domain](https://arxiv.org/abs/2303.13077)                                                  | [https://github.com/Brain-Cog-Lab/Transfer-for-DVS](https://github.com/Brain-Cog-Lab/Transfer-for-DVS)        | Proceedings of the AAAI Conference on Artificial Intelligence, 2024, 38(1): 512-520 |\n\n\n### Decision Making\n| Papers                                                                                                                                                                                      | Codes                                                                                                   | Publisher                  |\n| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------- | -------------------------- |\n| [Nature-inspired self-organizing collision avoidance for drone swarm based on reward-modulated spiking neural network](https://www.cell.com/patterns/fulltext/S2666-3899(22)00236-7)        | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/swarm/Collision-Avoidance.py | Cell Patterns              |\n| [Solving the spike feature information vanishing problem in spiking deep Q network with potential based normalization](https://www.frontiersin.org/articles/10.3389/fnins.2022.953368/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/sdqn                     
 | Frontiers in Neuroscience  |\n| [A Brain-Inspired Decision-Making Spiking Neural Network and Its Application in Unmanned Aerial Vehicle](https://www.frontiersin.org/articles/10.3389/fnbot.2018.00056/full )               | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/BDM-SNN/BDM-SNN-hh.py        | Frontiers in Neurorobotics |\n| [Multi-compartment Neuron and Population Encoding improved Spiking Neural Network for Deep Distributional Reinforcement Learning](https://arxiv.org/abs/2301.07275)                         | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/mcs-fqf                   | Arxiv                      |\n\n\n### Motor Control\n| Papers | Codes                                                                                | Publisher |\n| ------ | ------------------------------------------------------------------------------------ | --------- |\n|        | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/MotorControl/experimental |           |\n\n\n### Social Cognition\n| Papers                                                                                                                                                                    | Codes                                                                                                                                                                                                          | Publisher                               |\n| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |\n| [Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self 
Model for Self-Recognition](https://link.springer.com/article/10.1007/s12559-017-9505-1 )         | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test)                             | Cognitive Computation                   |\n| [A brain-inspired intention prediction model and its applications to humanoid robot](https://www.frontiersin.org/articles/10.3389/fnins.2022.1009237/full)                | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction)           | Frontiers in Neuroscience               |\n| [A Brain-Inspired Theory of Mind Spiking Neural Network for Reducing Safety Risks of Other Agents](https://www.frontiersin.org/articles/10.3389/fnins.2022.753900/full)   | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM)                                             | Frontiers in Neuroscience               |\n| [Brain-Inspired Affective Empathy Computational Model and Its Application on Altruistic Rescue Task](https://www.frontiersin.org/articles/10.3389/fncom.2022.784967/full) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN) | Frontiers in Computational Neuroscience |\n| [A brain-inspired robot pain model based on a spiking neural network](https://www.frontiersin.org/articles/10.3389/fnbot.2022.1025338/full)                               | 
[https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN) | Frontiers in Neurorobotics              |\n| [Brain-Inspired Theory of Mind Spiking Neural Network Elevates Multi-Agent Cooperation and Competition](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4271099)      | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN  ](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN  )                             | SSRN                                    |\n\n\n### Development and Evolution\n| Papers                                                                                                                                                                  | Codes                                                                                                                                                                                      | Publisher                                       |\n| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------------------------- |\n| [Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks](https://arxiv.org/abs/2211.12714)                                  | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP)           | Arxiv                                           |\n| [Adaptive Sparse Structure Development with 
Pruning and Regeneration for Spiking Neural Networks](https://arxiv.org/abs/2211.12219)                                     | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN)       | Arxiv                                           |\n| [Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution](https://arxiv.org/abs/2304.10749)                                               | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM)           | Arxiv                                           |\n| [Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks](https://arxiv.org/abs/2308.04749)                                | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN)     | IJCAI                                           |\n| [Brain-inspired Evolutionary Neural Architecture Search for Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/10542732)                            | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS)             | IEEE Transactions on Artificial Intelligence    |\n| [Adaptive Structure Evolution and Biologically Plausible Synaptic Plasticity for Recurrent Spiking Neural Networks](https://www.nature.com/articles/s41598-023-43488-x) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm) | Scientific Reports                   
           |\n| [Brain-inspired Neural Circuit Evolution for Spiking Neural Networks](https://www.pnas.org/doi/10.1073/pnas.2218173120)                                                 | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo)                                 | Proceedings of the National Academy of Sciences |\n\n### Safety and Security\n| Papers                                    | Codes                                                                                                                                                      | Publisher |\n| ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- |\n| [DPSNN](https://arxiv.org/abs/2205.12718) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN) | Arxiv     |\n\n### Dataset\n| Papers                                                                                                                                                                                                                           | Codes                                                                                                                                                              | Publisher   |\n| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------- |\n| [Bullying10K: A Large-scale Neuromorphic Dataset Towards 
Privacy-preserving Bullying Recognition](https://proceedings.neurips.cc/paper_files/paper/2023/file/05ffe69463062b7f9fb506c8351ffdd7-Paper-Datasets_and_Benchmarks.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k) | Neurips2023 |"
  },
  {
    "path": "documents/Pub_brain_simulation.md",
    "content": "# Publications Using BrainCog \n## Brain Simulation\n\n### Funtion\n\n| Papers                                                                                                                                                                                                       | Codes                                                                                                                                                                                                                                                 | Publisher           |\n|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------|\n| [A neural algorithm for Drosophila linear and nonlinear decision-making](https://www.nature.com/articles/s41598-020-75628-y)                                                                                 | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Brain_Cognitive_Function_Simulation/drosophila/drosophila.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Brain_Cognitive_Function_Simulation/drosophila/drosophila.py)\t   | Scientific Reports  |\n\n### Structure\n\n| Papers                                                                                                                                                                                                       | Codes                                                                                                                                                                                                                                                
| Publisher                        |\n|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------|\n| Corticothalamic minicolumn\t                                                                                                                                                                                  | \t\t[https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn)   |                                  |\n| Human Brain\t                                                                                                                                                                                                 | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/HumanBrain](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/HumanBrain)\t\t                         |                                  |\n| Macaque Brain                                                                                                                                                                                                | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain)\t                      |                        
           |\n| Mouse Brain                                                                                                                                                                                                  | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/Mouse_brain ](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/Mouse_brain )\t                      |                                  |\n| [Comparison Between Human and Rodent Neurons for Persistent Activity Performance: A Biologically Plausible Computational Investigation](https://www.frontiersin.org/articles/10.3389/fnsys.2021.628839/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Multiscale_Brain_Structure_Simulation/Human_PFC_Model\t                                                                                                                                    | Frontiers in Systems Neuroscience |\n"
  },
  {
    "path": "documents/Pub_sh_codesign.md",
    "content": "# Publications Using BrainCog \n## Software-Hardware Co-design\n\n\n### Hardware Acceleration\n| Papers                                                       | Codes                                    | Publisher                                                    |\n| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------------------------------ |\n| [FireFly: A High-Throughput and Reconfigurable Hardware Accelerator for Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/10143752) | https://github.com/adamgallas/FireFly-v1 | IEEE Transactions on Very Large Scale Integration (VLSI) Systems |\n| [FireFly v2: Advancing Hardware Support for High-Performance Spiking Neural Network With a Spatiotemporal FPGA Accelerator](https://ieeexplore.ieee.org/abstract/document/10478105) | https://github.com/adamgallas/FireFly-v2 | IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems |\n| [Revealing Untapped DSP Optimization Potentials for FPGA-Based Systolic Matrix Engines](https://ieeexplore.ieee.org/abstract/document/10705564) | https://github.com/adamgallas/SpinalDLA  | 2024 34th International Conference on Field-Programmable Logic and Applications (FPL) |\n| [FireFly-S: Exploiting Dual-Side Sparsity for Spiking Neural Networks Acceleration With Reconfigurable Spatial Architecture](https://ieeexplore.ieee.org/document/10754657) |                                          | IEEE Transactions on Circuits and Systems I: Regular Papers  |\n| Pushing up to the Limit of Memory Bandwidth and Capacity Utilization for Efficient LLM Decoding on Embedded FPGA |                                          | 2025 Design, Automation & Test in Europe Conference & Exhibition (DATE) |\n\n"
  },
  {
    "path": "documents/Publication.md",
    "content": "# Publications Using BrainCog\n\n## 2024\n| Papers                                                                                                                                                                                                                                           | Codes                                                                                                                                                                                                                          | Publisher                                                                   |\n| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------- |\n| [Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks](https://ieeexplore.ieee.org/document/10691937)                                                                                              | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DPAP)                                               | IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)      |\n| [Multi-compartment Neuron and Population Encoding improved Spiking Neural Network for Deep Distributional Reinforcement Learning](https://www.sciencedirect.com/science/article/abs/pii/S089360802400827X?via%3Dihub)                            | 
https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/mcs-fqf                                                                                                                                          | Neural Networks                                                             |\n| [Improving Stability and Performance of Spiking Neural Networks Through Enhancing Temporal Consistency](https://www.sciencedirect.com/science/article/abs/pii/S0031320324008458)                                                                 | [https://github.com/Brain-Cog-Lab/ETC](https://github.com/Brain-Cog-Lab/ETC)                                                                                                                                                   | Pattern Recognition                                                         |\n| [Spiking Generative Adversarial Network with Attention Scoring Decoding](https://www.sciencedirect.com/science/article/abs/pii/S0893608024003472)                                                                                                | [https://github.com/Brain-Cog-Lab/sgad](https://github.com/Brain-Cog-Lab/sgad)                                                                                                                                                 | Neural Networks                                                             |\n| [Temporal Knowledge Sharing Enables Spiking Neural Network Learning from Past and Future](https://ieeexplore.ieee.org/document/10462632)                                                                                                         | [https://github.com/Brain-Cog-Lab/TKS](https://github.com/Brain-Cog-Lab/TKS)                                                                                                                                                   | IEEE Transactions on Artificial Intelligence (TAI)                          |\n| [TIM: An Efficient Temporal Interaction Module 
for Spiking Transformer](https://www.ijcai.org/proceedings/2024/0347.pdf)                                                                                                                         | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/TIM)                                                                                               | IJCAI2024                                                                   |\n| [Are Conventional SNNs Really Efficient? A Perspective from Network Quantization](https://openaccess.thecvf.com/content/CVPR2024/papers/Shen_Are_Conventional_SNNs_Really_Efficient_A_Perspective_from_Network_Quantization_CVPR_2024_paper.pdf) |                                                                                                                                                                                                                                | 2024 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)  |\n| [CACE-Net: Co-guidance Attention and Contrastive Enhancement for Effective Audio-Visual Event Localization](https://dl.acm.org/doi/10.1145/3664647.3681503)                                                                                      | [https://github.com/Brain-Cog-Lab/CACE-Net](https://github.com/Brain-Cog-Lab/CACE-Net)                                                                                                                                         | Proceedings of the 32nd ACM International Conference on Multimedia (ACM MM) |\n| [Parallel Spiking Unit for Efficient Training of Spiking Neural Networks](https://ieeexplore.ieee.org/document/10650207)                                                                                                                         | [https://github.com/Brain-Cog-Lab/PSU](https://github.com/Brain-Cog-Lab/PSU)                                                                                 
                                                                  | IJCNN 2024                                                                  |\n| [Spiking Neural Networks with Consistent Mapping Relations Allow High-Accuracy Inference](https://www.sciencedirect.com/science/article/abs/pii/S0020025524007369)                                                                               | [https://github.com/Brain-Cog-Lab/casc](https://github.com/Brain-Cog-Lab/casc)                                                                                                                                                 | Information Sciences                                                        |\n| [Directly training temporal Spiking Neural Network with sparse surrogate gradient](https://www.sciencedirect.com/science/article/abs/pii/S0893608024004234)                                                                                      | [https://github.com/Brain-Cog-Lab/msg](https://github.com/Brain-Cog-Lab/msg)                                                                                                                                                   | Neural Networks                                                             |\n| [Brain-inspired Evolutionary Neural Architecture Search for Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/10542732)                                                                                                     | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/EB-NAS)                                                 | IEEE Transactions on Artificial Intelligence (TAI)                          |\n| [MSAT: Biologically Inspired Multi-Stage Adaptive Threshold for Conversion of Spiking Neural Networks](https://link.springer.com/article/10.1007/s00521-024-09529-w)                                        
                                     | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/Conversion/msat_conversion) | Neural Computing and Applications                                           |\n| [An Efficient Knowledge Transfer Strategy for Spiking Neural Networks from Static to Event Domain](https://ojs.aaai.org/index.php/AAAI/article/view/27806/27643)                                                                                 | [https://github.com/Brain-Cog-Lab/Transfer-for-DVS](https://github.com/Brain-Cog-Lab/Transfer-for-DVS)                                                                                                                         | Proceedings of the AAAI Conference on Artificial Intelligence (AAAI)        |\n\n## 2023\n| Papers                                                                                                                                                                                                                           | Codes                                                                                                                                                                                                                                                          | Publisher                                              |\n| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ |\n| [An 
unsupervised STDP-based spiking neural network inspired by biologically plausible learning rules and connections](https://www.sciencedirect.com/science/article/pii/S0893608023003301)                                       | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/UnsupervisedSTDP)                                                     | Neural Networks                                        |\n| [EventMix: An efficient data augmentation strategy for event-based learning](https://www.sciencedirect.com/science/article/pii/S0020025523007557)                                                                                | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/datasets/cut_mix.py)                                                                                               | Information Sciences                                   |\n| [Challenging deep learning models with image distortion based on the abutting grating illusion](https://www.sciencedirect.com/science/article/pii/S2666389923000260)                                                             | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion) | Patterns                                               |\n| [Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer](https://arxiv.org/abs/2303.13077)                                                                                         | 
[https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/transfer_for_dvs)                                     | Arxiv                                                  |\n| [Exploiting Nonlinear Dendritic Adaptive Computation in Training Deep Spiking Neural Networks](https://www.sciencedirect.com/science/article/pii/S0893608023006202)                                                              | [https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412](https://github.com/BrainCog-X/Brain-Cog/blob/main/braincog/base/node/node.py#L1412)                                                                                       | Neural Networks                                        |\n| [Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution](https://www.sciencedirect.com/science/article/pii/S258900422400066X)                                                                     | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/ELSM)                                                                               | iScience                                               |\n| [Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks](https://www.ijcai.org/proceedings/2023/0334.pdf)                                                                          | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/DSD-SNN)                                                                         | IJCAI                                                  |\n| [Adaptive Structure Evolution and Biologically 
Plausible Synaptic Plasticity for Recurrent Spiking Neural Networks](https://www.nature.com/articles/s41598-023-43488-x)                                                          | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm)                                                                     | Scientific Reports                                     |\n| [Brain-inspired Neural Circuit Evolution for Spiking Neural Networks](https://www.pnas.org/doi/10.1073/pnas.2218173120)                                                                                                          | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo/NeuEvo)                                                                                                     | Proceedings of the National Academy of Sciences (PNAS) |\n| [Bullying10K: A Large-scale Neuromorphic Dataset Towards Privacy-preserving Bullying Recognition](https://proceedings.neurips.cc/paper_files/paper/2023/file/05ffe69463062b7f9fb506c8351ffdd7-Paper-Datasets_and_Benchmarks.pdf) | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/bullying10k)                                                                                             | NeurIPS 2023                                           |\n\n## 2022\n| Papers                                                                                                                                                                                      | Codes                                                                                                                                                                                                                 
                   | Publisher                               |\n| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |\n| [BackEISNN: A deep spiking neural network with adaptive self-feedback and balanced excitatory–inhibitory neurons](https://www.sciencedirect.com/science/article/pii/S0893608022002520)      | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main_backei.py)             | Neural Networks                         |\n| [Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes](https://www.ijcai.org/proceedings/2022/345)                                                                 | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py) | IJCAI                                   |\n| [Spike Calibration: Fast and Accurate Conversion of Spiking Neural Network for Object Detection and Segmentation](https://arxiv.org/abs/2207.02702)                                         | [https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py](https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/Conversion/converted_CIFAR10.py) | Arxiv                                   |\n| [Multisensory Concept Learning Framework Based on 
Spiking Neural Networks](https://www.frontiersin.org/articles/10.3389/fnsys.2022.845177/full)                                             | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/MultisensoryIntegration)                 | Frontiers in Systems Neuroscience       |\n| [Backpropagation with biologically plausible spatiotemporal adjustment for training deep spiking neural networks](https://www.sciencedirect.com/science/article/pii/S2666389922001192)      | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp)                                           | Patterns                                |\n| [Spiking CapsNet: A spiking neural network with a biologically plausible routing rule between capsules](https://www.sciencedirect.com/science/article/pii/S002002552200843X)                | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/spiking_capsnet)                 | Information Sciences                    |\n| [N-Omniglot, a large-scale neuromorphic dataset for spatio-temporal sparse few-shot learning](https://www.nature.com/articles/s41597-022-01851-z)                                           | [https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot](https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/datasets/NOmniglot)                                                                           | Scientific Data                         |\n| [Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning](https://arxiv.org/abs/2207.05561 )                                  
                  | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CKRGSNN                                                                                                                                | Arxiv                                   |\n| [Nature-inspired self-organizing collision avoidance for drone swarm based on reward-modulated spiking neural network](https://www.cell.com/patterns/fulltext/S2666-3899(22)00236-7)        | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/swarm/Collision-Avoidance.py                                                                                                                                  | Cell Patterns                           |\n| [Solving the spike feature information vanishing problem in spiking deep Q network with potential based normalization](https://www.frontiersin.org/articles/10.3389/fnins.2022.953368/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/RL/sdqn                                                                                                                                                       | Frontiers in Neuroscience               |\n| [A brain-inspired intention prediction model and its applications to humanoid robot](https://www.frontiersin.org/articles/10.3389/fnins.2022.1009237/full)                                  | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/Intention_Prediction)                                     | Frontiers in Neuroscience               |\n| [A Brain-Inspired Theory of Mind Spiking Neural Network for Reducing Safety Risks of Other Agents](https://www.frontiersin.org/articles/10.3389/fnins.2022.753900/full)                     | 
[https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/ToM)                                                                       | Frontiers in Neuroscience               |\n| [Brain-Inspired Affective Empathy Computational Model and Its Application on Altruistic Rescue Task](https://www.frontiersin.org/articles/10.3389/fncom.2022.784967/full)                   | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BAE-SNN)                           | Frontiers in Computational Neuroscience |\n| [A brain-inspired robot pain model based on a spiking neural network](https://www.frontiersin.org/articles/10.3389/fnbot.2022.1025338/full)                                                 | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN)                           | Frontiers in Neurorobotics              |\n| [Brain-Inspired Theory of Mind Spiking Neural Network Elevates Multi-Agent Cooperation and Competition](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4271099)                        | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN  ](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/MAToM-SNN  )                                                       | SSRN                                    |\n| [Adaptive Sparse Structure Development with Pruning and Regeneration for Spiking Neural Networks](https://www.sciencedirect.com/science/article/abs/pii/S0020025524013951)                  | 
[https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structural_Development/SD-SNN)                                                     | Information Sciences                    |\n| [DPSNN](https://arxiv.org/abs/2205.12718)                                                                                                                                                   | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Snn_safety/DPSNN)                                                                               | Arxiv                                   |\n\n## 2021\n| Papers                                                                                                                                                           | Codes                                                                                                                                                                              | Publisher                               |\n| ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |\n| [Quantum superposition inspired spiking neural network](https://www.cell.com/iscience/fulltext/S2589-0042(21)00848-8)                                            | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/QSNN) | iScience                                |\n| [A Brain-Inspired Causal Reasoning Model Based on Spiking Neural 
Networks](https://ieeexplore.ieee.org/abstract/document/9534102)                                | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/CRSNN                                                                            | IJCNN 2021                              |\n| [Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP](https://www.frontiersin.org/articles/10.3389/fncom.2021.612041/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/SPSNN                                                                            | Frontiers in Computational Neuroscience |\n| [Stylistic Composition of Melodies Based on a Brain-Inspired Spiking Neural Network](https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full)        | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/affective_empathy/BRP-SNN                                                                              | Frontiers in Systems Neuroscience       |\n\n## 2020\n| Papers                                                                                                                                                                                | Codes                                                                                                                                                                                                | Publisher                               |\n| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |\n| [GLSNN: A Multi-Layer Spiking 
Neural Network Based on Global Feedback Alignment and Local STDP Plasticity](https://www.frontiersin.org/articles/10.3389/fncom.2020.576841/full)       | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/glsnn) | Frontiers in Neuroscience               |\n| [Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory](https://www.frontiersin.org/articles/10.3389/fncom.2020.00051/full) | https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Knowledge_Representation_and_Reasoning/musicMemory                                                                                        | Frontiers in Computational Neuroscience |\n\n## 2018\n| Papers                                                                                                                                                                        | Codes                                                                                                                                                                              | Publisher                  |\n| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------- |\n| [A Brain-Inspired Decision-Making Spiking Neural Network and Its Application in Unmanned Aerial Vehicle](https://www.frontiersin.org/articles/10.3389/fnbot.2018.00056/full ) | https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/decision_making/BDM-SNN/BDM-SNN-hh.py                                                                                   | Frontiers in Neurorobotics 
|\n| [Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition](https://link.springer.com/article/10.1007/s12559-017-9505-1 )             | [https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Social_Cognition/mirror_test) | Cognitive Computation      |\n"
  },
  {
    "path": "documents/Tutorial.md",
    "content": "# Tutorial\n\n\n- How to Install BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/1_installation.html), [Chinese Version](https://mp.weixin.qq.com/s/fX-S3fKKDfR3NV4ISAHrZg)]\n- Overall Introduction of BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/2_Overall%20introduction%20of%20BrainCog.html), [Chinese Version](https://mp.weixin.qq.com/s/ywgQ5ydQxr6W7d_Y_XqC6w)]\n- Spiking Neuron Modeling with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/3_Tutorial%20of%20Spiking%20Neuron%20Modeling.html), [Chinese Version](https://mp.weixin.qq.com/s/pCdlbkrdnMNHj7wcwi9D2Q)]\n- Building Efficient Spiking Neural Networks with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/4_Building%20Efficient%20Spiking%20Neural%20Networks%20with%20BrainCog.html), [Chinese Version](https://mp.weixin.qq.com/s/NIz7MSAOJQ79m97hkP3HVg)]\n- Building Cognitive Networks with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/5_How%20to%20build%20a%20cognitive%20network.html), [Chinese Version](https://mp.weixin.qq.com/s/K0pY6V8TupJgYXH4WvS9jg)]\n- Implementing Deep Reinforcement Learning SNNs with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/6_SDQN.html), [Chinese Version](https://mp.weixin.qq.com/s/Zt78vj_sKn5jffEyeh86Cw)]\n- Quantum Superposition State Inspired Encoding With BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/7_QSSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/YVNkwDwuF9FG-YqQyd3KTg)]\n- Converting ANNs to SNNs with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/8_conversion.html), [Chinese Version](https://mp.weixin.qq.com/s/4g8WBoa4SOcb_VQ24wO-Xw)]\n- SNNs with Global Feedback Connections Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/9_GLSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/-AS0BihlXdFwCt1hgzFx3g)]\n- Multi-brain Areas 
Coordinated Brain-inspired Decision-Making SNNs Based on Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/10_BDMSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/dltOGjhUZ9yTzSssIvkhtA)]\n- Backpropagation with Spatiotemporal Adjustment for Training SNNs Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/11_biobp.html), [Chinese Version](https://mp.weixin.qq.com/s/6ZtaWtiY9K96UhVnvhfP7w)]\n- Unsupervised STDP Based SNNs with Multiple Adaptive Mechanisms Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/12_unsup.html), [Chinese Version](https://mp.weixin.qq.com/s/v_svhQ0N3JAYo1l8NQ1W1w)]\n- Symbolic Representation and Reasoning SNNs Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/13_krr.html), [Chinese Version](https://mp.weixin.qq.com/s/UXc5QDbg8tUKVU3L6dhvxw)]\n- Multisensory Integration Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/14_multisensory.html), [Chinese Version](https://mp.weixin.qq.com/s/s9gXF5NkPUGXkcFAFtsCMw)]\n- SNN-based Music Memory and Generation Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/15_musicSNN.html), [Chinese Version](https://mp.weixin.qq.com/s/y-t9rFEXYSmI_-rnVk_usw)]\n- Brain-inspired Bodily Self-perception Model Based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/16_self.html), [Chinese Version](https://mp.weixin.qq.com/s/RLbS3GmhE8ScyZTKJKl_SQ)]\n- A Brain-inspired Theory of Mind Model Based on BrainCog for Reducing Other Agents’ Safety Risks [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/xzFF6IK86W4h2CMs7jt7Xw)]\n- Application of the Prefrontal Cortex Column Model in Working Memory Task with BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/E6pb5W0Q4noK72NGf4SNGQ)]\n- 
Multi-brain areas Coordinated Brain-inspired Affective Empathy Spiking Neural Network Based on Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/xUnal4I5tM-Rl49fu0ZRVw)]\n- Developmental Plasticity-inspired Adaptive Pruning for SNNs based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/-F0XcIovI1WY6YyoO5P58Q)]\n- Dynamic structural development for SNNs incorporating constraints, pruning and regeneration based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/n1EK6lddriS0dnSgfdsa8w)]\n- BrainCog Data Engine: Spatio-temporal Sequence Data N-Omniglot and Its Applications [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]\n- The Implement of Object Detection and Semantic Segmentation Based on SNNs with Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/fyf_cbzH_i1_ZEFWTSH81Q)]\n- DVS Augmentation Strategies based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/BNMZ_oQK9Foui2zcnKWYgg)]\n- SNNs with adaptive self-feedback and balanced excitatory–inhibitory neurons based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/TFEucKDQHB69wChEL_Y3Cw)]\n- Spiking CapsNet based on BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]\n- Drosophila-inspired Linear and Nonlinear Decision-Making Spiking Neural Network implemented by Braincog [[English 
Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]\n- Reward-modulated brain-inspired spiking neural network to empower group for nature-inspired self-organized obstacle avoidance implemented by Braincog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]\n- Using BrainCog to Realize Robot's Intention Prediction for Users [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]\n- FPGA-based Spiking Neural Networks Acceleration Using BrainCog [[English Version](http://www.brain-cog.network/docs/tutorial/17_tom.html), [Chinese Version](https://mp.weixin.qq.com/s/VIzDb2kn_4LUyUh0qc2j4w)]\n"
  },
  {
    "path": "examples/Brain_Cognitive_Function_Simulation/drosophila/README.md",
    "content": "# Drosophila-inspired decision-making SNN\n\n## Run\n\nThe drosophila.py implements the core code of the Drosophila-inspired linear and non-linear decision-making in the paper entitled \"A Neural Algorithm for linear and non-linear Decision-making inspired by Drosophila\".\n\nThe experiments include a training phase and a testing phase:\n\n* Training Phase\n\nTraining linear network and nonlinear network by reward-modulated spiking neural network: green-upright T is safe and blue-inverted T is dangerous\n\n* Testing Phase \n\nFor linear pathway and nonlinear pathway, choose between blue-upright T and green-inverted T, and count the PI values under different color intensity\n\n## Results\n\nThe following picture shows the linear (a) and nonlinear (b) pathways, the training and testing phases (c), and the PI values on different color intensities (d).\n\n![description](./dro.jpg)\n\nDifferences from the original article: an improved reward-modulated STDP learning rule. \n\n## Citation\n\nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{zhao2020neural,\n  title={A neural algorithm for Drosophila linear and nonlinear decision-making},\n  author={Zhao, Feifei and Zeng, Yi and Guo, Aike and Su, Haifeng and Xu, Bo},\n  journal={Scientific Reports},\n  volume={10},\n  number={1},\n  pages={1--16},\n  year={2020},\n  publisher={Nature Publishing Group}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for 
Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```\n"
  },
  {
    "path": "examples/Brain_Cognitive_Function_Simulation/drosophila/drosophila.py",
    "content": "import numpy as np\r\nimport torch,os,sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\nfrom braincog.base.strategy.surrogate import *\r\nfrom braincog.base.node.node import IFNode\r\nfrom braincog.base.learningrule.STDP import STDP,MutliInputSTDP\r\nfrom braincog.base.connection.CustomLinear import CustomLinear\r\nfrom braincog.model_zoo.nonlinearNet import droDMTestNet\r\nfrom braincog.model_zoo.linearNet import droDMTrainNet\r\nimport copy\r\n\r\nif __name__==\"__main__\":\r\n    \"\"\"\r\n    建立训练网络\r\n    \"\"\"\r\n    num_state=5\r\n    num_action=2\r\n    weight_exc=0.5\r\n    weight_inh=-0.05\r\n    trace_decay=0.8\r\n    mb_connection=[]\r\n    #input-visual\r\n    con_matrix0 = torch.eye((num_state), dtype=torch.float)\r\n    mb_connection.append(CustomLinear(weight_exc * con_matrix0,con_matrix0))\r\n    # visual-kc\r\n    con_matrix1 =torch.eye((num_state), dtype=torch.float)\r\n    mb_connection.append(CustomLinear( weight_exc * con_matrix1,con_matrix1))\r\n    # kc-mbon\r\n    con_matrix2 = torch.ones((num_state,num_action), dtype=torch.float)\r\n    mb_connection.append(CustomLinear(weight_exc * con_matrix2,con_matrix2))\r\n    # mbon-mbon\r\n    con_matrix3 = torch.ones((num_action,num_action), dtype=torch.float)\r\n    con_matrix4 = torch.eye((num_action), dtype=torch.float)\r\n    con_matrix5=con_matrix3-con_matrix4\r\n    con_matrix5=con_matrix5\r\n    mb_connection.append(CustomLinear(weight_inh * con_matrix5,con_matrix5))\r\n\r\n    MB=droDMTrainNet(mb_connection)\r\n    weight_trace_mbon=torch.zeros(con_matrix2.shape, dtype=torch.float)\r\n    \"\"\"\r\n    学习绿色正立T是安全的  蓝色倒立T是有惩罚的\r\n    \"\"\"\r\n    #learning GT\r\n    # RGB T t\r\n    GT = torch.tensor([0, 0.8, 0, 1.0, 0])\r\n    Bt = 
torch.tensor([0, 0, 0.8, 0, 1.0])\r\n    input = GT - Bt  # input GT\r\n    input[input < 0] = 0\r\n    for i_train in range(20):\r\n        GT_out,dwkc,dwmbon=MB(input)\r\n        print(\"stdp:\",dwkc,dwmbon)\r\n        #vis-kc STDP\r\n        MB.UpdateWeight(1, dwkc)\r\n        #kc-mbon rstdp\r\n        weight_trace_mbon *= trace_decay\r\n        weight_trace_mbon += dwmbon\r\n        if max(GT_out)>0:\r\n            r=torch.ones((num_state,num_action), dtype=torch.float)\r\n            p_action= torch.tensor([0])\r\n            r[:,p_action]=-1\r\n            dw_mbon = r * weight_trace_mbon\r\n            MB.UpdateWeight(2, dw_mbon)\r\n            print(\"output:\",GT_out)\r\n\r\n    MB.reset()\r\n    weight_trace_mbon = torch.zeros(con_matrix2.shape, dtype=torch.float)\r\n    #learning Bt\r\n    GT = torch.tensor([0,0.8,0, 1.0, 0])\r\n    Bt = torch.tensor([0, 0, 0.8, 0, 1.0])\r\n    input = Bt - GT  # input Bt\r\n    input[input < 0] = 0\r\n    for i_train in range(20):\r\n        GT_out,dwkc,dwmbon=MB(input)\r\n        #vis-kc STDP\r\n        MB.UpdateWeight(1, dwkc)\r\n        #kc-mbon rstdp\r\n        weight_trace_mbon *= trace_decay\r\n        weight_trace_mbon += dwmbon\r\n        if max(GT_out)>0:\r\n            r=torch.ones((num_state,num_action), dtype=torch.float)\r\n            p_action= torch.tensor([1])\r\n            r[:,p_action]=-1\r\n            dw_mbon = r * weight_trace_mbon\r\n            MB.UpdateWeight(2, dw_mbon)\r\n    train_weight=MB.getweight()\r\n    for i in range(len(train_weight)):\r\n        print(\"weight after learning:\", train_weight[i].weight.data)\r\n    print(\"end training\")\r\n\r\n\r\n    #linear test conflict decision making\r\n    test_num=12\r\n    t1=torch.zeros((test_num), dtype=torch.float)\r\n    t2=torch.zeros((test_num), dtype=torch.float)\r\n    for c in range(t1.shape[0]):\r\n        MB_test = droDMTrainNet(copy.deepcopy(train_weight))\r\n        MB_test.reset()\r\n        Gt = torch.tensor([0, (c*0.1), 0, 0, 
0.5])\r\n        BT = torch.tensor([0, 0, (c*0.1), 0.5, 0])\r\n        input =Gt - BT   # input Gt\r\n        input[input < 0] = 0\r\n        count=torch.zeros((num_action), dtype=torch.float)\r\n        for i_train in range(500):\r\n            GT_out,dwkc,dwmbon=MB_test(input)\r\n            count+=GT_out\r\n        t1[c]=count[0]\r\n        t2[c]=count[1]\r\n    p1=(t1-t2)/(t1+t2)\r\n    print(t1,t2,p1)\r\n    for i in range(len(train_weight)):\r\n        print(\"weight after learning:\", train_weight[i].weight.data)\r\n\r\n    \"\"\"\r\n    建立测试网络，验证不同浓度下绿色正立T和蓝色倒立T\r\n    \"\"\"\r\n    # non-linear test conflict decision making\r\n    weight_inh_test=-0.3\r\n    num_apl=2\r\n    num_da=1\r\n    da_mb_connection=train_weight\r\n    # kc-apl\r\n    con_matrix6 = torch.ones((num_state, num_apl), dtype=torch.float)\r\n    da_mb_connection.append(CustomLinear(weight_exc * con_matrix6, con_matrix6))\r\n    # apl-kc\r\n    con_matrix7 = torch.ones((num_apl,num_state), dtype=torch.float)\r\n    da_mb_connection.append(CustomLinear(weight_inh_test * con_matrix7, con_matrix7))\r\n    # da-apl\r\n    con_matrix8 = torch.ones((num_da, num_apl), dtype=torch.float)\r\n    da_mb_connection.append(CustomLinear(weight_inh_test * con_matrix8, con_matrix8))\r\n    # apl-da\r\n    con_matrix9 = torch.ones((num_apl, num_da), dtype=torch.float)\r\n    da_mb_connection.append(CustomLinear(weight_inh_test * con_matrix9, con_matrix9))\r\n    # 1-da\r\n    con_matrix10 = torch.ones((num_da), dtype=torch.float)\r\n    da_mb_connection.append(CustomLinear(weight_exc * con_matrix10, con_matrix10))\r\n    # da-mbon\r\n    con_matrix11 = torch.ones((num_da,num_action), dtype=torch.float)\r\n    da_mb_connection.append(CustomLinear(weight_exc * con_matrix11, con_matrix11))\r\n\r\n    #0 input-vis 1 vis-kc 2 kc-mbon 3-mbon-mbon 4 kc-apl 5 apl-kc 6 da-apl 7 apl-da 8 input-da\r\n    t1 = torch.zeros((test_num), dtype=torch.float)\r\n    t2 = torch.zeros((test_num), dtype=torch.float)\r\n    for 
c in range(t1.shape[0]):\r\n        DA_MB_test = droDMTestNet(copy.deepcopy(da_mb_connection))\r\n        DA_MB_test.reset()\r\n        Gt = torch.tensor([0, (c * 0.1), 0, 0, 0.5])\r\n        BT = torch.tensor([0, 0, (c * 0.1), 0.5, 0])\r\n        input = Gt - BT  # input Gt\r\n        input[input < 0] = 0\r\n        count = torch.zeros((num_action), dtype=torch.float)\r\n        for i_train in range(500):\r\n            if i_train<10 and i_train%2==0:\r\n                input_da = torch.tensor([0.5])\r\n            else:\r\n                input_da = torch.tensor([0.0])\r\n            GT_out, dwkc, dwapl= DA_MB_test(input,input_da)\r\n            DA_MB_test.UpdateWeight(5, dwkc)\r\n            DA_MB_test.UpdateWeight(4, dwapl)\r\n            count += GT_out\r\n        t1[c] = count[0]\r\n        t2[c] = count[1]\r\n    p2 = (t1 - t2) / (t1 + t2)\r\n    print(t1, t2, p2)\r\n    MB_test = MB.getweight()\r\n    for i in range(len(train_weight)):\r\n        print(\"weight after learning:\", train_weight[i].weight.data)\r\n\r\n\r\nx = torch.arange(0, test_num)\r\nx=x*0.1\r\nplt.figure()\r\nA,=plt.plot(x, p1,label=\"linear\")\r\nB,=plt.plot(x, p2,label=\"non-linear\")\r\nfont1 = {'family' : 'Times New Roman','weight' : 'normal','size' : 15,}\r\nplt.xlabel(\"color intensity\",font1)\r\nplt.ylabel(\"PI\",font1)\r\nplt.legend(handles=[A,B],prop=font1)\r\nplt.show()"
  },
  {
    "path": "examples/Embodied_Cognition/RHI/RHI_Test.py",
    "content": "import numpy as np\r\nimport torch,os,sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter \r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\nfrom braincog.base.strategy.surrogate import *\r\nimport os\r\nos.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\r\nimport random\r\nimport gc\r\nfrom braincog.base.node.node import IzhNodeMU\r\nimport objgraph\r\nfrom pympler import tracker\r\n\r\nclass CustomLinear(nn.Module):\r\n    def __init__(self, weight,mask=None):\r\n        super().__init__()\r\n\r\n        self.weight = nn.Parameter(weight, requires_grad=True)\r\n        self.mask=mask\r\n    def forward(self, x: torch.Tensor):\r\n        #\r\n        # ret.shape = [C]\r\n        return x.mul(self.weight) # Changed\r\n\r\n    def update(self, dw):\r\n        with torch.no_grad():\r\n            if self.mask is not None:\r\n                dw *= self.mask\r\n            self.weight.data+= dw\r\n\r\nclass M1Net(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input):        \r\n        input_n = input*I_max\r\n        input_r = torch.round(input_n)\r\n        Spike = torch.zeros(num_neuron, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_AI):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                
n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n\r\n    \r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass VNet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input):        \r\n        input_n = input*I_max\r\n        input_r = torch.round(input_n)\r\n        Spike = torch.zeros(num_neuron, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass S1Net(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input, FR, C): \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in range(FR.shape[0]):\r\n      
          FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n, input_n\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass EBANet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input, FR, C):   \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        
Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n, input_n\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass TPJNet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection       \r\n    \r\n    def forward(self, input, FR, C):   \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        
else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n, input_n\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\n    def UpdateWeight(self, i, W):\r\n        self.connection[i].weight.data = W\r\n\r\nclass AINet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection       \r\n\r\n    def forward(self, input, FR, C):   \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n 
       return FR_n, input_n\r\n   \r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\n    def UpdateWeight(self, i, W):\r\n        self.connection[i].weight.data = self.connection[i].weight.data + W\r\n\r\n\r\ndef DeltaWeight(Pre, Pre_n, Post, Post_n):\r\n    alpha = -0.0035\r\n    beta = 0.35\r\n    gamma = -0.55\r\n    T1 = alpha * (Pre_n*Post_n)\r\n    T2 = beta * (Pre_n*(Post_n-Post))\r\n    T3 = gamma * ((Pre_n-Pre)*Post_n)\r\n    dW = T1 + T2 + T3\r\n    return dW\r\n\r\n\r\n\r\nif __name__==\"__main__\":\r\n    \"\"\"\r\n    Set the number of neurons, and each neuron represents unique motion information (such as angle)\r\n    \"\"\"\r\n    # number of neurons\r\n    num_neuron = 9 \r\n    num_M1 = num_neuron \r\n    num_S1  = num_neuron \r\n    num_TPJ = num_neuron\r\n    num_V = num_neuron\r\n    num_EBA = num_neuron\r\n    num_AI = num_neuron\r\n\r\n    Init_Weight = 1.\r\n\r\n    param_threshold = 30.\r\n    param_a = 0.02\r\n    param_b = -0.1\r\n    param_c = -55.\r\n    param_d = 18.\r\n    param_mem = -70.\r\n    param_u = 0.\r\n    param_dt = 1.\r\n    Simulation_time = 1000\r\n    I_max = 1000\r\n\r\n    # When TrickID is set to 1, it means that the mapping relationship from input current \r\n    # to firing rate is obtained directly by loading Izh.npy, \r\n    # which can significantly reduce the program running time\r\n    TrickID = 1 \r\n    if TrickID == 1:\r\n        spike_num_list=np.load('Izh.npy')\r\n        spike_num_list = spike_num_list/I_max\r\n\r\n    ##############################\r\n    # M1\r\n    ##############################\r\n    # M1_Input-M1\r\n    M1_connection = []\r\n    con_matrix0 = torch.ones(num_M1, dtype=torch.float)*Init_Weight\r\n    M1_connection.append(CustomLinear(con_matrix0))\r\n    M1 = M1Net(M1_connection)\r\n  \r\n    ##############################\r\n    # V\r\n    ##############################\r\n    # V_Input-V\r\n    V_connection = []\r\n    
con_matrix3 = torch.ones(num_V, dtype=torch.float)*Init_Weight\r\n    V_connection.append(CustomLinear(con_matrix3))\r\n    V = VNet(V_connection)\r\n \r\n    ##############################\r\n    # S1\r\n    ##############################\r\n    # M1-S1\r\n    S1_connection = []\r\n    con_matrix1 = torch.ones(num_S1, dtype=torch.float)*Init_Weight\r\n    S1_connection.append(CustomLinear(con_matrix1))\r\n    S1 = S1Net(S1_connection)\r\n\r\n    ##############################\r\n    # EBA\r\n    ##############################\r\n    # V-EBA\r\n    EBA_connection = []\r\n    con_matrix4 = torch.ones(num_EBA, dtype=torch.float)*Init_Weight\r\n    EBA_connection.append(CustomLinear(con_matrix4))\r\n    EBA = EBANet(EBA_connection)\r\n\r\n    ##############################\r\n    # TPJ\r\n    ##############################\r\n    # S1-TPJ, EBA-TPJ\r\n    TPJ_connection = []\r\n    # S1-TPJ\r\n    con_matrix2 = torch.ones(num_TPJ, dtype=torch.float)*Init_Weight*150\r\n    TPJ_connection.append(CustomLinear(con_matrix2))\r\n    # EBA-TPJ\r\n    con_matrix5 = torch.ones(num_TPJ, dtype=torch.float)*Init_Weight*150\r\n    TPJ_connection.append(CustomLinear(con_matrix5))\r\n    TPJ = TPJNet(TPJ_connection)\r\n\r\n    ##############################\r\n    # AI\r\n    ##############################\r\n    # S1-AI, TPJ-AI, EBA-AI\r\n    AI_connection = []\r\n    # S1-AI\r\n    con_matrix6 = torch.ones(num_AI, dtype=torch.float)*Init_Weight\r\n    AI_connection.append(CustomLinear(con_matrix6))\r\n    # TPJ-AI\r\n    con_matrix7 = torch.ones(num_AI, dtype=torch.float)*Init_Weight\r\n    AI_connection.append(CustomLinear(con_matrix7))\r\n    # EBA-AI\r\n    con_matrix8 = torch.ones(num_AI, dtype=torch.float)*Init_Weight\r\n    AI_connection.append(CustomLinear(con_matrix8))\r\n    AI = AINet(AI_connection)\r\n    \r\n    AI.connection[0].weight.data = torch.from_numpy(np.load('W_S1_AI.npy'))\r\n    AI.connection[2].weight.data = torch.from_numpy(np.load('W_EBA_AI.npy'))\r\n\r\n  
  ##############################\r\n    # Coding\r\n    ##############################\r\n    S = 1\r\n    ISI = 1\r\n    JMax = int((num_neuron-1)/2)\r\n    listJ = list(range(-JMax,JMax+1))\r\n    Coding = torch.zeros([num_neuron, num_neuron], dtype=torch.float)\r\n    for i in range(len(listJ)):\r\n        e = float(listJ[i])\r\n        listY = []\r\n        for j in range(len(listJ)):\r\n            x = float(listJ[j])\r\n            y = math.exp(-(x-e)**2/(2*S**2))\r\n            Coding[i][j] = y \r\n\r\n    print(AI.connection[0].weight.data) # dW_S1AI\r\n    print(AI.connection[2].weight.data) # dW_EBAAI\r\n        \r\n    ##############################\r\n    # Test\r\n    ##############################\r\n    Time = 300\r\n    CT = 100\r\n    Motion_Start = 1\r\n    Motion_End = Motion_Start + CT\r\n    Vision_Start = Motion_End\r\n    Vision_End = Vision_Start + CT\r\n\r\n    CM1 = 0.04 \r\n    CV = 0.04 \r\n    CS1 = 0.04 \r\n    CEBA = 0.04 \r\n    CTPJ = 0.01 \r\n    CAI = 0.15 \r\n    \r\n    Result_List = []\r\n    Veridical_hand = int((num_neuron-1)/2)\r\n    for Disparity in range(-JMax,JMax+1):\r\n        M1_input = torch.zeros(num_M1, dtype=torch.float)\r\n        V_input = torch.zeros(num_V, dtype=torch.float)\r\n        S1_input = torch.zeros(num_S1, dtype=torch.float)\r\n        TPJ_input = torch.zeros(num_TPJ, dtype=torch.float)\r\n        EBA_input = torch.zeros(num_EBA, dtype=torch.float)\r\n        AI_input = torch.zeros(num_AI, dtype=torch.float)\r\n        \r\n        FR_M1 = torch.zeros(num_M1, dtype=torch.float)\r\n        FR_V = torch.zeros(num_V, dtype=torch.float)\r\n        FR_S1 = torch.zeros(num_S1, dtype=torch.float)\r\n        FR_EBA = torch.zeros(num_EBA, dtype=torch.float)\r\n        FR_TPJ = torch.zeros(num_TPJ, dtype=torch.float)\r\n        FR_AI = torch.zeros(num_AI, dtype=torch.float)\r\n\r\n        FR_AI_List = torch.zeros([Time, num_AI], dtype=torch.float)\r\n\r\n        \r\n        with torch.no_grad():\r\n            
for t in range(1,Time+1):\r\n                S_M1 = torch.zeros(num_M1, dtype=torch.float)\r\n                S_V = torch.zeros(num_V, dtype=torch.float)\r\n\r\n                if t>=Motion_Start and t<=Motion_End:\r\n                    S_M1 = Coding[Veridical_hand]\r\n                    M1_input = (1-(1-CM1)**t)*S_M1\r\n                else:\r\n                    M1_input = S_M1 \r\n                \r\n                    \r\n                if t>=Vision_Start and t<=Vision_End:\r\n                    S_V = Coding[Veridical_hand+Disparity]\r\n                    V_input = (1-(1-CV)**(t-CT))*S_V\r\n                else:\r\n                    V_input = S_V\r\n\r\n                FR_M1_n = M1(M1_input)\r\n\r\n                FR_V_n = V(V_input)       \r\n\r\n                [FR_S1_n, S1_input_n] = S1(S1_input, FR_M1_n, CS1)\r\n\r\n                [FR_EBA_n, EBA_input_n] = EBA(EBA_input, FR_V_n, CEBA)\r\n\r\n                FR_Input_TPJ_n = torch.stack((FR_S1_n, FR_EBA_n), 0)\r\n                [FR_TPJ_n, TPJ_input_n] = TPJ(TPJ_input, FR_Input_TPJ_n, CTPJ)\r\n                \r\n                FR_Input_AI = torch.stack((FR_S1_n, FR_TPJ_n, FR_EBA_n), 0)\r\n                [FR_AI_n, AI_input_n] = AI(AI_input, FR_Input_AI, CAI)\r\n\r\n                FR_AI_List[t-1] = FR_AI_n\r\n\r\n                FR_M1 = FR_M1_n\r\n                FR_V = FR_V_n\r\n                FR_S1 = FR_S1_n\r\n                FR_EBA = FR_EBA_n\r\n                FR_TPJ = FR_TPJ_n\r\n                FR_AI = FR_AI_n\r\n                \r\n                S1_input = S1_input_n\r\n                TPJ_input = TPJ_input_n\r\n                EBA_input = EBA_input_n\r\n                AI_input = AI_input_n\r\n\r\n            print('Test Time End')\r\n            Estimated_hand = torch.max(torch.max(FR_AI_List, 0)[0],0)[1].item()\r\n            Proprioceptive_drift = Estimated_hand - Veridical_hand\r\n            R = [Disparity, Proprioceptive_drift]            \r\n            
Result_List.append(R)\r\n\r\n            print(R)\r\n            print(torch.max(FR_AI_List, 0)[0])\r\n            print(\"----------------------------\") \r\n    \r\n    print(Result_List)\r\n\r\n    X = [x[0] for x in Result_List]\r\n    Y = [x[1] for x in Result_List]\r\n    S = np.polyfit(X,Y,3)\r\n    xn = np.linspace(-(num_neuron-1)/2, (num_neuron-1)/2, 1000)\r\n    yn = np.poly1d(S)\r\n    plt.plot(xn, yn(xn), X, Y, 'o')\r\n    plt.show()"
  },
  {
    "path": "examples/Embodied_Cognition/RHI/RHI_Train.py",
    "content": "import numpy as np\r\nimport torch,os,sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter \r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\nfrom braincog.base.strategy.surrogate import *\r\nimport os\r\nos.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\r\nimport random\r\nimport gc\r\nfrom braincog.base.node.node import IzhNodeMU\r\nimport objgraph\r\nfrom pympler import tracker\r\n\r\nclass CustomLinear(nn.Module):\r\n    def __init__(self, weight,mask=None):\r\n        super().__init__()\r\n\r\n        self.weight = nn.Parameter(weight, requires_grad=True)\r\n        self.mask=mask\r\n    def forward(self, x: torch.Tensor):\r\n        #\r\n        # ret.shape = [C]\r\n        return x.mul(self.weight)\r\n\r\n    def update(self, dw):\r\n        with torch.no_grad():\r\n            if self.mask is not None:\r\n                dw *= self.mask\r\n            self.weight.data+= dw\r\n\r\nclass M1Net(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input):        \r\n        input_n = input*I_max\r\n        input_r = torch.round(input_n)\r\n        Spike = torch.zeros(num_neuron, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_AI):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = 
self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n\r\n    \r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass VNet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input):        \r\n        input_n = input*I_max\r\n        input_r = torch.round(input_n)\r\n        Spike = torch.zeros(num_neuron, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass S1Net(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input, FR, C, Fired, W_LatInh): \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in 
range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n\r\n        S = input_n\r\n        S = torch.where(input_n>= fire_threshold, 1, S)\r\n        S = torch.where(input_n< fire_threshold, 0, S)\r\n        if torch.sum(S) > 0:\r\n            Fired = Fired + 1; \r\n            W_LatInh = torch.tanh(W_LatInh - 2 * torch.acos(S) * torch.exp(Fired) - 1) + 1\r\n\r\n        return FR_n, input_n, Fired, W_LatInh\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass EBANet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection\r\n\r\n    def forward(self, input, FR, C, Fired, W_LatInh):  \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in 
range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        \r\n        S = input_n\r\n        S = torch.where(input_n>= fire_threshold, 1, S)\r\n        S = torch.where(input_n< fire_threshold, 0, S)\r\n        if torch.sum(S) > 0:\r\n            Fired = Fired + 1; \r\n            W_LatInh = torch.tanh(W_LatInh - 2 * torch.acos(S) * torch.exp(Fired) - 1) + 1\r\n\r\n        return FR_n, input_n, Fired, W_LatInh\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\nclass TPJNet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection       \r\n    \r\n    def forward(self, input, FR, C):   \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in 
range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n = torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n, input_n\r\n\r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\n    def UpdateWeight(self, i, W):\r\n        self.connection[i].weight.data = W\r\n\r\nclass AINet(nn.Module):\r\n    def __init__(self,connection):\r\n        super().__init__()\r\n        self.node = []\r\n        self.node.append(IzhNodeMU(threshold=param_threshold, a=param_a, b=param_b, c=param_c, d=param_d, mem=param_mem, u=param_u, dt=param_dt))\r\n        self.connection = connection       \r\n\r\n    def forward(self, input, FR, C):   \r\n        FR_W = torch.zeros(num_neuron, dtype=torch.float)\r\n        if len(FR.shape) == 1:\r\n            FR_W = FR*self.connection[0].weight\r\n        else:\r\n            for i in range(FR.shape[0]):\r\n                FR_Wi = FR[i]*self.connection[i].weight\r\n                FR_W = FR_W + FR_Wi\r\n        sf = torch.tanh(FR_W)\r\n        sf = torch.where(sf<0, 0, sf)\r\n        input_n = -C * (input-sf) + input\r\n        input_n 
= torch.where(input_n<0, 0, input_n)\r\n        input = input_n*I_max\r\n        input_r = torch.round(input)\r\n        Spike = torch.zeros(num_S1, dtype=torch.float)\r\n        self.node[0].n_reset()\r\n        if TrickID == 1:\r\n            for i in range(num_neuron):\r\n                input_ri = int(input_r[i].item())\r\n                FR_i = spike_num_list[input_ri]\r\n                Spike[i] = FR_i\r\n            FR_n = Spike\r\n        else:\r\n            for t in range(Simulation_time):    \r\n                self.out=self.node[0](input_r)\r\n                n_Spike = self.node[0].spike          \r\n                Spike = Spike + n_Spike\r\n            FR_n = Spike/Simulation_time\r\n        return FR_n, input_n\r\n   \r\n    def reset(self):\r\n        for i in range(len(self.node)):\r\n            self.node[i].n_reset()\r\n\r\n    def UpdateWeight(self, i, W, WIn):\r\n        self.connection[i].weight.data = self.connection[i].weight.data + W*WIn\r\n\r\n\r\ndef DeltaWeight(Pre, Pre_n, Post, Post_n):\r\n    alpha = -0.0035\r\n    beta = 0.35\r\n    gamma = -0.55\r\n    T1 = alpha * (Pre_n*Post_n)\r\n    T2 = beta * (Pre_n*(Post_n-Post))\r\n    T3 = gamma * ((Pre_n-Pre)*Post_n)\r\n    dW = T1 + T2 + T3\r\n    return dW\r\n\r\n\r\n\r\nif __name__==\"__main__\":\r\n    \"\"\"\r\n    Set the number of neurons, and each neuron represents unique motion information (such as angle)\r\n    \"\"\"\r\n    # number of neurons\r\n    num_neuron = 9 \r\n    num_M1 = num_neuron \r\n    num_S1  = num_neuron \r\n    num_TPJ = num_neuron\r\n    num_V = num_neuron\r\n    num_EBA = num_neuron\r\n    num_AI = num_neuron\r\n\r\n    Init_Weight = 1.\r\n\r\n    param_threshold = 30.\r\n    param_a = 0.02\r\n    param_b = -0.1\r\n    param_c = -55.\r\n    param_d = 18.\r\n    param_mem = -70.\r\n    param_u = 0.\r\n    param_dt = 1.\r\n    Simulation_time = 1000\r\n    I_max = 1000\r\n\r\n    # When the TrickID is set to 1, it means that the mapping relationship from input 
current \r\n    # to firing rate is obtained directly by loading the Izh.npy, \r\n    # which can significantly reduce the program running time\r\n    TrickID = 1 \r\n    if TrickID == 1:\r\n        spike_num_list=np.load('Izh.npy')\r\n        spike_num_list = spike_num_list/I_max\r\n\r\n    ##############################\r\n    # M1\r\n    ##############################\r\n    # M1_Input-M1\r\n    M1_connection = []\r\n    con_matrix0 = torch.ones(num_M1, dtype=torch.float)*Init_Weight\r\n    M1_connection.append(CustomLinear(con_matrix0))\r\n    M1 = M1Net(M1_connection)\r\n  \r\n    ##############################\r\n    # V\r\n    ##############################\r\n    # V_Input-V\r\n    V_connection = []\r\n    con_matrix3 = torch.ones(num_V, dtype=torch.float)*Init_Weight\r\n    V_connection.append(CustomLinear(con_matrix3))\r\n    V = VNet(V_connection)\r\n \r\n    ##############################\r\n    # S1\r\n    ##############################\r\n    # M1-S1\r\n    S1_connection = []\r\n    con_matrix1 = torch.ones(num_S1, dtype=torch.float)*Init_Weight\r\n    S1_connection.append(CustomLinear(con_matrix1))\r\n    S1 = S1Net(S1_connection)\r\n\r\n    ##############################\r\n    # EBA\r\n    ##############################\r\n    # V-EBA\r\n    EBA_connection = []\r\n    con_matrix4 = torch.ones(num_EBA, dtype=torch.float)*Init_Weight\r\n    EBA_connection.append(CustomLinear(con_matrix4))\r\n    EBA = EBANet(EBA_connection)\r\n\r\n    ##############################\r\n    # TPJ\r\n    ##############################\r\n    # S1-TPJ, EBA-TPJ\r\n    TPJ_connection = []\r\n    # S1-TPJ\r\n    con_matrix2 = torch.ones(num_TPJ, dtype=torch.float)*Init_Weight*150\r\n    TPJ_connection.append(CustomLinear(con_matrix2))\r\n    # EBA-TPJ\r\n    con_matrix5 = torch.ones(num_TPJ, dtype=torch.float)*Init_Weight*150\r\n    TPJ_connection.append(CustomLinear(con_matrix5))\r\n    TPJ = TPJNet(TPJ_connection)\r\n\r\n    ##############################\r\n    # AI\r\n  
  ##############################\r\n    # S1-AI, TPJ-AI, EBA-AI\r\n    AI_connection = []\r\n    # S1-AI\r\n    con_matrix6 = torch.ones(num_AI, dtype=torch.float)*Init_Weight\r\n    AI_connection.append(CustomLinear(con_matrix6))\r\n    # TPJ-AI\r\n    con_matrix7 = torch.ones(num_AI, dtype=torch.float)*Init_Weight\r\n    AI_connection.append(CustomLinear(con_matrix7))\r\n    # EBA-AI\r\n    con_matrix8 = torch.ones(num_AI, dtype=torch.float)*Init_Weight\r\n    AI_connection.append(CustomLinear(con_matrix8))\r\n    AI = AINet(AI_connection)   \r\n    \r\n    ##############################\r\n    # Coding\r\n    ##############################\r\n    S = 1\r\n    ISI = 1\r\n    JMax = int((num_neuron-1)/2)\r\n    listJ = list(range(-JMax,JMax+1))\r\n    Coding = torch.zeros([num_neuron, num_neuron], dtype=torch.float)\r\n    for i in range(len(listJ)):\r\n        e = float(listJ[i])\r\n        listY = []\r\n        for j in range(len(listJ)):\r\n            x = float(listJ[j])\r\n            y = math.exp(-(x-e)**2/(2*S**2))\r\n            Coding[i][j] = y\r\n\r\n    ##############################\r\n    # Train\r\n    ##############################\r\n    MoveNum = 25  \r\n    Time = 300\r\n    CT = 100\r\n    Motion_Start = 1\r\n    Motion_End = Motion_Start + CT\r\n    Vision_Start = Motion_End\r\n    Vision_End = Vision_Start + CT\r\n\r\n    CM1 = 0.04 \r\n    CV = 0.04 \r\n    CS1 = 0.04 \r\n    CEBA = 0.04 \r\n    CTPJ = 0.01 \r\n    CAI = 0.15 \r\n\r\n    for k in range(num_neuron):\r\n        for i in range(MoveNum): \r\n            print(i)        \r\n            M1_input = torch.zeros(num_M1, dtype=torch.float)\r\n            V_input = torch.zeros(num_V, dtype=torch.float)\r\n            S1_input = torch.zeros(num_S1, dtype=torch.float)\r\n            TPJ_input = torch.zeros(num_TPJ, dtype=torch.float)\r\n            EBA_input = torch.zeros(num_EBA, dtype=torch.float)\r\n            AI_input = torch.zeros(num_AI, dtype=torch.float)           \r\n            
\r\n            FR_M1 = torch.zeros(num_M1, dtype=torch.float)\r\n            FR_V = torch.zeros(num_V, dtype=torch.float)\r\n            FR_S1 = torch.zeros( num_S1, dtype=torch.float)\r\n            FR_EBA = torch.zeros(num_EBA, dtype=torch.float)\r\n            FR_TPJ = torch.zeros(num_TPJ, dtype=torch.float)\r\n            FR_AI = torch.zeros( num_AI, dtype=torch.float)\r\n\r\n            dW_S1TPJ = torch.zeros(num_M1, dtype=torch.float)\r\n            dW_EBATPJ = torch.zeros(num_M1, dtype=torch.float)\r\n            dW_S1AI = torch.zeros(num_M1, dtype=torch.float)\r\n            dW_EBAAI = torch.zeros(num_M1, dtype=torch.float)\r\n\r\n            fire_threshold = 0.7\r\n            W_LatInh_Init = torch.ones(num_neuron, dtype=torch.float)*Init_Weight\r\n            W_LatInh_S1_AI = W_LatInh_Init\r\n            W_LatInh_EBA_AI = W_LatInh_Init\r\n            Fired_S1 = torch.zeros(num_S1, dtype=torch.float)\r\n            Fired_EBA = torch.zeros(num_EBA, dtype=torch.float)\r\n            \r\n            with torch.no_grad():\r\n                for t in range(1,Time+1):\r\n                    S_M1 = torch.zeros(num_M1, dtype=torch.float)\r\n                    S_V = torch.zeros(num_V, dtype=torch.float)\r\n\r\n                    if t>=Motion_Start and t<=Motion_End:\r\n                        S_M1 = Coding[k]\r\n                        M1_input = (1-(1-CM1)**t)*S_M1 \r\n                    else:\r\n                        M1_input = S_M1\r\n                    \r\n                        \r\n                    if t>=Vision_Start and t<=Vision_End:\r\n                        S_V = Coding[k]\r\n                        V_input = (1-(1-CV)**(t-CT))*S_V \r\n                    else:\r\n                        V_input = S_V\r\n\r\n                    FR_M1_n = M1(M1_input)\r\n                    \r\n                    FR_V_n = V(V_input)       \r\n                                \r\n                    [FR_S1_n, S1_input_n, Fired_S1, W_LatInh_S1_AI] = S1(S1_input, 
FR_M1_n, CS1, Fired_S1, W_LatInh_S1_AI)\r\n\r\n                    [FR_EBA_n, EBA_input_n, Fired_EBA, W_LatInh_EBA_AI] =  EBA(EBA_input, FR_V_n, CEBA, Fired_EBA, W_LatInh_EBA_AI)\r\n\r\n                    FR_Input_TPJ_n = torch.stack((FR_S1_n, FR_EBA_n), 0)\r\n                    [FR_TPJ_n, TPJ_input_n] = TPJ(TPJ_input, FR_Input_TPJ_n, CTPJ)\r\n                    \r\n                    FR_Input_AI = torch.stack((FR_S1_n, FR_TPJ_n, FR_EBA_n), 0)\r\n                    [FR_AI_n, AI_input_n] = AI(AI_input, FR_Input_AI, CAI)\r\n\r\n                    # Update weights\r\n                    # S1-AI\r\n                    ddW_S1AI = DeltaWeight(FR_S1, FR_S1_n, FR_AI, FR_AI_n)\r\n                    dW_S1AI = dW_S1AI + ddW_S1AI\r\n                    # EBA-AI\r\n                    ddW_EBAAI = DeltaWeight(FR_EBA, FR_EBA_n, FR_AI, FR_AI_n)\r\n                    dW_EBAAI = dW_EBAAI + ddW_EBAAI\r\n\r\n                    FR_M1 = FR_M1_n\r\n                    FR_V = FR_V_n\r\n                    FR_S1 = FR_S1_n\r\n                    FR_EBA = FR_EBA_n\r\n                    FR_TPJ = FR_TPJ_n\r\n                    FR_AI = FR_AI_n\r\n                    \r\n\r\n                    S1_input = S1_input_n\r\n                    TPJ_input = TPJ_input_n\r\n                    EBA_input = EBA_input_n\r\n                    AI_input = AI_input_n\r\n\r\n            AI.UpdateWeight(0, dW_S1AI, W_LatInh_S1_AI)\r\n            AI.UpdateWeight(2, dW_EBAAI, W_LatInh_EBA_AI)\r\n            \r\n            print(AI.connection[0].weight.data) # dW_S1AI\r\n            print(AI.connection[2].weight.data) # dW_EBAAI\r\n            \r\n            M1.reset()\r\n            V.reset()\r\n            S1.reset()\r\n            EBA.reset()\r\n            TPJ.reset()\r\n            AI.reset()                \r\n            \r\n    np.save('W_S1_AI.npy', AI.connection[0].weight.data)\r\n    np.save('W_EBA_AI.npy', AI.connection[2].weight.data)\r\n\r\n    print('Training End')"
  },
  {
    "path": "examples/Embodied_Cognition/RHI/ReadMe.md",
    "content": "\n"
  },
  {
    "path": "examples/Hardware_acceleration/README.md",
    "content": "## FireFly: A High-Throughput Hardware Accelerator for Spiking Neural Networks\n\n### Demo of Deploying SNNs on FPGA platform\n\nThis is an example of deploying an SNN model on Xilinx Zynq Ultrascale FPGA based on Braincog.\n\n### Requirements\n\n- Xilinx Zynq Ultrascale FPGA evaluation board Ultra96v2 or ZCU104.\n- PYNQ images for the chosen evaluation boards. You can download the latest pre-compiled images from the [PYNQ website](http://www.pynq.io/board.html), or you can compile a new one following the [PYNQ Tutorial](https://pynq.readthedocs.io/en/latest/). Install the PYNQ image to the SD card, and boot the evaluation board in SD mode.\n\n### Examples\n\nClone the project to fetch the necessary bitstream files and pre-processed SNN models, copy all the files to the Ultra96v2 or ZCU104 board.\n\n```shell\ngit clone https://github.com/adamgallas/firefly_v1_cifar_test\n```\n\nOpen a terminal in Ultra96v2 or ZCU104. Install einops on Ultra96v2 or ZCU104.\n\n```shell\ncd firefly_v1_common\npip install einops-0.6.0-py3-none-any.whl\n```\n\nRun CIFAR10 classification test on Ultra96v2:\n\n```shell\npython ultra96_test.py\n```\n\nRun CIFAR10 classification test on ZCU104:\n\n```shell\npython zcu104_test.py\n```\n\n### Citation\n\nIf you find this work helpful, please consider citing it:\n\n```BibTex\n@article{li2023firefly,\n  title={FireFly: A High-Throughput Hardware Accelerator for Spiking Neural Networks With Efficient DSP and Memory Optimization},\n  author={Li, Jindong and Shen, Guobin and Zhao, Dongcheng and Zhang, Qian and Zeng, Yi},\n  journal={IEEE Transactions on Very Large Scale Integration (VLSI) Systems},\n  year={2023},\n  publisher={IEEE}\n}\n```"
  },
  {
    "path": "examples/Hardware_acceleration/firefly_v1_schedule_on_pynq.py",
    "content": "import numpy as np\nimport tqdm\nfrom standalone_utils import *\nimport math\nimport time\nimport ctypes as ct\n\n\nclass FireFlyV1ConvSchedule:\n    def __init__(\n            self,\n            ctrl_io,\n            allocate_method,\n            input_buffer_addr,\n            output_buffer_addr,\n            weight_data,\n            bias_data,\n\n            parallel_channel=16,\n            kernel_size=3,\n            input_channels=64,\n            output_channels=128,\n            width=32,\n            height=32,\n            enable_pooling=False,\n            direct_adapt=False,\n            winner_takes_all=False,\n            final_conv=False,\n            time_step=8,\n            threshold=64,\n            max_cnt=2048\n    ):\n        self.ctrl_io = ctrl_io\n\n        self.input_buffer_addr = np.uint32(input_buffer_addr)\n        self.output_buffer_addr = np.uint32(output_buffer_addr)\n\n        self.weight_buffer = allocate_method(shape=weight_data.size, dtype=np.int8)\n        self.bias_buffer = allocate_method(shape=bias_data.size, dtype=np.int16)\n        self.weight_buffer_addr = np.uint32(self.weight_buffer.device_address)\n        self.bias_buffer_addr = np.uint32(self.bias_buffer.device_address)\n\n        self.weight_buffer[:] = np.ascontiguousarray(weight_data.flatten())\n        self.bias_buffer[:] = np.ascontiguousarray(bias_data.flatten())\n        self.weight_buffer.flush()\n        self.bias_buffer.flush()\n        self.max_cnt = max_cnt\n\n        self.parallel_channel = np.uint32(parallel_channel)\n        self.kernel_size = np.uint32(kernel_size)\n        self.input_channels = np.uint32(input_channels)\n        self.output_channels = np.uint32(output_channels)\n        self.width = np.uint32(width)\n        self.height = np.uint32(height)\n        self.enable_pooling = np.uint32(enable_pooling)\n        self.direct_adapt = np.uint32(direct_adapt)\n        self.winner_takes_all = np.uint32(winner_takes_all)\n        
self.time_step = np.uint32(time_step)\n        self.threshold = np.int32(threshold)\n\n        self.out_width = np.uint32(width >> enable_pooling)\n        self.out_height = np.uint32(height >> enable_pooling)\n\n        self.numOfIFMs = np.uint32(input_channels / parallel_channel - 1)\n        self.numOfOFMs = np.uint32(output_channels / parallel_channel - 1)\n        self.numOfTimeSteps = np.uint32(time_step - 1)\n        self.numOfTimeStepIFMs = np.uint32((input_channels / parallel_channel) * time_step - 1)\n        self.numOfTimeStepOFMs = np.uint32((output_channels / parallel_channel) * time_step - 1)\n        self.weightsLength = np.uint32(input_channels - 1)\n\n        if direct_adapt:\n            factor = kernel_size * kernel_size\n            padded_length = math.ceil(input_channels / parallel_channel / factor)\n            self.numOfIFMs = np.uint32(padded_length - 1)\n            self.numOfTimeStepIFMs = np.uint32(padded_length * time_step - 1)\n            self.weightsLength = np.uint32(padded_length * parallel_channel - 1)\n            self.out_width = np.uint32(1)\n            self.out_height = np.uint32(1)\n\n        self.mm2s_fix_len = np.uint32(self.width * self.height * self.time_step * self.input_channels / 8)\n        self.s2mm_fix_len = np.uint32(self.out_width * self.out_height * self.parallel_channel / 8)\n        self.bias_len = np.uint32(self.output_channels * 2)\n        self.weight_len = np.uint32(self.output_channels * self.input_channels * self.kernel_size * self.kernel_size)\n\n        self.stride_of_channel = np.uint32(self.out_width * self.out_height * self.parallel_channel / 8)\n        self.stride_of_time_step = np.uint32(self.out_width * self.out_height * self.output_channels / 8)\n\n        if direct_adapt:\n            self.mm2s_fix_len = np.uint32(self.time_step * self.input_channels / 8)\n            self.weight_len = np.uint32(self.output_channels * self.input_channels)\n            self.stride_of_channel = np.uint32(2 * 
self.parallel_channel / 8)\n            self.stride_of_time_step = np.uint32(2 * self.output_channels / 8)\n\n        if final_conv:\n            self.flatten_channel = np.uint32(self.out_width * self.out_height * self.output_channels)\n            factor = kernel_size * kernel_size * 8 * 4\n            round_channel = int(math.ceil(self.flatten_channel / factor) * factor)\n            self.stride_of_time_step = np.uint32(round_channel / 8)\n\n        self.configReg_0x00 = np.uint32(((self.time_step - 1) << 16) + (self.numOfOFMs << 8) + self.numOfIFMs).tobytes()\n        self.configReg_0x04 = np.uint32((self.numOfTimeStepOFMs << 12) + self.numOfTimeStepIFMs).tobytes()\n        self.configReg_0x08 = np.uint32((self.threshold << 14) + self.weightsLength).tobytes()\n        self.configReg_0x0c = np.uint32((self.winner_takes_all << 30) + (self.direct_adapt << 29) + (\n                self.enable_pooling << 28) + ((self.height - 1) << 16) + self.width - 1).tobytes()\n        self.configReg_0x20 = np.uint32(self.stride_of_time_step).tobytes()\n        self.configReg_0x24 = np.uint32(self.stride_of_channel).tobytes()\n        self.configReg_0x28 = np.uint32(self.input_buffer_addr).tobytes()\n        self.configReg_0x2c = np.uint32(self.output_buffer_addr).tobytes()\n        self.configReg_0x30 = np.uint32(self.mm2s_fix_len).tobytes()\n        self.configReg_0x34 = np.uint32(self.s2mm_fix_len).tobytes()\n        self.configReg_0x38 = np.uint32(self.weight_buffer_addr).tobytes()\n        self.configReg_0x3c = np.uint32(self.weight_len).tobytes()\n        self.configReg_0x40 = np.uint32(self.bias_buffer_addr).tobytes()\n        self.configReg_0x44 = np.uint32(self.bias_len).tobytes()\n\n        self.paramCmd = np.uint32(0x00010000).tobytes()\n        self.inOutCmd = np.uint32(0x00000101).tobytes()\n\n    def gen_cmd(self):\n        cmd_list = []\n        cmd_list.append(np.frombuffer(self.configReg_0x00, dtype=np.uint32))\n        
cmd_list.append(np.frombuffer(self.configReg_0x04, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x08, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x0c, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x20, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x24, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x28, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x2c, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x30, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x34, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x38, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x3c, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x40, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.configReg_0x44, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.paramCmd, dtype=np.uint32))\n        cmd_list.append(np.frombuffer(self.inOutCmd, dtype=np.uint32))\n        return np.array(cmd_list).flatten()\n\n    def send_config(self):\n        self.ctrl_io.write(0x00, self.configReg_0x00)\n        self.ctrl_io.write(0x04, self.configReg_0x04)\n        self.ctrl_io.write(0x08, self.configReg_0x08)\n        self.ctrl_io.write(0x0c, self.configReg_0x0c)\n\n        self.ctrl_io.write(0x20, self.configReg_0x20)\n        self.ctrl_io.write(0x24, self.configReg_0x24)\n        self.ctrl_io.write(0x28, self.configReg_0x28)\n        self.ctrl_io.write(0x2c, self.configReg_0x2c)\n        self.ctrl_io.write(0x30, self.configReg_0x30)\n        self.ctrl_io.write(0x34, self.configReg_0x34)\n        self.ctrl_io.write(0x38, self.configReg_0x38)\n        self.ctrl_io.write(0x3c, self.configReg_0x3c)\n        self.ctrl_io.write(0x40, self.configReg_0x40)\n        self.ctrl_io.write(0x44, self.configReg_0x44)\n\n 
   def begin_schedule_non_blocking(self):\n        self.ctrl_io.write(0x48, self.paramCmd)\n        self.ctrl_io.write(0x48, self.paramCmd)\n        self.ctrl_io.write(0x48, self.inOutCmd)\n\n    def begin_schedule_blocking(self):\n        self.begin_schedule_non_blocking()\n        while self.ctrl_io.read(0x18, length=8) == 0:\n            continue\n        self.clear_schedule()\n\n    def clear_schedule(self):\n        self.ctrl_io.write(0x18, 0)\n\n    def run_all(self):\n        self.ctrl_io.write(0x00, self.configReg_0x00)\n        self.ctrl_io.write(0x04, self.configReg_0x04)\n        self.ctrl_io.write(0x08, self.configReg_0x08)\n        self.ctrl_io.write(0x0c, self.configReg_0x0c)\n\n        self.ctrl_io.write(0x20, self.configReg_0x20)\n        self.ctrl_io.write(0x24, self.configReg_0x24)\n        self.ctrl_io.write(0x28, self.configReg_0x28)\n        self.ctrl_io.write(0x2c, self.configReg_0x2c)\n        self.ctrl_io.write(0x30, self.configReg_0x30)\n        self.ctrl_io.write(0x34, self.configReg_0x34)\n        self.ctrl_io.write(0x38, self.configReg_0x38)\n        self.ctrl_io.write(0x3c, self.configReg_0x3c)\n        self.ctrl_io.write(0x40, self.configReg_0x40)\n        self.ctrl_io.write(0x44, self.configReg_0x44)\n\n        self.ctrl_io.write(0x48, self.paramCmd)\n        self.ctrl_io.write(0x48, self.paramCmd)\n        self.ctrl_io.write(0x48, self.inOutCmd)\n        cnt = 0\n        while self.ctrl_io.read(0x18, length=8) == 0:\n            cnt = cnt + 1\n            if cnt > self.max_cnt:\n                print(\"timeout, abort!\")\n                break\n        end = time.time()\n        self.ctrl_io.write(0x18, 0)\n\n    def read_status(self):\n        status = np.uint32(self.ctrl_io.read(0x50, length=4)).tobytes()\n        busy_status = status[0]\n        input_status = status[1]\n        output_status = status[2]\n        param_status = status[3]\n\n        print(\"busy_status\", busy_status)\n        print(\"input_status\", 
input_status)\n        print(\"output_status\", output_status)\n        print(\"param_status\", param_status)\n\n\ndef create_schedule(model_config_list: list,\n                    ctrl_io,\n                    allocate_method,\n                    buffer_0,\n                    buffer_1,\n                    image_height,\n                    image_width,\n                    time_step=4,\n                    parallel_channel=16\n                    ):\n    schedule_list = []\n    curr_image_height = image_height\n    curr_image_width = image_width\n    input_buffer_addr = buffer_0.device_address\n    output_buffer_addr = buffer_1.device_address\n\n    for config in model_config_list:\n        if config[\"layer_type\"] == \"conv+IFNode\":\n            schedule = FireFlyV1ConvSchedule(\n                ctrl_io=ctrl_io,\n                allocate_method=allocate_method,\n\n                input_buffer_addr=input_buffer_addr,\n                output_buffer_addr=output_buffer_addr,\n                weight_data=conv_weight_channel_tiling(parallel_channel, config[\"weight\"]),\n                bias_data=config[\"bias\"].flatten(),\n\n                parallel_channel=parallel_channel,\n                kernel_size=3,\n                input_channels=config[\"input_channel\"],\n                output_channels=config[\"output_channel\"],\n                width=curr_image_width,\n                height=curr_image_height,\n                enable_pooling=False,\n                time_step=time_step,\n                threshold=config[\"threshold\"],\n                final_conv=\"flatten\" in config\n            )\n            schedule_list.append(schedule)\n            input_buffer_addr, output_buffer_addr = output_buffer_addr, input_buffer_addr\n\n        elif config[\"layer_type\"] == \"conv+IFNode+maxpool\":\n            schedule = FireFlyV1ConvSchedule(\n                ctrl_io=ctrl_io,\n                allocate_method=allocate_method,\n\n                
input_buffer_addr=input_buffer_addr,\n                output_buffer_addr=output_buffer_addr,\n                weight_data=conv_weight_channel_tiling(parallel_channel, config[\"weight\"]),\n                bias_data=config[\"bias\"].flatten(),\n\n                parallel_channel=parallel_channel,\n                kernel_size=3,\n                input_channels=config[\"input_channel\"],\n                output_channels=config[\"output_channel\"],\n                width=curr_image_width,\n                height=curr_image_height,\n                enable_pooling=True,\n                time_step=time_step,\n                threshold=config[\"threshold\"],\n                final_conv=\"flatten\" in config\n            )\n            schedule_list.append(schedule)\n            input_buffer_addr, output_buffer_addr = output_buffer_addr, input_buffer_addr\n            curr_image_height = int(curr_image_height / 2)\n            curr_image_width = int(curr_image_width / 2)\n\n        elif config[\"layer_type\"].__contains__(\"linear\"):\n\n            input_channel = config[\"input_channel\"]\n            weight = config[\"weight\"]\n            if \"weight_reshape\" in config:\n                factor = 9 * 8 * 4\n                round_channel = int(math.ceil(input_channel / factor) * factor)\n                weight = rearrange(weight, \"o (i p h w) -> o (i h w p)\", p=parallel_channel,\n                                   h=curr_image_height, w=curr_image_width)\n                weight = np.pad(weight, ((0, 0), (0, round_channel - input_channel)), mode=\"constant\")\n                input_channel = round_channel\n\n            weight = linear_weight_channel_tiling(parallel_channel, weight)\n\n            schedule = FireFlyV1ConvSchedule(\n                ctrl_io=ctrl_io,\n                allocate_method=allocate_method,\n\n                input_buffer_addr=input_buffer_addr,\n                output_buffer_addr=output_buffer_addr,\n                weight_data=weight,\n         
       bias_data=config[\"bias\"].flatten(),\n\n                parallel_channel=parallel_channel,\n                kernel_size=3,\n                input_channels=input_channel,\n                output_channels=config[\"output_channel\"],\n                width=3,\n                height=3,\n                enable_pooling=False,\n                time_step=time_step,\n                threshold=config[\"threshold\"],\n                direct_adapt=config[\"direct_adapt\"],\n                winner_takes_all=config[\"winner_take_all\"]\n            )\n            schedule_list.append(schedule)\n            input_buffer_addr, output_buffer_addr = output_buffer_addr, input_buffer_addr\n\n    return schedule_list, input_buffer_addr\n\n\ndef schedule_run_all(schedule_list):\n    for schedule in schedule_list:\n        schedule.run_all()\n\n\ndef gen_cmd_array(schedule_list):\n    cmd_array = []\n    for schedule in schedule_list:\n        cmd_array.append(schedule.gen_cmd().flatten())\n    return np.array(cmd_array)\n\n\ndef init_firefly_c_lib(path, schedule_list):\n    cmd_arr = gen_cmd_array(schedule_list)\n    lib = ct.CDLL(path)\n    sche = lib.firefly_v1_schedule\n    u32Ptr = ct.POINTER(ct.c_uint32)\n    u32PtrPtr = ct.POINTER(u32Ptr)\n\n    ct_arr = np.ctypeslib.as_ctypes(cmd_arr)\n    u32PtrArr = u32Ptr * ct_arr._length_\n    ct_ptr = ct.cast(u32PtrArr(*(ct.cast(row, u32Ptr) for row in ct_arr)), u32PtrPtr)\n    sche_len = ct.c_uint8(cmd_arr.shape[0])\n\n    return sche, ct_ptr, sche_len\n\n\ndef init_firefly_c_lib_with_time(path, schedule_list):\n    cmd_arr = gen_cmd_array(schedule_list)\n    lib = ct.CDLL(path)\n    sche = lib.firefly_v1_schedule_time_it\n    u32Ptr = ct.POINTER(ct.c_uint32)\n    u32PtrPtr = ct.POINTER(u32Ptr)\n\n    ct_arr = np.ctypeslib.as_ctypes(cmd_arr)\n    u32PtrArr = u32Ptr * ct_arr._length_\n    ct_ptr = ct.cast(u32PtrArr(*(ct.cast(row, u32Ptr) for row in ct_arr)), u32PtrPtr)\n    sche_len = ct.c_uint8(cmd_arr.shape[0])\n\n    return sche, 
ct_ptr, sche_len\n\n\ndef firefly_v1_simulate(model_config_list, x):\n    for config in model_config_list:\n        if config[\"layer_type\"] == \"input_quant_stub\":\n            x = np_quantize_prepare(x, config[\"scale\"], config[\"zero_point\"])\n        elif config[\"layer_type\"] == \"encoder+conv+IFNode\":\n            x = direct_coding(x, config[\"weight\"], config[\"bias\"], config[\"time_step\"], config[\"threshold\"])\n        elif config[\"layer_type\"] == \"conv+IFNode\":\n            x = conv_ifnode_forward(x, config[\"weight\"], config[\"bias\"], config[\"threshold\"])\n        elif config[\"layer_type\"] == \"conv+IFNode+maxpool\":\n            x = conv_ifnode_maxpool_forward(x, config[\"weight\"], config[\"bias\"], config[\"threshold\"])\n        elif config[\"layer_type\"] == \"linear+WTA\":\n            x = linear_wta_forward(x, config[\"weight\"], config[\"bias\"])\n        elif config[\"layer_type\"] == \"linear+IFNode\":\n            x = linear_ifnode_forward(x, config[\"weight\"], config[\"bias\"], config[\"threshold\"])\n\n    return x\n\n\ndef evaluate_simulate(model_config_list, sample):\n    correct = 0\n    for (image, target) in tqdm.tqdm(zip(sample[0], sample[1]), total=len(sample[0])):\n        sim_in = np.expand_dims(image.numpy(), axis=0)\n        _, sim_out = firefly_v1_simulate(model_config_list, sim_in)\n        correct += sim_out == target.item()\n    return correct / len(sample[0])\n"
  },
  {
    "path": "examples/Hardware_acceleration/standalone_utils.py",
    "content": "import math\n\nimport numpy as np\nfrom einops import rearrange\n\n\ndef get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):\n    N, C, H, W = x_shape\n    assert (H + 2 * padding - field_height) % stride == 0\n    assert (W + 2 * padding - field_height) % stride == 0\n    out_height = int((H + 2 * padding - field_height) / stride + 1)\n    out_width = int((W + 2 * padding - field_width) / stride + 1)\n\n    i0 = np.repeat(np.arange(field_height), field_width)\n    i0 = np.tile(i0, C)\n    i1 = stride * np.repeat(np.arange(out_height), out_width)\n    j0 = np.tile(np.arange(field_width), field_height * C)\n    j1 = stride * np.tile(np.arange(out_width), out_height)\n    i = i0.reshape(-1, 1) + i1.reshape(1, -1)\n    j = j0.reshape(-1, 1) + j1.reshape(1, -1)\n\n    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)\n\n    return k, i, j\n\n\ndef im2col_indices(x, field_height, field_width, padding=1, stride=1):\n    p = padding\n    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')\n    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride)\n    cols = x_padded[:, k, i, j]\n    C = x.shape[1]\n    cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)\n    return cols\n\n\ndef max_pool_forward_reshape(x, pool_param):\n    N, C, H, W = x.shape\n    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']\n    stride = pool_param['stride']\n    assert pool_height == pool_width == stride, 'Invalid pool params'\n    assert H % pool_height == 0\n    assert W % pool_height == 0\n    x_reshaped = x.reshape(N, C, int(H / pool_height), pool_height, int(W / pool_width), pool_width)\n    out = x_reshaped.max(axis=3).max(axis=4)\n    return out\n\n\ndef max_pool_forward_fast(x, pool_param):\n    N, C, H, W = x.shape\n    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']\n    stride = 
pool_param['stride']\n\n    same_size = pool_height == pool_width == stride\n    tiles = H % pool_height == 0 and W % pool_width == 0\n    if same_size and tiles:\n        out = max_pool_forward_reshape(x, pool_param)\n    else:\n        out = max_pool_forward_im2col(x, pool_param)\n    return out\n\n\ndef max_pool_forward_im2col(x, pool_param):\n    N, C, H, W = x.shape\n    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']\n    stride = pool_param['stride']\n\n    assert (H - pool_height) % stride == 0, 'Invalid height'\n    assert (W - pool_width) % stride == 0, 'Invalid width'\n\n    out_height = int((H - pool_height) / stride + 1)\n    out_width = int((W - pool_width) / stride + 1)\n\n    x_split = x.reshape(N * C, 1, H, W)\n    x_cols = im2col_indices(x_split, pool_height, pool_width, padding=0, stride=stride)\n    x_cols_argmax = np.argmax(x_cols, axis=0)\n    x_cols_max = x_cols[x_cols_argmax, np.arange(x_cols.shape[1])]\n    out = x_cols_max.reshape(out_height, out_width, N, C).transpose(2, 3, 0, 1)\n    return out\n\n\ndef conv_forward_fast(x, w, b, pad=1, stride=1):\n    N, C, H, W = x.shape\n    # x = x.astype(np.int32)\n    w = w.astype(np.int32)\n    b = b.astype(np.int16)\n    num_filters, _, filter_height, filter_width = w.shape\n\n    out_height = int((H + 2 * pad - filter_height) / stride + 1)\n    out_width = int((W + 2 * pad - filter_width) / stride + 1)\n    out = np.zeros((N, num_filters, out_height, out_width), dtype=np.int32)\n\n    x_cols = im2col_indices(x, w.shape[2], w.shape[3], pad, stride)\n    res = w.reshape((w.shape[0], -1)).dot(x_cols) + b.reshape(-1, 1)\n    out = res.reshape(w.shape[0], out.shape[2], out.shape[3], x.shape[0])\n    out = out.transpose(3, 0, 1, 2)\n\n    return out\n\n\ndef spike_map_pack_to_bytes_array(spike_map, parallel):\n    buf_in = rearrange(spike_map, 't (c p) h w->t c h w p', p=parallel)\n    buf_in = np.packbits(buf_in.flatten(), bitorder='little')\n    return buf_in\n\n\ndef 
bytes_array_split_to_spike_map(buf_in, time_step, parallel, H, W):\n    unpacked = np.unpackbits(buf_in, bitorder='little')\n    unpacked = rearrange(unpacked, '(t c h w p)->t (c p) h w', t=time_step, p=parallel, h=H, w=W)\n    return unpacked\n\n\ndef preprocess(model_config_list, x, parallel):\n    time_step = model_config_list[1][\"time_step\"]\n    scale = model_config_list[0][\"scale\"]\n    zero_point = model_config_list[0][\"zero_point\"]\n    weight = model_config_list[1][\"weight\"]\n    bias = model_config_list[1][\"bias\"]\n    threshold = model_config_list[1][\"threshold\"]\n\n    encode_in = np_quantize_prepare(x, scale, zero_point)\n    encode_in = np.expand_dims(encode_in, axis=0)\n    firefly_in = direct_coding(encode_in, weight, bias, time_step, threshold)\n    packed = spike_map_pack_to_bytes_array(firefly_in, parallel)\n    return firefly_in, packed\n\n\ndef integrate_and_fire(y, threshold):\n    membrane = np.zeros(y.shape[1:], dtype=np.int32)\n    out_spike = []\n    for v in y:\n        membrane = membrane + v\n        o = membrane > threshold\n        out_spike.append(o)\n        membrane[o] = 0\n    return np.array(out_spike)\n\n\ndef direct_coding(x, w, b, time_step, threshold):\n    x = x.repeat(time_step, axis=0)\n    out_spike = conv_ifnode_forward(x, w, b, threshold)\n    return out_spike\n\n\ndef conv_ifnode_forward(x, w, b, threshold):\n    y = conv_forward_fast(x, w, b)\n    out_spike = integrate_and_fire(y, threshold)\n    return out_spike\n\n\ndef conv_ifnode_maxpool_forward(x, w, b, threshold):\n    y = conv_ifnode_forward(x, w, b, threshold)\n    out_spike = max_pool_forward_fast(y, {'pool_height': 2, 'pool_width': 2, 'stride': 2})\n    return out_spike\n\n\ndef linear_wta_forward(x, w, b):\n    x = x.astype(np.int32)\n    w = w.astype(np.int32)\n    b = b.astype(np.int32)\n    x = x.reshape([x.shape[0], -1])\n    x = np.pad(x, ((0, 0), (0, w.shape[1] - x.shape[1])), 'constant')\n    out = np.dot(x, w.T) + b\n    out_sum = 
out.sum(axis=0)\n    max_index = out_sum.argmax()\n    return out, max_index\n\n\ndef linear_ifnode_forward(x, w, b, threshold):\n    x = x.astype(np.int32)\n    w = w.astype(np.int32)\n    b = b.astype(np.int32)\n    x = x.reshape([x.shape[0], -1])\n    x = np.pad(x, ((0, 0), (0, w.shape[1] - x.shape[1])), 'constant')\n    out = np.dot(x, w.T) + b\n    out_spike = integrate_and_fire(out, threshold)\n    return out_spike\n\n\ndef pad_conv_weight_round_to_parallel(parallel, weight, pad_output_channel_only=False):\n    output_channel = weight.shape[0]\n    input_channel = weight.shape[1]\n    padded_output_channel = (parallel - (output_channel % parallel)) % parallel\n    padded_input_channel = (parallel - (input_channel % parallel)) % parallel\n    if pad_output_channel_only:\n        padded_input_channel = 0\n    new_weight = np.pad(weight, ((0, padded_output_channel), (0, padded_input_channel), (0, 0), (0, 0)), 'constant')\n    return new_weight\n\n\ndef pad_linear_weight_round_to_parallel(parallel, weight):\n    output_channel = weight.shape[0]\n    padded_output_channel = (parallel - (output_channel % parallel)) % parallel\n    new_weight = np.pad(weight, ((0, padded_output_channel), (0, 0)), 'constant')\n    return new_weight\n\n\ndef pad_linear_weight_round_to_factor(weight, factor):\n    input_channel = weight.shape[1]\n    round_channel = int(math.ceil(input_channel / factor) * factor)\n    padded_input_channel = round_channel - input_channel\n    new_weight = np.pad(weight, ((0, 0), (0, padded_input_channel)), 'constant')\n    return new_weight\n\n\ndef pad_bias_round_to_parallel(parallel, bias, pad_value=0):\n    channel = bias.shape[0]\n    padded_channel = (parallel - (channel % parallel)) % parallel\n    new_bias = np.pad(bias, (0, padded_channel), 'constant', constant_values=pad_value)\n    return new_bias\n\n\ndef np_quantize_per_tensor(x, scale, zero_point):\n    q_min = np.iinfo(np.int8).min\n    q_max = np.iinfo(np.int8).max\n    x = np.round(x / 
scale + zero_point)\n    x = np.clip(x, q_min, q_max)\n    return x.astype(np.int8)\n\n\ndef np_quantize_prepare(x, scale, zero_point):\n    x = np_quantize_per_tensor(x, scale, zero_point)\n    return x - zero_point\n\n\ndef conv_weight_channel_tiling(parallel, weight):\n    return rearrange(weight, '(o op) (i ip) kr kc -> (o i kr kc) ip op', op=parallel, ip=parallel)\n\n\ndef linear_weight_channel_tiling(parallel, weight):\n    return rearrange(weight, '(o op) (i ip) -> (o i) ip op', op=parallel, ip=parallel)\n\n\ndef conv_to_linear_weight_tiling(parallel, h, w, weight):\n    rearrange(weight, '(o op) (i ip h w)-> (o i h w) ip op', op=parallel, ip=parallel, h=h, w=w)\n\n\ndef init_input_buffer(input_spikes,\n                      parallel=16,\n                      stride_of_channel=8 * 1024,\n                      stride_of_time_step=512 * 1024):\n    t, c, h, w = input_spikes.shape\n    input_spikes_rearrange = rearrange(input_spikes, 't (c p) h w -> t c (h w p)', p=parallel)\n    pack_spikes = np.packbits(input_spikes_rearrange, axis=-1, bitorder='little')\n    input_buffer = np.zeros(stride_of_time_step * t, dtype=np.uint8)\n    length = int(h * w * parallel / 8)\n    for i in range(t):\n        for j in range(int(c / parallel)):\n            addr = i * stride_of_time_step + j * stride_of_channel\n            input_buffer[addr:addr + length] = pack_spikes[i, j]\n    return input_buffer\n\n\ndef get_from_output_buffer(output_buffer,\n                           t, c, h, w,\n                           parallel=16,\n                           stride_of_channel=8 * 1024,\n                           stride_of_time_step=512 * 1024):\n    ret = []\n    length = int(h * w * parallel / 8)\n    for i in range(t):\n        for j in range(int(c / parallel)):\n            addr = i * stride_of_time_step + j * stride_of_channel\n            ret.append(output_buffer[addr:addr + length])\n    ret = np.array(ret)\n    ret = np.unpackbits(ret, axis=-1, bitorder='little')\n    
ret = rearrange(ret, '(t c) (h w p) -> t (c p) h w', p=parallel, t=t, h=h, w=w)\n    return ret.astype(bool)\n\n\ndef get_output_index(buffer, parallel):\n    valid_data = buffer[:parallel]\n    if parallel == 16:\n        return np.unpackbits(valid_data[12:14], bitorder='little').argmax()\n    elif parallel == 32:\n        return np.unpackbits(valid_data[24:28], bitorder='little').argmax()\n    else:\n        return 0\n\n\ndef save_model_config_list(model_config_list, path):\n    np.save(path, model_config_list)\n    return\n\n\ndef load_model_config_list(path):\n    model_config_list = np.load(path, allow_pickle=True)\n    return model_config_list\n"
  },
  {
    "path": "examples/Hardware_acceleration/ultra96_test.py",
    "content": "from standalone_utils import *\nfrom firefly_v1_schedule_on_pynq import *\nfrom pynq import PL\nfrom pynq import Overlay\nfrom pynq import allocate\nfrom pynq import MMIO\nimport numpy as np\nimport time\nfrom einops import rearrange\nol = Overlay('firefly_v1_ultra96_bitstream/sys_wrapper.bit')\n\nimage = np.load(\"firefly_v1_cifar10_data/image.npy\")\ntarget = np.load(\"firefly_v1_cifar10_data/target.npy\")\nmodel_config_list = load_model_config_list(\"firefly_v1_cifar10_data/snn7_cifar10_x16.npy\")\n\ninput_buffer = allocate(shape=(1<<23), dtype=np.uint8)\noutput_buffer = allocate(shape=(1<<23), dtype=np.uint8)\nctrl_io = MMIO(0x0400000000, 0x400)\n\nschedule_list,output_addr=create_schedule(\n    model_config_list=model_config_list,\n    ctrl_io=ctrl_io,\n    allocate_method=allocate,\n    buffer_0=input_buffer,\n    buffer_1=output_buffer,\n    image_height=32,\n    image_width=32,\n    time_step=4,\n    parallel_channel=16\n)\n\nsche, ct_ptr, sche_len = init_firefly_c_lib_with_time(\"firefly_v1_common/firefly_v1_lib.so\", schedule_list)\n\nprint(\" ----------------- initialize finish!\")\n\n\nprint(\" ----------------- python schedule begin\")\nerr_cnt = 0\nfor i in range(len(image)):\n    image_0 = image[i]\n    target_0 = target[i]\n    print(i)\n\n    start = time.time()\n    firefly_in, input_packed = preprocess(model_config_list, image_0, 16)\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    # print(\"preprocess:\", elapsed, \"us\")\n\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    start = time.time()\n    schedule_run_all(schedule_list)\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    print(\"snn inference:\", elapsed, \"us\")\n\n    output_buffer.invalidate()\n    test_out = output_buffer[:16]\n    test_index = np.unpackbits(test_out[12:14],bitorder='little').argmax()\n    
print(\"test result:\", test_index,\"gold result\", target_0)\n    if test_index != target_0:\n        err_cnt = err_cnt + 1\n\nprint(\"accuray: \" , 1 - err_cnt/len(image))\nprint(\" ----------------- python schedule finish\")\n\nprint(\" ----------------- c schedule begin\")\nerr_cnt = 0\nfor i in range(len(image)):\n    image_0 = image[i]\n    target_0 = target[i]\n    print(i)\n\n    start = time.time()\n    firefly_in, input_packed = preprocess(model_config_list, image_0, 16)\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    # print(\"preprocess:\", elapsed, \"us\")\n\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    start = time.time()\n    sche(ct_ptr, sche_len)\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    # print(\"clib inference call:\", elapsed, \"us\")\n\n    output_buffer.invalidate()\n    test_out = output_buffer[:16]\n    test_index = np.unpackbits(test_out[12:14],bitorder='little').argmax()\n    print(\"test result:\", test_index,\"gold result\", target_0)\n    if test_index != target_0:\n        err_cnt = err_cnt + 1\n\nprint(\"accuray: \" , 1 - err_cnt/len(image))\nprint(\" ----------------- c schedule finish\")\n"
  },
  {
    "path": "examples/Hardware_acceleration/zcu104_test.py",
    "content": "from standalone_utils import *\nfrom firefly_v1_schedule_on_pynq import *\nfrom pynq import PL\nfrom pynq import Overlay\nfrom pynq import allocate\nfrom pynq import MMIO\nimport numpy as np\nimport time\nfrom einops import rearrange\nol = Overlay('firefly_v1_zcu104_bitstream/sys_wrapper.bit')\n\nimage = np.load(\"firefly_v1_cifar10_data/image.npy\")\ntarget = np.load(\"firefly_v1_cifar10_data/target.npy\")\nmodel_config_list = load_model_config_list(\"firefly_v1_cifar10_data/snn7_cifar10_x32.npy\")\n\ninput_buffer = allocate(shape=(1<<23), dtype=np.uint8)\noutput_buffer = allocate(shape=(1<<23), dtype=np.uint8)\nctrl_io = MMIO(0x0400000000, 0x400)\n\nschedule_list,output_addr=create_schedule(\n    model_config_list=model_config_list,\n    ctrl_io=ctrl_io,\n    allocate_method=allocate,\n    buffer_0=input_buffer,\n    buffer_1=output_buffer,\n    image_height=32,\n    image_width=32,\n    time_step=4,\n    parallel_channel=32\n)\n\nsche, ct_ptr, sche_len = init_firefly_c_lib_with_time(\"firefly_v1_common/firefly_v1_lib.so\", schedule_list)\n\nprint(\" ----------------- initialize finish!\")\n\n\nprint(\" ----------------- python schedule begin\")\nerr_cnt = 0\nfor i in range(len(image)):\n    image_0 = image[i]\n    target_0 = target[i]\n    print(i)\n\n    start = time.time()\n    firefly_in, input_packed = preprocess(model_config_list, image_0, 32)\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    # print(\"preprocess:\", elapsed, \"us\")\n\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    start = time.time()\n    schedule_run_all(schedule_list)\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    print(\"snn inference:\", elapsed, \"us\")\n\n    output_buffer.invalidate()\n    test_out = output_buffer[:32]\n    test_index = np.unpackbits(test_out[24:28],bitorder='little').argmax()\n    
print(\"test result:\", test_index,\"gold result\", target_0)\n    if test_index != target_0:\n        err_cnt = err_cnt + 1\n\nprint(\"accuray: \" , 1 - err_cnt/len(image))\nprint(\" ----------------- python schedule finish\")\n\nprint(\" ----------------- c schedule begin\")\nerr_cnt = 0\nfor i in range(len(image)):\n    image_0 = image[i]\n    target_0 = target[i]\n    print(i)\n\n    start = time.time()\n    firefly_in, input_packed = preprocess(model_config_list, image_0, 32)\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    # print(\"preprocess:\", elapsed, \"us\")\n\n    input_buffer[:input_packed.size]=input_packed\n    input_buffer.flush()\n    start = time.time()\n    sche(ct_ptr, sche_len)\n    end = time.time()\n\n    elapsed = round((end - start) * 1000000)\n    # print(\"clib inference call:\", elapsed, \"us\")\n\n    output_buffer.invalidate()\n    test_out = output_buffer[:32]\n    test_index = np.unpackbits(test_out[24:28],bitorder='little').argmax()\n    print(\"test result:\", test_index,\"gold result\", target_0)\n    if test_index != target_0:\n        err_cnt = err_cnt + 1\n\nprint(\"accuray: \" , 1 - err_cnt/len(image))\nprint(\" ----------------- c schedule finish\")\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/CKRGSNN/README.md",
    "content": "# Commonsense Knowledge Representation SNN\n\n(https://arxiv.org/abs/2207.05561)\n\nThis repository contains code from our paper [**Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning**] preprint in: https://arxiv.org/abs/2207.05561 . If you use our code or refer to this project, please cite this paper.\n\n\n\n\n## Requirements\n\n* python=3.8\n* numpy\n* scipy\n* turicreate\n* pytorch >= 1.7.0\n* torchvision\n\n\n## Dataset\n\nConceptNet: https://github.com/commonsense/conceptnet5\n\n\n## Run\n\n```shell\npython main.py\n```\n\nThis module selects core knowledge in ConceptNet to form the sub_Concept.csv file as the input of the model. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.\n\n\n### Citation \nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{KRRfang2022,\n    title   = {Brain-inspired Graph Spiking Neural Networks for Commonsense Knowledge Representation and Reasoning},\n    author  = { Fang, Hongjian and Zeng, Yi and  Tang, Jianbo and Wang, Yuwei and Liang, Yao and  Liu, Xin},\n    journal = {arXiv preprint arXiv:2207.05561},\n    year    = {2022}\n}\n\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/CKRGSNN/main.py",
    "content": "import time\nimport numpy as np\nimport os\nimport warnings\nimport scipy.io as scio\nimport math\nfrom matplotlib import pyplot as plt\nimport torch\nfrom braincog.base.node.node import *\nimport turicreate as tc\nfrom braincog.base.brainarea.BrainArea import *\nfrom braincog.utils import *\n\n\nwarnings.filterwarnings('ignore')\nnp.set_printoptions(threshold=np.inf)\n\n\nclass CKRNet(BrainArea):\n    \"\"\"\n    Commonsense Knowledge Representation  Net\n    \"\"\"\n\n    def __init__(self, w1, w2):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n        self.node = [LIFNode(threshold=16, tau=15)]\n        self.connection = [CustomLinear(w1), CustomLinear(w2)]\n        self.stdp = []\n\n        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]], decay=0.83))\n        self.x1 = torch.zeros(1, w2.shape[0])\n\n    def forward(self, x):\n        \"\"\"\n        x is spike train\n        \"\"\"\n        self.x1, dw1 = self.stdp[0](self.x1, x)\n\n        return self.x1, dw1\n\n    def reset(self):\n        self.x1 *= 0\n\n\ndef S_bound(S):\n\n    S[S > synapse_bound] = synapse_bound\n    S[S < -synapse_bound] = -synapse_bound\n\n    for i in range(N_entity):\n        temp1 = S[Index_E[i], :]\n        temp2 = temp1[:, Index_E[i]]\n        temp2[temp2 > inner_bound_E] = inner_bound_E\n        temp1[:, Index_E[i]] = temp2\n        S[Index_E[i], :] = temp1\n\n    for i in range(N_relation):\n        temp1 = S[Index_R[i], :]\n        temp2 = temp1[:, Index_R[i]]\n        temp2[temp2 > inner_bound_R] = inner_bound_R\n        temp1[:, Index_R[i]] = temp2\n        S[Index_R[i], :] = temp1\n\n    return S\n\n\nif __name__ == \"__main__\":\n\n    print(os.getcwd())\n\n    KG = tc.SFrame.read_csv('./sub_Conceptnet.csv')\n\n    Set_R = set()\n    Set_E = set()\n\n    for i in range(KG.shape[0]):\n\n        Set_R.add(KG[i]['Relation'])\n        Set_E.add(KG[i]['Head'])\n        Set_E.add(KG[i]['Tail'])\n\n    List_E 
= sorted(Set_E)\n    List_R = list(Set_R)\n    List_R.sort()\n\n    # Network Parameter#dkenf.kejlklkelkvjlkxjel\n\n    I_syn = 5\n    tau_m = 30\n    I_t = 3  # Time duration of stimu current\n    I_P = 150  # Strength of input current\n    A_P = 0.009\n    certainty = 0.2\n\n    synapse_bound = 1    # The bound of all synapse\n    inner_bound_E = 0.6  # The bound of population inner synapse\n    inner_bound_R = 0.3  # The bound of population inner synapse\n\n    Ce = 20   # num of entity\n    Cr = 100  # num of relation\n    N_entity = len(List_E)\n    N_relation = len(List_R)\n    total_neurons = Ce * N_entity + Cr * N_relation\n\n    KG_No = KG.shape[0]\n    trail_time = 40\n    runtime = KG_No * trail_time\n\n    print('N_entity=', N_entity)\n    print('N_relation=', N_relation)\n    print('KG_No=', KG_No)\n    print('runtime=', runtime)\n    print('total_neurons=', total_neurons)\n\n    S = np.zeros((total_neurons, total_neurons), dtype=float)  # Initial Weights\n    S = torch.tensor(S, dtype=torch.float32)\n    E = np.identity((total_neurons), dtype=float)\n    E = torch.tensor(E, dtype=torch.float32)\n\n    I_stimu = np.zeros((total_neurons, runtime))\n    ADJ = np.zeros((total_neurons, runtime))  # record the firing condition\n\n    Index_E = []\n    Index_R = []\n    for i in range(N_entity):\n        Index_E.append(np.arange(i * Ce, i * Ce + Ce))\n\n    for i in range(N_relation):\n        Index_R.append(np.arange(N_entity * Ce + i * Cr, N_entity * Ce + i * Cr + Cr))\n\n    for i in range(KG_No):\n        Head = KG[i]['Head']\n        Rela = KG[i]['Relation']\n        Tail = KG[i]['Tail']\n        Weig = KG[i]['Weight']\n\n        # print(List_E.index(Head))\n        # print(List_R.index(Rela))\n        # print(List_E.index(Rela))\n        # print(Index_R[List_R.index(Rela)])\n\n        I_stimu[Index_E[List_E.index(Head)], 10 + i * trail_time: 10 + I_t + i * trail_time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n        
I_stimu[Index_R[List_R.index(Rela)], 15 + i * trail_time: 15 + I_t + i * trail_time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n        I_stimu[Index_E[List_E.index(Tail)], 20 + i * trail_time: 20 + I_t + i * trail_time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n\n    CKRGSNN = CKRNet(S, E)\n\n    for t in range(runtime):\n\n        I_input = torch.tensor(I_stimu[:, t].reshape(1, total_neurons), dtype=torch.float32)\n\n        x, dw = CKRGSNN(I_input)\n\n        S += A_P * dw[1]\n\n        S += S_bound(S) - S\n\n        ADJ[:, t] = x\n        print(t, 'step in >>', runtime)\n\n    img_I = plt.matshow(I_stimu)\n    plt.savefig(\"I_stimu1.jpg\", dpi=500, bbox_inches='tight')\n\n    img_ADJ = plt.matshow(ADJ)\n    plt.savefig(\"ADJ1.jpg\", dpi=500, bbox_inches='tight')\n\n    img_S = plt.matshow(S)\n    plt.colorbar()\n    plt.savefig(\"S1.jpg\", dpi=500, bbox_inches='tight')\n\n    plt.show()\n\n    S = np.mat(S)\n    dataNew = './data_save.mat'\n    scio.savemat(dataNew, {'I_stimu': I_stimu, 'ADJ': ADJ, 'Weight': S})\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/CKRGSNN/sub_Conceptnet.csv",
    "content": "Relation,Head,Tail,Weight\nantonym,ab_extra,ab_intra,1.0\nantonym,ab_intra,ab_extra,1.0\nantonym,abactinal,actinal,1.0\nantonym,abandon,acquire,1.0\nantonym,abandon,arrogate,1.0\nantonym,abandon,embrace,1.0\nantonym,abandon,engage,1.0\nantonym,abandon,gain,1.0\nantonym,abandon,join,1.0\nantonym,abandon,maintain,1.0\nantonym,abandon,retain,1.0\nantonym,abandon,unite,1.0\natlocation,clock,department_store,1.0\natlocation,clock,desk,4.472\natlocation,clock,house,2.828\natlocation,clock,office,2.0\natlocation,crisps,table,1.0\natlocation,crisps,vending_machines,1.0\natlocation,crockery,cupboard,1.0\natlocation,crocs,united_states,0.5\natlocation,fungus,damp_spot,1.0\natlocation,fungus,damp_warm_place,1.0\natlocation,fungus,damp_wood,2.0\natlocation,fungus,damp_woods,1.0\natlocation,fungus,dank_place,1.0\natlocation,fungus,dark,1.0\natlocation,fungus,dark_and_dank_place,1.0\natlocation,fungus,dark_damp_area,2.828\natlocation,fungus,dark_damp_place,2.0\natlocation,carnival_rides,fairgrounds,1.0\natlocation,carousel,carnival,2.828\natlocation,carpet,at_hotel,1.0\ncapableof,adult,drive_car,1.0\ncapableof,adult,drive_train,1.0\ncapableof,adult,explain_rules_to_child,1.0\ncapableof,adult,feed_and_take_care_of_itself,1.0\ncapableof,adult,gift_knowledge_to_child,1.0\ncapableof,adult,hand_toy_to_child,1.0\ncapableof,adult,help_child,3.464\ncapableof,adult,keep_property,1.0\ncapableof,adults,act_like_infants,1.0\ncapableof,adults,care_for_babies,1.0\ncapableof,adults,carry_infants,1.0\ncapableof,adults,count,2.0\ncapableof,adults,demand_respect_from_children,1.0\ncapableof,adults,dress_themselves,2.828\ncapableof,adults,drink_beer,2.0\ncapableof,adults,drive_vehicle,1.0\ncapableof,adults,eat_sushi,2.0\ncapableof,adults,fail_to_pass_test,1.0\ncauses,competing_against,testing_yourself_against_another_person,1.0\ncauses,competing_against,try_hardest,1.0\ncauses,computing_sum,get_total_amount,1.0\ncauses,computing_sum,getting_answer,4.899\ncauses,computing_sum,getting
_right_answer,1.0\ncauses,computing_sum,getting_total,1.0\ncauses,computing_sum,having_answer,2.0\ncauses,computing_sum,having_total,1.0\ncauses,computing_sum,headache,3.464\ncauses,computing_sum,insight,1.0\ncauses,computing_sum,know_total,2.0\ncauses,computing_sum,knowing_total,1.0\ncauses,computing_sum,may_get_wrong_answer,1.0\ncauses,computing_sum,number,1.0\ncauses,computing_sum,obtaining_total,1.0\ncauses,computing_sum,reaching_total,1.0\ncauses,computing_sum,receiving_total,1.0\ndesires,person,have_day_off,1.0\ndesires,person,have_diamonds,2.0\ndesires,person,have_easy,1.0\ndesires,person,have_enough_food,2.0\ndesires,person,have_enough_to_eat,1.0\ndesires,person,have_everyone_happy_with,2.0\ndesires,person,have_everything,1.0\ndesires,person,have_fast_internet_acess,1.0\ndesires,person,have_femily,1.0\ndesires,person,have_firm_body,1.0\ndesires,person,have_fond_memories,1.0\ndesires,person,have_fortune,1.0\ndesires,person,have_free_time,1.0\ndesires,person,have_friends,1.0\ndesires,person,have_fulfilling_life,1.0\ndesires,person,have_fun_in_life,2.0\ndesires,person,have_fun_on_weekends,1.0\ndesires,person,have_fun_weekend,2.0\ndesires,person,have_future,1.0\ndesires,person,have_good_bones,1.0\ndesires,person,have_good_day,2.828\ndesires,person,have_good_eyesight,1.0\ndesires,person,have_good_feelings,1.0\ndesires,person,have_good_friends,2.828\ndesires,person,have_good_life,1.0\ndesires,person,have_good_memories,2.0\ndesires,person,have_good_memory,1.0\ndesires,person,have_good_relationships_with_others,1.0\ndesires,person,have_good_skin,1.0\ndesires,person,have_great_sex,2.0\ndesires,person,have_happy_childhood,1.0\ndesires,person,have_happy_family,1.0\ndesires,person,have_healthy_life,1.0\ndesires,person,have_healthy_sex_life,1.0\ndesires,person,have_home_to_live_in,2.828\ndesires,person,have_hot_water,1.0\ndesires,person,have_influence,1.0\ndesires,person,have_inner_peace,2.828\ndesires,person,have_interesting_job,1.0\ndesires,person,have_large_vocabulary
,1.0\ndesires,person,have_lasting_friendships,1.0\ndesires,person,have_long_live,2.0\ndesires,person,have_loving_family,1.0\ndesires,person,have_many_good_friends,2.0\ndesires,person,have_meaningful_life,2.0\ndesires,person,have_minimum_necessities_of_life,1.0\ndesires,person,have_money_to_buy_chocolate,2.0\ndesires,person,have_more,1.0\nhascontext,bridge,card_games,1.0\nhascontext,bridge,communication,1.0\nhascontext,bridge,computing,1.0\nhascontext,bridge,music,1.0\nhascontext,bridge,wrestling,1.0\nhascontext,bridge_and_tunnel,new_york_city,1.0\nhascontext,bridge_and_tunnel,pejorative,1.0\nhascontext,bridge_and_tunnel,slang,1.0\nhassubevent,maintain_good_health,avoid_guns,1.0\nhassubevent,maintain_good_health,avoid_hate,1.0\nhassubevent,maintain_good_health,avoid_heavily_processed_foods,2.0\nhassubevent,maintain_good_health,avoid_highly_processed_food,1.0\nhassubevent,maintain_good_health,avoid_illegal_drugs,1.0\nhassubevent,maintain_good_health,avoid_losing_sleep,1.0\nhassubevent,maintain_good_health,avoid_marijuana,1.0\nhassubevent,maintain_good_health,avoid_much_sunlight,1.0\nhassubevent,maintain_good_health,avoid_poison,1.0\nhassubevent,maintain_good_health,avoid_racism,1.0\nhassubevent,maintain_good_health,avoid_smoking_marijuana,1.0\nhassubevent,maintain_good_health,avoid_smoking_tobacco,1.0\nhassubevent,maintain_good_health,avoid_tobacco,1.0\nhassubevent,maintain_good_health,avoid_unpleasant_people,1.0\nhassubevent,maintain_good_health,avoid_war,1.0\nhassubevent,maintain_good_health,become_vegan,1.0\nhassubevent,maintain_good_health,calm,1.0\nhassubevent,maintain_good_health,care_about_people,1.0\nhassubevent,maintain_good_health,chew_food_well,1.0\nhassubevent,maintain_good_health,determine_nutritional_needs,1.0\nhassubevent,maintain_good_health,do_exercise,1.0\nhassubevent,maintain_good_health,dress_warmly_in_cold_weather,1.0\nhassubevent,maintain_good_health,drink_very_little_alcoholic_beverages,1.0\nhassubevent,maintain_good_health,drive_safely,1.0\nhas
subevent,maintain_good_health,eat_apple_day,1.0\nhassubevent,maintain_good_health,eat_enough_and_exercise,1.0\nhassubevent,maintain_good_health,eat_foods_containing_fiber,1.0\nhassubevent,maintain_good_health,excersize,1.0\nhassubevent,maintain_good_health,exercise,1.0\nhassubevent,maintain_good_health,have_fullfilled_sex_life,1.0\nhassubevent,maintain_good_health,have_spouse,1.0\nhassubevent,maintain_good_health,leave_cat_alone,1.0\nhassubevent,maintain_good_health,live_healthy_lifestyle,1.0\nhassubevent,maintain_good_health,loving,1.0\nhassubevent,maintain_good_health,monitor_health_often,1.0\nhassubevent,maintain_good_health,not_eat_too_much_sna,1.0\nisa,antidiarrheal,medicine,2.0\nisa,antidiarrheal_therapy,drug_therapy,1.0\nisa,antidiuretic,medicine,2.0\nisa,antidiuretic_agent,medicine,1.0\nisa,antido,artificial_language,2.0\nisa,antidorcas,mammal_genus,2.0\nisa,antidoron,food,0.5\nisa,antidote,neutralizer,1.0\nisa,antidote,remedy,2.0\nisa,antiemetic,medicine,1.0\nisa,antiemetic,medicine,2.0\nisa,antiemetic_therapy,drug_therapy,1.0\nisa,antiepileptic,anticonvulsant,1.0\nisa,antiestablishmentarianism,doctrine,2.0\nisa,antietam,national_cemetery_in_maryland,1.0\nisa,antifeminist,bigot,2.0\nisa,antiferromagnetism,magnetism,2.0\nisa,antifibrinolytic_agent,medicine,1.0\nisa,antiflatulent,agent,2.0\nisa,antifouling_paint,paint,2.0\nisa,antifreeze,automotive_fluid,1.0\nisa,antifreeze,liquid,2.0\nisa,antifungal,agent,2.0\nisa,antigen,antigen,1.0\nisa,antigen,carbohydrate,1.0\nisa,antigen,tangible_thing,1.0\nisa,antigen,substance,2.0\nisa,antigenes,person,0.5\nisa,antigenic_determinant,site,2.0\nisa,antigone,play,0.5\nisa,antigonia,fish_genus,2.0\nisa,antigonia,fish,0.5\nisa,antigorite,serpentine,1.0\nisa,antigorite,serpentine,1.0\nisa,gardens,often_in_yards,1.0\nisa,gardens,places_where_plants_grow,1.0\nisa,gardens,pleasant_outdoor_place,1.0\nisa,gardens,pleasant_places_to,1.0\nisa,garding,town,0.5\nisa,gardnerian_wicca,wicca,1.0\nisa,gardon,river,0.5\nisa,garfield,fict
ional_character,0.5\nisa,garfield,station,0.5\nisa,garfield,station,0.5\nisa,garfish,fish,0.5\nisa,gargamel,comics_character,0.5\nisa,garganey,bird,0.5\nisa,garlic,food_ingredient,1.0\nisa,garlic,flavorer,2.0\nisa,garlic,alliaceous_plant,2.0\nisa,tire_iron,rigid_portable_object,1.0\nisa,tire_iron,shaped_thing,1.0\nisa,tire_iron,hand_tool,2.0\nisa,tire_iron,lever,2.0\nisa,tire_pump,useful_tool,1.0\nisa,tire_pump,mechanical_pump,1.0\nisa,tire_rotation,axis_constrained_rotation,1.0\nisa,tire_sealant,automotive_product,1.0\nisa,tire_sealant,sealant,1.0\nisa,garlic_chive,alliaceous_plant,2.0\nisa,garlic_clove,bulb,1.0\npartof,paragraph,textual_document,1.0\npartof,paragraph,text,2.0\npartof,victoria,zambezi,2.0\npartof,victoria,zambia,2.0\npartof,victoria,zimbabwe,2.0\npartof,victoria_land,antarctica,2.0\npartof,vidalia,georgia,2.0\npartof,video,television,2.0\npartof,vienna,austria,2.0\npartof,vienne,poitou_charentes,0.5\npartof,vienne,france,2.0\npartof,vientiane,laos,2.0\npartof,vieques,puerto_rico,2.0\npartof,vietnam,indochina,2.0\npartof,viewfinder,camera,1.0\npartof,vigo,galicia,0.5\npartof,vigo,vigo,0.5\npartof,vigo,galicia,0.5\npartof,villa,asturias,0.5\npartof,villahermosa,tabasco,0.5\npartof,villahermosa,mexico,2.0\npartof,villarreal,valencian_community,0.5\npartof,vilnius,dzūkija,0.5\npartof,vilnius,lithuania,0.5\npartof,vilnius,lithuania,2.0\npartof,visual_purple,rod,2.0\npartof,visual_signal,visual_communication,2.0\npartof,viti_levu,fiji_islands,2.0\npartof,viña_del_mar,valparaíso,0.5\npartof,vladivostok,russia,2.0\npartof,vocabulary,language,2.0\npartof,vocal_cord,larynx,2.0\npartof,voider,body_armor,2.0\npartof,volapük,international_auxiliary_language,0.5\npartof,volcanic_crater,volcano,2.0\npartof,volcano,south_park,0.5\npartof,volcano_islands,japan,2.0\npartof,volcano_islands,pacific,2.0\npartof,volga,russia,2.0\npartof,volgograd,russia,2.0\npartof,volkhov,russia,2.0\npartof,parthenon,athens,2.0\nreceivesaction,carpet,found_on_ground,1.0\nreceivesaction
,carpet,used_as_floor_covering,1.0\nreceivesaction,carpeted_floors,found_in_many_kinds_of_buildings,1.0\nreceivesaction,carpeting,used_in_place_of_hardwood_floors,1.0\nreceivesaction,carpets,bought_at_carpet_stores,1.0\nreceivesaction,cartoons,animated,1.0\nreceivesaction,case,tried_in_appeals_court,1.0\nreceivesaction,cases,heard_in_court_of_law,1.0\nreceivesaction,cash,denominated_in_dollars,1.0\nreceivesaction,cash,earned,1.0\nreceivesaction,cash,measured_in_dollars_and_cents,1.0\nreceivesaction,castanets,bound_together_with_leather_string,1.0\nreceivesaction,castanets,used_in_form_of_dance,1.0\nreceivesaction,casual_describes_clothing,worn_for_comfort_and_function,1.0\nreceivesaction,cat,attracted_to_parakeets,1.0\nreceivesaction,cats,thought_to_hate_dogs,1.0\nreceivesaction,cats_and_dogs,treated_badly,1.0\nreceivesaction,cats_purr_when,contented,1.0\nreceivesaction,cattle,fed_in_feed_lots,1.0\nreceivesaction,cauldron,steeped_in_magical_tradition_and_mystery,1.0\nreceivesaction,cauliflower_and_broccoli,combined_into_one_super_vegetable,1.0\nreceivesaction,cavitron,used_in_brain_surgery,1.0\nreceivesaction,cds,bought_in_stores,1.0\nreceivesaction,cds_usually,made_from_plastic,1.0\nreceivesaction,cedar,used_as_shingles_on_houses,1.0\nreceivesaction,ceilings,painted_with_brush,1.0\nreceivesaction,ceilings_have_color_which,painted_onto,1.0\nreceivesaction,celebrity,associated_with_autographs,1.0\nreceivesaction,celebrity,associated_with_desire,1.0\nreceivesaction,celebrity,associated_with_fame,1.0\nreceivesaction,celebrity,associated_with_fans,1.0\nrelatedto,penis,tarse,1.0\nrelatedto,penis_pump,cock_pump,1.0\nrelatedto,penis_worm,priapulid,1.0\nrelatedto,penised,bedicked,1.0\nrelatedto,penitence,compunction,2.0\nrelatedto,penitence,remorse,1.0\nrelatedto,penitence,repentance,1.0\nrelatedto,penitence,repentance,2.0\nrelatedto,penitent,penaunt,1.0\nrelatedto,penitentiary,penitential,2.0\nrelatedto,penitentiary,jail,1.0\nrelatedto,penrose_diagram,carter_penrose_diagra
m,1.0\nrelatedto,penrose_process,penrose_mechanism,1.0\nrelatedto,penrose_staircase,penrose_stairs,1.0\nrelatedto,penrose_staircase,penrose_steps,1.0\nrelatedto,penrose_stairs,penrose_staircase,1.0\nrelatedto,penrose_stairs,penrose_steps,1.0\nrelatedto,penrose_steps,penrose_staircase,1.0\nrelatedto,penrose_steps,penrose_stairs,1.0\nrelatedto,penrose_triangle,penrose_triangle,0.5\nrelatedto,pensacola,pensacola,0.5\nrelatedto,pension,pension,0.5\nrelatedto,pension,hotel,1.0\nrelatedto,penstemon_linarioides,narrow_leaf_penstemon,2.0\nrelatedto,penstemon_newberryi,mountain_pride,2.0\nrelatedto,penstemon_palmeri,balloon_flower,2.0\nrelatedto,penstemon_rupicola,rock_penstemon,2.0\nrelatedto,penstemon_serrulatus,cascade_penstemon,2.0\nrelatedto,penstock,sluice,2.0\nrelatedto,penstock,sluicegate,2.0\nrelatedto,pent,shut_up,2.0\nrelatedto,pent_up,repressed,2.0\nrelatedto,penta,quinque,1.0\nrelatedto,pentaborane,pentaborane,0.5\nrelatedto,pentabromodiphenyl_ether,pentabromodiphenyl_ether,0.5\nrelatedto,pentabromodiphenyl_ether,pentabromodiphenyl_oxide,1.0\nrelatedto,pentacene,pentacene,0.5\nrelatedto,pentachloronitrobenzene,pentachloronitrobenzene,0.5\nrelatedto,pentagon,pentagon,0.5\nrelatedto,pentagonal,pentangular,2.0\nrelatedto,pentagram,pentacle,1.0\nrelatedto,pentagram,pentalpha,1.0\nrelatedto,pentagram,pentangle,1.0\nrelatedto,pentagram,pentacle,2.0\nrelatedto,pentagraph,pentagraph,0.5\nrelatedto,pentail,pen_tailed_treeshrew,1.0\nrelatedto,pentalpha,pentagram,1.0\nrelatedto,pentalpha,pentangle,1.0\nrelatedto,pentamethylbenzene,pentamethylbenzene,0.5\nrelatedto,pentamethylenetetrazol,pentylenetetrazol,2.0\nrelatedto,pentamidine,pentamidine,0.5\nrelatedto,pentanal,pentanal,0.5\nrelatedto,pentanal,pentanaldehyde,1.0\nrelatedto,pentanal,valeraldehyde,1.0\nrelatedto,pentane,pentane,0.5\nrelatedto,pentangle,pentacle,2.0\nrelatedto,pentanoate,valerate,1.0\nrelatedto,pentanoic_acid,valeric_acid,2.0\nrelatedto,pentanol,amyl_alcohol,1.0\nrelatedto,pentanol,pentyl_alcohol,1.0\nre
latedto,pentastomid,tongue_worm,2.0\nrelatedto,pentastomida,pentastomida,0.5\nrelatedto,pentateuch,books_of_moses,1.0\nrelatedto,pentateuch,law,1.0\nrelatedto,pentateuch,torah,1.0\nrelatedto,pentateuch,torah,2.0\nrelatedto,pentatone,pentatonic_scale,2.0\nrelatedto,pentatonic_scale,pentatonic_scale,0.5\nrelatedto,pentazocine,pentazocine,0.5\nrelatedto,pentazole,pentazole,0.5\nrelatedto,pentecost,pentecost,0.5\nrelatedto,pentecost,feast_of_weeks,1.0\nrelatedto,pentecost,shavuos,1.0\nrelatedto,pentecost,shavuot,1.0\nrelatedto,pentecost,whit,1.0\nrelatedto,pentecost,whit_sunday,1.0\nrelatedto,pentecost,whitsun,1.0\nrelatedto,pentecost,whitsunday,1.0\nrelatedto,pentecost,shavous,2.0\nrelatedto,pentecostal,pentecostalist,2.0\nrelatedto,pentecostalism,pentecostalism,0.5\nrelatedto,pentel,pentel,0.5\nrelatedto,pentelic,pentelican,1.0\nrelatedto,pentene,pentene,0.5\nusedfor,gourmet_shop,buy_weird_food,1.0\nusedfor,gourmet_shop,buying_imported_foods,1.0\nusedfor,gourmet_shop,customers_who_knowlegeable_about_cooking,1.0\nusedfor,gourmet_shop,fine_foods,1.0\nusedfor,gourmet_shop,hard_to_find_foods,1.0\nusedfor,gourmet_shop,icky_foods,1.0\nusedfor,lip,pleasure,1.0\nusedfor,lip,pouring,1.0\nusedfor,vessel,containing_highway_for_blood_flow,1.0\nusedfor,vessel,float,1.0\nusedfor,vessel,hold_flowers,1.0\nusedfor,vessel,moving,1.0\nusedfor,vessel,moving_people_around,1.0\nusedfor,vessel,navigate,1.0\nusedfor,vessel,piloting,1.0\nusedfor,vessel,sailing_on_ocean,1.0\nusedfor,vessel,ship_goods,1.0\nusedfor,vessel,shipping,2.0\nusedfor,vessel,store_liquid,1.0\nusedfor,vessel,storing_liquids,2.0\nusedfor,vessel,transporting_things,1.0\nusedfor,vessel,usually_staying_above_water,1.0\nusedfor,veterinarians,sick_animals,1.0\nusedfor,vibrator,entertain_yourself,1.0\nusedfor,vibrator,get_off,1.0\nusedfor,vibrator,increase_pleasure_during_sex,1.0\nusedfor,vibrator,sexually_stimulate_yourself_or_else,1.0\nusedfor,viewing_video,having_fun,1.0\nusedfor,viewing_video,having_good_time,1.0\nusedfor,v
iewing_video,learning,2.828\nusedfor,viewing_video,learning_language,1.0\nusedfor,viewing_video,learning_new,1.0\nusedfor,viewing_video,relaxation,1.0\nusedfor,viewing_video,relaxing,2.0\nusedfor,viewing_video,reviewing_video,1.0\nusedfor,viewing_video,spending_time_with_grandson,1.0\nusedfor,viewing_video,watching_film,1.0\nusedfor,viewing_video,watching_home_movies,1.0\nusedfor,viewing_video,watching_memories,1.0\nusedfor,viewing_video,watching_movie_star,1.0\nusedfor,village,learning,1.0\nusedfor,village,people_to_live_in,1.0\nusedfor,village,playing,1.0\nusedfor,village,raise_child,1.0\nusedfor,village,sedentary_living,1.0\nusedfor,viola,play_song,1.0\nusedfor,viola,playing,1.0\nusedfor,viola,playing_music,2.0\nusedfor,viola,playing_sissy_music,1.0\nusedfor,viola,sing_song,1.0\nusedfor,violence,kill,1.0\nusedfor,violence,terrorism,1.0\nusedfor,violin,create_music,1.0\nusedfor,violin,creating,2.0\nusedfor,violin,creating_art,1.0\nusedfor,violin,entertaining,1.0\nusedfor,violin,entertainment,1.0\nusedfor,violin,fu_n,1.0\nusedfor,violin,making_lovely_music,1.0\nusedfor,violin,mkae_annoying_noises,1.0\nusedfor,violin,music,1.0\nusedfor,violin,play_music,2.828\nusedfor,violin,playing,1.0\nusedfor,violin,playing_music,5.657\nusedfor,violin,playing_music_on_stringed_instrument,1.0\nusedfor,visa_card,acquiring_debt,1.0\nusedfor,visa_card,adults_not_kids,1.0\nusedfor,visiting_museum,entertainment,1.0\nusedfor,visiting_museum,feeling_young,1.0\nusedfor,visiting_museum,finding_out_about_past,1.0\nusedfor,visiting_museum,finding_out_about_world,1.0\nusedfor,visiting_museum,fun,2.0\nusedfor,visiting_museum,getting_ideas_from_past_experiences,1.0\nusedfor,visiting_museum,having_fun,1.0\nusedfor,visiting_museum,having_fun_day,1.0\nusedfor,visiting_museum,learing_about_past,1.0\nusedfor,visiting_museum,learning_about_culture,1.0\nusedfor,visiting_museum,learning_about_history,1.0\nusedfor,visiting_museum,learning_about_other_places,1.0\nusedfor,visiting_museum,learning_history,
2.0\nusedfor,lip,protect,1.0\nusedfor,lip,speak,2.0\nusedfor,lip,suck,2.0\nusedfor,lip,whistling,2.0\nusedfor,lips,communicate,2.828\nusedfor,lips,flapping,1.0"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/CRSNN/README.md",
    "content": "# Causal Reasoning SNN\n(https://doi.org/10.1109/IJCNN52387.2021.9534102)\n\nThis repository contains code from our paper [**A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks**] published in 2021 International Joint Conference on Neural Networks (IJCNN).  https://ieeexplore.ieee.org/abstract/document/9534102. If you use our code or refer to this project, please cite this paper.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n\n## Run\n\n```shell\npython main.py\n```\n\n\nThis module builds an example of a brain-like causal inference spiking neural network model. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.\n\n\n### Citation \nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@inproceedings{fang2021CRSNN,\n  title={A Brain-Inspired Causal Reasoning Model Based on Spiking Neural Networks},\n  author={Fang, Hongjian and Zeng, Yi},\n  booktitle={2021 International Joint Conference on Neural Networks (IJCNN)},\n  pages={1--5},\n  year={2021},\n  organization={IEEE}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/CRSNN/main.py",
    "content": "import time\nimport numpy as np\nimport os\nimport warnings\nimport math\nfrom matplotlib import pyplot as plt\nimport torch\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nfrom braincog.utils import *\n\n\nwarnings.filterwarnings('ignore')\nnp.set_printoptions(threshold=np.inf)\n\n\nclass CRNet(BrainArea):\n    \"\"\"\n\n       网络结构类:CRNet（Causal Reasoning Net)，定义了网络的结构，继承自BrainArea基类。\n       :param threshold: 神经元发放脉冲需要达到的阈值\n       :param tau: 神经元膜电位常数，控制膜电位衰减\n       :param decay:STDP机制衰减常数，控制STDP机制作用强度随时间变化\n       :param w1:神经网络内部连接权重\n       :param w2:外部输入电流到每个神经元的连接\n\n    \"\"\"\n\n    def __init__(self, w1, w2):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n        self.node = [LIFNode(threshold=16, tau=15)]\n        self.connection = [CustomLinear(w1), CustomLinear(w2)]\n        self.stdp = []\n\n        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]], decay=0.8))\n        self.x1 = torch.zeros(1, w2.shape[0])\n\n    def forward(self, x):\n        \"\"\"\n        一次时间步的前向传播过冲函数，计算脉冲发放情况和权重改变量\n\n        :param x1:经过该时间步后的脉冲发放情况\n        :param dw1:STDP机制在一个时间步后带来的权重改变量\n        \"\"\"\n        self.x1, dw1 = self.stdp[0](x, self.x1)\n\n        return self.x1, dw1\n\n    def reset(self):\n        self.x1 *= 0\n\n\ndef S_bound(S):\n    \"\"\"\n       S_bound:网络权重边界控制函数，主要功能为控制全网络突触连接权不超过阈值，维持弱连接\n       另外，本函数还需要将神经元组内部权重控制在一定的范围之内，以防神经元组不断重复激活自身的情况发生。\n\n       :param synapse_bound: 全网络的突触连接的阈值，以维持网络整体为弱连接\n       :param inner_bound: 神经元组内部突触连接的阈值，防止神经元组不断重复激活自身导致网络放电紊乱\n\n    \"\"\"\n\n    S[S > synapse_bound] = synapse_bound\n    S[S < -synapse_bound] = -synapse_bound\n\n    temp1 = S[E1_index, :]\n    temp2 = temp1[:, E1_index]\n    temp2[temp2 > inner_bound_E] = inner_bound_E\n    temp1[:, E1_index] = temp2\n    S[E1_index, :] = temp1\n\n    temp1 = S[E2_index, :]\n    temp2 = temp1[:, E2_index]\n    temp2[temp2 > inner_bound_E] = 
inner_bound_E\n    temp1[:, E2_index] = temp2\n    S[E2_index, :] = temp1\n\n    temp1 = S[E3_index, :]\n    temp2 = temp1[:, E3_index]\n    temp2[temp2 > inner_bound_E] = inner_bound_E\n    temp1[:, E3_index] = temp2\n    S[E3_index, :] = temp1\n\n    temp1 = S[E4_index, :]\n    temp2 = temp1[:, E4_index]\n    temp2[temp2 > inner_bound_E] = inner_bound_E\n    temp1[:, E4_index] = temp2\n    S[E4_index, :] = temp1\n\n    temp1 = S[E4_index, :]\n    temp2 = temp1[:, E4_index]\n    temp2[temp2 > inner_bound_E] = inner_bound_E\n    temp1[:, E4_index] = temp2\n    S[E4_index, :] = temp1\n\n    temp1 = S[E5_index, :]\n    temp2 = temp1[:, E5_index]\n    temp2[temp2 > inner_bound_E] = inner_bound_E\n    temp1[:, E5_index] = temp2\n    S[E5_index, :] = temp1\n\n    temp1 = S[R1_index, :]\n    temp2 = temp1[:, R1_index]\n    temp2[temp2 > inner_bound_R] = inner_bound_R\n    temp1[:, R1_index] = temp2\n    S[R1_index, :] = temp1\n\n    temp1 = S[R2_index, :]\n    temp2 = temp1[:, R2_index]\n    temp2[temp2 > inner_bound_R] = inner_bound_R\n    temp1[:, R2_index] = temp2\n    S[R2_index, :] = temp1\n\n    return S\n\n\nif __name__ == \"__main__\":\n\n    # Neurons Parameter\n\n    Cr = 200   # num of relation\n    Ce = 50    # num of entity\n\n    total_time = 2500            # Runtime in ms\n\n    tau = 100             # time constant of STDP\n    stdpwin = 25               # STDP windows in ms\n    thresh = 30               # Judge if the neurons fire or not\n    abs_T = 25               # The length of the ABS\n    Reset = 0                # Reset Potential\n    I_syn = 5\n    tau_m = 30\n    Rm = 10\n\n    N_entity = 5\n    N_relation = 2\n    I_t = 5          # Duration of Current\n    I_P = 25         # Strength of input current\n    certainty = 0.5\n\n    A_P = 0.01\n    synapse_bound = 0.2   # The bound of all synapse\n    inner_bound_E = 0.08   # The bound of population inner synapse\n    inner_bound_R = 0.06   # The bound of population inner synapse\n\n    
total_neurons = Ce * N_entity + Cr * N_relation\n\n    \"\"\"\n        SPSNN主函数，实现网络核心主要功能\n\n        :param Cr: 因果图中节点神经元组中神经元数量\n        :param Ce: 因果图中因果关系神经元组中神经元数量\n        :param total_time: 网络总体模拟的时间步长\n        :param learning_times: 网络进行序列学习的次数\n        :param N_entity: 对网络添加外部输入电流的时间长度\n        :param N_relation: 对网络添加外部输入电流的强度\n        :param A_P: 网络在进行STDP学习后突触改变放缩量\n        :param certainty: 网络输入电流大小的确定度\n        :param total_neurons: 网络神经元总量\n        :param ADJ: 网络中脉冲放电情况矩阵\n        :param I_stimu: 网络中外部输入电流矩阵\n        :param S: 网络突触连接权重矩阵\n        :param E: 单位矩阵，用以对每个神经元引入外部电流\n\n    \"\"\"\n\n    # Initial Neurual Network\n\n    E1_index = np.linspace(0, Ce - 1, Ce, dtype=int)\n    E2_index = np.linspace(Ce, 2 * Ce - 1, Ce, dtype=int)\n    E3_index = np.linspace(2 * Ce, 3 * Ce - 1, Ce, dtype=int)\n    E4_index = np.linspace(3 * Ce, 4 * Ce - 1, Ce, dtype=int)\n    E5_index = np.linspace(4 * Ce, 5 * Ce - 1, Ce, dtype=int)\n\n    R1_index = np.linspace(5 * Ce, 5 * Ce + Cr - 1, Cr, dtype=int)\n    R2_index = np.linspace(5 * Ce + Cr, 5 * Ce + 2 * Cr - 1, Cr, dtype=int)\n\n    Ne = total_neurons\n\n    v = Reset * np.zeros(Ne)\n\n    firings = []                       # spike timings\n\n    ADJ = np.zeros((total_time, Ne))   # record the firing condition\n    abs_Ne = np.zeros(Ne)             # maintain the ABS of every neurons\n\n    I_stimu = np.zeros((Ne, total_time), dtype=float)\n\n    If_Memory = np.zeros((total_neurons), dtype=bool)\n    If_Memory[:] = True\n\n    # Pre-set synapses\n\n    S = np.zeros((total_neurons, total_neurons), dtype=float)  # Initial Weights\n\n    S = S - np.diag(S)      # Set the diag num to\n\n    E = np.identity((total_neurons), dtype=float)\n\n    W_set_innner = 0.7\n\n    temp = S[E1_index, :]\n    temp[:, E1_index] = W_set_innner * np.random.rand(Ce, Ce)\n    S[E1_index, :] = temp\n\n    temp = S[E2_index, :]\n    temp[:, E2_index] = W_set_innner * np.random.rand(Ce, Ce)\n    S[E2_index, :] = temp\n\n    temp = 
S[E3_index, :]\n    temp[:, E3_index] = W_set_innner * np.random.rand(Ce, Ce)\n    S[E3_index, :] = temp\n\n    temp = S[E4_index, :]\n    temp[:, E4_index] = W_set_innner * np.random.rand(Ce, Ce)\n    S[E4_index, :] = temp\n\n    temp = S[E5_index, :]\n    temp[:, E5_index] = W_set_innner * np.random.rand(Ce, Ce)\n    S[E5_index, :] = temp\n\n    temp = S[R1_index, :]\n    temp[:, R1_index] = 0.25 * W_set_innner * np.random.rand(Cr, Cr)\n    S[R1_index, :] = temp\n\n    temp = S[R2_index, :]\n    temp[:, R2_index] = 0.25 * W_set_innner * np.random.rand(Cr, Cr)\n    S[R2_index, :] = temp\n\n    \"\"\"\n    对于因果图中的因果关系，给予网络不同神经元组输入电流刺激，使其建立连接\n    \"\"\"\n    i = 1\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E1_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R1_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E2_index, :] = temp\n\n    i = 3\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E2_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R2_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E1_index, :] = temp\n\n    i = 6\n    time = np.linspace(11 + i * 100, 
10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E2_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R1_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E3_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E3_index, :] = temp\n\n    i = 8\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E3_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E3_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R2_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E2_index, :] = temp\n\n    i = 11\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E2_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R1_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E4_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E4_index, :] = temp\n\n    i = 13\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E4_index, 
:]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E4_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R2_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E2_index, :] = temp\n\n    i = 16\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E3_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E3_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R1_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E5_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E5_index, :] = temp\n\n    i = 18\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E5_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E5_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R2_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E3_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E3_index, :] = temp\n\n    i = 21\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E4_index, :]\n    temp[:, time] = certainty * I_P + I_P * 
np.random.rand(Ce, I_t)\n    I_stimu[E4_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R1_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R1_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E5_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E5_index, :] = temp\n\n    i = 23\n    time = np.linspace(11 + i * 100, 10 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E5_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E5_index, :] = temp\n\n    time = np.linspace(21 + i * 100, 20 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[R2_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Cr, I_t)\n    I_stimu[R2_index, :] = temp\n\n    time = np.linspace(31 + i * 100, 30 + I_t + i * 100, I_t, dtype=int)\n    temp = I_stimu[E4_index, :]\n    temp[:, time] = certainty * I_P + I_P * np.random.rand(Ce, I_t)\n    I_stimu[E4_index, :] = temp\n\n    S = torch.tensor(S, dtype=torch.float32)\n    E = torch.tensor(E, dtype=torch.float32)\n    CRSNN = CRNet(S, E)\n\n    for t in range(total_time):\n        I_input = torch.tensor(I_stimu[:, t].reshape(1, total_neurons), dtype=torch.float32)\n\n        x, dw = CRSNN(I_input)\n        S += A_P * dw[1]\n\n        S += S_bound(S) - S\n\n        ADJ[t] = x\n\n    plt.matshow(I_stimu)\n    plt.matshow(ADJ.transpose())\n\n    plt.matshow(S)\n    plt.colorbar()\n\n    plt.show()\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/SPSNN/README.md",
    "content": "#  Sequence Production SNN\n\nThis repository contains code from our paper [**Brain Inspired Sequences Production by Spiking Neural Networks With Reward-Modulated STDP**] published in Frontiers in Computational Neuroscience. This module builds a Sequence Production spiking neural network model, realizing the memory and reconstruction functions for arbitrary symbol sequences. The input causal graph is shown in figure causal_graph.png. The input current, spike trains during the learning process and the network weight distribution after the learning are shown in the Results folder.\n\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n\n\n\n## Run\n\n```shell\npython main.py file\n```\n\n\n\n\n### Citation \nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{fang2021spsnn,\n    title     = {Brain inspired sequences production by spiking neural networks with reward-modulated stdp},\n    author    = {Fang, Hongjian and Zeng, Yi and Zhao, Feifei},\n    journal   = {Frontiers in Computational Neuroscience},\n    volume    = {15},\n    pages     = {8},\n    year      = {2021},\n    publisher = {Frontiers}\n}\n\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/SPSNN/main.py",
    "content": "import time\nimport numpy as np\nimport os\nimport warnings\nimport math\nfrom matplotlib import pyplot as plt\nimport torch\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nfrom braincog.utils import *\n\n\nwarnings.filterwarnings('ignore')\nnp.set_printoptions(threshold=np.inf)\n\n\n\n\n\nclass SPNet(BrainArea):\n    \"\"\"\n\n    网络结构类:SPNet（Sequence Production Net)，定义了网络的结构，继承自BrainArea基类。\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param tau: 神经元膜电位常数，控制膜电位衰减\n    :param decay:STDP机制衰减常数，控制STDP机制作用强度随时间变化\n    :param w1:神经网络内部连接权重\n    :param w2:外部输入电流到每个神经元的连接\n\n\n    \"\"\"\n\n    def __init__(self, w1, w2 ):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n        self.node = [LIFNode(threshold=10,tau=15) ]\n        self.connection = [CustomLinear(w1), CustomLinear(w2) ]\n        self.stdp = []\n\n        self.stdp.append(MutliInputSTDP(self.node[0], [self.connection[0], self.connection[1]],decay=0.745))\n        self.x1 = torch.zeros(1, w2.shape[0])\n\n    def forward(self, x):\n        \"\"\"\n        一次时间步的前向传播过冲函数，计算脉冲发放情况和权重改变量\n\n        :param x1:经过该时间步后的脉冲发放情况\n        :param dw1:STDP机制在一个时间步后带来的权重改变量\n        \"\"\"\n        self.x1, dw1 = self.stdp[0]( self.x1,x)\n\n\n        return self.x1, dw1\n\n    def reset(self):\n        self.x1 *= 0\n\n\n\ndef S_bound(S):\n    \"\"\"\n    S_bound:网络权重边界控制函数，主要功能为控制全网络突触连接权不超过阈值，维持弱连接\n    另外，本函数还需要将神经元组内部权重控制在一定的范围之内，以防神经元组不断重复激活自身的情况发生。\n\n    :param synapse_bound: 全网络的突触连接的阈值，以维持网络整体为弱连接\n    :param inner_bound: 神经元组内部突触连接的阈值，防止神经元组不断重复激活自身导致网络放电紊乱\n\n    \"\"\"\n\n    S[S > synapse_bound]  =  synapse_bound\n    S[S < -synapse_bound] = -synapse_bound\n\n\n    temp1 = S[l1_stimu,:]\n    temp2 = temp1[:, l1_stimu]\n    temp2 [temp2>inner_bound] = inner_bound\n    temp1[:, l1_stimu] = temp2\n    S[l1_stimu, :] = temp1\n\n    temp1 = S[l2_stimu, :]\n    temp2 = temp1[:, l2_stimu]\n    temp2[temp2 > inner_bound] = inner_bound\n    temp1[:, 
l2_stimu] = temp2\n    S[l2_stimu, :] = temp1\n\n    temp1 = S[l3_stimu, :]\n    temp2 = temp1[:, l3_stimu]\n    temp2[temp2 > inner_bound] = inner_bound\n    temp1[:, l3_stimu] = temp2\n    S[l3_stimu, :] = temp1\n\n    temp1 = S[l4_stimu, :]\n    temp2 = temp1[:, l4_stimu]\n    temp2[temp2 > inner_bound] = inner_bound\n    temp1[:, l4_stimu] = temp2\n    S[l4_stimu, :] = temp1\n\n\n    return S\n\n\nif __name__ == \"__main__\":\n\n\n\n    ## Neurons Parameter\n\n    C = 50;  # constant:the number of neurons of a symbol\n    runtime = 1000;  # Runtime in ms\n\n\n    thresh = 30;  # Judge if the neurons fire or not\n\n\n    I_syn = 5;\n    tau_m = 30;\n    Rm = 10;\n    learning_times = 3;\n    Sym_size = 6\n    I_t = 5  # Time duration of stimu current\n    I_P = 130  # Strength of input current\n    A_P = 0.024\n    certainty = 0.35\n\n    synapse_bound = 10  # The bound of all synapse\n    inner_bound = 1  # The bound of population inner synapse\n\n    \"\"\"\n        SPSNN主函数，实现网络核心主要功能\n\n        :param C: 神经元组中神经元数量\n        :param runtime: 网络总体模拟的时间步长\n        :param learning_times: 网络进行序列学习的次数\n        :param I_t: 对网络添加外部输入电流的时间长度\n        :param I_P: 对网络添加外部输入电流的强度\n        :param A_P: 网络在进行STDP学习后突触改变放缩量\n        :param certainty: 网络输入电流大小的确定度  \n        :param total_neurons: 网络神经元总量\n        :param ADJ: 网络中脉冲放电情况矩阵\n        :param I_stimu: 网络中外部输入电流矩阵\n        :param S: 网络突触连接权重矩阵\n        :param E: 单位矩阵，用以对每个神经元引入外部电流  \n\n    \"\"\"\n\n    # Initial Neurual Network\n\n    Net1 = [C,Sym_size*C,Sym_size*C,Sym_size*C,1]       #memory\n    Net2 = [Sym_size,Sym_size,Sym_size]                 #action\n\n\n    current_end   = 0\n    total_neurons = int(sum(Net1)+sum(Net2))\n\n    index1 = np.linspace(1,sum(Net1),sum(Net1),dtype=int)-1\n\n    index2 = np.linspace (sum(Net1)+1,total_neurons,sum(Net2),dtype=int)-1\n\n\n    Ne = total_neurons\n\n\n\n    firings= []                       # spike timings\n\n\n    ADJ    = np.zeros((runtime,Ne))   # record the 
firing condition\n    abs_Ne = np.zeros(Ne)             # maintain the ABS of every neurons\n\n\n\n    I_stimu = np.zeros((Ne,runtime));\n\n    P       = np.zeros((runtime,3));   # potential of neuron 1\n\n\n\n    # logical vector to differ if the neuron is belong to memory part\n    If_Memory = np.zeros((total_neurons),dtype=bool)\n    If_Memory[index1[:]] = True\n\n\n\n    # logical vector to differ if the neuron is belong to action part\n\n    If_Action = np.zeros((total_neurons),dtype=bool)\n    If_Action[index2[:]] = True\n\n\n\n\n\n    # Pre-set synapses\n\n    S = np.zeros((total_neurons,total_neurons),dtype=float)           #  Initial Weights\n\n    S = S - np.diag(S)      # Set the diag num to\n\n\n    E = np.identity((total_neurons),dtype=float)\n\n\n\n    W_r2a = 0.3\n\n\n    # Memory to Action\n    for i in range (C,sum(Net1)-2 ):\n        S[ int(index2[int(i/C)-1]), i] =W_r2a\n\n\n\n\n    # Learning Process\n    I_stimu = np.zeros((Ne,runtime))\n    seq = np.array([6,3,4])\n\n    l1_stimu = np.arange  (0,C)\n    l2_stimu = np.arange(C+(seq[0]-1)*C,C+(seq[0])*C)\n    l3_stimu = np.arange(C+ 6*C+(seq[1]-1)*C,C+ 6*C+(seq[1])*C)\n    l4_stimu = np.arange(C+12*C+(seq[2]-1)*C,C+12*C+(seq[2])*C)\n    l5_stimu = Ne-1\n\n\n\n\n\n    np.linspace(20 + i * 100 ,20+I_t+i*100-1,I_t ,dtype=int )\n\n    for i in range(learning_times):\n        \"\"\"\n        对网络添加输入电流\n        \"\"\"\n\n\n\n        temp = I_stimu [l1_stimu,:]\n\n        I_stimu [l1_stimu, 10 + i * 100 :10+I_t+i*100] = certainty*I_P + I_P * np.random.rand(C,I_t)\n        I_stimu [l2_stimu, 25 + i * 100 :25+I_t+i*100] = certainty*I_P + I_P * np.random.rand(C,I_t)\n        I_stimu [l3_stimu, 40 + i * 100 :40+I_t+i*100] = certainty*I_P + I_P * np.random.rand(C,I_t)\n        I_stimu [l4_stimu, 55 + i * 100 :55+I_t+i*100] = certainty*I_P + I_P * np.random.rand(C,I_t)\n        I_stimu [l5_stimu, 70 + i * 100 :70+I_t+i*100] = certainty*I_P + I_P * np.random.rand(1,I_t)\n\n\n\n    
I_stimu[l1_stimu,700:700+I_t] = I_P * np.random.rand(C,I_t)\n\n\n\n    S = torch.tensor(S,dtype=torch.float32)\n    E = torch.tensor(E,dtype=torch.float32)\n    SPSNN = SPNet(S,E)\n\n\n\n\n    for  t in range (runtime):\n\n        I_input = torch.tensor( I_stimu[:,t].reshape(1,total_neurons),dtype=torch.float32)\n\n\n        x,dw = SPSNN( I_input )\n\n        S   += A_P*dw[1]\n\n        S   += S_bound(S) - S\n\n\n\n        ADJ[t] = x\n\n\n\n\n\n\nplt.matshow(I_stimu)\nplt.matshow(ADJ)\n\nplt.matshow(S)\nplt.colorbar()\nplt.show()\n\n\n\n\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/apac.py",
    "content": "'''\nCreated on 2016.7.7\n\n@author: liangqian\n'''\nfrom Modal.note import Note\nfrom Modal.cluster import Cluster\nfrom conf.conf import configs\nfrom Modal.pitch import Pitch\nclass APAC():\n    '''\n    anterior primary auditory cortex,encoding the musical notes\n    '''\n\n\n    def __init__(self):\n        '''\n        Constructor\n        '''\n        self.notes = []\n        #self.cluster = Cluster()\n        \n    def encodingNote(self,NoteID):\n        NoteName = configs.notesMap.get(int(NoteID))\n        n = Pitch()\n        n.name = NoteName\n        n.frequence = int(NoteID)\n        self.notes.append(n)\n        return n\n    \n    def encodingMIDINote(self,p):\n        NoteName = configs.notesMap.get(p.frequence)\n        p.name = NoteName\n    "
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/cortex.py",
    "content": "'''\nCreated on 2016.7.6\n\n@author: liangqian\n'''\n\nfrom Areas.pfc import PFC\nfrom Areas.pac import PAC\nfrom conf.conf import *\nfrom Modal.synapse import Synapse\nimport random\nimport numpy as np\nfrom Areas.apac import APAC\nfrom Areas.pac import Music_Sequence_Mem\n\nclass Cortex():\n    '''\n    This class is used to control areas in the cortex, just cortex controlling\n    '''\n\n    def __init__(self, neutype, dt):\n        self.neutype = neutype\n        self.msm = Music_Sequence_Mem(neutype)\n        self.pfc = PFC(self.neutype)\n        self.dt = dt\n\n    def addSubGoalToPFC(self, goalname):\n        self.pfc.addNewSubGoal(goalname)\n        ''' tt = np.arange(0, 5, self.dt)\n        for t in tt:\n            self.pfc.doRemebering(goalname, self.dt, t)'''\n\n    def addComposerToPFC(self, composername):\n        self.pfc.addNewComposer(composername)\n        '''tt = np.arange(0, 5, self.dt)\n        for t in tt:\n            self.pfc.doRememberingComposer(composername, self.dt, t)'''\n\n    def addGenreToPFC(self, genrename):\n        self.pfc.addNewGenre(genrename)\n        '''tt = np.arange(0, 5, self.dt)\n        for t in tt:\n            self.pfc.doRememberingGenre(genrename, self.dt, t)'''\n\n    def musicSequenceMemroyInit(self):\n        self.msm.createActionSequenceMem(1, self.neutype)\n\n    def rememberANote(self, goalname, noteName, order):\n        note = self.apac.encodingNote(noteName)\n        if (order > len(self.msm.sequenceLayers.get(1).groups)):\n            self.msm.sequenceLayers.get(1).addNewGroups(GroupID=order, layerID=1, neunum=128)\n        tt = np.arange((order - 1) * 5, order * 5, self.dt)\n        for t in tt:\n            self.ips.doRemebering(goalname, self.dt, t)\n            self.msm.doRemembering_note_only(note, order, self.dt, t)\n        self.msm.doConnectToGoal(self.ips.goals.groups.get(goalname), order)\n        dic = {}\n        if (configs.flag_experiments == False):\n            
dic[\"GoalSpike\"] = self.ips.goals.groups.get(goalname).writeSpikeInfoToJson()\n            dic[\"MSMSpike\"] = self.msm.sequenceLayers.get(1).groups.get(order).writeSpikeInfoToJson()\n\n            dic[\"Neuron\"] = self.msm.sequenceLayers.get(1).groups.get(order).writeSelfInfoToJson(\"MSM\")\n            dic[\"GroupNum\"] = order\n        return dic\n\n    def connectKeyAndNotesUsingKSModel(self,keys, Notes):\n        noteNeurons = Notes.neurons\n        for tone in keys.groups.values():\n            tneurons = tone.neurons\n            for nindex, noteneu in enumerate(noteNeurons[1:]):\n                i = nindex%12\n                syn = Synapse(tneurons[i],noteneu) # from tone to pitch\n                if(tneurons[i].importance < 0):\n                    syn.excitability = 0\n                    syn.weight = -100 # 这个地方应该用KS pitch profile，稍等下再改\n                else:\n                    syn.excitability = 1\n                    syn.weight = 20+random.uniform(20,30)\n                syn.type = 3\n                noteneu.synapses.append(syn)\n                noteneu.pre_neurons.append(tneurons[i])\n\n                # syn1 = Synapse(noteneu, tneurons[i])  # from pitch to tone\n                # syn1.excitability = 1\n                # syn1.type = 3\n                # tneurons[i].synapses.append(syn1)\n                # tneurons[i].pre_neurons.append(noteneu)\n\n\n    def connectKeyAndNotes(self,keys, Notes):\n        noteNeurons = Notes.neurons\n        for tone in keys.groups.values():\n            tneurons = tone.neurons\n            for nindex, noteneu in enumerate(noteNeurons[1:]):\n                i = nindex%12\n                syn = Synapse(tneurons[i],noteneu)\n                if(tneurons[i].importance < 0):\n                    syn.excitability = 0\n                    syn.weight = -100\n                else:\n                    syn.excitability = 1\n                    syn.weight = 20+random.uniform(20,30)\n                syn.type = 3\n              
  noteneu.synapses.append(syn)\n                noteneu.pre_neurons.append(tneurons[i])\n\n    def connectKeyAndNotesUsingKSModel(self,keys, Notes):\n        noteNeurons = Notes.neurons\n        for tone in keys.groups.values():\n            tneurons = tone.neurons\n            for nindex, noteneu in enumerate(noteNeurons[1:]):\n                i = nindex%12\n                syn = Synapse(tneurons[i],noteneu) # from tone to pitch\n                if(tneurons[i].importance < 0):\n                    syn.excitability = 0\n                    syn.weight = -100 # 这个地方应该用KS pitch profile，稍等下再改\n                else:\n                    syn.excitability = 1\n                    syn.weight = 20+random.uniform(20,30)\n                syn.type = 3\n                noteneu.synapses.append(syn)\n                noteneu.pre_neurons.append(tneurons[i])\n\n                # syn1 = Synapse(noteneu, tneurons[i])  # from pitch to tone\n                # syn1.excitability = 1\n                # syn1.type = 3\n                # tneurons[i].synapses.append(syn1)\n                # tneurons[i].pre_neurons.append(noteneu)\n\n    def rememberANoteWithKnowledge(self,goalname, composername, genrename, emo, keyName, trackIndex, noteIndex, tinterval,order, xmldata):\n\n        instrumentTrack = self.msm.sequenceLayers.get(trackIndex)\n        if (order > len(instrumentTrack.get(\"N\").groups)):\n            instrumentTrack.get(\"N\").addNewGroups(GroupID=order, layerID=trackIndex, neunum=129)\n            instrumentTrack.get(\"T\").addNewGroups(GroupID=order, layerID=trackIndex, neunum=64)\n            self.connectKeyAndNotesUsingKSModel(self.pfc.keys,instrumentTrack.get(\"N\").groups.get(order))# prior knowledge\n        tt = np.arange((order - 1) * 5, order * 5, self.dt)\n        for t in tt:\n            self.pfc.doRemebering(goalname, self.dt, t)\n            self.pfc.doRememberingComposer(composername, self.dt, t)\n            self.pfc.doRememberingGenre(genrename, self.dt, t)\n        
    self.pfc.doRememberingKey(keyName,self.dt,t)\n            self.pfc.doRememberingMode(keyName%2, self.dt,t)\n            self.msm.doRemembering(trackIndex, noteIndex, order, self.dt, t, tinterval)\n\n        self.msm.doConnectToTitle(self.pfc.goals.groups.get(goalname), instrumentTrack, order)\n        self.msm.doConnectToComposer(self.pfc.composers.groups.get(composername), instrumentTrack, order)\n        self.msm.doConnectToGenre(self.pfc.genres.groups.get(genrename), instrumentTrack, order)\n        self.msm.doConnectToKey(self.pfc.keys.groups.get(keyName), instrumentTrack, order, noteIndex)\n        self.msm.doConnectToMode(self.pfc.modes.groups.get(keyName%2), keyName, instrumentTrack, order, noteIndex)\n        dic = {}\n        dic[\"GoalSpike\"] = self.pfc.goals.groups.get(goalname).writeSpikeInfoToJson()\n        dic[\"ComposerSpike\"] = self.pfc.composers.groups.get(composername).writeSpikeInfoToJson()\n        #dic[\"EmotionSpike\"] = self.amy.emos.groups.get(emo).writeSpikeInfoToJson()\n        dic[\"KeySpike\"] = self.pfc.keys.groups.get(keyName).writeSpikeInfoToJson()\n        dic[\"ModeSpike\"] = self.pfc.modes.groups.get(keyName%2).writeSpikeInfoToJson()\n        dic[\"MSMNSpike\"] = instrumentTrack.get(\"N\").groups.get(order).writeSpikeInfoToJson()\n        dic[\"MSMTSpike\"] = instrumentTrack.get(\"T\").groups.get(order).writeSpikeInfoToJson()\n        return(dic)\n\n    def rememberANoteandTempo(self, goalname, composername, genrename, trackIndex, noteIndex, order, tinterval):\n        instrumentTrack = self.msm.sequenceLayers.get(trackIndex)\n        if (order > len(instrumentTrack.get(\"N\").groups)):\n            instrumentTrack.get(\"N\").addNewGroups(GroupID=order, layerID=trackIndex, neunum=129)\n            instrumentTrack.get(\"T\").addNewGroups(GroupID=order, layerID=trackIndex, neunum=64)\n        tt = np.arange((order - 1) * 5, order * 5, self.dt)\n        for t in tt:\n            self.pfc.doRemebering(goalname, self.dt, t)\n     
       self.pfc.doRememberingComposer(composername, self.dt, t)\n            self.pfc.doRememberingGenre(genrename, self.dt, t)\n            self.msm.doRemembering(trackIndex, noteIndex, order, self.dt, t, tinterval)\n\n        self.msm.doConnectToTitle(self.pfc.goals.groups.get(goalname), instrumentTrack, order)\n        self.msm.doConnectToComposer(self.pfc.composers.groups.get(composername), instrumentTrack, order)\n        self.msm.doConnectToGenre(self.pfc.genres.groups.get(genrename), instrumentTrack, order)\n        dic = {}\n        ngraph = {}\n\n        if (configs.RunTimeState == 1):\n            dic[\"GoalSpike\"] = self.pfc.goals.groups.get(goalname).writeSpikeInfoToJson()\n            dic[\"ComposerSpike\"] = self.pfc.composers.groups.get(composername).writeSpikeInfoToJson()\n            dic[\"MSMSpike\"] = instrumentTrack.get(\"N\").groups.get(order).writeSpikeInfoToJson()\n            dic[\"MSMTSpike\"] = instrumentTrack.get(\"T\").groups.get(order).writeSpikeInfoToJson()\n\n            temp = {}\n            temp[1] = instrumentTrack.get(\"N\").groups.get(order).writeSelfInfoToJson(\"NMSM\")\n            temp[2] = instrumentTrack.get(\"T\").groups.get(order).writeSelfInfoToJson(\"TMSM\")\n            #         dic[\"GroupNum\"] = order\n\n            Nodes = []\n            Edges = []\n            for key, td in temp.items():\n                nlist = td.get(\"Neuron\")\n                for n in nlist:\n                    # if(len(n.get('synapses')) > 0):\n                    d = {}\n                    d['id'] = n.get('area') + '_' + str(n.get('TrackID')) + '_' + str(n.get('GroupID')) + '_' + str(\n                        n.get('Index'))\n                    d['area'] = n.get('area')\n                    if (n.get('area') == 'NMSM'):\n                        d['label'] = configs.notesMap.get(n.get('Index') - 2)\n                    else:\n                        d['label'] = str(n.get('Index') * 60)\n                    Nodes.append(d)\n           
         synlist = n.get('synapses')\n                    for syn in synlist:\n                        e = {}\n                        # e['id'] = syn.get('Sarea') + '_'+str(syn.get('SgroupID'))+'_'+str(syn.get('Sindex')) + '_' +syn.get('Tarea') + '_'+str(syn.get('TgroupID')) + '_'+str(syn.get('Tindex'))\n                        e['weight'] = str(syn.get('weight'))\n                        e['source'] = syn.get('Sarea') + '_' + str(syn.get('StrackID')) + '_' + str(\n                            syn.get('SgroupID')) + '_' + str(syn.get('Sindex'))\n                        e['target'] = syn.get('Tarea') + '_' + str(syn.get('TtrackID')) + '_' + str(\n                            syn.get('TgroupID')) + '_' + str(syn.get('Tindex'))\n\n                        Edges.append(e)\n\n            ngraph[\"node\"] = Nodes\n            ngraph[\"edge\"] = Edges\n\n        #print(dic)\n        return dic, ngraph\n\n    def actionSequenceMemoryInit(self):\n        self.asm.createActionSequenceMem(1, self.neutype, 16)\n\n    def recallMusicPFC(self, goalName):\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        result = self.pfc.doRecalling2(goalName, self.msm)\n        return result\n\n\n\n    def recallMusicByEpisode(self, episodeNotes):  # using time window search episode\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        #         sl = self.msm.sequenceLayers.get(1)\n        #         for index,group in sl.groups.items():\n        #             strs = \"group_\"+str(index)+\":\"\n        #             for n in group.neurons:\n        #                 if(n.preActive == True):\n        #                     strs +=\" neu_Index:\"+str(n.index)+\",\"\n        #             print(strs)\n        result = self.msm.recallByEpisode2(episodeNotes, self.pfc.goals)\n        return result\n\n\n    def generateEx_Nihilo(self, firstNote, durations, length):\n        '''\n        this function is used to generate the main melody, only one track\n  
      '''\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        result = {}\n        track1 = []\n        for i in range(length):\n            dic = {}\n            tt = np.arange(i * 5, (i + 1) * 5, self.dt)\n            for t in tt:\n                self.pfc.inhibiteGoals(self.dt, t)\n                self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)\n\n            panneu = []\n            maxrate = 0.0\n            maxneu = None\n            for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get(\"N\").groups.get(i + 1).neurons):\n                if (neu.preActive == True):\n                    panneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    # dic['N'] = neu.selectivity\n                    if (len(neu.spiketime) > maxrate):\n                        maxrate = len(neu.spiketime)\n                        maxneu = neu\n            print(maxneu.I)\n            if (dic.get('N') == None):\n                j = random.randint(0, len(panneu) - 1)\n                neu = panneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['N'] = neu.selectivity\n            else:  # chose the neuron which has the max firing rate\n                dic['N'] = maxneu.selectivity\n\n            patneu = []\n            maxrate = 0.0\n            maxneu = None\n            for neu in self.msm.sequenceLayers.get(1).get(\"T\").groups.get(i + 1).neurons:\n                if (neu.preActive == True): patneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    if (len(neu.spiketime) > maxrate):\n                        maxrate = len(neu.spiketime)\n                        maxneu = neu\n\n            if (dic.get('T') == None):\n                j = random.randint(0, len(patneu) - 1)\n                neu = patneu[j]\n                neu.I = 20\n                for t in tt:\n                    
neu.update_normal(self.dt, t)\n                dic['T'] = neu.selectivity\n            else:\n                dic['T'] = neu.selectivity\n\n            track1.append(dic)\n        result[1] = track1\n        #print(result)\n        return result\n\n    def generateEx_Nihilo2(self, firstNote, durations, length):\n        '''\n        this function is used to generate the main melody, only one track\n        '''\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        result = {}\n        track1 = []\n        for i in range(length):\n            dic = {}\n            tt = np.arange(i * 5, (i + 1) * 5, self.dt)\n            for t in tt:\n                self.pfc.inhibiteGoals(self.dt, t)\n                self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)\n\n            panneu = []\n            for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get(\"N\").groups.get(i + 1).neurons):\n                if (neu.preActive == True):\n                    panneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    dic['N'] = neu.selectivity\n\n            if (dic.get('N') == None):\n                j = random.randint(0, len(panneu) - 1)\n                neu = panneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['N'] = neu.selectivity\n\n            patneu = []\n            for neu in self.msm.sequenceLayers.get(1).get(\"T\").groups.get(i + 1).neurons:\n                if (neu.preActive == True): patneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    dic['T'] = neu.selectivity\n\n            if (dic.get('T') == None):\n                j = random.randint(0, len(patneu) - 1)\n                neu = patneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['T'] = neu.selectivity\n\n            track1.append(dic)\n          
  result[1] = track1\n        return result\n\n    def generateEx_NihiloAccordingToGenre(self, genreName, firstNote, durations, length):\n        genreName = genreName.title()\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        result = {}\n        track1 = []\n\n        for i in range(length):\n            dic = {}\n            tt = np.arange(i * 5, (i + 1) * 5, self.dt)\n            for t in tt:\n                self.pfc.inhibiteGoals(self.dt, t)\n                self.pfc.inhibitComposers(self.dt, t)\n                self.pfc.doRememberingGenre(genreName, self.dt, t)\n                self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)\n\n            panneu = []\n            for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get(\"N\").groups.get(i + 1).neurons):\n                if (neu.preActive == True):\n                    panneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    dic['N'] = neu.selectivity\n\n            if (dic.get('N') == None):\n                j = random.randint(0, len(panneu) - 1)\n                neu = panneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['N'] = neu.selectivity\n\n            patneu = []\n            for neu in self.msm.sequenceLayers.get(1).get(\"T\").groups.get(i + 1).neurons:\n                if (neu.preActive == True): patneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    dic['T'] = neu.selectivity\n\n            if (dic.get('T') == None):\n                j = random.randint(0, len(patneu) - 1)\n                neu = patneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['T'] = neu.selectivity\n\n            track1.append(dic)\n            result[1] = track1\n        return result\n\n    def generateEx_NihiloAccordingToComposer(self, composerName, 
firstNote, durations, length):\n        composerName = composerName.title()\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        result = {}\n        track1 = []\n\n        for i in range(length):\n            dic = {}\n            tt = np.arange(i * 5, (i + 1) * 5, self.dt)\n            for t in tt:\n                self.pfc.inhibiteGoals(self.dt, t)\n                self.pfc.doRememberingComposer(composerName, self.dt, t)\n                self.msm.generateEx_Nihilo(firstNote, durations, i, self.dt, t)\n            panneu = []\n            maxrate = 0.0\n            maxneu = None\n            for ni, neu in enumerate(self.msm.sequenceLayers.get(1).get(\"N\").groups.get(i + 1).neurons):\n                if (neu.preActive == True):\n                    panneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    # dic['N'] = neu.selectivity\n                    #print(str(neu.selectivity) + \":\" + str(len(neu.spiketime)))\n                    if (len(neu.spiketime) > maxrate):\n                        maxrate = len(neu.spiketime)\n                        maxneu = neu\n            if (dic.get('N') == None):\n                j = random.randint(0, len(panneu) - 1)\n                neu = panneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['N'] = neu.selectivity\n            else:  # chose the neuron which has the max firing rate\n                dic['N'] = maxneu.selectivity\n\n            patneu = []\n            maxrate = 0.0\n            maxneu = None\n            for neu in self.msm.sequenceLayers.get(1).get(\"T\").groups.get(i + 1).neurons:\n                if (neu.preActive == True): patneu.append(neu)\n                if (len(neu.spiketime) > 0):\n                    if (len(neu.spiketime) > maxrate):\n                        maxrate = len(neu.spiketime)\n                        maxneu = neu\n\n            if (dic.get('T') == 
None):\n                j = random.randint(0, len(patneu) - 1)\n                neu = patneu[j]\n                neu.I = 20\n                for t in tt:\n                    neu.update_normal(self.dt, t)\n                dic['T'] = neu.selectivity\n            else:\n                dic['T'] = neu.selectivity\n\n            track1.append(dic)\n        result[1] = track1\n        #print(result)\n        return result\n\n    def generateMelodyWithKey(self, key, firstNotes, durations, length):\n        result = {}\n\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n\n        row,col = firstNotes.shape\n\n        for i in range(length):\n\n            time = np.arange(i*5,(i+1)*5,self.dt)\n            for t in time:\n                self.pfc.inhibiteGoals(self.dt,t)\n                self.pfc.inhibitComposers(self.dt,t)\n                self.pfc.inhibitGenres(self.dt,t)\n                self.pfc.doRememberingKey(key,self.dt,t)\n                self.msm.generateMelodyWithTone(firstNotes[:,i] if i < col else None, durations[:,i] if durations else None,key,i+1,self.dt,t)\n            # find the max firing rate\n            for j,sl in self.msm.sequenceLayers.items():\n                if j > 4: break;\n                #print(\"***********************this is part \" + str(j) + \"****************************\")\n                dic = {}\n                if i < col:\n                    dic[\"N\"] = firstNotes[j-1][i]\n                    dic[\"T\"] = durations[j-1][i] if durations else 1.0\n                else:\n                    maxrate = 0\n                    for nn in sl.get(\"N\").groups.get(i+1).neurons:\n\n                        if nn.preActive:\n                            # print(\"------order: \"+str(i+1)+\"-------\")\n                            # print(nn.selectivity)\n                            # print(nn.I)\n                            # print(nn.I_upper)\n                            # print(nn.I_lower)\n                            # 
print(len(nn.spiketime))\n                            if len(nn.spiketime) > maxrate:\n                                maxrate = len(nn.spiketime)\n                                #print(\"top1:\"+str(nn.selectivity))\n                                dic[\"N\"] = nn.selectivity\n                    if durations is not None:\n                        dic[\"T\"] = 1\n                    else:\n                        maxrate = 0\n                        for tn in sl.get(\"T\").groups.get(i + 1).neurons:\n                            if tn.preActive:\n                                if len(tn.spiketime) > maxrate:\n                                    maxrate = len(tn.spiketime)\n                                    #print(tn.selectivity)\n                                    dic[\"N\"] = nn.selectivity\n                if dic.get(\"T\") is None:\n                    dic[\"T\"] = 1.0\n                if result.get(j) is None:\n                    part = []\n                    part.append(dic)\n                    result[j] = part\n                else:\n                    result.get(j).append(dic)\n        return result\n\n    def recallActionIPS(self, goalName):\n        self.pfc.setTestStates()\n        self.pfc.setTestStates()\n        self.pfc.doRecalling(goalName, self.asm)\n\n    def generate2TrackMusic(self, firstNotes, durations, lengths):\n        self.pfc.setTestStates()\n        self.msm.setTestStates()\n        result = {}\n        for k, notes in firstNotes.items():\n            track1 = []\n            for i in range(lengths[k - 1]):\n                dic = {}\n                tt = np.arange(i * 5, (i + 1) * 5, self.dt)\n                for t in tt:\n                    self.pfc.inhibiteGoals(self.dt, t)\n                    # self.msm.generateSimgleTrackNotes(j+1, firstNotes[j], durations[j], i, self.dt, t)\n                    self.msm.generateSimgleTrackNotes(k, notes, durations.get(k), i, self.dt, t)\n\n                panneu = []\n                for 
ni, neu in enumerate(self.msm.sequenceLayers.get(k).get(\"N\").groups.get(i + 1).neurons):\n                    if (neu.preActive == True):\n                        panneu.append(neu)\n                    if (len(neu.spiketime) > 0):\n                        dic['N'] = neu.selectivity\n\n                if (dic.get('N') == None):\n                    j = random.randint(0, len(panneu) - 1)\n                    neu = panneu[j]\n                    neu.I = 20\n                    for t in tt:\n                        neu.update_normal(self.dt, t)\n                    dic['N'] = neu.selectivity\n\n                patneu = []\n                for neu in self.msm.sequenceLayers.get(k).get(\"T\").groups.get(i + 1).neurons:\n                    if (neu.preActive == True): patneu.append(neu)\n                    if (len(neu.spiketime) > 0):\n                        dic['T'] = neu.selectivity\n\n                if (dic.get('T') == None):\n                    j = random.randint(0, len(patneu) - 1)\n                    neu = patneu[j]\n                    neu.I = 20\n                    for t in tt:\n                        neu.update_normal(self.dt, t)\n                    dic['T'] = neu.selectivity\n\n                track1.append(dic)\n            result[k] = track1\n        print(result)\n        return result"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/pac.py",
    "content": "'''\nPrimary Auditory Area\n'''\n\nfrom braincog.base.brainarea.BrainArea import BrainArea\n\nfrom Modal.sequencememory import SequenceMemory\nfrom Modal.notesequencelayer import NoteSequenceLayer\nfrom Modal.temposequencelayer import TempoSequenceLayer\nfrom Modal.synapse import Synapse\nimport numpy as np\nimport math\nfrom conf.conf import *\n\n\nclass PAC(BrainArea,SequenceMemory):\n    '''\n     the planum polare, anterior to PAC, as well as in the left planum temporale,posterior to PAC.\n    '''\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        SequenceMemory.__init__(self, neutype)\n\n    def forward(self, x):\n        pass\n\n    def createActionSequenceMem(self, layernum, neutype):\n\n        sl = NoteSequenceLayer(neutype)\n        tl = TempoSequenceLayer(neutype)\n        instrumentTrack = {}\n        instrumentTrack[\"N\"] = sl\n        instrumentTrack[\"T\"] = tl\n        self.sequenceLayers[layernum] = instrumentTrack\n        print(len(self.sequenceLayers))\n\n    def doRemembering_note_only(self, note, order, dt, t):\n        # remember note\n        sl = self.sequenceLayers.get(1)\n        sgroup = sl.groups.get(order)\n        dt = 0.1\n        for n in sgroup.neurons:\n            n.I_ext = note.frequence\n            n.computeFilterCurrent()\n            n.update(dt, t, 'Learn')\n\n    def doRemembering(self, trackIndex, noteIndex, order, dt, t, tinterval=0):\n        # remember note\n        iTrack = self.sequenceLayers.get(trackIndex)\n        sl = iTrack.get(\"N\")\n        sgroup = sl.groups.get(order)\n        dt = 0.1\n        for n in sgroup.neurons:\n            n.I_ext = noteIndex\n            n.computeFilterCurrent()\n            n.update(dt, t, 'Learn')\n\n        # remember tempo\n        tl = iTrack.get(\"T\")\n        tgroup = tl.groups.get(order)\n        dt = 0.1\n        for n in tgroup.neurons:\n            n.I_ext = tinterval\n            n.computeFilterCurrent()\n      
      n.update(dt, t, 'Learn')\n\n\n\n    def doConnectToTitle(self, title, track, order):\n        for sl in track.values():\n            self.doConnecting(title, sl, order)\n\n    def doConnectToComposer(self, composer, track, order):\n        for sl in track.values():\n            self.doConnecting(composer, sl, order)\n\n    def doConnectToGenre(self, genre, track, order):\n        for sl in track.values():\n            self.doConnecting(genre, sl, order)\n\n    def generateEx_Nihilo(self, firstNote, durations, order, dt, t):\n        ns = self.sequenceLayers.get(1).get(\"N\")\n        ts = self.sequenceLayers.get(1).get(\"T\")\n        nneurons = ns.groups.get(order + 1).neurons\n        tneurons = ts.groups.get(order + 1).neurons\n        # firstNotes specify the beginning notes to trigger the following notes\n        if (order < len(firstNote)):  # beginning notes\n            i = firstNote[order]\n            nneu = nneurons[i + 1]\n            nneu.I = 20\n            nneu.update_normal(dt, t)\n\n            d = int(durations[order] / 0.125) - 1\n            tneu = tneurons[d]\n            tneu.I = 20\n\n            tneu.update_normal(dt, t)\n        else:  # generate next note\n            for nn in nneurons:\n                nn.updateCurrentOfLowerAndUpperLayer(t)\n                nn.update(dt, t, 'test')\n            for tn in tneurons:\n                tn.updateCurrentOfLowerAndUpperLayer(t)\n                tn.update(dt, t, 'test')\n\n    def generateSimgleTrackNotes(self, trackIndex, firstNote, durations, order, dt, t):\n        ns = self.sequenceLayers.get(trackIndex).get(\"N\")\n        ts = self.sequenceLayers.get(trackIndex).get(\"T\")\n        nneurons = ns.groups.get(order + 1).neurons\n        tneurons = ts.groups.get(order + 1).neurons\n        # firstNotes specify the beginning notes to trigger the following notes\n        if (order < len(firstNote)):  # beginning notes\n            i = firstNote[order]\n            nneu = nneurons[i + 1]\n  
          nneu.I = 20\n            nneu.update_normal(dt, t)\n\n            d = int(durations[order] / 0.125) - 1\n            tneu = tneurons[d]\n            tneu.I = 20\n            tneu.update_normal(dt, t)\n        else:  # generate next note\n            for nn in nneurons:\n                nn.updateCurrentOfLowerAndUpperLayer(t)\n                nn.update(dt, t, 'test')\n            #                 if(neu.spike == True):\n            #                     print(neu.selectivity)\n            for tn in tneurons:\n                tn.updateCurrentOfLowerAndUpperLayer(t)\n                tn.update(dt, t, 'test')\n\n\nclass Music_Sequence_Mem(SequenceMemory):\n    '''\n     the planum polare, anterior to PAC, as well as in the left planum temporale,posterior to PAC.\n    '''\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        SequenceMemory.__init__(self, neutype)\n\n    def createActionSequenceMem(self, layernum, neutype):\n\n        sl = NoteSequenceLayer(neutype)\n        tl = TempoSequenceLayer(neutype)\n        instrumentTrack = {}\n        instrumentTrack[\"N\"] = sl\n        instrumentTrack[\"T\"] = tl\n        self.sequenceLayers[layernum] = instrumentTrack\n        print(len(self.sequenceLayers))\n\n    def doRemembering_note_only(self, note, order, dt, t):\n        # remember note\n        sl = self.sequenceLayers.get(1)\n        sgroup = sl.groups.get(order)\n        dt = 0.1\n        for n in sgroup.neurons:\n            n.I_ext = note.frequence\n            n.computeFilterCurrent()\n            n.update(dt, t, 'Learn')\n\n    def doRemembering(self, trackIndex, noteIndex, order, dt, t, tinterval=0):\n        # remember note\n        iTrack = self.sequenceLayers.get(trackIndex)\n        sl = iTrack.get(\"N\")\n        sgroup = sl.groups.get(order)\n        dt = 0.1\n        for n in sgroup.neurons:\n            n.I_ext = noteIndex\n            n.computeFilterCurrent()\n            n.update(dt, t, 'Learn')\n\n    
    # remember tempo\n        tl = iTrack.get(\"T\")\n        tgroup = tl.groups.get(order)\n        dt = 0.1\n        for n in tgroup.neurons:\n            n.I_ext = tinterval\n            n.computeFilterCurrent()\n            n.update(dt, t, 'Learn')\n\n    def recallByEpisode(self, episodeNotes, goals):\n        dt = 0.1\n        sl = self.sequenceLayers.get(1)\n        firstresult = {}\n        firstNote = episodeNotes[0]\n        tt = np.arange(0, 5, dt)\n\n        # find first note and activate goal neurons\n        for t in tt:\n            for id, group in sl.groups.items():\n                neuid = firstNote - 15;\n                if (group.neurons[neuid - 1].preActive == False): continue\n                group.neurons[neuid - 1].I_ext = 20\n                group.neurons[neuid - 1].updateCurrentOfLowerAndUpperLayer(t)\n                group.neurons[neuid - 1].update(dt, t, 'test')\n                if (group.neurons[neuid - 1].spike == True):\n                    firstresult[id] = 1\n            for name, g in goals.groups.items():\n                for neu in g.neurons:\n                    neu.updateCurrentOfLowerAndUpperLayer(t)\n                    neu.update(dt, t)\n\n        # find rest episode notes\n        restResult = {}\n        # for i in range(1,len(episodeNotes)):\n        goalchecked = {}\n        for groupID, value in firstresult.items():\n            tmp = {}\n            for i in range(1, len(episodeNotes)):\n                fre = episodeNotes[i]\n                tt = np.arange(i * 5, (i + 1) * 5, dt)\n                g = sl.groups.get(groupID + i)\n                for t in tt:\n                    # update memory\n                    neuid = fre - 15;\n                    if (g.neurons[neuid - 1].preActive == False): continue\n                    g.neurons[neuid - 1].I_ext = 20\n                    g.neurons[neuid - 1].updateCurrentOfLowerAndUpperLayer(t)\n                    if (g.neurons[neuid - 1].I_lower == 0): break\n                  
  g.neurons[neuid - 1].update(dt, t, 'test')\n                    if (g.neurons[neuid - 1].spike == True):\n                        tmp[i] = g.id\n\n                    # update goals' neurons\n                    for name, gg in goals.groups.items():\n                        for neu in gg.neurons:\n                            neu.updateCurrentOfLowerAndUpperLayer(t)\n                            neu.update(dt, t)\n                            # if(neu.spike == True):print(name)\n            # find\n            maxFiringRate = 0\n            maxGoalName = ''\n            maxGoal = {}  # an episode may be mapped to more than one songs\n            for name, gg in goals.groups.items():\n                if (goalchecked.get(name) == None):\n                    averageFiringRate = 0\n                    for neu in gg.neurons:\n                        averageFiringRate = averageFiringRate + len(neu.spiketime)\n                    averageFiringRate = float(averageFiringRate) / float(len(gg.neurons))\n                    gg.averageFiringRate = averageFiringRate\n                    if (averageFiringRate > maxFiringRate):\n                        maxFiringRate = averageFiringRate\n                        maxGoalName = name\n            maxGoal[maxGoalName] = 1\n            goalchecked[maxGoalName] = 1\n            for name, gg in goals.groups.items():\n                if (gg.averageFiringRate == maxFiringRate and goalchecked.get(name) == None):\n                    maxGoal[name] = 1\n                    goalchecked[name] = 1\n            tmp['goal'] = maxGoal\n            restResult[groupID] = tmp\n        print(restResult)\n        episodeResult = []\n        for key, value in firstresult.items():\n            tmp = {}\n            tmp[0] = key\n            dic = restResult.get(key)\n            for i in range(1, len(episodeNotes)):\n                if (dic.get(i) != None):\n                    tmp[i] = dic.get(i)\n            if (len(tmp) == len(episodeNotes)):\n            
    tmp['goal'] = dic.get('goal')\n                episodeResult.append(tmp)\n        print(episodeResult)\n\n        for res in episodeResult:\n            msmgroupID = res.get(len(episodeNotes) - 1) + 1\n            for gname, value in res.get('goal').items():\n                gg = goals.groups.get(gname)\n\n    def recallByEpisode2(self, episodeNotes, goals):\n        dt = 0.1\n        sl = self.sequenceLayers.get(1).get(\"N\")\n        tl = self.sequenceLayers.get(1).get(\"T\")\n        print(len(sl.groups))\n        firstresult = {}\n        firstNote = episodeNotes[0]\n        tt = np.arange(0, 5, dt)\n\n        # find first note and activate goal neurons\n        for t in tt:\n            for id, group in sl.groups.items():\n                neuid = firstNote;\n                if (group.neurons[neuid + 1].preActive == False): continue\n                # group.neurons[neuid+1].I_ext = 20\n                # group.neurons[neuid+1].updateCurrentOfLowerAndUpperLayer(t)\n                group.neurons[neuid + 1].I = 20\n                group.neurons[neuid + 1].update(dt, t, 'test')\n                if (group.neurons[neuid + 1].spike == True):\n                    firstresult[id] = 1\n\n        # find rest episode notes\n        restResult = {}\n        # for i in range(1,len(episodeNotes)):\n        goalchecked = {}\n        for groupID in firstresult.keys():\n            #\n            sl.setTestStates()\n\n            tmp = {}\n            tt = np.arange(0, 5, dt)\n            g = sl.groups.get(groupID)\n            neuid = firstNote\n            # g.neurons[neuid+1].I_ext = 20\n            g.neurons[neuid + 1].I = 20\n            for t in tt:\n                # g.neurons[neuid+1].updateCurrentOfLowerAndUpperLayer(t)\n                g.neurons[neuid + 1].update(dt, t, 'test')\n\n            for i in range(1, len(episodeNotes)):\n                fre = episodeNotes[i]\n                tt = np.arange(i * 5, (i + 1) * 5, dt)\n                if (groupID + i > 
len((sl.groups))): continue\n                g = sl.groups.get(groupID + i)\n                for t in tt:\n                    # update memory\n                    neuid = fre;\n                    if (g.neurons[neuid + 1].preActive == False): continue\n                    g.neurons[neuid + 1].I_ext = 20\n                    g.neurons[neuid + 1].updateCurrentOfLowerAndUpperLayer(t)\n                    if (g.neurons[neuid + 1].I_lower == 0): break\n                    g.neurons[neuid + 1].update(dt, t, 'test')\n                    if (g.neurons[neuid + 1].spike == True):\n                        tmp[i] = g.id\n\n            restResult[groupID] = tmp\n            # print(restResult)\n        episodeResult = []\n        for key in firstresult.keys():\n            tmp = {}\n            tmp[0] = key\n            dic = restResult.get(key)\n            for i in range(1, len(episodeNotes)):\n                if (dic.get(i) != None):\n                    tmp[i] = dic.get(i)\n            if (len(tmp) == len(episodeNotes)):\n                episodeResult.append(tmp)\n        # print(episodeResult)\n\n        # begin remembering\n        finalResult = []\n        for i, res in enumerate(episodeResult):\n            goals.setTestStates()\n            self.setTestStates()\n            for i, fre in enumerate(episodeNotes):\n                tt = np.arange(i * 5, (i + 1) * 5, dt)\n                neuid = fre\n                g = sl.groups.get(res.get(i))\n                for t in tt:\n                    # g.neurons[neuid+1].I_ext = 20\n                    # g.neurons[neuid+1].updateCurrentOfLowerAndUpperLayer(t)\n                    g.neurons[neuid + 1].I = 20\n                    g.neurons[neuid + 1].update(dt, t, 'test')\n\n                    for name, gg in goals.groups.items():\n                        for neu in gg.neurons:\n                            neu.updateCurrentOfLowerAndUpperLayer(t)\n                            neu.update(dt, t)\n            # find goal\n          
  maxFiringRate = 0\n            maxGoalName = ''\n            maxGoal = {}  # an episode may be mapped to more than one songs\n            for name, gg in goals.groups.items():\n                averageFiringRate = 0\n                for neu in gg.neurons:\n                    averageFiringRate = averageFiringRate + len(neu.spiketime)\n                averageFiringRate = float(averageFiringRate) / float(len(gg.neurons))\n                gg.averageFiringRate = averageFiringRate\n                if (averageFiringRate > maxFiringRate):\n                    maxFiringRate = averageFiringRate\n                    maxGoalName = name\n            maxGoal[maxGoalName] = 1\n            for name, gg in goals.groups.items():\n                if (gg.averageFiringRate == maxFiringRate and goalchecked.get(name) == None):\n                    maxGoal[name] = 1\n            # print(maxGoal)\n\n            # recall the rest song\n            for goalname, value in maxGoal.items():\n                # reset State\n                goals.setTestStates()\n                nextGroupId = res.get(len(episodeNotes) - 1) + 1\n                for i in range(nextGroupId, len(sl.groups) + 1):\n                    sl.groups.get(i).setTestStates()\n                    tl.groups.get(i).setTestStates()\n                restSpikeResult = {}\n                count = 0\n                gg = goals.groups.get(goalname)\n                for i in range(nextGroupId, len(sl.groups) + 1):\n                    order = len(episodeNotes) + count\n                    if (order == 3):\n                        print(\"debug\")\n                    tt = np.arange(order * 5, (order + 1) * 5, dt)\n                    msmgroup = sl.groups.get(i)\n                    msmtgroup = tl.groups.get(i)\n                    tdic = {}\n                    for t in tt:\n                        for n in gg.neurons:\n                            # n.updateCurrentOfLowerAndUpperLayer(t)\n                            n.I = 30\n          
                  n.update_normal(dt, t)\n\n                        for neu in msmgroup.neurons:\n                            neu.updateCurrentOfLowerAndUpperLayer(t)\n                            neu.update(dt, t, 'test')\n                            if (neu.spike == True and restSpikeResult.get(order) == None):\n                                # restSpikeResult[int(order)] = neu.selectivity\n                                tdic[\"N\"] = neu.selectivity\n\n                        for neu in msmtgroup.neurons:\n                            neu.updateCurrentOfLowerAndUpperLayer(t)\n                            neu.update(dt, t, 'test')\n                            if (neu.spike == True and restSpikeResult.get(order) == None):\n                                # restSpikeResult[int(order)] = neu.selectivity\n                                tdic[\"T\"] = neu.selectivity\n                    if (tdic):\n                        restSpikeResult[int(order)] = tdic\n                    count += 1\n                # print(restSpikeResult)\n                dic = {}\n                dic['goal'] = goalname\n                dic['rest'] = restSpikeResult\n                finalResult.append(dic)\n        return (finalResult)\n\n    def doConnectToTitle(self, title, track, order):\n        for sl in track.values():\n            self.doConnecting(title, sl, order)\n\n    def doConnectToComposer(self, composer, track, order):\n        for sl in track.values():\n            self.doConnecting(composer, sl, order)\n\n    def doConnectToGenre(self, genre, track, order):\n        for sl in track.values():\n            self.doConnecting(genre, sl, order)\n\n    def doConnectToEmotion(self, emo, track, order):\n        for sl in track.values():\n            self.doConnecting(emo, sl, order)\n\n    def doConnectToKey(self, key, track, order, noteIndex):\n        group = track.get(\"N\").groups.get(order)\n        if group == None: return\n        tb = (order - 1) * group.timeWindow\n        te 
= (order) * group.timeWindow\n        sp1_goal = {}\n        sp2 = []\n        kneu = key.neurons[noteIndex % 12]\n        sp = []\n        for st in kneu.spiketime:\n            if (st < te and st >= tb):\n                sp.append(st)\n        sp1_goal[kneu.index] = sp\n\n        n = group.neurons[noteIndex + 1]\n        if (len(n.spiketime) > 0):\n            for index, sp in sp1_goal.items():\n                temp = 0\n                for sp1 in n.spiketime:  # spike times of group\n                    for sp2 in sp:\n                        if (abs(sp1 - sp2) <= n.timeWindow):\n                            temp += 1\n                if (temp >= 2):  # super threshold, create a new synapse between goal and neurons of sequence group\n                    # syn = Synapse(goal.neurons[index - 1], n)\n                    # syn.type = 2\n                    # syn.weight = 30\n                    # n.pre_neurons.append(goal.neurons[index - 1])\n                    # n.synapses.append(syn)\n\n                    # add reverse synapse to neurons of the goal\n                    syn2 = Synapse(n, kneu)\n                    syn2.type = 3\n                    syn2.weight = 0\n                    kneu.synapses.append(syn2)\n\n    def doConnectToMode(self, mode, keyName, track, order, noteIndex):  # 只连接对应调式的音级(这是两个神经元之间的连接)\n        noteScales = np.where((configs.keyscales.get(keyName % 2)[keyName // 2]) == noteIndex % 12)[0][0]\n        group = track.get(\"N\").groups.get(order)\n        if (group == None): return\n        tb = (order - 1) * group.timeWindow\n        te = (order) * group.timeWindow\n        sp1_goal = {}\n        sp2 = []\n\n        modeneu = mode.neurons[noteScales]\n        sp = []\n        for st in modeneu.spiketime:\n            if (st < te and st >= tb):\n                sp.append(st)\n        sp1_goal[modeneu.index] = sp\n\n        n = group.neurons[noteIndex + 1]\n        if (len(n.spiketime) > 0):\n            for index, sp in sp1_goal.items():\n    
            temp = 0\n                for sp1 in n.spiketime:  # spike times of group\n                    for sp2 in sp:\n                        if (abs(sp1 - sp2) <= n.timeWindow):\n                            temp += 1\n                if (temp >= 2):  # super threshold, create a new synapse between goal and neurons of sequence group\n                    # syn = Synapse(goal.neurons[index - 1], n)\n                    # syn.type = 2\n                    # syn.weight = 30\n                    # n.pre_neurons.append(goal.neurons[index - 1])\n                    # n.synapses.append(syn)\n\n                    # add reverse synapse to neurons of the goal\n                    syn2 = Synapse(n, modeneu)\n                    syn2.type = 3\n                    syn2.weight = 0\n                    modeneu.synapses.append(syn2)\n\n    def generateEx_Nihilo(self, firstNote, durations, order, dt, t):\n        ns = self.sequenceLayers.get(1).get(\"N\")\n        ts = self.sequenceLayers.get(1).get(\"T\")\n        nneurons = ns.groups.get(order + 1).neurons\n        tneurons = ts.groups.get(order + 1).neurons\n        # firstNotes specify the beginning notes to trigger the following notes\n        if (order < len(firstNote)):  # beginning notes\n            i = firstNote[order]\n            nneu = nneurons[i + 1]\n            nneu.I = 20\n            nneu.update_normal(dt, t)\n\n            d = int(durations[order] / 0.125) - 1\n            tneu = tneurons[d]\n            tneu.I = 20\n\n            tneu.update_normal(dt, t)\n        else:  # generate next note\n            for nn in nneurons:\n                nn.updateCurrentOfLowerAndUpperLayer(t)\n                nn.update(dt, t, 'test')\n            for tn in tneurons:\n                tn.updateCurrentOfLowerAndUpperLayer(t)\n                tn.update(dt, t, 'test')\n\n    def generateMelodyWithTone(self, firstNote, duration, tone, order, dt, t):\n        for i, part in self.sequenceLayers.items():\n            if i > 4: 
break\n            ns = part.get(\"N\")\n            ts = None\n            if duration is not None:\n                ts = part.get(\"T\")  # 不输入值的话就都生成四分音符\n\n            # pitches updating\n            if firstNote is not None:\n                for nneu in ns.groups.get(order).neurons:\n                    nneu.I_ext = firstNote[i - 1]\n                    nneu.computeFilterCurrent()\n                    nneu.update_normal(dt, t)\n            else:\n                for nneu in ns.groups.get(order).neurons:\n                    nneu.updateCurrentOfLowerAndUpperLayer(t)\n                    nneu.update_normal(dt, t)\n\n            # durations updating\n            if ts is not None:\n                if duration is not None:\n                    for tneu in ts.groups.get(order).neurons:\n                        tneu.I_ext = duration[i - 1]\n                        tneu.computeFilterCurrent()\n                        tneu.update_normal(dt, t)\n                else:\n                    for tneu in ts.groups.get(order).neurons:\n                        tneu.updateCurrentOfLowerAndUpperLayer(t)\n                        tneu.update_normal(dt, t)\n\n    def generateSimgleTrackNotes(self, trackIndex, firstNote, durations, order, dt, t):\n        ns = self.sequenceLayers.get(trackIndex).get(\"N\")\n        ts = self.sequenceLayers.get(trackIndex).get(\"T\")\n        nneurons = ns.groups.get(order + 1).neurons\n        tneurons = ts.groups.get(order + 1).neurons\n        # firstNotes specify the beginning notes to trigger the following notes\n        if (order < len(firstNote)):  # beginning notes\n            i = firstNote[order]\n            nneu = nneurons[i + 1]\n            nneu.I = 20\n            nneu.update_normal(dt, t)\n\n            d = int(durations[order] / 0.125) - 1\n            tneu = tneurons[d]\n            tneu.I = 20\n            tneu.update_normal(dt, t)\n        else:  # generate next note\n            for nn in nneurons:\n                
nn.updateCurrentOfLowerAndUpperLayer(t)\n                nn.update(dt, t, 'test')\n            #                 if(neu.spike == True):\n            #                     print(neu.selectivity)\n            for tn in tneurons:\n                tn.updateCurrentOfLowerAndUpperLayer(t)\n                tn.update(dt, t, 'test')"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Areas/pfc.py",
    "content": "import numpy as np\nimport math\nfrom braincog.base.brainarea.PFC import PFC\nfrom Modal.synapse import Synapse\nfrom Modal.titlelayer import TitleLayer\nfrom Modal.composerlayer import ComposerLayer\nfrom Modal.genrelayer import GenreLayer\nfrom conf.conf import *\nfrom Modal.layer import *\n\n\nclass PFC(PFC):\n    '''\n    This area is used to store the sub-Goal of the task\n    '''\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        super().__init__()\n        self.neutype = neutype\n        self.goals = TitleLayer(self.neutype)  # store the musical titles\n        self.composers = ComposerLayer(self.neutype)  # store composers\n        self.keys = KeyLayer(self.neutype)\n        self.modes = ModeLayer(self.neutype)\n        self.genres = GenreLayer(self.neutype)\n        self.chords = ChordLayer(self.neutype)\n\n\n    def addNewKey(self):\n        row, col = configs.key_matrix.shape\n        for i in range(row):\n            #print(configs.key_matrix[i,:])\n            self.keys.addNewGroups(i+1, 1, col, configs.key_matrix[i,:])\n\n    def addNewSubGoal(self, goalname):\n        if (self.goals.groups.get(goalname) == None):\n            self.goals.addNewGroups(len(self.goals.groups) + 1, 1, 1, goalname)\n\n    def addNewComposer(self, composername):\n        if (self.composers.groups.get(composername) == None):\n            self.composers.addNewGroups(len(self.composers.groups) + 1, 1, 1, composername)\n\n    def addNewGenre(self, genrename):\n        if (self.genres.groups.get(genrename) == None):\n            self.genres.addNewGroups(len(self.genres.groups) + 1, 1, 1, genrename)\n\n    def addNewMode(self):\n        for i, m in configs.index2mode.items():\n            self.modes.addNewGroups(i + 1, 1, 12, m)\n            # 与调式网络相连\n            scales = configs.keyscales.get(i)\n            for k in range(12):  # key neurons project to mode neurons\n                for j, index in enumerate(scales[k, 
:]):\n                    pre = self.keys.groups.get(k).neurons[index]\n                    post = self.modes.groups.get(i).neurons[j]\n                    syn = Synapse(pre, post)\n                    syn.excitability = 1\n                    syn.type = 3\n                    post.synapses.append(syn)\n                    post.pre_neurons.append(pre)\n\n                    # syn1 = Synapse(post, pre) # mode neurons project to key neurons\n                    # syn1.type = 2\n                    # syn1.excitability = 1\n                    # syn1.weight = 10 # 这个地方应该设置成KS model\n                    # pre.synapses.append(syn1)\n                    # pre.pre_neurons.append(post)\n\n    def addNewKey(self):\n        row, col = configs.key_matrix.shape\n        for i in range(row):\n            #print(configs.key_matrix[i,:])\n            self.keys.addNewGroups(i+1, 1, col, configs.key_matrix[i,:])\n\n    def addNewChord(self):\n        for i in range(0, 7):  # 暂时先存储7个三和弦\n            self.chords.addNewGroups(i + 1, 1, 1)\n            # 与调式网络相连\n            # 先连接T,S,D和弦\n            for t, c in configs.chordsMap.items():\n                # print(t)\n                # print(configs.keyIndexMap.get(t))\n                # print(c[i, :])\n                # print('---------')\n                for k in c[i, :]:\n                    pre = self.chords.groups.get(i).neurons[0]\n                    post = self.keys.groups.get(configs.keyIndexMap.get(t)).neurons[k]\n                    syn = Synapse(pre, post)\n                    syn.excitability = 1\n                    post.synapses.append(syn)\n                    post.pre_neurons.append(pre)\n        # 建立和弦内部连接\n        r = np.argwhere(configs.chordsMatrix >= 1)\n        for i in range(len(r)):\n            # print(r[i][0])\n            # print(r[i][1])\n            pre = self.chords.groups.get(r[i][0]).neurons[0]\n            post = self.chords.groups.get(r[i][1]).neurons[0]\n            syn = Synapse(pre, post)\n           
 post.synapses.append(syn)\n            post.pre_neurons.append(pre)\n\n    def setTestStates(self):\n\n        self.goals.setTestStates()\n        self.composers.setTestStates()\n        self.genres.setTestStates()\n        self.keys.setTestStates()\n        self.modes.setTestStates()\n        self.chords.setTestStates()\n\n    def doRecalling(self, goalname, asm):\n        goal = self.goals.groups.get(goalname)\n        #         print(goal.name)\n        #         print(goal.id)\n        result = {}\n        sequences = asm.sequenceLayers.get(1).groups\n        dt = 0.1\n        time = np.arange(0, len(sequences) * 5, dt)\n\n        for t in time:\n            order = math.floor(t / 5) + 1\n\n            for neu in goal.neurons:\n                neu.I = 30\n                neu.update_normal(dt, t)\n            sg = sequences.get(order)\n            for neu in sg.neurons:\n                neu.updateCurrentOfLowerAndUpperLayer(t)\n                neu.update(dt, t, 'test')\n                # if(neu.spike == True):\n                # print(neu.index)\n                if (neu.spike == True and result.get(order) == None):\n                    result[int(order)] = neu.selectivity\n        return result\n\n    def doRecalling2(self, goalname, asm):\n        goal = self.goals.groups.get(goalname)\n        #         print(goal.name)\n        #         print(goal.id)\n\n        result = {}\n        for tindex, strack in asm.sequenceLayers.items():\n            nsequences = strack.get(\"N\").groups\n            tsequences = strack.get(\"T\").groups\n            dic = {}\n            ndic = {}\n            tdic = {}\n            dt = 0.1\n            time = np.arange(0, len(nsequences) * 5, dt)\n            for t in time:\n                order = math.floor(t / 5) + 1\n                for neu in goal.neurons:\n                    neu.I = 30\n                    neu.update_normal(dt, t)\n                nsg = nsequences.get(order)\n                for neu in nsg.neurons:\n    
                # print(neu.selectivity)\n                    neu.updateCurrentOfLowerAndUpperLayer(t)\n                    neu.update(dt, t, 'test')\n                    # if(neu.I > 0):\n                    #     print(neu.I)\n                    if (neu.spike == True and ndic.get(order) == None):\n                        ndic[int(order)] = neu.selectivity\n\n                tsg = tsequences.get(order)\n                for neu in tsg.neurons:\n                    neu.updateCurrentOfLowerAndUpperLayer(t)\n                    neu.update(dt, t, 'test')\n                    if (neu.spike == True and tdic.get(order) == None):\n                        tdic[int(order)] = neu.selectivity\n\n            dic[\"N\"] = ndic\n            dic[\"T\"] = tdic\n            result[tindex] = dic\n        return result\n\n    def doRecalling3(self,goalname,asm):\n        goal = self.goals.groups.get(goalname)\n        #         print(goal.name)\n        #         print(goal.id)\n\n        result = {}\n        for tindex, strack in asm.sequenceLayers.items():\n            nsequences = strack.get(\"N\").groups\n            tsequences = strack.get(\"T\").groups\n            part = []\n            ndic = {}\n            tdic = {}\n            dt = 0.1\n            # for order in range(len(nsequences)):\n            #     nns = nsequences.get(order+1).neurons\n            #     for n in nns:\n            #         if n.preActive:\n            #             print('-------order:'+str(order)+', selectivity: '+str(n.selectivity)+'--------------------')\n            #             for syn in n.synapses:\n            #                 if syn.weight > 0:\n            #                     print(syn.weight)\n            #time = np.arange(0, len(nsequences) * 5, dt)\n            print(len(nsequences))\n            for i in range(0,len(nsequences)):\n                order = i+1\n                tmp = {}\n                time = np.arange(i*5,(i+1)*5,dt)\n                for t in time:\n               
     for neu in goal.neurons:\n                        neu.I = 30\n                        neu.update_normal(dt, t)\n                    nsg = nsequences.get(order)\n                    for neu in nsg.neurons:\n                        neu.updateCurrentOfLowerAndUpperLayer(t)\n                        neu.update_normal(dt, t)\n                        # if (neu.I > 0):\n                        #     print('order: ' + str(order))\n                        #     print(neu.I)\n                        #     print(neu.selectivity)\n                        #     print(neu.I_lower)\n                        if (neu.spike == True and ndic.get(order) == None):# 这里用的是first spike的理念，我觉得最好改成max firingrate\n                            ndic[int(order)] = neu.selectivity\n                            tmp[\"N\"] = neu.selectivity\n\n                    tsg = tsequences.get(order)\n                    for neu in tsg.neurons:\n                        neu.updateCurrentOfLowerAndUpperLayer(t)\n                        neu.update_normal(dt, t)\n                        if (neu.spike == True and tdic.get(order) == None):#这个地方bug太邪乎了，等一会儿改\n                            tdic[int(order)] = neu.selectivity\n                            tmp[\"T\"] = neu.selectivity\n                part.append(tmp)\n            result[tindex] = part\n        print(len(result))\n        return result\n\n    def doRemebering(self, goalname, dt, t):\n        # storing the title information\n        goal_group = self.goals.groups.get(goalname)\n        for neu in goal_group.neurons:\n            neu.I = 50\n            neu.update_normal(dt, t)\n\n    def doRememberingComposer(self, composername, dt, t):\n        composer_group = self.composers.groups.get(composername)\n        for neu in composer_group.neurons:\n            neu.I = 50\n            neu.update_normal(dt, t)\n\n    def doRememberingGenre(self, genrename, dt, t):\n        genre_group = self.genres.groups.get(genrename)\n        for neu in 
genre_group.neurons:\n            neu.I = 50\n            neu.update_normal(dt, t)\n\n    def doRememberingKey(self, key, dt, t):\n        key_group = self.keys.groups.get(key)\n        for neu in key_group.neurons:\n            neu.I = 50 if neu.importance > 0 else -100\n            neu.update_learn(dt, t)\n\n    def doRememberingMode(self,mode, dt,t):\n        mode_group = self.modes.groups.get(mode)\n        for neu in mode_group.neurons:\n            neu.I = 50\n            neu.update_learn(dt,t)\n\n    def innerLearning(self, goalname, composer, genre):\n        g = self.goals.groups.get(goalname)\n        c = self.composers.groups.get(composer)\n        gre = self.genres.groups.get(genre)\n        if (g != None and c != None):\n            for n1 in c.neurons:\n                if (len(n1.spiketime) > 0):\n                    for n2 in g.neurons:\n                        if (len(n2.spiketime) > 0):\n                            temp = 0\n                            for sp1 in n1.spiketime:\n                                for sp2 in n2.spiketime:\n                                    if (abs(sp1 - sp2) <= n1.tau_ref):\n                                        temp += 1\n                            if (temp >= 3):\n                                syn = Synapse(n1, n2)\n                                syn.type = 2\n                                syn.weight = 5\n                                n2.synapses.append(syn)\n                                n2.pre_neurons.append(n1)\n\n        if (gre != None):\n            for n1 in gre.neurons:\n                if (len(n1.spiketime) > 0):\n                    if (c != None):\n                        for n2 in c.neurons:\n                            if (len(n2.spiketime) > 0):\n                                temp = 0\n                                for sp1 in n1.spiketime:\n                                    for sp2 in n2.spiketime:\n                                        if (abs(sp1 - sp2) <= n1.tau_ref):\n           
                                 temp += 1\n                                if (temp >= 4):\n                                    syn = Synapse(n1, n2)\n                                    syn.type = 2\n                                    syn.weight = 5\n                                    n2.synapses.append(syn)\n                                    n2.pre_neurons.append(n1)\n                    if (g != None):\n                        for n2 in g.neurons:\n                            if (len(n2.spiketime) > 0):\n                                temp = 0\n                                for sp1 in n1.spiketime:\n                                    for sp2 in n2.spiketime:\n                                        if (abs(sp1 - sp2) <= n1.tau_ref):\n                                            temp += 1\n                                if (temp >= 4):\n                                    syn = Synapse(n1, n2)\n                                    syn.type = 2\n                                    syn.weight = 5\n                                    n2.synapses.append(syn)\n                                    n2.pre_neurons.append(n1)\n\n    def inhibitGenres(self,dt,t):\n        gen_group = self.goals.groups\n        for g in gen_group.values():\n            for neu in g.neurons:\n                neu.I = -100\n                neu.update_normal(dt, t)\n\n    def inhibiteGoals(self, dt, t):\n        goal_group = self.goals.groups\n        for g in goal_group.values():\n            for neu in g.neurons:\n                neu.I = -100\n                neu.update(dt, t)\n\n    def inhibitComposers(self, dt, t):\n        com_group = self.composers.groups\n        for g in com_group.values():\n            for neu in g.neurons:\n                neu.I = -100\n                neu.update(dt, t)\n\n\n\n\n\n\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/PAC.py",
    "content": "'''\nPrimary Auditory Cortex\n'''\nimport torch\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nfrom braincog.base.connection import CustomLinear\nfrom braincog.base.learningrule.STDP import *\n\nclass PAC(BrainArea):\n\n    def __int__(self,w,mask):\n        self.noteNetworks = NoteLIFNode()\n        self.connection = [CustomLinear(w,mask),CustomLinear(w2,mask2)]\n        self.stdp = []\n        self.internalinputs = torch.zeros(640,640)\n        self.stdp.append(MutliInputSTDP(self.noteNetworks, self.connection))\n\n    def forward(self, x):\n        self.internalinputs,dw = self.stdp[0](x,self.internalinputs)\n        return self.internalinputs, dw"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/cluster.py",
    "content": "\nfrom .lifneuron import LIFNeuron\nfrom .synapse import Synapse\nfrom Modal.izhikevichneuron import *\n\nclass Cluster():\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype='LIF', neunum=10):\n        '''\n        Constructor\n        '''\n        self.id = 0  # starting with 1\n        self.name = ''  # name of this group\n        self.neutype = neutype\n        self.neunum = neunum\n        self.neurons = []\n        self.timeWindow = 5  # ms\n\n    def createClusterNetwork(self):\n        # create neurons\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = LIFNeuron()\n                node.index = i + 1\n                node.setPreference()\n                self.neurons.append(node)\n            if (self.neutype == 'Izhikevich'):\n                node = IzhikevichNeuron()\n                node.index = i\n                self.neurons.append(node)\n            if (self.neutype == 'Gaussian'):\n                node = GaussianNeuron()\n                node.index = i + 1\n                self.neurons.append(node)\n            if (self.neutype == 'HH'):\n                node = HHNeuron()\n                node.index = i\n                self.neurons.append(node)\n\n    def setInhibitoryNeurons(self, ratio_inhneuron):\n        for i in range(int(self.neunum * (1 - ratio_inhneuron)), self.neunum):\n            self.neurons[i].type = 'inh'\n\n    def setPropertiesofNeurons(self, groupID, layerType, layerID):\n        for n in self.neurons:\n            n.layerType = layerType\n            n.groupIndex = groupID\n            n.layerIndex = layerID\n\n    def setTestStates(self):\n        for neu in self.neurons:\n            neu.setTestStates()\n\n    def createFullConnections(self):  # all in all connections\n        for i in range(0, self.neunum):\n            neu = self.neurons[i]\n            for j in range(0, self.neunum):\n                if (j != i):\n                    syn = 
Synapse(self.neurons[j], neu)  # this neuron is considered as post_synapse neuron\n                    syn.type = 0\n                    neu.synapses.append(syn)\n                    neu.pre_neurons.append(self.neurons[j])\n\n    def createInhibitoryConnections(self):  # all in all inhibitory connections\n        for i in range(0, self.neunum):\n            neu = self.neurons[i]\n            for j in range(0, self.neunum):\n                if (j != i):\n                    syn = Synapse(self.neurons[j], neu)  # this neuron is considered as post_synapse neuron\n                    syn.type = 0\n                    syn.excitability = 0\n                    syn.weight = 20\n                    neu.synapses.append(syn)\n                    neu.pre_neurons.append(self.neurons[j])\n\n    def writeSelfInfoToJson(self):\n        dic = {}\n        nlist = []\n        for neu in self.neurons:\n            if (len(neu.spiketime) <= 0): continue\n            ndic = neu.writeBasicInfoToJson()\n            nlist.append(ndic)\n        dic[\"GroupID\"] = self.id\n        dic[\"Name\"] = self.name\n        dic[\"Neuron\"] = nlist\n        return dic\n\n    def writeSpikeInfoToJson(self):\n        nlist = []\n        for neu in self.neurons:\n            if (len(neu.spiketime) > 0):\n                tmp = {}\n                tmp[\"GroupID\"] = neu.groupIndex\n                tmp[\"Index\"] = neu.index\n                tmp[\"SpikeTime\"] = neu.writeSpikeTimeToJson()\n                nlist.append(tmp)\n        return nlist\n\nclass ModeCluster(Cluster):\n    def __init__(self, neutype, neunum):\n        Cluster.__init__(self, neutype,neunum)\n\n    def createClusterNetwork(self, areaName):\n        for i in range(self.neunum): # 暂时先不考虑importance\n            if (self.neutype == 'LIF'):\n                node = ModeLIFNeuron()\n                node.index = i + 1\n                node.areaName = areaName\n                node.selectivity = i+1\n                self.neurons.append(node)\n 
           if (self.neutype == 'Izhikevich'):\n                self.neutype = 'Izhikevich'\n                node = ModeIzhikevichNeuron()\n                node.index = i + 1\n                node.areaName = areaName\n                node.selectivity = i+1\n                self.neurons.append(node)\n\nclass KeyCluster(Cluster):\n    def __init__(self, neutype, neunum):\n        '''\n        Constructor\n        '''\n        Cluster.__init__(self, neutype, neunum)\n\n    def createClusterNetwork(self,tone,areaName):\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = KeyLIFNeuron()\n                node.index = i + 1\n                node.areaName = areaName\n                node.selectivity = i\n                node.importance = tone[i]\n                self.neurons.append(node)\n            if(self.neutype == 'Izhikevich'):\n                self.neutype = 'Izhikevich'\n                a=0\n                b=0\n                c=0\n                d=0\n                if tone[i] == 2:\n                    a = 0.02\n                    b = 0.2\n                    c = -55\n                    d = 4\n                if tone[i] == -1:\n                    a=0.1\n                    b = 0.2\n                    c = -65\n                    d = 2\n                if(tone[i] == 1):\n                    a = 0.02\n                    b = 0.2\n                    c = -65\n                    d = 8\n                node = KeyIzhikevichNeuron(a,b,c,d)\n                node.index = i + 1\n                node.areaName = areaName\n                node.selectivity = i\n                node.importance = tone[i]\n                self.neurons.append(node)\n\nclass ChordCluster(Cluster):\n    def __init__(self, neutype,neunum):\n        Cluster.__init__(self, neutype, neunum)\n\n    def createClusterNetwork(self):\n        for i in range(self.neunum):\n            if (self.neutype == 'LIF'):\n                node = 
ChordLIFNeuron()\n                node.index = i + 1\n                node.areaName = 'Chord'\n                node.selectivity = i\n                node.importance = 1\n                self.neurons.append(node)\n            if self.neutype == 'Izhikevich':\n                node = ChordIzhikevichNeuron()\n                node.index = i + 1\n                node.areaName = 'Chord'\n                node.selectivity = i\n                node.importance = 1\n                self.neurons.append(node)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/composercluster.py",
    "content": "from .cluster import Cluster\nfrom .composerlifneuron import ComposerLIFNeuron\n\n\nclass ComposerCluster(Cluster):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype, neunum):\n        '''\n        Constructor\n        '''\n        Cluster.__init__(self, neutype, neunum)\n\n    def createClusterNetwork(self):\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = ComposerLIFNeuron()\n                node.index = i + 1\n                node.areaName = 'Composer'\n                self.neurons.append(node)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/composerlayer.py",
    "content": "from .layer import Layer\nfrom .composercluster import ComposerCluster\n\n\nclass ComposerLayer(Layer):\n    '''\n        This layer defines the information of composer name. One neuron corresponds to a composer\n    '''\n\n    def __init__(self, neutype='LIF'):\n        self.neutype = neutype\n        self.groups = {}\n\n    def setTestStates(self):\n        for id, g in self.groups.items():\n            g.setTestStates()\n\n    def addNewGroups(self, groupID, layerID, neunum, composername):\n        g = ComposerCluster(self.neutype, neunum)\n        g.id = groupID\n        g.name = composername\n        g.createClusterNetwork()\n        g.setPropertiesofNeurons(groupID, 'G', layerID)\n        self.groups[composername] = g"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/composerlifneuron.py",
    "content": "from .lifneuron import LIFNeuron\n\n\nclass ComposerLIFNeuron(LIFNeuron):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2):\n        '''\n        Constructor\n        '''\n        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)\n\n    def update(self, dt, t):\n        self.spike = False\n        # self.updateCurrentOfLowerAndUpperLayer(t)\n        if (t >= self.t_rest):\n            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m\n            if (self.mem > self.vth):\n                self.spike = True\n                self.spiketime.append(t)\n                self.mem = 0\n                self.t_rest = t + self.tau_ref"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/genrecluster.py",
    "content": "from .cluster import Cluster\nfrom .genrelifneuron import GenreLIFNeuron\n\nclass GenreCluster(Cluster):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype, neunum):\n        '''\n        Constructor\n        '''\n        Cluster.__init__(self, neutype, neunum)\n\n    def createClusterNetwork(self):\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = GenreLIFNeuron()\n                node.index = i + 1\n                node.areaName = 'Genre'\n                self.neurons.append(node)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/genrelayer.py",
    "content": "from .layer import Layer\nfrom .genrecluster import GenreCluster\nclass GenreLayer(Layer):\n    '''\n        This layer defines the information of composer name. One neuron corresponds to a composer\n    '''\n\n    def __init__(self, neutype='LIF'):\n        self.neutype = neutype\n        self.groups = {}\n\n    def setTestStates(self):\n        for id, g in self.groups.items():\n            g.setTestStates()\n\n    def addNewGroups(self, groupID, layerID, neunum, genrename):\n        g = GenreCluster(self.neutype, neunum)\n        g.id = groupID\n        g.name = genrename\n        g.createClusterNetwork()\n        g.setPropertiesofNeurons(groupID, 'G', layerID)\n        self.groups[genrename] = g"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/genrelifneuron.py",
    "content": "from .lifneuron import LIFNeuron\n\nclass GenreLIFNeuron(LIFNeuron):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2):\n        '''\n        Constructor\n        '''\n        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)\n\n    def update(self, dt, t):\n        self.spike = False\n        # self.updateCurrentOfLowerAndUpperLayer(t)\n        if (t >= self.t_rest):\n            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m\n            if (self.v > self.vth):\n                self.spike = True\n                self.spiketime.append(t)\n                self.mem = 0\n                self.t_rest = t + self.tau_ref"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/izhikevichneuron.py",
    "content": "'''\nCreated on 2016.4.8\n\n@author: liangqian\n'''\n#from modal.izhikevich import Izhikevich\nfrom braincog.base.node import IzhNodeMU\nimport math\nimport random\nimport numpy as np\n\nclass IzhikevichNeuron(IzhNodeMU):\n    '''\n    classdocs\n    '''\n\n\n    def __init__(self, a = 0.1,b = 0.2,c = -65,d = 8,vthresh = 30, dt=0.1):\n        '''\n        Constructor\n        '''\n        super().__init__(threshold=vthresh, a=a, b=b, c=c, d=d, dt=dt)\n        self.layerType = 'S'  # S:sequenceLayer, G: goal layer\n        self.layerIndex = 0  # the layer in which the neuron situated\n        self.groupIndex = 0  # the group in which the neuron situated\n        self.index = 0  # starting with 1\n        self.areaName = ''\n        self.synapses = [] #this neuron is considered as post-synaptic neuron\n        self.spiketime = []\n        self.pre_neurons = []\n        self.I_syn_lower = 0\n        self.I_syn_upper = 0\n        self.I_ext = 0\n        self.I_lower = 0\n        self.I_upper = 0\n        self.I_ext = -100\n        self.timeWindow = 5  # ms\n        self.I_bg = random.randint(0, 10)\n        # self.state = 'Learn' # else test\n        self.selectivity = 0\n        self.importance = 0\n        self.preActive = False\n        self.I = 0\n        self.v = -65\n        self.u = b * self.v\n        self.vthresh = vthresh\n        self.spike = False\n        self.type = 'exc'\n        \n    def update_old(self,dt,t):\n        self.spike = 0\n        self.updateSynapses(t)\n        self.updateCurrentOfLowerAndUpperLayer(t)\n        self.I = self.I_ext + self.I_syn_lower + self.I_syn_upper\n        self.v += dt * (0.04*self.v * self.v + 5 * self.v + 140 - self.u + self.I)\n        self.u += dt * self.a * (self.b *self.v - self.u)\n        #self.synapseWeightsDepression()\n        \n        if(self.v >= 30):\n            self.spike = 1\n            self.v = self.c\n            self.u += self.d\n            self.spiketime.append(t)\n\n    def 
update(self,dt,t,state):\n        if (state == 'Learn'):\n            self.update_learn(dt, t)\n        if(state == 'test'):\n            self.update_test(dt,t)\n    def update_learn(self,dt,t):\n        self.spike = False\n        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)\n        self.u += dt * self.a * (self.b * self.v - self.u)\n        if self.v > self.vthresh:\n            self.spike = True\n            self.v = self.c\n            self.u += self.d\n            self.preActive = True\n            self.spiketime.append(t)\n            self.updateSynapses(t)\n\n    def update_test(self,dt,t):\n        self.spike = False\n        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)\n        self.u += dt * self.a * (self.b * self.v - self.u)\n        if self.v > self.vthresh:\n            self.spike = True\n            self.v = self.c\n            self.u += self.d\n\n    def update_normal(self,dt,t):\n        self.spike = False\n        self.v += dt * (0.04 * self.v * self.v + 5 * self.v + 140 - self.u + self.I)\n        self.u += dt * self.a * (self.b * self.v - self.u)\n        if self.v >= self.vthresh:\n            self.spike = True\n            self.v = self.c\n            self.u += self.d\n            self.spiketime.append(t)\n\n\n    def updateSynapses(self,t):\n        for syn in self.synapses:\n            syn.computeWeight(t) \n    \n    def updateCurrentOfLowerAndUpperLayer(self,t):\n        I_inh = 0\n        I_ext = 0\n        I_exc_ext = 0\n        for syn in self.synapses:\n            # compute the alpha value of all spikes before this time t\n            alpha_value = 0\n            for st in syn.pre.spiketime:\n                temp = 0\n                if(t - st >= 0): temp = 6*(t/1000)*math.exp(-0.03*(t - st)/1000)\n                else:temp = 0\n                alpha_value += temp\n            if(syn.type == 0): # from the same group\n                if(syn.pre.type == 'inh'):\n         
           I_inh += syn.weight * (self.v+80) * alpha_value\n                if(syn.pre.type == 'exc'):\n                    I_ext += syn.weight * self.v * alpha_value\n            if(syn.type == 1):# from other modules in the same layer\n                I_exc_ext += syn.weight * self.v * alpha_value\n            if(syn.type == 2): # from the upper layer\n                self.I_syn_upper += syn.weight * self.v * alpha_value\n                \n        self.I_syn_lower = -I_inh + I_ext + I_exc_ext\n\n    def setTestStates(self):\n        self.t_rest = 0\n        self.spiketime = []\n        self.v = -65\n        self.u = self.b*self.v\n        self.I = 0\n        self.I_ext = 0\n        for syn in self.synapses:\n            syn.strength = 0\n\n    def writeBasicInfoToJson(self):\n        dic = {}\n        dic[\"TrackID\"] = self.layerIndex\n        dic[\"GroupID\"] = self.groupIndex\n        dic[\"Index\"] = self.index\n        dic[\"selectivity\"] = self.selectivity\n        dic[\"area\"] = self.areaName\n        slist = []\n        for syn in self.synapses:\n            if (syn.weight <= 0): continue\n            tmp = {}\n            tmp[\"type\"] = syn.type\n            tmp[\"StrackID\"] = syn.pre.layerIndex\n            tmp[\"SgroupID\"] = syn.pre.groupIndex\n            tmp[\"Sindex\"] = syn.pre.index\n            tmp[\"Sarea\"] = syn.pre.areaName\n            tmp[\"pre-selectivity\"] = syn.pre.selectivity\n            tmp[\"weight\"] = syn.weight\n            slist.append(tmp)\n        dic[\"synapses\"] = slist\n        return dic\n\n    def writeSpikeTimeToJson(self):\n        slist = []\n        for i, t in enumerate(self.spiketime):\n            dic = {}\n            dic[i + 1] = round(t, 2)#两位有效数字\n            slist.append(dic)\n        return slist\n        \n\n\nclass NoteIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh = 30):\n        IzhikevichNeuron.__init__(self,a,b,c,d,vthresh)\n    def 
setPreference(self):\n        self.selectivity = self.index - 2\n    def computeFilterCurrent(self):\n        if(self.I_ext == self.selectivity):\n            self.I = 30\n\n    def updateCurrentOfLowerAndUpperLayer(self, t):\n        self.I_lower = 0\n        self.I_upper = 0\n        for syn in self.synapses:\n            syn.computeShortTermFacilitation2(t)\n            if (syn.type == 0):  # the same group\n                if (syn.excitability == 0):\n                    self.I_lower -= syn.weight * syn.strength\n                    if (self.I_lower < -20): self.I_lower = -20\n            if (syn.type == 1):  # pre and post neurons come from  the same layer but not the same group\n                #if(syn.weight > 0):\n                    # print('pre_neuron_group id:'+str(syn.pre.groupIndex) + ' neuron index:'+str(syn.pre.index))\n                    # print('post_neuron_group id:'+str(self.groupIndex) + ' neuron index:'+str(self.index))\n                    # print('syn.strength=' + str(syn.strength))\n                    # print('syn.weight='+ str(syn.weight))\n\n                self.I_lower += syn.weight * syn.strength\n                # print('syn.strength='+str(syn.strength))\n                # print('syn.weight='+ str(syn.weight))\n\n            if (syn.type >= 2):  # pre and post neurons come from the different layers\n                # print(syn.pre.groupIndex)\n                self.I_upper += syn.weight * syn.strength\n        self.I = self.I_lower + self.I_upper\nclass TempoIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh = 30):\n        IzhikevichNeuron.__init__(self,a,b,c,d,vthresh)\n    def setPreference(self):\n        self.selectivity = self.index * 0.125\n\n    def computeFilterCurrent(self):\n        if(self.I_ext <= self.selectivity + 0.0625 and self.I_ext >= self.selectivity - 0.0625 ):\n            self.I = 30\n\n    def updateCurrentOfLowerAndUpperLayer(self, t):\n        self.I_lower = 0\n    
    self.I_upper = 0\n        for syn in self.synapses:\n            syn.computeShortTermFacilitation2(t)\n            if (syn.type == 0):  # the same group\n                if (syn.excitability == 0):\n                    self.I_lower -=  syn.weight * syn.strength\n                    if (self.I_lower < -20): self.I_lower = -20\n            if (syn.type == 1):  # pre and post neurons come from  the same layer but not the same group\n                #                 if(syn.weight > 0):\n                #\n                #                     print('pre_neuron_group id:'+str(syn.pre.groupIndex) + ' neuron index:'+str(syn.pre.index))\n                #                     print('post_neuron_group id:'+str(self.groupIndex) + ' neuron index:'+str(self.index))\n                #                     print(syn.weight)\n                self.I_lower += syn.weight * syn.strength\n                # print('syn.strength='+str(syn.strength))\n\n            if (syn.type == 2):  # pre and post neurons come from the different layers\n                # print(syn.pre.groupIndex)\n                # self.I_lower = syn.weight * syn.strength\n                self.I_upper += syn.weight * syn.strength\n\n        #         if(self.I_upper == 0):\n        #             self.I = self.I_ext\n        self.I = self.I_lower + self.I_upper\n\nclass TitleIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh=30):\n        IzhikevichNeuron.__init__(self,a,b,c,d,vthresh)\n\nclass ComposerIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh=30):\n        IzhikevichNeuron.__init__(self, a,b,c,d,vthresh)\n\nclass GenreIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self, a = 0.1,b = 0.2,c = -65,d = 8, vthresh=30):\n        IzhikevichNeuron.__init__(self, a, b, c, d, vthresh)\n\nclass AmyIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh=30):\n     
   IzhikevichNeuron.__init__(self, a,b,c,d,vthresh)\n\nclass DirectionIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh=30):\n        IzhikevichNeuron.__init__(self,a,b,c,d,vthresh)\n\n    def setPreference(self):\n        # self.selectivity = 2 * math.pi/240 * self.index - math.pi/240\n        self.selectivity = (self.index + 1) * math.pi / 120\n\n    def computeFilterCurrent(self, input):\n        if (input < self.selectivity + math.pi / 240 and input >= self.selectivity - math.pi / 240):\n            self.I = self.I_ext = 30\n\n    def updateCurrentOfLowerAndUpperLayer(self, t):\n        self.I_lower = 0\n        self.I_upper = 0\n        for syn in self.synapses:\n            syn.computeShortTermFacilitation2(t)\n            if (syn.type == 0):  # the same group\n                if (syn.excitability == 0):\n                    self.I_lower -= syn.weight * syn.strength\n                    if (self.I_lower < -20): self.I_lower = -20\n            if (syn.type == 1):  # pre and post neurons come from  the same layer but not the same group\n                #if(syn.weight > 0):In t\n                    # print('pre_neuron_group id:'+str(syn.pre.groupIndex) + ' neuron index:'+str(syn.pre.index))\n                    # print('post_neuron_group id:'+str(self.groupIndex) + ' neuron index:'+str(self.index))\n                    # print('syn.strength=' + str(syn.strength))\n                    # print('syn.weight='+ str(syn.weight))\n\n                self.I_lower += syn.weight * syn.strength\n                # print('syn.strength='+str(syn.strength))\n                # print('syn.weight='+ str(syn.weight))\n\n            if (syn.type >= 2):  # pre and post neurons come from the different layers\n                # print(syn.pre.groupIndex)\n                self.I_upper += syn.weight * syn.strength\n        self.I = self.I_lower + self.I_upper\n\nclass GridIzhikevichCell(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 
0.2,c = -65,d = 8,vthresh = 30):\n        IzhikevichNeuron.__init__(self, a,b,c,d,vthresh)\nclass KeyIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh = 30):\n        IzhikevichNeuron.__init__(self,a,b,c,d,vthresh)\n\nclass ModeIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8,vthresh = 30):\n        IzhikevichNeuron.__init__(self,a,b,c,d,vthresh)\n\nclass ChordIzhikevichNeuron(IzhikevichNeuron):\n    def __init__(self,a = 0.1,b = 0.2,c = -65,d = 8, vthresh = 30):\n        IzhikevichNeuron.__init__(self, a,b,c,d,vthresh)\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/layer.py",
    "content": "from abc import ABCMeta,abstractmethod\nfrom conf.conf import configs\nfrom Modal.cluster import *\nclass Layer():\n    '''\n    classdocs\n    '''\n    _metaclass_ = ABCMeta\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        self.neutype = neutype\n        self.groups = {}\n\n    @abstractmethod\n    def resetProperties(self):\n        raise NotImplementedError\n\n    def addNewGroups(self, layerID, neunum):\n        raise NotImplementedError\n\nclass ModeLayer(Layer):\n    def __init__(self, neutype = 'LIF'):\n        self.neutype = neutype\n        self.groups = {}\n\n    def setTestStates(self):\n        for id, g in self.groups.items():\n            g.setTestStates()\n\n    def addNewGroups(self, groupID, layerID, neunum, modeName):\n        g = ModeCluster('Izhikevich', neunum)\n        g.id = groupID\n        g.name = modeName\n        g.createClusterNetwork(g.name)\n        g.setPropertiesofNeurons(groupID,'Mode',layerID)\n        self.groups[groupID-1] = g\n\nclass KeyLayer(Layer):\n    def __init__(self, neutype='LIF'):\n        self.neutype = neutype\n        self.groups = {}\n\n    def setTestStates(self):\n        for id, g in self.groups.items():\n            g.setTestStates()\n\n    def addNewGroups(self, groupID, layerID, neunum, key):\n        g = KeyCluster('Izhikevich', neunum)\n        g.id = groupID\n        g.name = configs.index2key.get(groupID-1)\n        g.createClusterNetwork(key,g.name)\n        g.setPropertiesofNeurons(groupID, 'Key', layerID)\n        self.groups[groupID-1] = g\n\nclass ChordLayer(Layer):\n    def __init__(self, neutype = 'LIF'):\n        Layer.__init__(self,neutype)\n\n    def setTestStates(self):\n        for id, g in self.groups.items():\n            g.setTestStates()\n\n    def addNewGroups(self,groupID, layerID, neunum):\n        g = ChordCluster('Izhikevich', neunum)\n        g.id = groupID\n        g.name = groupID\n        g.createClusterNetwork()\n        
g.setPropertiesofNeurons(groupID, 'Chord', layerID)\n        self.groups[groupID - 1] = g"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/lifneuron.py",
    "content": "import torch\nimport random\nfrom braincog.base.node import node\nimport numpy as np\nclass LIFNeuron(node.LIFNode):\n\n    def __init__(self, tau_ref = 0, vthresh = 5, Rm = 2, Cm = 0.2,dt = 0.1,*args, **kwargs):\n        super().__init__(threshold=vthresh, tau=Rm*Cm, dt=dt, *args, **kwargs)\n        self.layerType = 'S'  # S:sequenceLayer, G: goal layer\n        self.layerIndex = 0  # the layer in which the neuron situated\n        self.groupIndex = 0  # the group in which the neuron situated\n        self.index = 0  # starting with 1\n        self.areaName = ''\n        self.pre_neurons = []\n        self.synapses = []\n        self.spiketime = []\n        self.type = 'exc'\n        self.tau_ref = tau_ref\n        self.tau_m = Rm*Cm\n        self.vth = vthresh\n        self.Rm = Rm\n        self.Cm = Cm\n        self.t_rest = 0\n        self.I = 0\n        self.spike = False\n        self.firingrate = 0  # Hz\n        self.I_ower = 0\n        self.I_upper = 0\n        self.I_ext = -100\n        self.timeWindow = 5  # ms\n        self.I_bg = random.randint(0, 10)\n        # self.state = 'Learn' # else test\n        self.selectivity = 0\n        self.preActive = False\n\n    def update(self, dt, t, state):  # state = 'learn' or  state = 'test'\n\n        if (state == 'Learn'):\n            self.update_learn(dt, t)\n            '''\n            #------------Gaussian selectivity--------------#\n            self.I = math.exp(-((self.I_ext-math.pi/16)/0.24)**2)\n            self.I = self.I if self.I >= 0.5 else 0\n            self.I *= 10\n            '''\n\n        elif (state == 'test'):\n            self.update_test(dt, t)\n\n    def update_learn(self, dt, t):\n        self.spike = False\n        # self.computeFilterCurrent()\n        '''\n        #------------Gaussian selectivity--------------#\n        self.I = math.exp(-((self.I_ext-math.pi/16)/0.24)**2)\n        self.I = self.I if self.I >= 0.5 else 0\n        self.I *= 10\n        '''\n        
if (t >= self.t_rest):\n            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m\n            if (self.mem > self.vth):\n                self.spike = True\n                self.preActive = True\n                # print(\"groupID:\"+ str(self.groupIndex) + \", neuronID:\"+str(self.index))\n                self.spiketime.append(t)\n                self.mem = 0\n                self.t_rest = t + self.tau_ref\n                self.updateSynapses(t)\n\n    def update_test(self, dt, t):\n        self.spike = False\n        # self.updateCurrentOfLowerAndUpperLayer(t)\n        if (t >= self.t_rest):\n            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m\n            if (self.mem > self.vth):\n                self.spike = True\n                self.spiketime.append(t)\n                self.mem = 0\n                self.t_rest = t + self.tau_ref\n\n    def update_normal(self, dt, t):\n\n        self.spike = False\n        # self.I = self.I_ext\n        if (t >= self.t_rest):\n            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m\n            if (self.mem > self.vth):\n                self.spike = True\n                self.mem = 0\n                self.t_rest = t + self.tau_ref\n                self.spiketime.append(t)\n\n    def updateSynapses(self, t):\n        for syn in self.synapses:\n            syn.computeWeight(t)\n\n    def setTestStates(self):\n        self.t_rest = 0\n        self.spiketime = []\n        self.mem = 0\n        self.I = 0\n        self.I_ext = 0\n        for syn in self.synapses:\n            syn.strength = 0\n\n        # print('I=' + str(self.I))\n\n    def computeFilterCurrent(self):\n        pass\n\n    def setPreference(self):  # set preference of a neuron or called selectivity\n        # self.selectivity = 2 * math.pi/16 * self.index - math.pi/16 # the mean of the Gaussian funtion\n        pass\n\n    def writeBasicInfoToJson(self, areaName):\n        dic = {}\n        dic[\"TrackID\"] = 
self.layerIndex\n        dic[\"GroupID\"] = self.groupIndex\n        dic[\"Index\"] = self.index\n        dic[\"area\"] = areaName\n        slist = []\n        for syn in self.synapses:\n            if (syn.weight <= 0): continue\n            tmp = {}\n            tmp[\"StrackID\"] = syn.pre.layerIndex\n            tmp[\"SgroupID\"] = syn.pre.groupIndex\n            tmp[\"Sindex\"] = syn.pre.index\n            tmp[\"Sarea\"] = syn.pre.areaName\n            tmp[\"TtrackID\"] = self.layerIndex\n            tmp[\"TgroupID\"] = self.groupIndex\n            tmp[\"Tindex\"] = self.index\n            tmp[\"Tarea\"] = self.areaName\n            tmp[\"type\"] = syn.type\n            tmp[\"weight\"] = syn.weight\n            slist.append(tmp)\n        dic[\"synapses\"] = slist\n        return dic\n\n    def writeSpikeTimeToJson(self):\n        slist = []\n        for i, t in enumerate(self.spiketime):\n            dic = {}\n            dic[i + 1] = t\n            slist.append(dic)\n        return slist\n\n# neu  = LIFNeuron()\n# dt = 0.001\n# T = 1\n# time = np.arange(0,T,dt)\n# spikes = np.zeros(len(time))\n# for i in range(0,len(time)):\n#     if(i == 22):\n#         print(\"debug\")\n#     neu.I = 84.49\n#     neu.update_normal(dt, time[i])\n#     if(neu.spike == True):\n#         spikes[i] = 1\n#         #spikes[i] = neu.mem\n# print(len(neu.spiketime))\n# pl.plot(time,spikes)\n# pl.show()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/note.py",
    "content": "'''\nCreated on 2016.7.6\n\n@author: liangqian\n'''\nfrom Modal.pitch import Pitch\nclass Note():\n    '''\n    Because a chord consist of more than two pitches at the same time, so using\n    arrays to record the chord\n    '''\n    def __init__(self):\n        self.pitches = []\n#         self.startTime = []\n#         self.endTime = []\n        self.lastTime = []\n       \n        "
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/notecluster.py",
    "content": "from .cluster import Cluster\nfrom .notelifneuron import NoteLIFNeuron\nfrom Modal.izhikevichneuron import *\n\nclass NoteCluster(Cluster):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype, neunum):\n        '''\n        Constructor\n        '''\n        Cluster.__init__(self, neutype, neunum)\n\n    def createClusterNetwork(self):\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = NoteLIFNeuron()\n                node.index = i + 1\n                node.areaName = 'NMSM'\n                node.setPreference()\n                self.neurons.append(node)\n\n            if (self.neutype == 'Izhikevich'):\n                node = NoteIzhikevichNeuron()\n                node.index = i + 1\n                node.areaName = 'NMSM'\n                node.setPreference()\n                self.neurons.append(node)\n#             if(self.neutype == 'Izhi'):\n#                 node = IzhikevichNeuron(a = 0.02,b = 0.2,c = -65,d = 8,vthresh = 30)\n#                 node.index = i\n#                 self.neurons.append(node)\n#             if(self.neutype == 'Gaussian'):\n#                 node = GaussianNeuron()\n#                 node.index = i+1\n#                 self.neurons.append(node)\n#             if(self.neutype == 'HH'):\n#                 node = HHNeuron()\n#                 node.index = i\n#                 self.neurons.append(node)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/notelifneuron.py",
    "content": "from .lifneuron import LIFNeuron\n\n\nclass NoteLIFNeuron(LIFNeuron):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, tau_ref=0.5, vthresh=5, Rm=2, Cm=0.2):\n        '''\n        Constructor\n        '''\n        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)\n\n    def setPreference(self):\n        self.selectivity = self.index - 2\n\n    def computeFilterCurrent(self):\n        if (self.I_ext == self.selectivity):\n            self.I = 10\n\n    def updateCurrentOfLowerAndUpperLayer(self, t):\n        self.I_lower = 0\n        self.I_upper = 0\n        for syn in self.synapses:\n            syn.computeShortTermFacilitation(t)\n            if (syn.type == 0):  # the same group\n                if (syn.excitability == 0):\n                    self.I_lower -= syn.weight * syn.strength\n                    if (self.I_lower < -20): self.I_lower = -20\n            if (syn.type == 1):  # pre and post neurons come from  the same layer but not the same group\n                # if(syn.weight > 0):\n                #     print('pre_neuron_group id:'+str(syn.pre.groupIndex) + ' neuron index:'+str(syn.pre.index))\n                #     print('post_neuron_group id:'+str(self.groupIndex) + ' neuron index:'+str(self.index))\n                #     print('syn.strength=' + str(syn.strength))\n                #     print('syn.weight='+ str(syn.weight))\n\n                self.I_lower += 0.001 * syn.weight * syn.strength\n                # print('syn.strength='+str(syn.strength))\n                # print('syn.weight='+ str(syn.weight))\n\n            if (syn.type == 2):  # pre and post neurons come from the different layers\n                # print(syn.pre.groupIndex)\n                self.I_upper += 0.001 * syn.weight * syn.strength\n\n        #         if(self.I_upper == 0):\n\n        #             self.I = self.I_ext\n        self.I = 0.4 * self.I_lower + 0.6 * self.I_upper"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/notesequencelayer.py",
    "content": "from .sequencelayer import SequenceLayer\nfrom .notecluster import NoteCluster\nfrom .synapse import Synapse\n\n\nclass NoteSequenceLayer(SequenceLayer):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        SequenceLayer.__init__(self, neutype)\n\n    def addNewGroups(self, GroupID, layerID, neunum):\n        g = NoteCluster(self.neutype, neunum)\n        g.createClusterNetwork()\n        # g.createInhibitoryConnections()\n        g.id = GroupID\n        g.setPropertiesofNeurons(g.id, 'S', layerID)\n        self.groups[g.id] = g\n\n        # create full connection with the former group\n        if (len(self.groups) > 1):\n            s = 0\n            if (g.id <= 5):\n                s = 1\n            else:\n                s = g.id - 4\n            # for i in range(1,g.id)[::-1]:\n            for i in range(s, g.id)[::-1]:\n                pre_g = self.groups.get(i)\n                for n1 in pre_g.neurons:\n                    for n2 in g.neurons:\n                        if (n1.type == 'inh' or n2.type == 'inh'): continue;\n                        syn = Synapse(n1, n2)\n                        syn.type = 1\n                        syn.delay = g.id - pre_g.id\n                        n2.pre_neurons.append(n1)\n                        n2.synapses.append(syn)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/pitch.py",
    "content": "'''\nCreated on 2018.8.29\n\n@author: liangqian\n'''\n\nclass Pitch():\n    '''\n    classdocs\n    '''\n\n\n    def __init__(self):\n        '''\n        Constructor\n        '''\n        self.name = ''\n        self.frequence = 0 #midi index number just now"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/sequencelayer.py",
    "content": "from .layer import Layer\nfrom .cluster import Cluster\nfrom .synapse import Synapse\n\n\nclass SequenceLayer(Layer):\n    '''\n    This class mainly stores the musical sequential elements, including pitches and durations\n    '''\n\n    def __init__(self, neutype='LIF'):\n        '''\n        Constructor\n        '''\n        self.type = \"\"\n        self.neutype = neutype\n        self.groups = {}\n\n    def addNewGroups(self, GroupID, layerID, neunum):\n        g = Cluster(self.neutype, neunum)\n        g.createClusterNetwork()\n        g.id = GroupID\n        g.setPropertiesofNeurons(g.id, 'S', layerID)\n        self.groups[g.id] = g\n\n        # create full connection with the former group\n        if (len(self.groups) > 1):\n            for i in range(1, g.id)[::-1]:\n                pre_g = self.groups.get(i)\n                for n1 in pre_g.neurons:\n                    for n2 in g.neurons:\n                        if (n1.type == 'inh' or n2.type == 'inh'): continue;\n                        syn = Synapse(n1, n2)\n                        syn.type = 1\n                        syn.delay = g.id - pre_g.id\n                        n2.pre_neurons.append(n1)\n                        n2.synapses.append(syn)\n\n    def setTestStates(self):\n        for gid, g in self.groups.items():\n            g.setTestStates()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/sequencememory.py",
    "content": "from .synapse import Synapse\nfrom Modal.sequencelayer import SequenceLayer\nimport numpy as np\nfrom Modal.synapse import Synapse\n\nclass SequenceMemory():\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        self.neutype = neutype\n        self.sequenceLayers = {}\n\n    def createActionSequenceMem(self, layernum, neutype, neunumpergroup):\n        pass\n\n    def doRemembering(self):\n        pass\n\n    def doConnecting(self, goal, sl, order):\n        # the goal and the group always generate spikes in a limit time window,create a synapse between them.\n        group = sl.groups.get(order)\n        if (group == None): return\n        tb = (order - 1) * group.timeWindow\n        te = (order) * group.timeWindow\n        sp1_goal = {}\n        sp2 = []\n\n        for n in goal.neurons:\n            sp = []\n            for st in n.spiketime:\n                if (st < te and st >= tb):\n                    sp.append(st)\n            sp1_goal[n.index] = sp\n\n        for n in group.neurons:\n            if (len(n.spiketime) > 0):\n                for index, sp in sp1_goal.items():\n                    temp = 0\n                    for sp1 in n.spiketime:  # spike times of group\n                        for sp2 in sp:\n                            if (abs(sp1 - sp2) <= n.tau_ref):\n                                temp += 1\n                    if (\n                            temp >= 4):  # super threshold, create a new synapse between goal and neurons of sequence group\n                        syn = Synapse(goal.neurons[index - 1], n)\n                        syn.type = 2\n                        syn.weight = 3\n                        n.pre_neurons.append(goal.neurons[index - 1])\n                        n.synapses.append(syn)\n\n                        # add reverse synapse to neurons of the goal\n                        syn2 = Synapse(n, goal.neurons[index - 1])\n             
           syn2.type = 2\n                        syn2.weight = 1\n                        goal.neurons[index - 1].synapses.append(syn2)\n\n        # clear the goal's spike time\n\n    ''' *************************************************************\n    I have forgot why neurons here needs to be cleaned, but this must be important, mark here\n    for n in goal.neurons:\n        n.spiketime = []    \n    ******************************************************************\n    '''\n\n    #     def doConnectToGoal(self,goal,track,order): # connect to the goal in the time window\n    #\n    #             for sl in track.values():\n    #                 group = sl.groups.get(order)\n    #\n    #                 if(group == None): continue\n    #                 # the goal and the group always generate spikes in a limit time window,create a synapse between them.\n    #                 tb = (order-1)*group.timeWindow\n    #                 te = (order) * group.timeWindow\n    #                 sp1_goal = {}\n    #                 sp2 = []\n    #\n    #                 for n in goal.neurons:\n    #                     sp = []\n    #                     for st in n.spiketime:\n    #                         if(st < te and st >= tb ):\n    #                             sp.append(st)\n    #                     sp1_goal[n.index] = sp\n    #\n    #                 for n in group.neurons:\n    #                     if(len(n.spiketime) > 0):\n    #                         for index,sp in sp1_goal.items():\n    #                             temp = 0\n    #                             for sp1 in n.spiketime: #spike times of group\n    #                                 for sp2 in sp:\n    #                                     if(abs(sp1-sp2) <= n.tau_ref):\n    #                                         temp += 1\n    #                             if(temp >= 4): # super threshold, create a new synapse between goal and neurons of sequence group\n    #                                 
syn = Synapse(goal.neurons[index-1],n)\n    #                                 syn.type = 2\n    #                                 syn.weight = 3\n    #                                 n.pre_neurons.append(goal.neurons[index-1])\n    #                                 n.synapses.append(syn)\n    #\n    #                                 #add reverse synapse to neurons of the goal\n    #                                 syn2 = Synapse(n,goal.neurons[index-1])\n    #                                 syn2.type = 2\n    #                                 syn2.weight = 1\n    #                                 goal.neurons[index-1].synapses.append(syn2)\n    #\n    #             #clear the goal's spike time\n    #             ''' *************************************************************\n    #             I have forgot why neurons here needs to be cleaned, but this must be important, mark here\n    #             for n in goal.neurons:\n    #                 n.spiketime = []\n    #             ******************************************************************\n    #         '''\n    #\n    #     def doConnectToComposer(self, composer, track, order):\n    #         for sl in track.values():\n    #             group = sl.groups.get(order)\n    #\n    #             if(group == None): continue\n    #             # the goal and the group always generate spikes in a limit time window,create a synapse between them.\n    #             tb = (order-1)*group.timeWindow\n    #             te = (order) * group.timeWindow\n    #             sp1_composer = {}\n    #             sp2 = []\n    #\n    #             for n in composer.neurons:\n    #                 sp = []\n    #                 for st in n.spiketime:\n    #                     if(st < te and st >= tb ):\n    #                         sp.append(st)\n    #                 sp1_composer[n.index] = sp\n    #\n    #             for n in group.neurons:\n    #                 if(len(n.spiketime) > 0):\n    #                     for 
index,sp in sp1_composer.items():\n    #                         temp = 0\n    #                         for sp1 in n.spiketime: #spike times of group\n    #                             for sp2 in sp:\n    #                                 if(abs(sp1-sp2) <= n.tau_ref):\n    #                                     temp += 1\n    #                         if(temp >= 4): # super threshold, create a new synapse between composer and neurons of sequence group\n    #                             syn = Synapse(composer.neurons[index-1],n)\n    #                             syn.type = 2\n    #                             syn.weight = 3\n    #                             n.pre_neurons.append(composer.neurons[index-1])\n    #                             n.synapses.append(syn)\n    #\n    #                             #add reverse synapse to neurons of the goal\n    # #                             syn2 = Synapse(n,goal.neurons[index-1])\n    # #                             syn2.type = 2\n    # #                             syn2.weight = 1\n    # #                             goal.neurons[index-1].synapses.append(syn2)\n    #\n    #\n    #     def doConnectToGenre(self, genre, track, order):\n    #         for sl in track.values():\n    #             group = sl.groups.get(order)\n    #\n    #             if(group == None): continue\n    #             # the goal and the group always generate spikes in a limit time window,create a synapse between them.\n    #             tb = (order-1)*group.timeWindow\n    #             te = (order) * group.timeWindow\n    #             sp1_genre = {}\n    #             sp2 = []\n    #\n    #             for n in genre.neurons:\n    #                 sp = []\n    #                 for st in n.spiketime:\n    #                     if(st < te and st >= tb ):\n    #                         sp.append(st)\n    #                 sp1_genre[n.index] = sp\n    #\n    #             for n in group.neurons:\n    #                 if(len(n.spiketime) > 0):\n   
 #                     for index,sp in sp1_genre.items():\n    #                         temp = 0\n    #                         for sp1 in n.spiketime: #spike times of group\n    #                             for sp2 in sp:\n    #                                 if(abs(sp1-sp2) <= n.tau_ref):\n    #                                     temp += 1\n    #                         if(temp >= 4): # super threshold, create a new synapse between composer and neurons of sequence group\n    #                             syn = Synapse(genre.neurons[index-1],n)\n    #                             syn.type = 2\n    #                             syn.weight = 3\n    #                             n.pre_neurons.append(genre.neurons[index-1])\n    #                             n.synapses.append(syn)\n\n    def setTestStates(self):\n        for itrack in self.sequenceLayers.values():\n            for sl in itrack.values():\n                sl.setTestStates()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/synapse.py",
    "content": "import math\n\nclass Synapse():\n    '''\n    classdocs\n    '''\n\n    def __init__(self, pre, post):\n        '''\n        Constructor\n        '''\n        self.type = 0  # 0: within a group; 1: different groups in the same layer; 2: other layer\n        self.pre = pre\n        self.post = post\n        self.weight = 0\n        self.excitability = 1  # 1:excited connection; 0:inhibitory connection\n        self.strength = 0  # short term depression and facilitation factor\n        self.delay = 0  # time delay of transmission between pre and post\n\n    def computeWeight(self, t):\n\n        if (self.type == 0):  # pre and post neurons are in the same group\n            pass\n        elif (self.type == 1):  # pre and post neurons are in the same layer but different groups\n            for st in self.pre.spiketime:\n                s = t - st - (self.delay-1)*self.post.timeWindow\n                temp = 0\n                if (self.post.groupIndex - self.pre.groupIndex == self.delay):  # compute weight according to time delay\n                    # using STDP rules\n                    if (s >= 0):\n                        temp = math.exp(-s / 5)\n                    else:\n                        #                         print(self.pre.groupIndex)\n                        #                         print(self.post.groupIndex)\n                        temp = -math.exp(s / 5)\n                    self.weight += temp\n\n\n        elif (self.type == 2):  # pre and post neurons are in the different layers\n            pass\n            # computing the STDP to update the weight within the time window\n        elif (self.type == 3):\n            # pass #fixed weight\n            for st in self.pre.spiketime:\n                s = t - st - (self.delay - 1) * self.post.timeWindow\n                temp = 0\n                # using STDP rules\n                if (s >= 0):\n                    temp = 5 * math.exp(-s / 5)\n                else:\n                 
   #                         print(self.pre.groupIndex)\n                    #                         print(self.post.groupIndex)\n                    temp = -5 * math.exp(s / 5)\n                self.weight += temp\n\n    def computeShortTermFacilitation(self, t):\n        if (self.type == 1):\n            for st in self.pre.spiketime[::-1]:\n                at = st + self.delay\n                # if (at <= t and at >= t - self.post.tau_ref):  # between current time and time minus refractory period\n                if (at <= t and at >= t):\n                    temp = (self.strength + 1) * 0.2\n                    self.strength += temp\n\n        elif (self.type == 2):\n            #             print(self.pre.areaName)\n            #             print(self.pre.index)\n            #             print(self.pre.groupIndex)\n            for st in self.pre.spiketime:\n                if (st <= t and st >= t - self.post.tau_ref):\n                    temp = (self.strength + 1) * 0.5\n                    self.strength = self.strength + temp\n\n        elif (self.type == 0):\n            if (self.excitability == 0):\n                for st in self.pre.spiketime:\n                    self.strength += (self.strength + 1) * 0.8\n\n    def computeShortTermFacilitation2(self, t):\n        if (self.type == 1):\n            for st in self.pre.spiketime[::-1]:\n                at = st + self.delay\n                # if ( at <= t and at >= t - self.post.tau_ref):\n                if (at <= t):  # between current time and time minus refractory period\n                    self.strength = 1\n\n        elif (self.type >= 2):\n            #             print(self.pre.areaName)\n            #             print(self.pre.index)\n            #             print(self.pre.groupIndex)\n            for st in self.pre.spiketime:\n                # if (st <= t and st >= t - self.post.tau_ref):\n                if (st <= t):\n                    self.strength = 1\n\n    def 
computeShortTermReduction(self, t):\n        if (self.type == 2):\n            for st in self.pre.spiketime[::-1]:\n                if (t - st > self.post.timeWindow):\n                    self.strength -= (self.strength + 1) * 0.5"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/tempocluster.py",
    "content": "from .cluster import Cluster\nfrom .tempolifneuron import TempoLIFNeuron\nfrom Modal.izhikevichneuron import *\n\nclass TempoCluster(Cluster):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype, neunum):\n        '''\n        Constructor\n        '''\n        Cluster.__init__(self, neutype, neunum)\n\n    def createClusterNetwork(self):\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = TempoLIFNeuron()\n                node.index = i + 1\n                node.areaName = 'TMSM'\n                node.setPreference()\n                self.neurons.append(node)\n            if (self.neutype == 'Izhikevich'):\n                node = TempoIzhikevichNeuron()\n                node.index = i + 1\n                node.areaName = 'TMSM'\n                node.setPreference()\n                self.neurons.append(node)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/tempolifneuron.py",
    "content": "from .lifneuron import LIFNeuron\nimport math\n\n\nclass TempoLIFNeuron(LIFNeuron):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, tau_ref=0.5, vthresh=5, Rm=2, Cm=0.2):\n        '''\n        Constructor\n        '''\n        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)\n\n    def setPreference(self):\n        # Gaussian function to set selectivity\n        self.selectivity = self.index * 0.125  #\n\n    def computeFilterCurrent(self):\n        if (self.I_ext <= self.selectivity + 0.0625 and self.I_ext >= self.selectivity - 0.0625):\n            self.I = 10\n\n    def updateCurrentOfLowerAndUpperLayer(self, t):\n        self.I_lower = 0\n        self.I_upper = 0\n        for syn in self.synapses:\n            syn.computeShortTermFacilitation(t)\n            if (syn.type == 0):  # the same group\n                if (syn.excitability == 0):\n                    self.I_lower -= 0.001 * syn.weight * syn.strength\n                    if (self.I_lower < -20): self.I_lower = -20\n            if (syn.type == 1):  # pre and post neurons come from  the same layer but not the same group\n                #                 if(syn.weight > 0):\n                #\n                #                     print('pre_neuron_group id:'+str(syn.pre.groupIndex) + ' neuron index:'+str(syn.pre.index))\n                #                     print('post_neuron_group id:'+str(self.groupIndex) + ' neuron index:'+str(self.index))\n                #                     print(syn.weight)\n                self.I_lower += 0.001 * syn.weight * syn.strength\n                # print('syn.strength='+str(syn.strength))\n\n            if (syn.type == 2):  # pre and post neurons come from the different layers\n                # print(syn.pre.groupIndex)\n                # self.I_lower = syn.weight * syn.strength\n                self.I_upper += 0.001 * syn.weight * syn.strength\n\n        #         if(self.I_upper == 0):\n        #             self.I = self.I_ext\n        
else:\n            self.I = 0.4 * self.I_lower + 0.6 * self.I_upper"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/temposequencelayer.py",
    "content": "from .sequencelayer import SequenceLayer\nfrom .tempocluster import TempoCluster\nfrom .synapse import Synapse\n\n\nclass TempoSequenceLayer(SequenceLayer):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype):\n        '''\n        Constructor\n        '''\n        SequenceLayer.__init__(self, neutype)\n\n    def addNewGroups(self, GroupID, layerID, neunum):\n        g = TempoCluster(self.neutype, neunum)\n        g.createClusterNetwork()\n        # g.createInhibitoryConnections()\n        g.id = GroupID\n        g.setPropertiesofNeurons(g.id, 'S', layerID)\n        self.groups[g.id] = g\n\n        # create full connection with the former group\n\n        if (len(self.groups) > 1):\n            s = 0\n            if (g.id <= 5):\n                s = 1\n            else:\n                s = g.id - 4\n            for i in range(s, g.id)[::-1]:\n                pre_g = self.groups.get(i)\n                for n1 in pre_g.neurons:\n                    for n2 in g.neurons:\n                        if (n1.type == 'inh' or n2.type == 'inh'): continue;\n                        syn = Synapse(n1, n2)\n                        syn.type = 1\n                        syn.delay = g.id - pre_g.id\n                        n2.pre_neurons.append(n1)\n                        n2.synapses.append(syn)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/titlecluster.py",
    "content": "from .cluster import Cluster\nfrom .titlelifneuron import TitleLIFNeuron\n\n\nclass TitleCluster(Cluster):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype, neunum):\n        '''\n        Constructor\n        '''\n        Cluster.__init__(self, neutype, neunum)\n        self.averageFiringRate = 0\n\n    def createClusterNetwork(self):\n        for i in range(0, self.neunum):\n            if (self.neutype == 'LIF'):\n                node = TitleLIFNeuron()\n                node.index = i + 1\n                node.areaName = 'IPS'\n                self.neurons.append(node)\n#             if(self.neutype == 'Izhi'):\n#                 node = IzhikevichNeuron(a = 0.02,b = 0.2,c = -65,d = 8,vthresh = 30)\n#                 node.index = i\n#                 self.neurons.append(node)\n#             if(self.neutype == 'Gaussian'):\n#                 node = GaussianNeuron()\n#                 node.index = i+1\n#                 self.neurons.append(node)\n#             if(self.neutype == 'HH'):\n#                 node = HHNeuron()\n#                 node.index = i\n#                 self.neurons.append(node)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/titlelayer.py",
    "content": "\n\nfrom .layer import Layer\nfrom .titlecluster import TitleCluster\n\n\nclass TitleLayer(Layer):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype='LIF'):\n        '''\n        Constructor\n        '''\n        self.neutype = neutype\n        self.groups = {}\n\n    def setTestStates(self):\n        for id, g in self.groups.items():\n            g.setTestStates()\n\n    def addNewGroups(self, groupID, layerID, neunum, goalname):\n        g = TitleCluster(self.neutype, neunum)\n        g.id = groupID\n        g.name = goalname\n        g.createClusterNetwork()\n        g.setPropertiesofNeurons(groupID, 'G', layerID)\n        self.groups[goalname] = g"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/Modal/titlelifneuron.py",
    "content": "from .lifneuron import LIFNeuron\nimport math\n\n\nclass TitleLIFNeuron(LIFNeuron):\n    '''\n    classdocs\n    '''\n\n    def __init__(self, tau_ref=0, vthresh=5, Rm=2, Cm=0.2):\n        '''\n        Constructor\n        '''\n        LIFNeuron.__init__(self, tau_ref, vthresh, Rm, Cm)\n\n    def updateCurrentOfLowerAndUpperLayer(self, t):\n        self.I_lower = 0\n        self.I_upper = 0\n        for syn in self.synapses:\n            syn.computeShortTermFacilitation(t)\n            syn.computeShortTermReduction(t)\n            if (syn.type == 2):  # pre and post neurons come from the different layers\n                self.I_lower += syn.weight * syn.strength\n\n        if (self.I_lower <= 0):\n            self.I = self.I_lower\n        else:\n            self.I = math.log(self.I_lower)\n\n    def update(self, dt, t):\n        self.spike = False\n        # self.updateCurrentOfLowerAndUpperLayer(t)\n        if (t >= self.t_rest):\n            self.mem += dt * (-self.mem + self.I * self.Rm) / self.tau_m\n            if (self.mem > self.vth):\n                self.spike = True\n                self.spiketime.append(t)\n                self.mem = 0\n                self.t_rest = t + self.tau_ref\n\n    def computeFiringRate(self):\n        if (self.I == 0):\n            self.firingrate = 0\n        else:\n\n            self.firingrate = 1 / (self.tau_ref + self.Rm * self.Cm * math.log(self.I / (self.I - self.vth)))\n            self.firingrate *= 1000\n            self.firingrate = round(self.firingrate)\n            # print(self.firingrate)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/README.md",
    "content": "# Music Memory and stylistic composition\n\nThis repository contains code from our paper:\n- [**Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory**](https://www.cell.com/patterns/fulltext/S2666-3899(22)00119-2) published in Frontiers in Computational Neuroscience. **https://www.cell.com/patterns/fulltext/S2666-3899(22)00119-2**,\n- [**Stylistic composition of melodies based on a brain-inspired spiking neural network**](https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full) published in Frontiers in Systems Neuroscience **https://www.frontiersin.org/articles/10.3389/fnsys.2021.639484/full**.\n- [**Mode-conditioned music learning and composition: a spiking neural network inspired by neuroscience and psychology**](https://arxiv.org/pdf/2411.14773) preprint in arXiv **https://arxiv.org/pdf/2411.14773**.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* pretty_midi >= 0.2.9\n* music21\n\n\n## Data preparation\n\nThe dataset used here can be referred to the website http://www.piano-midi.de/.\nThe dataset used in mode-conditioned music learning can be referred to the website https://github.com/lqnankai/Music-Dataset. \n\n\n## Run\n* Run the script *task/musicMemory.py* to memorize and recall the musical melodies, the result will be recorded in a midi file.\n* Run the script *task/musicGeneration.py* to learn and generate melodies with different styles, the result will be recorded in a midi file.\n* Run the script *task/mode-conditioned learning.py* to learn and generate melodies with different mode and keys, the result will be recorded in a midi file.\nThe API and details can be found in these scripts. 
\n\n## Citation\nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{LQ2020,\n    author  = {Liang, Qian and Zeng, Yi and Xu, Bo},\n    year    = {2020},\n    month   = {07},\n    pages   = {51},\n    title   = {Temporal-Sequential Learning With a Brain-Inspired Spiking Neural Network and Its Application to Musical Memory},\n    volume  = {14},\n    journal = {Frontiers in Computational Neuroscience}\n}\n\n\n@article{LQ2021,\n    title     = {Stylistic composition of melodies based on a brain-inspired spiking neural network},\n    author    = {Liang, Qian and Zeng, Yi},\n    journal   = {Frontiers in systems neuroscience},\n    volume    = {15},\n    pages     = {21},\n    year      = {2021},\n    publisher = {Frontiers}\n}\n\n@misc{liang2024modeconditionedmusiclearningcomposition,\n      title={Mode-conditioned music learning and composition: a spiking neural network inspired by neuroscience and psychology}, \n      author={Qian Liang and Yi Zeng and Menghaoran Tang},\n      year={2024},\n      eprint={2411.14773},\n      archivePrefix={arXiv},\n      primaryClass={cs.SD},\n      url={https://arxiv.org/abs/2411.14773}, \n}\n\n```\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/api/music_engine_api.py",
    "content": "\n\nfrom conf.conf import *\nfrom Areas.cortex import Cortex\nimport pretty_midi\nimport math\nimport json\nimport music21 as m21\n\n\nclass EngineAPI():\n    '''\n\n    '''\n\n    def __init__(self):\n        '''\n        Constructor\n        '''\n\n        self.cortex = Cortex(configs.neuron_type, configs.dt)\n\n    def cortexInit(self):\n        self.cortex.musicSequenceMemroyInit()\n        self.cortex.pfc.addNewKey()\n        self.cortex.pfc.addNewMode()\n        self.cortex.pfc.addNewChord()\n\n    def rememberMusic(self, muiscName, composerName=\"None\"):\n        '''\n        :param muiscName: the name of the melody\n        :param composerName: the composer\n        :return:\n        '''\n        muiscName = muiscName.title()\n        composerName = composerName.title()\n        self.cortex.pfc.setTestStates()\n        self.cortex.msm.setTestStates()\n        self.cortex.addSubGoalToPFC(muiscName)\n        self.cortex.addComposerToPFC(composerName)\n        genreName = str(configs.GenreMap.get(composerName))\n        self.cortex.addGenreToPFC(genreName)\n        self.cortex.pfc.innerLearning(muiscName, composerName, genreName)\n\n        goaldic = {}\n        composerdic = {}\n        genredic = {}\n        if (configs.RunTimeState == 1):\n            g = self.cortex.pfc.titles.groups.get(muiscName)\n            c = self.cortex.pfc.composers.groups.get(composerName)\n            gre = self.cortex.pfc.genres.groups.get(genreName)\n            goaldic = g.writeSelfInfoToJson(\"IPS\")\n            composerdic = c.writeSelfInfoToJson(\"Composer\")\n            genredic = gre.writeSelfInfoToJson(\"Genre\")\n\n        return goaldic, composerdic\n\n    def learnFourPartMusic(self,xmldata, musicName, composerName=\"None\"):\n        musicName = musicName.title()\n        composerName = composerName.title()\n        genreName = \"None\"\n        toneName = configs.keyIndexMap.get(str(xmldata.analyze('key')))\n\n        print(musicName + \" 
learning...\")\n\n        emo = \"None\"\n        for i, part in enumerate(xmldata.parts):\n            if (self.cortex.msm.sequenceLayers.get(i + 1) == None):\n                self.cortex.msm.createActionSequenceMem(i + 1, self.cortex.neutype)\n            self.rememberPartNotes(musicName, composerName, genreName, emo, toneName, i + 1, part)\n\n\n    def rememberPartNotes(self,musicName, composerName, genreName, emo, keyName, partIndex, part):\n        print(\"Learning the part \"+str(partIndex))\n        for i,note in enumerate(part.flat.notes[:20]):\n            p = 0\n            dur = 0\n            if note.isChord:\n                dur = note.duration.quarterLength\n                for chord_note in note:\n                    p = m21.pitch.Pitch(chord_note.pitch).midi\n\n            else:\n                dur = note.duration.quarterLength\n                p = m21.pitch.Pitch(note.pitch).midi\n            if dur == 0.0:\n                dur = 0.125\n            if keyName == 'None':\n                self.cortex.rememberANoteandTempo(musicName, composerName, genreName, emo, partIndex, p, i+1, dur)\n            else:\n                self.cortex.rememberANoteWithKnowledge(musicName, composerName, genreName, emo, keyName, partIndex, p, dur, i+1, part)\n\n    def rememberMIDIMusic(self, musicName, composerName, noteLength, fileName):\n        '''\n        :param musicName: the name of the piece of music\n        :param composerName: the composer who writes this melody\n        :param fileName: the name of this midi file\n        :return: none\n        '''\n        musicName = musicName.title()\n        composerName = composerName.title()\n        print(musicName + \" processing...\")\n        pm = pretty_midi.PrettyMIDI(fileName)\n        genreName = str(configs.GenreMap.get(composerName))\n        for i, ins in enumerate(pm.instruments):\n            if (i >= 1): break;\n            if (self.cortex.msm.sequenceLayers.get(i + 1) == None):\n                # create 
a new layer to store the track\n                self.cortex.msm.createActionSequenceMem(i + 1, self.cortex.neutype)\n            self.rememberTrackNotes(musicName, composerName, genreName, i + 1, ins, pm, noteLength)\n        print(musicName + \" finished!\")\n\n    def rememberTrackNotes(self, musicName, composerName, genreName, trackIndex, track, pm, noteLength):\n        r_notes = []\n        r_intervals = []\n        total_dic = {}\n\n        print(track)\n        if(noteLength == \"ALL\"):\n            noteLength = len(track.notes)\n        order = 1\n        i = 0\n        #while (i < len(track.notes)):\n        while (i < noteLength):\n            #if (i >= rl): break;\n            note = track.notes[i]\n            start = pm.time_to_tick(note.start)\n            end = pm.time_to_tick(note.end)\n            pitches = []\n            durations = []\n            restFlag = False\n            # this part recognizes a rest\n            if (i == 0):  # determine whether the first note is a rest\n                if (start >= 30):\n                    pitches.append(-1)  # -1 represents a rest\n                    durations.append(start / pm.resolution)\n                    restFlag = True\n            else:\n                lastend = pm.time_to_tick(track.notes[i - 1].end)\n                if (start - lastend >= 50):\n                    pitches.append(-1)\n                    durations.append((start - lastend) / pm.resolution)\n                    restFlag = True\n            if (restFlag == True):\n                dic, g = self.rememberANote(musicName, composerName, genreName, trackIndex, pitches[0], order,\n                                            durations[0], True)\n                if (configs.RunTimeState == 1):\n                    jstr = json.dumps(g)\n                    self.conn.send('/Queue/SampleQueue', jstr)\n                #print(str(order) + \":(-1,\" + str(durations[0]) + \")\")\n                order = order + 1\n                pitches = 
[]\n                durations = []\n\n                # this part recognizes a chord\n            pitches.append(note.pitch)\n            durations.append((end - start) / pm.resolution)\n            j = i + 1\n            while (j < len(track.notes)):\n                nextstart = pm.time_to_tick(track.notes[j].start)\n                nextend = pm.time_to_tick(track.notes[j].end)\n                # if(start == nextstart or end > nextstart):\n                if (math.fabs(start - nextstart) <= 30 or end - nextstart >= 30):\n                    pitches.append(track.notes[j].pitch)\n                    durations.append((nextend - nextstart) / pm.resolution)\n                    j = j + 1\n                else:\n                    break\n            i = j\n\n            if (i < noteLength):\n                dic, g = self.rememberANote(musicName, composerName, genreName, trackIndex, pitches[0], order,\n                                            durations[0], True)\n                str1 = str(order) + \":(\"\n                for t in range(len(pitches)):\n                    str1 += str(pitches[t]) + \",\" + str(durations[t]) + \";\"\n                #print(str1 + \")\")\n                order = order + 1\n                if (configs.RunTimeState == 1):\n                    jstr = json.dumps(g)\n                    self.conn.send('/Queue/SampleQueue', jstr)\n                    nlist = dic.get('MSMSpike')\n                    ns = []\n                    for l in nlist:\n                        n = l.get('Index')\n                        ns.append(n)\n                    r_notes.append(ns)\n                    tlist = dic.get('MSMTSpike')\n                    ts = []\n                    for l in tlist:\n                        t = l.get('Index')\n                        ts.append(t * 60)\n                    r_intervals.append(ts)\n        return total_dic\n\n    def rememberNotes(self, MusicName, notes, intervals, tempo=True):\n        jStr = ''\n        # 
print(intervals)\n        notesStr = notes.split(\",\")\n        intervalsStr = intervals.split(\",\")\n        intervaltimes = []\n        for i in range(len(intervalsStr) - 1):\n            intervaltimes.append(int(intervalsStr[i]))\n        print(intervaltimes)\n        for i, note in enumerate(notesStr):\n            note = int(note)\n            if (i < len(notesStr) - 1):\n                tinterval = intervalsStr[i]\n                tinterval = int(intervalsStr[i])\n            self.rememberANote(MusicName, note, i + 1, tinterval, tempo)\n        return jStr\n\n    def rememberANote(self, MusicName, ComposerName, genreName, TrackIndex, NoteIndex, order, tinterval, tempo=False):\n        if (tempo == False):\n            dic = self.cortex.rememberANote(MusicName, NoteIndex, order)\n            jsonStr = json.dumps(dic)\n            return jsonStr\n        else:\n            dic, g = self.cortex.rememberANoteandTempo(MusicName, ComposerName, genreName, TrackIndex, NoteIndex, order,\n                                                       tinterval)\n            return dic, g\n\n    def memorizing(self,MusicName, ComposerName, noteLength, fileName):\n        '''\n        :param musicName: the name of the piece of music\n        :param composerName: the composer who writes this melody\n        :param noteLength: the number of notes to be trained(integer), if you want to learn all the notes of a musical work, the value should be specified as \"ALL\"\n        :param fileName: the path and the name of this midi file\n        :return: none\n        '''\n        self.rememberMusic(MusicName, ComposerName)\n        self.rememberMIDIMusic(MusicName,ComposerName,noteLength, fileName)\n\n    def recallMusic(self, musicName):\n        print(\"Recall the \" + musicName + \" ......\")\n        musicName = musicName.title()\n        result = self.cortex.recallMusicPFC(musicName)\n        #print(result)\n        noteResult = {}\n        for tindex,track in result.items():\n     
       ns = track.get('N')\n            ts = track.get('T')\n            tmp = []\n            for key in ns.keys():\n                dic = {}\n                dic['N']=ns.get(key)\n                dic['T']=ts.get(key)\n                tmp.append(dic)\n            noteResult[tindex] = tmp\n        self.writeMidiFile(musicName+\"_recall\",noteResult)\n        print(\"Recall \" + musicName + \" finished!\")\n        return noteResult\n\n\n    def generateEx_Nihilo(self, firstNote, durations, length,gen_fName):\n        '''\n        parameters:\n        fistNote: Specify the beginning notes to generate a note\n        durations: Specify the duration of the beginning notes\n        length: the length of the generated music, less than 50 notes\n        '''\n        print(\"Generate melody with no style............\")\n        result = self.cortex.generateEx_Nihilo2(firstNote, durations, length)\n        self.writeMidiFile(gen_fName,result)\n        print(\"Generating finished!\")\n        return result\n\n    def generateEx_NihiloAccordingToGenre(self, genreName, firstNote, durations, length,gen_fName):\n        '''\n        parameters:\n        genreName:Specify the style of genre of the generated melody, for example: Baroque,Classical,Romantic\n        fistNote: Specify the beginning notes to generate a note\n        durations: Specify the duration of the beginning notes\n        length: the length of the generated music, less than 50 notes\n        '''\n        print(\"Generate melody with \"+ genreName+\"\\'s style............\")\n        result = self.cortex.generateEx_NihiloAccordingToGenre(genreName, firstNote, durations, length)\n        self.writeMidiFile(gen_fName,result)\n        print(\"Generating finished!\")\n        return result\n\n    def generateEx_NihiloAccordingToComposer(self, composerName, firstNote, durations, length,gen_fName):\n        '''\n        parameters:\n        composerName:Specify the style of composer of the generated melody, for 
example: Bach, Mozart and etc.\n        fistNote: Specify the beginning notes to generate a note\n        durations: Specify the duration of the beginning notes\n        length: the length of the generated music, less than 50 notes\n        '''\n        print(\"Generate melody with \" + composerName + \"'s style............\")\n        result = self.cortex.generateEx_NihiloAccordingToComposer(composerName, firstNote, durations, length)\n        self.writeMidiFile(gen_fName,result)\n        print(\"Generating finished!\")\n        return result\n\n    def generate2TrackMusic(self, firstNotes, durations, lengths):\n        result = self.cortex.generate2TrackMusic(firstNotes, durations, lengths)\n        return result\n\n    def generateMelodyWithKey(self,tone, firstNotes,durations = None,length = 8):\n\n        result = self.cortex.generateMelodyWithKey(tone, firstNotes,durations,length)\n        return result\n\n\n    def writeMidiFile(self,fileName, mudic):\n        '''\n            mudic format description:\n            mudic = {1:[{'N':71,'T':0.5}.....],\n                         2:[{'N':60,'T':0.25}.....],\n                         ....\n            }\n        '''\n        fileName += \".mid\"\n        pm = pretty_midi.PrettyMIDI()\n        # Create an Instrument instance for a cello instrument\n        for values in mudic.values():\n\n            piano = pretty_midi.Instrument(program=0)\n            # Iterate over note names, which will be converted to note number later\n            start = 0\n            end = 0\n            for i, n in enumerate(values):\n                # Retrieve the MIDI note number for this note name\n                # note_number = pretty_midi.note_name_to_number(note_name)\n                # Create a Note instance, starting at 0s and ending at .5s\n                end = start + n.get('T')\n                note_name = n.get('N')\n                if (note_name == -1):\n                    note = pretty_midi.Note(\n                        
velocity=0, pitch=0, start=start, end=end)\n                else:\n                    note = pretty_midi.Note(\n                        velocity=100, pitch=note_name, start=start, end=end)\n                # Add it to our cello instrument\n                piano.notes.append(note)\n                start = end\n            # Add the cello instrument to the PrettyMIDI object\n            pm.instruments.append(piano)\n        # Write out the MIDI data\n        pm.write(fileName)"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/conf/GenreData.txt",
    "content": "Baroque:Bach\nClassical:Haydn,Mozart,Beethoven,Schubert,Clementi\nRomantic:Mendelssohn,Liszt,Chopin,Schumann,Brahms,Burgmueller,Debussy,Godowsky,Moszkowski,Mussorgsky,Rachmaninov,Ravel,Tchaikovsky,Albéniz,Balakirew,Borodin,Granados,Grieg,Sinding\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/conf/MIDIData.txt",
    "content": "-1:rest\n0:C3\n1:C sharp3/D flat3\n2:D3\n3:D sharp3/E flat3\n4:E3\n5:F3\n6:F sharp3/G flat3\n7:G3\n8:G sharp3/A flat3\n9:A3\n10:A sharp3/B flat3\n11:B3\n12:C2\n13:C sharp2/D flat2\n14:D2\n15:D sharp2/E flat2\n16:E2\n17:F2\n18:F sharp2/G flat2\n19:G2\n20:G sharp2/A flat2\n21:A2\n22:A sharp2/B flat2\n23:B2\n24:C1\n25:C sharp1/D flat1\n26:D1\n27:D sharp1/E flat1\n28:E1\n29:F1\n30:F sharp1/G flat1\n31:G1\n32:G sharp1/A flat1\n33:A1\n34:A sharp1/B flat1\n35:B1\n36:C\n37:C sharp/D flat\n38:D\n39:D sharp/E flat\n40:E\n41:F\n42:F sharp/G flat\n43:G\n44:G sharp/A flat\n45:A\n46:A sharp/B flat\n47:B\n48:c\n49:c sharp/d flat\n50:d\n51:d sharp/e flat\n52:e\n53:f\n54:f sharp/g flat\n55:g\n56:g sharp/a flat\n57:a\n58:a sharp/b flat\n59:b\n60:c1\n61:c sharp1/d flat1\n62:d1\n63:d sharp1/e flat1\n64:e1\n65:f1\n66:f sharp1/g flat1\n67:g1\n68:g sharp1/a flat1\n69:a1\n70:a sharp1/b flat1\n71:b1\n72:c2\n73:c sharp2/d flat2\n74:d2\n75:d sharp2/e flat2\n76:e2\n77:f2\n78:f sharp2/g flat2\n79:g2\n80:g sharp2/a flat2\n81:a2\n82:a sharp2/b flat2\n83:b2\n84:c3\n85:c sharp3/d flat3\n86:d3\n87:d sharp3/e flat3\n88:e3\n89:f3\n90:f sharp3/g flat3\n91:f3\n92:g sharp3/a flat3\n93:a3\n94:a sharp3/b flat3\n95:b3\n96:c4\n97:c sharp4/d flat4\n98:d4\n99:d sharp4/e flat4\n100:e4\n101:f4\n102:f sharp4/g flat4\n103:g4\n104:g sharp4/a flat4\n105:a4\n106:a sharp4/b flat4\n107:b4\n108:c5\n109:c sharp5/d flat5\n110:d5\n111:d sharp5/e flat5\n112:e5\n113:f5\n114:f sharp5/g flat5\n115:g5\n116:g sharp5/a flat5\n117:a5\n118:a sharp5/b flat5\n119:b5\n120:c6\n121:c sharp6/d flat6\n122:d6\n123:d sharp6/e flat6\n124:e6\n125:f6\n126:f sharp6/g flat6\n127:g6"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/conf/conf.py",
    "content": "import numpy\nimport numpy as np\nimport pandas as pd\nclass Conf():\n    '''\n    classdocs\n    '''\n\n    def __init__(self, neutype=\"LIF\", task=\"MusicLearning\", dt=0.1):\n        '''\n        Constructor\n        '''\n        self.neuron_type = neutype\n        self.task = task\n        self.dt = dt\n        self.notesMap = {}\n        self.GenreMap = {}\n        self.emoMap = {}\n        self.key_matrix = []\n        self.keysMap = {}\n        self.index2key = {}\n        self.index2mode = {}\n        self.keyIndexMap = {}\n        self.keyscales = {}\n        self.chordsMap = {}\n        self.chordsMatrix = np.zeros((7, 7))\n        self.RunTimeState = 0  # 0: GUI, 1: Bigdata experiments 2: other\n\n    def readNoteFiles(self):\n        # f = open(\"./Data.txt\",\"r\")\n        f = open(\"../inputs/MIDIData.txt\", \"r\")\n        while (True):\n            line = f.readline()\n            if not line:\n                break\n            else:\n                strs = line.split(\":\")\n                index = int(strs[0])\n                self.notesMap[index] = strs[1].strip()\n        f.close()\n\n    def readGenreFils(self):\n        f = open(\"../inputs/GenreData.txt\", \"r\")\n        while (True):\n            line = f.readline()\n            if not line:\n                break\n            else:\n                strs = line.split(\":\")\n                g = strs[0].strip()\n                ns = strs[1].split(\",\")\n                for n in ns:\n                    self.GenreMap[(n.strip()).title()] = g.title()\n        f.close()\n\n    def readEmotionFiles(self):\n        f = open(\"../inputs/information.csv\", \"r\")\n        while (True):\n            line = f.readline()\n            if not line:\n                break\n            else:\n                strs = line.split(\",\")\n                mn = strs[0].strip()\n                e = strs[3].strip()\n                self.emoMap[mn.title()] = e.title()\n        f.close()\n\n    
def readKeysFile(self):\n        f = open(\"../inputs/keyIndex.csv\", \"r\")\n        while (True):\n            line = (f.readline()).strip()\n            if not line:\n                break\n            else:\n                strs = line.split(\",\")\n                toneName = strs[0].strip()\n                self.keysMap[toneName] = int(strs[1].strip())\n        # print(self.keysMap)\n        self.index2key = dict(zip(self.keysMap.values(), self.keysMap.keys()))\n        # print(self.index2key)\n        self.key_matrix = np.array(pd.read_excel(\"../inputs/keys.xlsx\", sheet_name='keys'))\n        self.keyscales = {0: np.array(pd.read_excel(\"../inputs/keys.xlsx\", sheet_name='major')),\n                          1: np.array(pd.read_excel(\"../inputs/keys.xlsx\", sheet_name='minor'))}\n\n    def readKeys2IndexFile(self):\n        f = open(\"../inputs/keyIndex.csv\", \"r\")\n        while (True):\n            line = f.readline().strip()\n            if not line:\n                break\n            else:\n                strs = line.split(\",\")\n                self.keyIndexMap[strs[0].strip()] = int(strs[1].strip())\n        # print(self.keyIndexMap)\n\n    def readChordsFile(self):\n        tmp = pd.read_excel(\"../inputs/chords.xlsx\", sheet_name=None)\n        for key, chords in tmp.items():\n            chords = np.array(chords)\n            self.chordsMap[key.strip()] = chords\n        # print(self.chordsMap)\n        # 暂时先连接主，下属，属和弦\n        self.chordsMatrix = np.array([[1, 0, 0, 1, 1, 0, 0],\n                                      [0, 0, 0, 0, 0, 0, 0],\n                                      [0, 0, 0, 0, 0, 0, 0],\n                                      [1, 0, 0, 1, 1, 0, 0],\n                                      [1, 0, 0, 0, 1, 0, 0],\n                                      [0, 0, 0, 0, 0, 0, 0],\n                                      [0, 0, 0, 0, 0, 0, 0]])\n\n    def readModesFile(self):\n        f = open(\"../inputs/modeindex.csv\", \"r\")\n        
while (True):\n            line = (f.readline()).strip()\n            if not line:\n                break\n            else:\n                strs = line.split(\",\")\n                self.index2mode[int(strs[0].strip())] = strs[1].strip()\n\n\nconfigs = Conf(neutype = 'Izhikevich')\nconfigs.readNoteFiles()\nconfigs.readGenreFils()\nconfigs.readEmotionFiles()\nconfigs.readKeysFile()\nconfigs.readKeys2IndexFile()\nconfigs.readChordsFile()\nconfigs.readModesFile()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/1.txt",
    "content": "1:A-B2\n2:A#-B2\n3:B-B2\n4:C-B1\n5:C#-B1\n6:D-B1\n7:D#-B1\n8:E-B1\n9:F-B1\n10:F#-B1\n11:G-B1\n12:G#-B1\n13:A-B1\n14:A#-B1\n15:B-B1\n16:C-B\n17:C#-B\n18:D-B\n19:D#-B\n20:E-B\n21:F-B\n22:F#-B\n23:G-B\n24:G#-B\n25:A-B\n26:A#-B\n27:B-B\n28:C-S\n29:C#-S\n30:D-S\n31:D#-S\n32:E-S\n33:F-S\n34:F#-S\n35:G-S\n36:G#-S\n37:A-S\n38:A#-S\n39:B-S\n40:C-S1\n41:C#-S1\n42:D-S1\n43:D#-S1\n44:E-S1\n45:F-S1\n46:F#-S1\n47:G-S1\n48:G#-S1\n49:A-S1\n50:A#-S1\n51:B-S1\n52:C-S2\n53:C#-S2\n54:D-S2\n55:D#-S2\n56:E-S2\n57:F-S2\n58:F#-S2\n59:G-S2\n60:G#-S2\n61:A-S2\n62:A#-S2\n63:B-S2\n64:C-S3\n65:C#-S3\n66:D-S3\n67:D#-S3\n68:E-S3\n69:F-S3\n70:F#-S3\n71:G-S3\n72:G#-S3\n73:A-S3\n74:A#-S3\n75:B-S3\n76:C-S4\n77:C#-S4\n78:D-S4\n79:D#-S4\n80:E-S4\n81:F-S4\n82:F#-S4\n83:G-S4\n84:G#-S4\n85:A-S4\n86:A#-S4\n87:B-S4\n88:C-S5\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/Data.txt",
    "content": "1:A2\n2:A#2\n3:B2\n4:C1\n5:C#1\n6:D1\n7:D#1\n8:E1\n9:F1\n10:F#1\n11:G1\n12:G#1\n13:A1\n14:A#1\n15:B1\n16:C\n17:C#\n18:D\n19:D#\n20:E\n21:F\n22:F#\n23:G\n24:G#\n25:A\n26:A#\n27:B\n28:c\n29:c#\n30:d\n31:d#\n32:e\n33:f\n34:f#\n35:g\n36:g#\n37:a\n38:a#\n39:b\n40:c1\n41:c#1\n42:d1\n43:d#1\n44:e1\n45:f1\n46:f#1\n47:g1\n48:g#1\n49:a1\n50:a#1\n51:b1\n52:c2\n53:c#2\n54:d2\n55:d#2\n56:e2\n57:f2\n58:f#2\n59:g2\n60:g#2\n61:a2\n62:a#2\n63:b2\n64:c3\n65:c#3\n66:d3\n67:d#3\n68:e3\n69:f3\n70:f#3\n71:f3\n72:g#3\n73:a3\n74:a#3\n75:b3\n76:c4\n77:c#4\n78:d4\n79:d#4\n80:e4\n81:f4\n82:f#4\n83:g4\n84:g#4\n85:a4\n86:a#4\n87:b4\n88:c5"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/GenreData.txt",
    "content": "Baroque:Bach\nClassical:Haydn,Mozart,Beethoven,Schubert,Clementi\nRomantic:Mendelssohn,Liszt,Chopin,Schumann,Brahms,Burgmueller,Debussy,Godowsky,Moszkowski,Mussorgsky,Rachmaninov,Ravel,Tchaikovsky,Albéniz,Balakirew,Borodin,Granados,Grieg,Sinding\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/MIDIData.txt",
    "content": "-1:rest\n0:C3\n1:C sharp3/D flat3\n2:D3\n3:D sharp3/E flat3\n4:E3\n5:F3\n6:F sharp3/G flat3\n7:G3\n8:G sharp3/A flat3\n9:A3\n10:A sharp3/B flat3\n11:B3\n12:C2\n13:C sharp2/D flat2\n14:D2\n15:D sharp2/E flat2\n16:E2\n17:F2\n18:F sharp2/G flat2\n19:G2\n20:G sharp2/A flat2\n21:A2\n22:A sharp2/B flat2\n23:B2\n24:C1\n25:C sharp1/D flat1\n26:D1\n27:D sharp1/E flat1\n28:E1\n29:F1\n30:F sharp1/G flat1\n31:G1\n32:G sharp1/A flat1\n33:A1\n34:A sharp1/B flat1\n35:B1\n36:C\n37:C sharp/D flat\n38:D\n39:D sharp/E flat\n40:E\n41:F\n42:F sharp/G flat\n43:G\n44:G sharp/A flat\n45:A\n46:A sharp/B flat\n47:B\n48:c\n49:c sharp/d flat\n50:d\n51:d sharp/e flat\n52:e\n53:f\n54:f sharp/g flat\n55:g\n56:g sharp/a flat\n57:a\n58:a sharp/b flat\n59:b\n60:c1\n61:c sharp1/d flat1\n62:d1\n63:d sharp1/e flat1\n64:e1\n65:f1\n66:f sharp1/g flat1\n67:g1\n68:g sharp1/a flat1\n69:a1\n70:a sharp1/b flat1\n71:b1\n72:c2\n73:c sharp2/d flat2\n74:d2\n75:d sharp2/e flat2\n76:e2\n77:f2\n78:f sharp2/g flat2\n79:g2\n80:g sharp2/a flat2\n81:a2\n82:a sharp2/b flat2\n83:b2\n84:c3\n85:c sharp3/d flat3\n86:d3\n87:d sharp3/e flat3\n88:e3\n89:f3\n90:f sharp3/g flat3\n91:f3\n92:g sharp3/a flat3\n93:a3\n94:a sharp3/b flat3\n95:b3\n96:c4\n97:c sharp4/d flat4\n98:d4\n99:d sharp4/e flat4\n100:e4\n101:f4\n102:f sharp4/g flat4\n103:g4\n104:g sharp4/a flat4\n105:a4\n106:a sharp4/b flat4\n107:b4\n108:c5\n109:c sharp5/d flat5\n110:d5\n111:d sharp5/e flat5\n112:e5\n113:f5\n114:f sharp5/g flat5\n115:g5\n116:g sharp5/a flat5\n117:a5\n118:a sharp5/b flat5\n119:b5\n120:c6\n121:c sharp6/d flat6\n122:d6\n123:d sharp6/e flat6\n124:e6\n125:f6\n126:f sharp6/g flat6\n127:g6"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/chords.csv",
    "content": "C,,,\r\n1,C,E,G\r\n4,F,A,C\r\n5,G,B,D\r\n,,,\r\na,,,\r\n1,A,C,E\r\n4,D,F,A\r\n5,E,G#,B\r\n,,,\r\nG,,,\r\n1,G,B,D\r\n4,C,E,G\r\n5,D,F#,A\r\n,,,\r\nF,,,\r\n1,F,A,C\r\n4,B-,D,F\r\n5,C,E,G\r\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/information.csv",
    "content": "chpn_op35_2.mid,Chopin,Romantic,unclear\nchpn_op33_4.mid,Chopin,Romantic,unclear\nchpn-p14.mid,Chopin,Romantic,unclear\nchpn-p17.mid,Chopin,Romantic,soft\nchpn_op33_2.mid,Chopin,Romantic,happy\nchpn_op53.mid,Chopin,Romantic,unclear\nchpn_op35_3.mid,Chopin,Romantic,unclear\nchpn-p1.mid,Chopin,Romantic,unclear\nchpn_op7_2.mid,Chopin,Romantic,soft\nchpn_op25_e11.mid,Chopin,Romantic,passionate\nchpn-p16.mid,Chopin,Romantic,unclear\nchpn-p20.mid,Chopin,Romantic,depressed\nchpn-p13.mid,Chopin,Romantic,soft\nchpn_op27_2.mid,Chopin,Romantic,soft\nchpn-p6.mid,Chopin,Romantic,depressed\nchpn-p9.mid,Chopin,Romantic,unclear\nchpn-p15.mid,Chopin,Romantic,depressed\nchpn_op10_e12.mid,Chopin,Romantic,passionate\nchpn-p22.mid,Chopin,Romantic,unclear\nchpn_op23.mid,Chopin,Romantic,unclear\nchpn-p24.mid,Chopin,Romantic,passionate\nchpn-p18.mid,Chopin,Romantic,happy\nchpn_op27_1.mid,Chopin,Romantic,depressed\nchpn_op10_e05.mid,Chopin,Romantic,passionate\nchpn_op25_e4.mid,Chopin,Romantic,unclear\nchpn_op66.mid,Chopin,Romantic,unclear\nchpn_op25_e2.mid,Chopin,Romantic,soft\nchpn_op10_e01.mid,Chopin,Romantic,passionate\nchpn-p11.mid,Chopin,Romantic,happy\nchpn-p7.mid,Chopin,Romantic,soft\nchp_op18.mid,Chopin,Romantic,unclear\nchpn-p23.mid,Chopin,Romantic,soft\nchpn-p5.mid,Chopin,Romantic,unclear\nchpn_op35_4.mid,Chopin,Romantic,unclear\nchpn_op25_e3.mid,Chopin,Romantic,unclear\nchpn-p12.mid,Chopin,Romantic,passionate\nchpn-p21.mid,Chopin,Romantic,soft\nchpn-p4.mid,Chopin,Romantic,depressed\nchpn-p2.mid,Chopin,Romantic,depressed\nchpn_op35_1.mid,Chopin,Romantic,unclear\nchp_op31.mid,Chopin,Romantic,unclear\nchpn_op25_e1.mid,Chopin,Romantic,unclear\nchpn-p10.mid,Chopin,Romantic,unclear\nchpn-p8.mid,Chopin,Romantic,passionate\nchpn_op25_e12.mid,Chopin,Romantic,unclear\nchpn_op7_1.mid,Chopin,Romantic,happy\nchpn-p3.mid,Chopin,Romantic,unclear\nchpn-p19.mid,Chopin,Romantic,soft\nscn15_5.mid,Schumann,Romantic,unclear\nscn15_7.mid,Schumann,Romantic,unclear\nscn15_12.mid,Schumann
,Romantic,unclear\nscn15_6.mid,Schumann,Romantic,unclear\nscn15_13.mid,Schumann,Romantic,unclear\nscn68_12.mid,Schumann,Romantic,unclear\nscn15_1.mid,Schumann,Romantic,unclear\nscn15_3.mid,Schumann,Romantic,unclear\nscn15_2.mid,Schumann,Romantic,unclear\nscn16_3.mid,Schumann,Romantic,unclear\nscn16_2.mid,Schumann,Romantic,unclear\nscn68_10.mid,Schumann,Romantic,unclear\nscn16_6.mid,Schumann,Romantic,unclear\nscn16_5.mid,Schumann,Romantic,unclear\nscn16_1.mid,Schumann,Romantic,unclear\nscn15_9.mid,Schumann,Romantic,unclear\nscn15_8.mid,Schumann,Romantic,unclear\nscn15_11.mid,Schumann,Romantic,unclear\nscn16_7.mid,Schumann,Romantic,unclear\nscn16_8.mid,Schumann,Romantic,unclear\nscn16_4.mid,Schumann,Romantic,unclear\nscn15_10.mid,Schumann,Romantic,unclear\nschum_abegg.mid,Schumann,Romantic,unclear\nscn15_4.mid,Schumann,Romantic,unclear\nty_august.mid,Tchaikovsky,Romantic,unclear\nty_april.mid,Tchaikovsky,Romantic,unclear\nty_juli.mid,Tchaikovsky,Romantic,unclear\nty_oktober.mid,Tchaikovsky,Romantic,unclear\nty_juni.mid,Tchaikovsky,Romantic,unclear\nty_november.mid,Tchaikovsky,Romantic,unclear\nty_januar.mid,Tchaikovsky,Romantic,unclear\nty_dezember.mid,Tchaikovsky,Romantic,unclear\nty_maerz.mid,Tchaikovsky,Romantic,unclear\nty_mai.mid,Tchaikovsky,Romantic,unclear\nty_februar.mid,Tchaikovsky,Romantic,unclear\nty_september.mid,Tchaikovsky,Romantic,unclear\ndebussy_cc_3.mid,Debussy,Romantic,happy\nDEB_CLAI.MID,Debussy,Romantic,soft\ndebussy_cc_2.mid,Debussy,Romantic,unclear\ndebussy_cc_6.mid,Debussy,Romantic,unclear\ndeb_menu.mid,Debussy,Romantic,unclear\nDEB_PASS.MID,Debussy,Romantic,unclear\ndeb_prel.mid,Debussy,Romantic,happy\ndebussy_cc_1.mid,Debussy,Romantic,unclear\ndebussy_cc_4.mid,Debussy,Romantic,unclear\nalb_se2.mid,Albeniz,Romantic,unclear\nalb_se8.mid,Albeniz,Romantic,unclear\nalb_se1.mid,Albeniz,Romantic,unclear\nalb_esp6.mid,Albeniz,Romantic,unclear\nalb_esp4.mid,Albeniz,Romantic,unclear\nalb_esp2.mid,Albeniz,Romantic,unclear\nalb_se4.mid,Albeniz,Romantic,u
nclear\nalb_esp1.mid,Albeniz,Romantic,unclear\nalb_esp5.mid,Albeniz,Romantic,unclear\nalb_se6.mid,Albeniz,Romantic,unclear\nalb_se3.mid,Albeniz,Romantic,unclear\nalb_esp3.mid,Albeniz,Romantic,unclear\nalb_se5.mid,Albeniz,Romantic,unclear\nalb_se7.mid,Albeniz,Romantic,unclear\nliz_et_trans8.mid,Liszt,Romantic,unclear\nliz_liebestraum.mid,Liszt,Romantic,soft\nliz_rhap10.mid,Liszt,Romantic,unclear\nliz_rhap15.mid,Liszt,Romantic,unclear\nliz_et2.mid,Liszt,Romantic,unclear\nliz_rhap09.mid,Liszt,Romantic,unclear\nliz_rhap02.mid,Liszt,Romantic,unclear\nliz_rhap12.mid,Liszt,Romantic,unclear\nliz_et4.mid,Liszt,Romantic,unclear\nliz_et1.mid,Liszt,Romantic,unclear\nliz_et_trans4.mid,Liszt,Romantic,unclear\nliz_et5.mid,Liszt,Romantic,unclear\nliz_donjuan.mid,Liszt,Romantic,unclear\nliz_et3.mid,Liszt,Romantic,soft\nliz_et6.mid,Liszt,Romantic,unclear\nliz_et_trans5.mid,Liszt,Romantic,unclear\npathetique_1.mid,Beethoven,Classical,unclear\nbeethoven_hammerklavier_4.mid,Beethoven,Classical,unclear\nbeethoven_hammerklavier_2.mid,Beethoven,Classical,unclear\nbeethoven_opus22_4.mid,Beethoven,Classical,happy\nwaldstein_2.mid,Beethoven,Classical,unclear\nbeethoven_hammerklavier_3.mid,Beethoven,Classical,depressed\nbeethoven_les_adieux_3.mid,Beethoven,Classical,happy\nbeethoven_opus22_2.mid,Beethoven,Classical,unclear\nwaldstein_1.mid,Beethoven,Classical,happy\nappass_2.mid,Beethoven,Classical,unclear\nbeethoven_opus10_3.mid,Beethoven,Classical,unclear\nbeethoven_opus22_1.mid,Beethoven,Classical,unclear\nbeethoven_hammerklavier_1.mid,Beethoven,Classical,unclear\npathetique_3.mid,Beethoven,Classical,passionate\nbeethoven_opus90_2.mid,Beethoven,Classical,unclear\nbeethoven_opus22_3.mid,Beethoven,Classical,happy\nbeethoven_opus90_1.mid,Beethoven,Classical,unclear\nwaldstein_3.mid,Beethoven,Classical,happy\nbeethoven_opus10_1.mid,Beethoven,Classical,unclear\nappass_1.mid,Beethoven,Classical,unclear\nbeethoven_opus10_2.mid,Beethoven,Classical,unclear\nmond_2.mid,Beethoven,Classical,unclear\nbe
ethoven_les_adieux_1.mid,Beethoven,Classical,depressed\nelise.mid,Beethoven,Classical,soft\nappass_3.mid,Beethoven,Classical,passionate\npathetique_2.mid,Beethoven,Classical,soft\nmond_3.mid,Beethoven,Classical,passionate\nbeethoven_les_adieux_2.mid,Beethoven,Classical,depressed\nmond_1.mid,Beethoven,Classical,depressed\nmuss_3.mid,Mussorgsky,Romantic,unclear\nmuss_5.mid,Mussorgsky,Romantic,unclear\nmuss_1.mid,Mussorgsky,Romantic,unclear\nmuss_8.mid,Mussorgsky,Romantic,unclear\nmuss_6.mid,Mussorgsky,Romantic,unclear\nmuss_2.mid,Mussorgsky,Romantic,unclear\nmuss_7.mid,Mussorgsky,Romantic,unclear\nmuss_4.mid,Mussorgsky,Romantic,unclear\nmendel_op30_4.mid,Mendelssohn,Romantic,unclear\nmendel_op19_2.mid,Mendelssohn,Romantic,unclear\nmendel_op19_5.mid,Mendelssohn,Romantic,unclear\nmendel_op19_4.mid,Mendelssohn,Romantic,soft\nmendel_op19_6.mid,Mendelssohn,Romantic,depressed\nmendel_op19_1.mid,Mendelssohn,Romantic,soft\nmendel_op30_5.mid,Mendelssohn,Romantic,unclear\nmendel_op30_2.mid,Mendelssohn,Romantic,unclear\nmendel_op62_3.mid,Mendelssohn,Romantic,depressed\nmendel_op30_3.mid,Mendelssohn,Romantic,unclear\nmendel_op62_4.mid,Mendelssohn,Romantic,unclear\nmendel_op19_3.mid,Mendelssohn,Romantic,unclear\nmendel_op53_5.mid,Mendelssohn,Romantic,unclear\nmendel_op30_1.mid,Mendelssohn,Romantic,soft\nmendel_op62_5.mid,Mendelssohn,Romantic,unclear\ngra_esp_2.mid,Granados,Romantic,unclear\ngra_esp_3.mid,Granados,Romantic,unclear\ngra_esp_4.mid,Granados,Romantic,unclear\nfruehlingsrauschen.mid,Sinding,Romantic,unclear\nrac_op32_1.mid,Rachmaninov,Romantic,unclear\nrac_op23_3.mid,Rachmaninov,Romantic,unclear\nrac_op33_8.mid,Rachmaninov,Romantic,unclear\nrac_op33_6.mid,Rachmaninov,Romantic,unclear\nrac_op33_5.mid,Rachmaninov,Romantic,unclear\nrac_op23_7.mid,Rachmaninov,Romantic,unclear\nrac_op32_13.mid,Rachmaninov,Romantic,unclear\nrac_op23_2.mid,Rachmaninov,Romantic,unclear\nrac_op3_2.mid,Rachmaninov,Romantic,unclear\nrac_op23_5.mid,Rachmaninov,Romantic,unclear\ngod_alb_esp2.mid,God
owsky,Romantic,unclear\ngod_chpn_op10_e01.mid,Godowsky,Romantic,unclear\nmz_545_1.mid,Mozart,Classical,unclear\nmz_570_2.mid,Mozart,Classical,unclear\nmz_311_2.mid,Mozart,Classical,unclear\nmz_333_3.mid,Mozart,Classical,unclear\nmz_331_3.mid,Mozart,Classical,passionate\nmz_332_3.mid,Mozart,Classical,unclear\nmz_331_2.mid,Mozart,Classical,unclear\nmz_332_1.mid,Mozart,Classical,unclear\nmz_545_2.mid,Mozart,Classical,unclear\nmz_330_1.mid,Mozart,Classical,unclear\nmz_545_3.mid,Mozart,Classical,unclear\nmz_333_2.mid,Mozart,Classical,unclear\nmz_330_2.mid,Mozart,Classical,unclear\nmz_333_1.mid,Mozart,Classical,unclear\nmz_311_3.mid,Mozart,Classical,unclear\nmz_331_1.mid,Mozart,Classical,happy\nmz_570_3.mid,Mozart,Classical,unclear\nmz_332_2.mid,Mozart,Classical,unclear\nmz_570_1.mid,Mozart,Classical,unclear\nmz_330_3.mid,Mozart,Classical,unclear\nmz_311_1.mid,Mozart,Classical,unclear\nrav_scarbo.mid,Ravel,Romantic,unclear\nrav_ondi.mid,Ravel,Romantic,unclear\nrav_eau.mid,Ravel,Romantic,unclear\nrav_gib.mid,Ravel,Romantic,unclear\nravel_miroirs_1.mid,Ravel,Romantic,unclear\nclementi_opus36_4_1.mid,Clementi,Classical,unclear\nclementi_opus36_2_2.mid,Clementi,Classical,unclear\nclementi_opus36_6_1.mid,Clementi,Classical,happy\nclementi_opus36_1_3.mid,Clementi,Classical,happy\nclementi_opus36_3_3.mid,Clementi,Classical,unclear\nclementi_opus36_5_1.mid,Clementi,Classical,happy\nclementi_opus36_1_2.mid,Clementi,Classical,soft\nclementi_opus36_4_3.mid,Clementi,Classical,unclear\nclementi_opus36_1_1.mid,Clementi,Classical,happy\nclementi_opus36_2_3.mid,Clementi,Classical,unclear\nclementi_opus36_4_2.mid,Clementi,Classical,unclear\nclementi_opus36_3_2.mid,Clementi,Classical,unclear\nclementi_opus36_5_3.mid,Clementi,Classical,unclear\nclementi_opus36_5_2.mid,Clementi,Classical,unclear\nclementi_opus36_3_1.mid,Clementi,Classical,unclear\nclementi_opus36_6_2.mid,Clementi,Classical,unclear\nclementi_opus36_2_1.mid,Clementi,Classical,happy\nmos_op36_6.mid,Moszkowski,Romantic,unclear\n
haydn_7_2.mid,Haydn,Classical,unclear\nhaydn_35_2.mid,Haydn,Classical,unclear\nhaydn_7_3.mid,Haydn,Classical,unclear\nhaydn_8_4.mid,Haydn,Classical,unclear\nhaydn_9_1.mid,Haydn,Classical,unclear\nhaydn_8_3.mid,Haydn,Classical,unclear\nhaydn_35_3.mid,Haydn,Classical,unclear\nhaydn_33_3.mid,Haydn,Classical,unclear\nhaydn_8_1.mid,Haydn,Classical,unclear\nhaydn_8_2.mid,Haydn,Classical,unclear\nhaydn_7_1.mid,Haydn,Classical,unclear\nhaydn_9_2.mid,Haydn,Classical,unclear\nhaydn_43_1.mid,Haydn,Classical,unclear\nhay_40_2.mid,Haydn,Classical,unclear\nhaydn_9_3.mid,Haydn,Classical,unclear\nhaydn_35_1.mid,Haydn,Classical,unclear\nhaydn_33_1.mid,Haydn,Classical,unclear\nhaydn_43_2.mid,Haydn,Classical,unclear\nhay_40_1.mid,Haydn,Classical,unclear\nhaydn_33_2.mid,Haydn,Classical,unclear\nhaydn_43_3.mid,Haydn,Classical,unclear\nbach_850.mid,Bach,Baroque,happy\nbach_846.mid,Bach,Baroque,soft\nbach_847.mid,Bach,Baroque,unclear\ngrieg_halling.mid,Grieg,Romantic,unclear\ngrieg_wedding.mid,Grieg,Romantic,unclear\ngrieg_brooklet.mid,Grieg,Romantic,unclear\ngrieg_butterfly.mid,Grieg,Romantic,happy\ngrieg_wanderer.mid,Grieg,Romantic,unclear\ngrieg_zwerge.mid,Grieg,Romantic,unclear\ngrieg_march.mid,Grieg,Romantic,unclear\ngrieg_album.mid,Grieg,Romantic,unclear\ngrieg_elfentanz.mid,Grieg,Romantic,unclear\ngrieg_spring.mid,Grieg,Romantic,unclear\ngrieg_waechter.mid,Grieg,Romantic,unclear\ngrieg_kobold.mid,Grieg,Romantic,unclear\ngrieg_berceuse.mid,Grieg,Romantic,unclear\ngrieg_voeglein.mid,Grieg,Romantic,unclear\ngrieg_walzer.mid,Grieg,Romantic,unclear\ngrieg_once_upon_a_time.mid,Grieg,Romantic,unclear\nbr_im2.mid,Brahms,Romantic,unclear\nbr_rhap.mid,Brahms,Romantic,unclear\nbrahms_opus1_1.mid,Brahms,Romantic,unclear\nbrahms_opus1_3.mid,Brahms,Romantic,unclear\nbrahms_opus117_1.mid,Brahms,Romantic,unclear\nbrahms_opus1_2.mid,Brahms,Romantic,unclear\nBR_IM6.MID,Brahms,Romantic,unclear\nbrahms_opus1_4.mid,Brahms,Romantic,unclear\nbrahms_opus117_2.mid,Brahms,Romantic,unclear\nbr_im5.mid,Brahms
,Romantic,unclear\nburg_geschwindigkeit.mid,Burgmueller,Romantic,unclear\nburg_perlen.mid,Burgmueller,Romantic,unclear\nburg_trennung.mid,Burgmueller,Romantic,unclear\nburg_agitato.mid,Burgmueller,Romantic,unclear\nburg_sylphen.mid,Burgmueller,Romantic,unclear\nburg_spinnerlied.mid,Burgmueller,Romantic,unclear\nburg_quelle.mid,Burgmueller,Romantic,unclear\nburg_erwachen.mid,Burgmueller,Romantic,unclear\nburg_gewitter.mid,Burgmueller,Romantic,unclear\nschuim-4.mid,Schubert,Classical,unclear\nschumm-6.mid,Schubert,Classical,unclear\nschu_143_2.mid,Schubert,Classical,unclear\nschubert_D850_3.mid,Schubert,Classical,unclear\nschub_d960_1.mid,Schubert,Classical,unclear\nschubert_D935_4.mid,Schubert,Classical,unclear\nschu_143_1.mid,Schubert,Classical,unclear\nschubert_D850_2.mid,Schubert,Classical,unclear\nschumm-2.mid,Schubert,Classical,unclear\nschubert_D850_4.mid,Schubert,Classical,unclear\nschub_d960_3.mid,Schubert,Classical,unclear\nschub_d760_3.mid,Schubert,Classical,unclear\nschumm-5.mid,Schubert,Classical,unclear\nschumm-3.mid,Schubert,Classical,unclear\nschumm-1.mid,Schubert,Classical,unclear\nschub_d760_4.mid,Schubert,Classical,unclear\nschubert_D935_2.mid,Schubert,Classical,unclear\nschumm-4.mid,Schubert,Classical,unclear\nschu_143_3.mid,Schubert,Classical,unclear\nschub_d960_2.mid,Schubert,Classical,unclear\nschuim-1.mid,Schubert,Classical,unclear\nschub_d760_1.mid,Schubert,Classical,unclear\nschubert_D850_1.mid,Schubert,Classical,unclear\nschubert_D935_3.mid,Schubert,Classical,unclear\nschub_d960_4.mid,Schubert,Classical,unclear\nschuim-3.mid,Schubert,Classical,unclear\nschub_d760_2.mid,Schubert,Classical,unclear\nschubert_D935_1.mid,Schubert,Classical,unclear\nschuim-2.mid,Schubert,Classical,unclear\nislamei.mid,Balakirew,Romantic,unclear\nbor_ps2.mid,Borodin,Romantic,unclear\nbor_ps1.mid,Borodin,Romantic,unclear\nbor_ps7.mid,Borodin,Romantic,unclear\nbor_ps4.mid,Borodin,Romantic,unclear\nbor_ps6.mid,Borodin,Romantic,unclear\nbor_ps3.mid,Borodin,Romantic,unc
lear\nbor_ps5.mid,Borodin,Romantic,unclear\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/keyIndex.csv",
    "content": "C major,0\na minor,1\nG major,2\ne minor,3\nD major,4\nb minor,5\nA major,6\nf# minor,7\nE major,8\nc# minor,9\nB major,10\ng# minor,11\nF major,12\nd minor,13\nB- major,14\ng minor,15\nE- major,16\nc minor,17\nA- major,18\nf minor,19\nD- major,20\nb- minor,21\nG- major,22\ne- minor,23\nC# major,20\na# minor,21\nF# major,22\nd# minor,23\nC- major,10\na- minor,11\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/keys.csv",
    "content": "1,-1,2,-1,3,4,-1,5,-1,6,-1,7\n3,-1,4,-1,5,6,-1,-1,7,1,-1,2\n4,-1,5,-1,6,-1,7,1,-1,2,-1,3\n6,-1,-1,7,1,-1,2,3,-1,4,-1,5\n-1,7,1,-1,2,-1,3,4,-1,5,-1,6\n-1,2,3,-1,4,-1,5,6,-1,-1,7,1\n-1,3,4,-1,5,-1,6,-1,7,1,-1,2\n-1,5,6,-1,-1,7,1,-1,2,3,-1,4\n-1,6,-1,7,1,-1,2,-1,3,4,-1,5\n7,1,-1,2,3,-1,4,-1,5,6,-1,-1\n-1,2,-1,3,4,-1,5,-1,6,-1,7,1\n-1,4,-1,5,6,-1,-1,7,1,-1,2,3\n5,-1,6,-1,7,1,-1,2,-1,3,4,-1\n-1,7,1,-1,2,3,-1,4,-1,5,6,-1\n2,-1,3,4,-1,5,-1,6,-1,7,1,-1\n4,-1,5,6,-1,-1,7,1,-1,2,3,-1\n6,-1,7,1,-1,2,-1,3,4,-1,5,-1\n1,-1,2,3,-1,4,-1,5,6,-1,-1,7\n3,4,-1,5,-1,6,-1,7,1,-1,2,-1\n5,6,-1,-1,7,1,-1,2,3,-1,4,-1\n7,1,-1,2,-1,3,4,-1,5,-1,6,-1\n2,3,-1,4,-1,5,6,-1,-1,7,1,-1\n-1,5,-1,6,-1,7,1,-1,2,-1,3,4\n-1,-1,7,1,-1,2,3,-1,4,-1,5,6\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/modeindex.csv",
    "content": "0,major\n1,minor\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/pitch2midi.csv",
    "content": "C,0,12,24,36,48,60,72,84,96,108,120\r\nC#,1,13,25,37,49,61,73,85,97,109,121\r\nC-,11,23,35,47,59,71,83,95,107,119,\r\nD,2,14,26,38,50,62,74,86,98,110,122\r\nD#,3,15,27,39,51,63,75,87,99,111,123\r\nD-,1,13,25,37,49,61,73,85,97,109,121\r\nE,4,16,28,40,52,64,76,88,100,112,124\r\nE#,5,17,29,41,53,65,77,89,101,113,125\r\nE-,3,15,27,39,51,63,75,87,99,111,123\r\nF,5,17,29,41,53,65,77,89,101,113,125\r\nF#,6,18,30,42,54,66,78,90,102,114,126\r\nF-,4,16,28,40,52,64,76,88,100,112,124\r\nG,7,19,31,43,55,67,79,91,103,115,127\r\nG#,8,20,32,44,56,68,80,92,104,116,\r\nG-,6,18,30,42,54,66,78,90,102,114,126\r\nA,9,21,33,45,57,69,81,93,105,117,\r\nA#,10,22,34,46,58,70,82,94,106,118,\r\nA-,8,20,32,44,56,68,80,92,104,116,\r\nB,11,23,35,47,59,71,83,95,107,119,\r\nB+,0,12,24,36,48,60,72,84,96,108,120\r\nB-,10,22,34,46,58,70,82,94,106,118,\r\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/inputs/tones2.csv",
    "content": ",0,1,2,3,4,5,6,7,8,9,10,11\r\nC major,2,-1,1,-1,1,1,-1,1,-1,1,-1,1\r\na minor,1,-1,1,-1,1,1,-1,-1,2,2,-1,1\r\nG major,1,-1,1,-1,1,-1,2,2,-1,1,-1,1\r\ne minor,1,-1,-1,2,2,-1,2,1,-1,1,-1,1\r\nD major,-1,2,2,-1,1,-1,2,1,-1,1,-1,1\r\nb minor,-1,2,1,-1,1,-1,2,1,-1,-1,2,2\r\nA major,-1,2,1,-1,1,-1,2,-1,2,2,-1,1\r\nf# minor,-1,2,1,-1,-1,2,2,-1,2,1,-1,1\r\nE major,-1,2,-1,2,2,-1,2,-1,2,1,-1,1\r\nc# minor,2,2,-1,2,1,-1,2,-1,2,1,-1,-1\r\nB major,-1,2,-1,2,1,-1,2,-1,2,-1,2,2\r\ng# minor,-1,2,-1,2,1,-1,-1,2,2,-1,2,1\r\nF major,1,-1,1,-1,1,2,-1,1,-1,1,2,-1\r\nd minor,-1,2,2,-1,1,1,-1,1,-1,1,2,-1\r\nB- major,1,-1,1,2,-1,1,-1,1,-1,1,2,-1\r\ng minor,1,-1,1,2,-1,-1,2,2,-1,1,2,-1\r\nE- major,1,-1,1,2,-1,1,-1,1,2,-1,2,-1\r\nc minor,1,-1,1,2,-1,1,-1,1,2,-1,-1,2\r\nA- major,1,2,-1,2,-1,1,-1,1,2,-1,2,-1\r\nf minor,1,2,-1,-1,2,1,-1,1,2,-1,2,-1\r\nD- major,1,2,-1,2,-1,1,2,-1,2,-1,2,-1\r\nb- minor,1,2,-1,2,-1,1,2,-1,-1,2,2,-1\r\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/task/mode-conditioned learning.py",
    "content": "import sys\nimport os\nimport time\nsys.path.append(\"../../../../\")\nsys.path.append(\"../\")\nimport numpy as np\nimport music21 as m21\nfrom conf.conf import *\nfrom api.music_engine_api import EngineAPI\n\n\nif __name__==\"__main__\":\n    musicEngine = EngineAPI()\n    musicEngine.cortexInit()\n    #------------Bach dataset learning----------------#\n    paths = m21.corpus.getComposer('bach')\n    print(len(paths))\n    for path in paths:\n        musicName = (str(path).split('\\\\'))[-1]\n        print(musicName)\n        if musicName.split('.')[-1] != 'mxl': continue\n        xmldata = m21.corpus.parse(path)\n        musicEngine.rememberMusic(musicName, \"None\")\n        musicEngine.learnFourPartMusic(xmldata, musicName, \"None\")\n\n    #------------generation test----------------#\n    key = 'C major'\n    firstnotes = np.array([[m21.pitch.Pitch('E5').midi],\n                           [m21.pitch.Pitch('G4').midi],\n                           [m21.pitch.Pitch('C4').midi],\n                           [m21.pitch.Pitch('C3').midi]])\n\n    result = musicEngine.generateMelodyWithKey(configs.keyIndexMap.get(key),firstnotes,None,4)\n    steam1 = m21.stream.Stream()\n    for i,part in result.items():\n        pt = m21.stream.Stream()\n        for v in part:\n            p = v.get(\"N\")\n            d = v.get(\"T\")\n            n = m21.note.Note(p)\n            n.quarterLength = d\n            pt.append(n)\n        steam1.insert(0,pt)\n    opath = '../result_output/tone learning/'\n    nowtime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))\n    t2 = ''.join([x for x in nowtime if x.isdigit()])\n    steam1.write('midi', fp=opath+key+\"_\"+t2+'.mid')\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/task/musicGeneration.py",
    "content": "import sys\nsys.path.append(\"../../../../\")\nsys.path.append(\"../\")\nfrom api.music_engine_api import EngineAPI\nimport os\n\nif __name__==\"__main__\":\n    #----------------------------------Init------------------------------#\n    musicEngine = EngineAPI()\n    musicEngine.cortexInit()\n\n    #----------------------------Learning process------------------------#\n    input_path = \"../testData/\"\n    for composerName in os.listdir(input_path):\n        dpath = os.path.join(input_path,composerName)\n        if os.path.isdir(dpath):\n            for musicName in os.listdir(dpath):\n                fileName = (os.path.join(dpath,musicName))\n                musicEngine.memorizing(musicName,composerName,20,fileName)\n\n\n    #-------------------------Generation Process------------------------#\n    beginnotes = {1:[-1,67],\n                  2:[-1]}\n    begindurs = {1:[0.5,0.25],\n                 2:[0.5]}\n    lengths = [10,8]\n\n    genreName = \"Classical\"\n    composerName = \"Bach\"\n    #Generate a piece of music melody#\n    musicEngine.generateEx_Nihilo(beginnotes.get(2),begindurs.get(2),20,\"melody_generated\")\n    #Generate a piece of melody with a composer style\n    musicEngine.generateEx_NihiloAccordingToComposer(composerName,beginnotes.get(2),begindurs.get(2),15,\"Bach_generated\")\n    #Generate a piece of melody with a genre style\n    musicEngine.generateEx_NihiloAccordingToGenre(genreName,beginnotes.get(1),begindurs.get(1),15,\"Classical_generated\")"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/task/musicMemory.py",
    "content": "import sys\nsys.path.append(\"../\")\nsys.path.append(\"../../../../\")\nfrom api.music_engine_api import EngineAPI\nimport os\n\n\nif __name__==\"__main__\":\n    musicEngine = EngineAPI()\n    musicEngine.cortexInit()\n\n\n    input_path = \"../testData/\"\n    #--------------------learning process---------------#\n\n    for composerName in os.listdir(input_path):\n        dpath = os.path.join(input_path,composerName)\n        if os.path.isdir(dpath):\n            for musicName in os.listdir(dpath):\n                fileName = (os.path.join(dpath,musicName))\n                #Here is the training function, the first and the second parameters refer to the title and composer of a melody.\n                #The third parameter indicates the number of notes you want to learn. If you want to learn all the notes of melodies, this value is \"ALL\".\n                musicEngine.memorizing(musicName, composerName, 20, fileName)\n\n    # recall the music based on the name of a music\n    musicEngine.recallMusic(\"Sonate C Major.Mid\")\n\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/generateData.py",
    "content": "'''\nCreated on 2016.4.27\n\n@author: liangqian\n'''\nimport json\nimport random\n\nclass joint():\n    def __init__(self):\n        self.x = 0\n        self.y = 0\n        self.z = 0\n    def self2dic(self):\n        return {'x':self.x,\n                'y':self.y,\n                'z':self.z}\n        \nclass data():\n    \n    def __init__(self):\n        self.T = 0\n        self.position = 0\n        self.joints = {}\n        \n    def self2dic(self):\n        return {'T':self.T,\n                'position':self.position\n                }\n\nData_List = []\n\n# for i in range(1,419):\n#     d = data()\n#     d.T = i\n#     d.position = random.randint(0,100)\n#     dic = d.self2dic()\n#     Data_List.append(d)\n# \n# dic = {}\n# dic['Data'] = Data_List\n#     \n# # using json to write file\n# strs = (json.dumps(dic))\n# \n# fout = open('position.txt','w')\n# fout.write(strs)\n# fout.close()\n# \n# \n# fin = open('position.txt','r')\n# data_json = fin.read()\n# \n# print(data_json)\nf = open('../Data.txt','r')\nline = f.readline()\nwhile(len(line) != 0):\n    if('T=' in line):\n        ss = line.split('=')\n        d = data()\n        d.T = int(ss[1].strip())\n        d.position = random.randint(0,100)\n        line = f.readline()\n        while(len(line.strip()) != 0):\n            sss = line.split(' ')\n            jj = joint()\n            s = sss[1].split('=')\n            jj.x = float(s[1].strip())\n            s = sss[2].split('=')\n            jj.y = float(s[1].strip())\n            s = sss[3].split('=')\n            jj.z = float(s[1].strip())\n            d.joints[sss[0]] = jj\n            line = f.readline()\n        Data_List.append(d)\n        print(len(d.joints))\n    line = f.readline()     \nprint(len(Data_List))        \n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/hamonydataset_test.py",
    "content": "import os\nimport sys\nimport music21 as m21\nimport numpy as np\nfirstnotes = np.array([[m21.pitch.Pitch('b-4').midi],\n                       [m21.pitch.Pitch('d4').midi],\n                       [m21.pitch.Pitch('f3').midi],\n                       [m21.pitch.Pitch('B2').midi]])\nprint(firstnotes)\ninput_path = \"../xmlfiles/hamony dataset/\"\nfor ch in os.listdir(input_path):\n    if ch == '.DS_Store': continue\n    fileName = os.path.join(input_path, ch)\n    for fName in os.listdir(fileName):\n        if fName == 'four':\n            if fName == '.DS_Store': continue\n            fn = os.path.join(fileName+'/four')\n            for f in os.listdir(fn):\n                if f == '.DS_Store': continue\n                print(f)\n                musicName = f.split(\"_\")[0]\n                fTone = (f.split(\"_\")[1]).split(\".\")[0]\n                print(musicName)\n                print(fTone)\n                print('---')\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/msg.py",
    "content": "'''\nCreated on 2016.6.8\n\n@author: liangqian\n'''\nimport time\nimport sys\nimport stomp\nclass MyListener(object):\n    def on_error(self, headers, message):\n        print('received an error : %s' % message)\n    def on_message(self, headers, message):\n        print('%s' % message)\n\nconn = stomp.Connection([('159.226.19.16',61613)])\n#conn = stomp.Connection([('10.10.10.106',61613)])   \nconn.set_listener('', MyListener())\nconn.start()\nprint('hh')\n\nconn.connect(wait=True,headers={'client-id':'LXYNB','non_persistent':'true'})\n \nconn.subscribe(destination='/topic/TEST2.FOO',id='LX', ack='auto',headers={'activemq.subscriptionName':'LXYNB'})\n#conn.send(body='hello,garfield!', destination='/topic/myTopic.messages')\n \nwhile(True):\n    pass\nconn.disconnect()\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/msgq.py",
    "content": "'''\nCreated on 2018.9.7\n\n@author: liangqian\n'''\nimport time\nimport sys\nimport stomp\n\ndef createMSQ():\n    queue_name = '/queue/SampleQueue'\n    conn = stomp.Connection([('localhost',61613)])\n    #conn.start()\n    print(\"building connection to activemq......\")\n    conn.connect()\n    #return conn\n#     for i in range (10):\n#         msg = 'this is the '+ str(i) + 'th messages'\n#         conn.send(queue_name,msg)\n#         print(msg)\n#     conn.disconnect()\n    return conn\n\n#createMSQ()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/oscillations.py",
    "content": "'''\nCreated on 2016.5.13\n\n@author: liangqian\n'''\n\nfrom modal.lifneuron import LIFNeuron\nfrom modal.cluster import Cluster\nfrom modal.synapse import Synapse\n\n\nc = Cluster('LIF')\nc.createClusterNetwork()\nc.setInhibitoryNeurons(0.2)\n\nfor i in range(0,c.neunum):\n    for j in range(0,c.neunum):\n        if(i != j):\n            node = c.neurons[j]\n            node.pre_neurons.append(c.neurons[i])\n            syn = Synapse(c.neurons[i],node)\n\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/position.txt",
    "content": "{\"1\": {\"position\": 58, \"T\": 1}, \"2\": {\"position\": 96, \"T\": 2}, \"3\": {\"position\": 36, \"T\": 3}, \"4\": {\"position\": 28, \"T\": 4}, \"5\": {\"position\": 43, \"T\": 5}, \"6\": {\"position\": 46, \"T\": 6}, \"7\": {\"position\": 91, \"T\": 7}, \"8\": {\"position\": 44, \"T\": 8}, \"9\": {\"position\": 59, \"T\": 9}, \"10\": {\"position\": 83, \"T\": 10}, \"11\": {\"position\": 17, \"T\": 11}, \"12\": {\"position\": 50, \"T\": 12}, \"13\": {\"position\": 46, \"T\": 13}, \"14\": {\"position\": 53, \"T\": 14}, \"15\": {\"position\": 9, \"T\": 15}, \"16\": {\"position\": 88, \"T\": 16}, \"17\": {\"position\": 23, \"T\": 17}, \"18\": {\"position\": 28, \"T\": 18}, \"19\": {\"position\": 17, \"T\": 19}, \"20\": {\"position\": 16, \"T\": 20}, \"21\": {\"position\": 43, \"T\": 21}, \"22\": {\"position\": 93, \"T\": 22}, \"23\": {\"position\": 13, \"T\": 23}, \"24\": {\"position\": 0, \"T\": 24}, \"25\": {\"position\": 92, \"T\": 25}, \"26\": {\"position\": 56, \"T\": 26}, \"27\": {\"position\": 2, \"T\": 27}, \"28\": {\"position\": 38, \"T\": 28}, \"29\": {\"position\": 17, \"T\": 29}, \"30\": {\"position\": 13, \"T\": 30}, \"31\": {\"position\": 95, \"T\": 31}, \"32\": {\"position\": 86, \"T\": 32}, \"33\": {\"position\": 45, \"T\": 33}, \"34\": {\"position\": 30, \"T\": 34}, \"35\": {\"position\": 42, \"T\": 35}, \"36\": {\"position\": 87, \"T\": 36}, \"37\": {\"position\": 14, \"T\": 37}, \"38\": {\"position\": 56, \"T\": 38}, \"39\": {\"position\": 76, \"T\": 39}, \"40\": {\"position\": 55, \"T\": 40}, \"41\": {\"position\": 5, \"T\": 41}, \"42\": {\"position\": 98, \"T\": 42}, \"43\": {\"position\": 5, \"T\": 43}, \"44\": {\"position\": 17, \"T\": 44}, \"45\": {\"position\": 73, \"T\": 45}, \"46\": {\"position\": 55, \"T\": 46}, \"47\": {\"position\": 86, \"T\": 47}, \"48\": {\"position\": 56, \"T\": 48}, \"49\": {\"position\": 80, \"T\": 49}, \"50\": {\"position\": 87, \"T\": 50}, \"51\": {\"position\": 47, \"T\": 51}, \"52\": 
{\"position\": 51, \"T\": 52}, \"53\": {\"position\": 62, \"T\": 53}, \"54\": {\"position\": 3, \"T\": 54}, \"55\": {\"position\": 65, \"T\": 55}, \"56\": {\"position\": 90, \"T\": 56}, \"57\": {\"position\": 76, \"T\": 57}, \"58\": {\"position\": 35, \"T\": 58}, \"59\": {\"position\": 81, \"T\": 59}, \"60\": {\"position\": 0, \"T\": 60}, \"61\": {\"position\": 32, \"T\": 61}, \"62\": {\"position\": 93, \"T\": 62}, \"63\": {\"position\": 100, \"T\": 63}, \"64\": {\"position\": 70, \"T\": 64}, \"65\": {\"position\": 10, \"T\": 65}, \"66\": {\"position\": 37, \"T\": 66}, \"67\": {\"position\": 87, \"T\": 67}, \"68\": {\"position\": 72, \"T\": 68}, \"69\": {\"position\": 77, \"T\": 69}, \"70\": {\"position\": 65, \"T\": 70}, \"71\": {\"position\": 21, \"T\": 71}, \"72\": {\"position\": 45, \"T\": 72}, \"73\": {\"position\": 81, \"T\": 73}, \"74\": {\"position\": 27, \"T\": 74}, \"75\": {\"position\": 55, \"T\": 75}, \"76\": {\"position\": 8, \"T\": 76}, \"77\": {\"position\": 2, \"T\": 77}, \"78\": {\"position\": 66, \"T\": 78}, \"79\": {\"position\": 80, \"T\": 79}, \"80\": {\"position\": 76, \"T\": 80}, \"81\": {\"position\": 49, \"T\": 81}, \"82\": {\"position\": 49, \"T\": 82}, \"83\": {\"position\": 24, \"T\": 83}, \"84\": {\"position\": 51, \"T\": 84}, \"85\": {\"position\": 43, \"T\": 85}, \"86\": {\"position\": 31, \"T\": 86}, \"87\": {\"position\": 58, \"T\": 87}, \"88\": {\"position\": 84, \"T\": 88}, \"89\": {\"position\": 58, \"T\": 89}, \"90\": {\"position\": 95, \"T\": 90}, \"91\": {\"position\": 91, \"T\": 91}, \"92\": {\"position\": 97, \"T\": 92}, \"93\": {\"position\": 53, \"T\": 93}, \"94\": {\"position\": 33, \"T\": 94}, \"95\": {\"position\": 11, \"T\": 95}, \"96\": {\"position\": 8, \"T\": 96}, \"97\": {\"position\": 96, \"T\": 97}, \"98\": {\"position\": 28, \"T\": 98}, \"99\": {\"position\": 54, \"T\": 99}, \"100\": {\"position\": 37, \"T\": 100}, \"101\": {\"position\": 60, \"T\": 101}, \"102\": {\"position\": 71, \"T\": 102}, \"103\": 
{\"position\": 6, \"T\": 103}, \"104\": {\"position\": 7, \"T\": 104}, \"105\": {\"position\": 4, \"T\": 105}, \"106\": {\"position\": 39, \"T\": 106}, \"107\": {\"position\": 75, \"T\": 107}, \"108\": {\"position\": 47, \"T\": 108}, \"109\": {\"position\": 30, \"T\": 109}, \"110\": {\"position\": 100, \"T\": 110}, \"111\": {\"position\": 30, \"T\": 111}, \"112\": {\"position\": 40, \"T\": 112}, \"113\": {\"position\": 11, \"T\": 113}, \"114\": {\"position\": 3, \"T\": 114}, \"115\": {\"position\": 83, \"T\": 115}, \"116\": {\"position\": 45, \"T\": 116}, \"117\": {\"position\": 91, \"T\": 117}, \"118\": {\"position\": 85, \"T\": 118}, \"119\": {\"position\": 56, \"T\": 119}, \"120\": {\"position\": 33, \"T\": 120}, \"121\": {\"position\": 74, \"T\": 121}, \"122\": {\"position\": 100, \"T\": 122}, \"123\": {\"position\": 56, \"T\": 123}, \"124\": {\"position\": 91, \"T\": 124}, \"125\": {\"position\": 16, \"T\": 125}, \"126\": {\"position\": 0, \"T\": 126}, \"127\": {\"position\": 17, \"T\": 127}, \"128\": {\"position\": 73, \"T\": 128}, \"129\": {\"position\": 40, \"T\": 129}, \"130\": {\"position\": 10, \"T\": 130}, \"131\": {\"position\": 6, \"T\": 131}, \"132\": {\"position\": 63, \"T\": 132}, \"133\": {\"position\": 23, \"T\": 133}, \"134\": {\"position\": 47, \"T\": 134}, \"135\": {\"position\": 79, \"T\": 135}, \"136\": {\"position\": 97, \"T\": 136}, \"137\": {\"position\": 13, \"T\": 137}, \"138\": {\"position\": 8, \"T\": 138}, \"139\": {\"position\": 39, \"T\": 139}, \"140\": {\"position\": 72, \"T\": 140}, \"141\": {\"position\": 59, \"T\": 141}, \"142\": {\"position\": 55, \"T\": 142}, \"143\": {\"position\": 91, \"T\": 143}, \"144\": {\"position\": 22, \"T\": 144}, \"145\": {\"position\": 62, \"T\": 145}, \"146\": {\"position\": 23, \"T\": 146}, \"147\": {\"position\": 14, \"T\": 147}, \"148\": {\"position\": 70, \"T\": 148}, \"149\": {\"position\": 86, \"T\": 149}, \"150\": {\"position\": 10, \"T\": 150}, \"151\": {\"position\": 72, \"T\": 151}, 
\"152\": {\"position\": 91, \"T\": 152}, \"153\": {\"position\": 29, \"T\": 153}, \"154\": {\"position\": 29, \"T\": 154}, \"155\": {\"position\": 94, \"T\": 155}, \"156\": {\"position\": 69, \"T\": 156}, \"157\": {\"position\": 58, \"T\": 157}, \"158\": {\"position\": 0, \"T\": 158}, \"159\": {\"position\": 11, \"T\": 159}, \"160\": {\"position\": 24, \"T\": 160}, \"161\": {\"position\": 70, \"T\": 161}, \"162\": {\"position\": 58, \"T\": 162}, \"163\": {\"position\": 58, \"T\": 163}, \"164\": {\"position\": 7, \"T\": 164}, \"165\": {\"position\": 25, \"T\": 165}, \"166\": {\"position\": 32, \"T\": 166}, \"167\": {\"position\": 92, \"T\": 167}, \"168\": {\"position\": 58, \"T\": 168}, \"169\": {\"position\": 55, \"T\": 169}, \"170\": {\"position\": 13, \"T\": 170}, \"171\": {\"position\": 17, \"T\": 171}, \"172\": {\"position\": 4, \"T\": 172}, \"173\": {\"position\": 30, \"T\": 173}, \"174\": {\"position\": 58, \"T\": 174}, \"175\": {\"position\": 91, \"T\": 175}, \"176\": {\"position\": 53, \"T\": 176}, \"177\": {\"position\": 41, \"T\": 177}, \"178\": {\"position\": 45, \"T\": 178}, \"179\": {\"position\": 69, \"T\": 179}, \"180\": {\"position\": 82, \"T\": 180}, \"181\": {\"position\": 44, \"T\": 181}, \"182\": {\"position\": 48, \"T\": 182}, \"183\": {\"position\": 16, \"T\": 183}, \"184\": {\"position\": 75, \"T\": 184}, \"185\": {\"position\": 91, \"T\": 185}, \"186\": {\"position\": 16, \"T\": 186}, \"187\": {\"position\": 71, \"T\": 187}, \"188\": {\"position\": 59, \"T\": 188}, \"189\": {\"position\": 89, \"T\": 189}, \"190\": {\"position\": 25, \"T\": 190}, \"191\": {\"position\": 85, \"T\": 191}, \"192\": {\"position\": 34, \"T\": 192}, \"193\": {\"position\": 83, \"T\": 193}, \"194\": {\"position\": 36, \"T\": 194}, \"195\": {\"position\": 30, \"T\": 195}, \"196\": {\"position\": 5, \"T\": 196}, \"197\": {\"position\": 7, \"T\": 197}, \"198\": {\"position\": 5, \"T\": 198}, \"199\": {\"position\": 74, \"T\": 199}, \"200\": {\"position\": 29, \"T\": 
200}, \"201\": {\"position\": 40, \"T\": 201}, \"202\": {\"position\": 48, \"T\": 202}, \"203\": {\"position\": 69, \"T\": 203}, \"204\": {\"position\": 0, \"T\": 204}, \"205\": {\"position\": 67, \"T\": 205}, \"206\": {\"position\": 7, \"T\": 206}, \"207\": {\"position\": 79, \"T\": 207}, \"208\": {\"position\": 72, \"T\": 208}, \"209\": {\"position\": 68, \"T\": 209}, \"210\": {\"position\": 97, \"T\": 210}, \"211\": {\"position\": 40, \"T\": 211}, \"212\": {\"position\": 84, \"T\": 212}, \"213\": {\"position\": 20, \"T\": 213}, \"214\": {\"position\": 91, \"T\": 214}, \"215\": {\"position\": 16, \"T\": 215}, \"216\": {\"position\": 12, \"T\": 216}, \"217\": {\"position\": 100, \"T\": 217}, \"218\": {\"position\": 89, \"T\": 218}, \"219\": {\"position\": 33, \"T\": 219}, \"220\": {\"position\": 60, \"T\": 220}, \"221\": {\"position\": 25, \"T\": 221}, \"222\": {\"position\": 82, \"T\": 222}, \"223\": {\"position\": 28, \"T\": 223}, \"224\": {\"position\": 53, \"T\": 224}, \"225\": {\"position\": 6, \"T\": 225}, \"226\": {\"position\": 63, \"T\": 226}, \"227\": {\"position\": 93, \"T\": 227}, \"228\": {\"position\": 14, \"T\": 228}, \"229\": {\"position\": 47, \"T\": 229}, \"230\": {\"position\": 42, \"T\": 230}, \"231\": {\"position\": 94, \"T\": 231}, \"232\": {\"position\": 61, \"T\": 232}, \"233\": {\"position\": 88, \"T\": 233}, \"234\": {\"position\": 17, \"T\": 234}, \"235\": {\"position\": 3, \"T\": 235}, \"236\": {\"position\": 97, \"T\": 236}, \"237\": {\"position\": 38, \"T\": 237}, \"238\": {\"position\": 30, \"T\": 238}, \"239\": {\"position\": 84, \"T\": 239}, \"240\": {\"position\": 37, \"T\": 240}, \"241\": {\"position\": 99, \"T\": 241}, \"242\": {\"position\": 36, \"T\": 242}, \"243\": {\"position\": 100, \"T\": 243}, \"244\": {\"position\": 53, \"T\": 244}, \"245\": {\"position\": 44, \"T\": 245}, \"246\": {\"position\": 37, \"T\": 246}, \"247\": {\"position\": 80, \"T\": 247}, \"248\": {\"position\": 8, \"T\": 248}, \"249\": {\"position\": 79, 
\"T\": 249}, \"250\": {\"position\": 94, \"T\": 250}, \"251\": {\"position\": 82, \"T\": 251}, \"252\": {\"position\": 84, \"T\": 252}, \"253\": {\"position\": 47, \"T\": 253}, \"254\": {\"position\": 27, \"T\": 254}, \"255\": {\"position\": 22, \"T\": 255}, \"256\": {\"position\": 22, \"T\": 256}, \"257\": {\"position\": 26, \"T\": 257}, \"258\": {\"position\": 58, \"T\": 258}, \"259\": {\"position\": 83, \"T\": 259}, \"260\": {\"position\": 60, \"T\": 260}, \"261\": {\"position\": 16, \"T\": 261}, \"262\": {\"position\": 54, \"T\": 262}, \"263\": {\"position\": 65, \"T\": 263}, \"264\": {\"position\": 7, \"T\": 264}, \"265\": {\"position\": 57, \"T\": 265}, \"266\": {\"position\": 15, \"T\": 266}, \"267\": {\"position\": 63, \"T\": 267}, \"268\": {\"position\": 54, \"T\": 268}, \"269\": {\"position\": 58, \"T\": 269}, \"270\": {\"position\": 16, \"T\": 270}, \"271\": {\"position\": 45, \"T\": 271}, \"272\": {\"position\": 57, \"T\": 272}, \"273\": {\"position\": 72, \"T\": 273}, \"274\": {\"position\": 93, \"T\": 274}, \"275\": {\"position\": 55, \"T\": 275}, \"276\": {\"position\": 77, \"T\": 276}, \"277\": {\"position\": 74, \"T\": 277}, \"278\": {\"position\": 20, \"T\": 278}, \"279\": {\"position\": 57, \"T\": 279}, \"280\": {\"position\": 89, \"T\": 280}, \"281\": {\"position\": 68, \"T\": 281}, \"282\": {\"position\": 74, \"T\": 282}, \"283\": {\"position\": 87, \"T\": 283}, \"284\": {\"position\": 11, \"T\": 284}, \"285\": {\"position\": 74, \"T\": 285}, \"286\": {\"position\": 69, \"T\": 286}, \"287\": {\"position\": 95, \"T\": 287}, \"288\": {\"position\": 76, \"T\": 288}, \"289\": {\"position\": 23, \"T\": 289}, \"290\": {\"position\": 6, \"T\": 290}, \"291\": {\"position\": 42, \"T\": 291}, \"292\": {\"position\": 71, \"T\": 292}, \"293\": {\"position\": 24, \"T\": 293}, \"294\": {\"position\": 18, \"T\": 294}, \"295\": {\"position\": 76, \"T\": 295}, \"296\": {\"position\": 97, \"T\": 296}, \"297\": {\"position\": 41, \"T\": 297}, \"298\": 
{\"position\": 60, \"T\": 298}, \"299\": {\"position\": 13, \"T\": 299}, \"300\": {\"position\": 84, \"T\": 300}, \"301\": {\"position\": 93, \"T\": 301}, \"302\": {\"position\": 75, \"T\": 302}, \"303\": {\"position\": 81, \"T\": 303}, \"304\": {\"position\": 76, \"T\": 304}, \"305\": {\"position\": 93, \"T\": 305}, \"306\": {\"position\": 2, \"T\": 306}, \"307\": {\"position\": 34, \"T\": 307}, \"308\": {\"position\": 27, \"T\": 308}, \"309\": {\"position\": 37, \"T\": 309}, \"310\": {\"position\": 71, \"T\": 310}, \"311\": {\"position\": 84, \"T\": 311}, \"312\": {\"position\": 97, \"T\": 312}, \"313\": {\"position\": 56, \"T\": 313}, \"314\": {\"position\": 100, \"T\": 314}, \"315\": {\"position\": 41, \"T\": 315}, \"316\": {\"position\": 88, \"T\": 316}, \"317\": {\"position\": 51, \"T\": 317}, \"318\": {\"position\": 80, \"T\": 318}, \"319\": {\"position\": 62, \"T\": 319}, \"320\": {\"position\": 34, \"T\": 320}, \"321\": {\"position\": 34, \"T\": 321}, \"322\": {\"position\": 32, \"T\": 322}, \"323\": {\"position\": 56, \"T\": 323}, \"324\": {\"position\": 58, \"T\": 324}, \"325\": {\"position\": 25, \"T\": 325}, \"326\": {\"position\": 29, \"T\": 326}, \"327\": {\"position\": 15, \"T\": 327}, \"328\": {\"position\": 55, \"T\": 328}, \"329\": {\"position\": 87, \"T\": 329}, \"330\": {\"position\": 20, \"T\": 330}, \"331\": {\"position\": 18, \"T\": 331}, \"332\": {\"position\": 25, \"T\": 332}, \"333\": {\"position\": 21, \"T\": 333}, \"334\": {\"position\": 69, \"T\": 334}, \"335\": {\"position\": 50, \"T\": 335}, \"336\": {\"position\": 12, \"T\": 336}, \"337\": {\"position\": 30, \"T\": 337}, \"338\": {\"position\": 41, \"T\": 338}, \"339\": {\"position\": 59, \"T\": 339}, \"340\": {\"position\": 100, \"T\": 340}, \"341\": {\"position\": 20, \"T\": 341}, \"342\": {\"position\": 20, \"T\": 342}, \"343\": {\"position\": 72, \"T\": 343}, \"344\": {\"position\": 30, \"T\": 344}, \"345\": {\"position\": 94, \"T\": 345}, \"346\": {\"position\": 31, \"T\": 
346}, \"347\": {\"position\": 26, \"T\": 347}, \"348\": {\"position\": 13, \"T\": 348}, \"349\": {\"position\": 5, \"T\": 349}, \"350\": {\"position\": 2, \"T\": 350}, \"351\": {\"position\": 97, \"T\": 351}, \"352\": {\"position\": 97, \"T\": 352}, \"353\": {\"position\": 3, \"T\": 353}, \"354\": {\"position\": 75, \"T\": 354}, \"355\": {\"position\": 98, \"T\": 355}, \"356\": {\"position\": 97, \"T\": 356}, \"357\": {\"position\": 35, \"T\": 357}, \"358\": {\"position\": 12, \"T\": 358}, \"359\": {\"position\": 46, \"T\": 359}, \"360\": {\"position\": 91, \"T\": 360}, \"361\": {\"position\": 40, \"T\": 361}, \"362\": {\"position\": 48, \"T\": 362}, \"363\": {\"position\": 98, \"T\": 363}, \"364\": {\"position\": 70, \"T\": 364}, \"365\": {\"position\": 7, \"T\": 365}, \"366\": {\"position\": 75, \"T\": 366}, \"367\": {\"position\": 35, \"T\": 367}, \"368\": {\"position\": 48, \"T\": 368}, \"369\": {\"position\": 77, \"T\": 369}, \"370\": {\"position\": 91, \"T\": 370}, \"371\": {\"position\": 96, \"T\": 371}, \"372\": {\"position\": 9, \"T\": 372}, \"373\": {\"position\": 2, \"T\": 373}, \"374\": {\"position\": 76, \"T\": 374}, \"375\": {\"position\": 25, \"T\": 375}, \"376\": {\"position\": 95, \"T\": 376}, \"377\": {\"position\": 72, \"T\": 377}, \"378\": {\"position\": 59, \"T\": 378}, \"379\": {\"position\": 4, \"T\": 379}, \"380\": {\"position\": 87, \"T\": 380}, \"381\": {\"position\": 100, \"T\": 381}, \"382\": {\"position\": 91, \"T\": 382}, \"383\": {\"position\": 73, \"T\": 383}, \"384\": {\"position\": 63, \"T\": 384}, \"385\": {\"position\": 64, \"T\": 385}, \"386\": {\"position\": 16, \"T\": 386}, \"387\": {\"position\": 20, \"T\": 387}, \"388\": {\"position\": 51, \"T\": 388}, \"389\": {\"position\": 56, \"T\": 389}, \"390\": {\"position\": 81, \"T\": 390}, \"391\": {\"position\": 16, \"T\": 391}, \"392\": {\"position\": 88, \"T\": 392}, \"393\": {\"position\": 68, \"T\": 393}, \"394\": {\"position\": 65, \"T\": 394}, \"395\": {\"position\": 54, 
\"T\": 395}, \"396\": {\"position\": 29, \"T\": 396}, \"397\": {\"position\": 37, \"T\": 397}, \"398\": {\"position\": 88, \"T\": 398}, \"399\": {\"position\": 82, \"T\": 399}, \"400\": {\"position\": 93, \"T\": 400}, \"401\": {\"position\": 71, \"T\": 401}, \"402\": {\"position\": 33, \"T\": 402}, \"403\": {\"position\": 72, \"T\": 403}, \"404\": {\"position\": 83, \"T\": 404}, \"405\": {\"position\": 93, \"T\": 405}, \"406\": {\"position\": 72, \"T\": 406}, \"407\": {\"position\": 3, \"T\": 407}, \"408\": {\"position\": 98, \"T\": 408}, \"409\": {\"position\": 99, \"T\": 409}, \"410\": {\"position\": 39, \"T\": 410}, \"411\": {\"position\": 62, \"T\": 411}, \"412\": {\"position\": 60, \"T\": 412}, \"413\": {\"position\": 48, \"T\": 413}, \"414\": {\"position\": 32, \"T\": 414}, \"415\": {\"position\": 66, \"T\": 415}, \"416\": {\"position\": 40, \"T\": 416}, \"417\": {\"position\": 33, \"T\": 417}, \"418\": {\"position\": 68, \"T\": 418}}"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/readjson.py",
    "content": "'''\nCreated on 2016.5.24\n\n@author: liangqian\n'''\n\n\nimport json\nimport os\n\ndef readjsonFile(filename):\n    #print(os.path.abspath(os.curdir))\n    f = open(filename,'r')\n    jsonstrs = f.read()\n    #print(jsonstrs)\n    jdata = json.loads(jsonstrs)\n    return jdata\n    \n    \n\n#readjsonFile('../jsondata.txt')"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testSound.py",
    "content": "'''\nCreated on 2016.6.29\n\n@author: liangqian\n'''\n\nimport pygame,sys\npygame.init()\npygame.mixer.init()\npygame.time.delay(1000)\npygame.mixer.music.load(\"do.wav\")\npygame.mixer.music.play()\nwhile 1:\n    for event in pygame.event.get():\n        if event.type==pygame.QUIT:\n            sys.exit()\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testmusic21.py",
    "content": "from music21 import *\n#s = converter.parse('../xmlfiles/four_part_hamony/ch4-03_A-major.xml')\n\ns = corpus.parse('bach/bwv65.2.xml')\ns.analyze('key')\nprint(len(s.parts))\ns.show()\n\nfor i,part in enumerate(s.parts):\n    print(i)\n    print('-----------')\n    for ns in part.flat.notes:\n        print(ns.pitch)\n        print(ns.duration.quarterLength)\n\nnote1 = note.Note(\"D5\")\nnote2 = note.Note(\"F#5\")\nnote2.duration.quarterLength = 0.5\nnote3 = note.Note(\"A5\")\n\nstream1 = stream.Stream()\nstream1.append(note1)\nstream1.append(note2)\nstream1.append(note3)\n\nprint(note2.offset)\n\nsout = stream1.getElementsByOffset(0,2)\n\nsBach = corpus.parse('bach/bwv57.8')\ns = sBach.chordify()\n#cs = s.getElementsByClass('Chord')\ns1 = s.flatten()\nchords = s1.getElementsByClass('Chord')\n\n\n# cMinor = chord.Chord([\"A4\",\"F4\",\"D5\"])\n# print(cMinor.inversion())\n# print(cMinor.isMinorTriad())\n\nkeyA = key.Key('B-')\nfor c in chords:\n    rn = roman.romanNumeralFromChord(c, keyA)\n    c.addLyric(str(rn.figure))\n\nchords.show()\n\n"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testopengl.py",
    "content": "'''\nCreated on 2018.8.31\n\n@author: liangqian\n'''\nfrom OpenGL.GL import *\nfrom OpenGL.GLU import *\nfrom OpenGL.GLUT import *\n \ndef drawFunc():\n    glClear(GL_COLOR_BUFFER_BIT)\n    #glRotatef(1, 0, 1, 0)\n    glutWireTeapot(0.5)\n    glFlush()\n \nglutInit()\nglutInitDisplayMode(GLUT_SINGLE | GLUT_RGBA)\nglutInitWindowSize(400, 400)\nglutCreateWindow(b\"First\")\nglutDisplayFunc(drawFunc)\n#glutIdleFunc(drawFunc)\nglutMainLoop()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/testwave.py",
    "content": "import wave\nimport struct\nimport os\nimport numpy as np\n\nf = 440\nframerate = 44100.0\nfw = wave.open(\"sine.wav\",\"wb\")\nfw.setnchannels(1)\nfw.setframerate(framerate)\nfw.setsampwidth(2)\ntt = np.arange(0, 1, 1.0/framerate)\n\ndata = [2000*(np.sin(2*np.pi*f*t)+np.sin(2*np.pi*2*f*t)+np.sin(2*np.pi*3*f*t)) for t in tt]\nprint(data)\nfor d in data:\n    fw.writeframes(struct.pack('h',int(d)))\nfw.close()"
  },
  {
    "path": "examples/Knowledge_Representation_and_Reasoning/musicMemory/tools/xmlParser.py",
    "content": "import librosa\nimport music21 as m21\nimport pandas as pd\nimport os\n\n\n'''\nThis function parses MusicXML file and extracts necessary score information as CSV.\n'''\ndef readXmlAsCsv(xmlPath='xml/'):\n    for subfolder in os.listdir(xmlPath):\n        if subfolder.startswith('.'):\n            continue\n        subfolder_path = os.path.join(xmlPath, subfolder)\n        for item in os.listdir(subfolder_path):\n            if item.endswith('xml'):\n                item_path = os.path.join(subfolder_path, item)\n                xml_data = m21.converter.parse(item_path)\n                print(\"Converting \", item_path)\n                score = []\n                for part in xml_data.parts:\n                    for note in part.flat.notes:\n                        if note.isChord:\n                            print('note is chord: ', note)\n                            measureNo = note.measureNumber\n                            start = note.offset\n                            duration = note.quarterLength\n\n                            for chord_note in note:\n                                pitch = chord_note.pitch\n                                articulations = note.articulations\n                                expressions = note.expressions\n                                spanners = note.getSpannerSites()\n                                gliss = []\n                                for spanner in spanners:\n                                    if 'Glissando' in spanner.classes:\n                                        if spanner.isFirst(chord_note):\n                                            gliss.append('slide start')\n                                        if spanner.isLast(chord_note):\n                                            gliss.append('slide last')\n                                score.append(\n                                    [measureNo, start, duration, pitch, m21.pitch.Pitch(pitch).frequency,\n                                
     articulations, expressions, gliss, spanners])\n\n                        else:\n                            measureNo = note.measureNumber\n                            start = note.offset\n                            duration = note.quarterLength\n                            pitch = note.pitch\n                            articulations = note.articulations\n                            expressions = note.expressions\n                            spanners = note.getSpannerSites()\n                            gliss = []\n                            for spanner in spanners:\n                                if 'Glissando' in spanner.classes:\n                                    if spanner.isFirst(note):\n                                        gliss.append('slide start')\n                                    if spanner.isLast(note):\n                                        gliss.append('slide last')\n\n                            score.append(\n                                [measureNo, start, duration, pitch, m21.pitch.Pitch(pitch).frequency,\n                                 articulations, expressions, gliss, spanners])\n                score = sorted(score, key=lambda x: (x[0], x[1], x[2]))\n                df = pd.DataFrame(score,\n                                  columns=['MeasureNumber', 'Start', 'Duration', 'Pitch', 'f0',\n                                           'Articulations', 'Expressions', 'Glissando', 'Spanner'])\n                df.to_csv(os.path.join(path, 'csv', subfolder, os.path.splitext(item)[0] + '.csv'))"
  },
  {
    "path": "examples/MotorControl/experimental/README.md",
    "content": "# Experimental works for motor control with different Brain Areas.\n\nThe project is still immature and under continuous development...\n\n## Citation\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n```"
  },
  {
    "path": "examples/MotorControl/experimental/brain_area.py",
    "content": "import torch\nimport numpy as  np\nimport torch.nn as nn\nfrom braincog.base.node.node import *\n\n\nclass MoColumnPOP(nn.Module):\n    def __init__(self, \n                input_dims: int,\n                pop_num: int = 16,\n                embedding_dim: int = 64,\n                time_window: int = 16) -> None:\n        super().__init__()\n        self._threshold = 1.0\n        self.v_reset = 0.0\n        self._time_window = time_window\n        self._pop_num = pop_num\n        self._node = LIFNode\n        self.column_net = nn.ModuleList(\n            [nn.Sequential(\n                nn.Linear(input_dims, embedding_dim), \n                self._node(threshold=self._threshold, v_reset=self.v_reset)) \n            for _ in range(pop_num)\n            ]\n        )\n\n        self.decode = nn.Linear(embedding_dim, 64)\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n    def _emb_decode(self, x):\n        pop_emb_decode = []\n        for net in self.column_net:\n            emb = net(x)\n            pop_emb_decode.append(self.decode(emb))\n        return pop_emb_decode\n\n\n    def forward(self, inputs):\n        pop_emb_decode = self._emb_decode(inputs)\n        out = sum(pop_emb_decode) / self._pop_num\n        return out\n\n\nclass MotorCortex(nn.Module):\n    def __init__(self, \n                input_dims: int,\n                out_dims: int = 128,\n                time_window: int = 16) -> None:\n        super().__init__()\n        self._threshold = 1.0\n        self.v_reset = 0.0\n        self._time_window = time_window\n        self._node = LIFNode\n        self.pfc_net = nn.Sequential(\n            nn.Linear(input_dims, 512),\n            self._node(threshold=self._threshold, v_reset=self.v_reset)\n        )\n        self.sma_net = nn.Sequential(\n            nn.Linear(input_dims, 512),\n            self._node(threshold=self._threshold, 
v_reset=self.v_reset)\n        )\n\n        self.ganglia_net = nn.Sequential(\n            nn.Linear(512, 128),\n            self._node(threshold=self._threshold, v_reset=self.v_reset)\n        )\n        self.pmc_net = nn.Sequential(\n            nn.Linear(512, 512),\n            self._node(threshold=self._threshold, v_reset=self.v_reset)\n        )\n\n        self.motor_net = nn.Sequential(\n            nn.Linear(512+128, 128),\n            self._node(threshold=self._threshold, v_reset=self.v_reset)\n        )\n\n        self.motor_emb = MoColumnPOP(input_dims=128, embedding_dim=out_dims, time_window=time_window)\n\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n        \n\n    def _compute_motor_out(self, inputs):\n        sma_out = self.sma_net(inputs)\n        ganglia_out = self.ganglia_net(sma_out)\n        motor_in = torch.concat([ganglia_out, sma_out], dim=-1)\n        motor_out = self.motor_net(motor_in)\n        # pop coding \n        return motor_out\n\n\n    def forward(self, inputs):\n        self.reset()\n        outs = []\n        for step in range(self._time_window):\n            motor_out =  self._compute_motor_out(inputs)\n            m_emb = self.motor_emb(motor_out)  # [Batch, 128]\n            outs.append(m_emb)\n        return outs\n        \n\n\nclass Celebellum(nn.Module):\n    def __init__(self,\n                 input_dims: int =  512,\n                 out_dims: int =  7, \n                 time_window: int = 16,\n                 ) -> None:\n        super().__init__()\n        self._threshold = 1.0\n        self.v_reset = 0.0\n        self._time_window = time_window\n        self._node = LIFNode\n        self.gc_layer = nn.Sequential(\n            nn.Linear(input_dims, 512),\n            self._node(threshold=self._threshold, v_reset=self.v_reset)\n        )\n\n        self.pc_layer = nn.Sequential(\n            nn.Linear(512, 512),\n            
self._node(threshold=self._threshold, v_reset=self.v_reset)\n        )\n        self.dcn_layer = nn.Sequential(\n            nn.Linear(input_dims + 512, 512),\n            self._node(threshold=self._threshold, v_reset=self.v_reset),\n            nn.Linear(512, out_dims)\n        )\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n    def forward(self, x):\n        self.reset()\n        outs = []\n        for step in range(self._time_window):\n            gc = self.gc_layer(x[step])\n            pc = self.pc_layer(gc)\n            dcn_in = torch.concat([x[step], pc], dim=-1)\n            dcn = self.dcn_layer(dcn_in)\n            outs.append(dcn)\n        cel_out = sum(outs) / self._time_window\n        return cel_out\n\n\n\nif __name__ == '__main__':\n    motor = MotorCortex(input_dims=1024)\n    for mod in motor.modules():\n        # print('mod: ', mod)\n        if hasattr(mod, 'n_reset'):\n            print('mod: ', mod)\n    "
  },
  {
    "path": "examples/MotorControl/experimental/main.py",
    "content": "import torch\nimport numpy as  np\nimport torch.nn as nn\nfrom model import Motion\nimport tqdm\nimport argparse\nfrom torch.nn import functional as F\n\nparser = argparse.ArgumentParser(description='Motor Parameters')\nparser.add_argument('--lr', default=0.001, type=float, help='learning rate')\nparser.add_argument('--time-window', type=int, default=8, help=\"Number of timesteps to do.\")\nparser.add_argument('--device', type=str, default='0', help=\"CUDA device\")\nparser.add_argument('--log-path', type=str, default='./logs/out.txt', help=\"Log path\")\n\nargs = parser.parse_args()\nprint(args)\n\ndevice = torch.device('cuda:'+args.device)\nLABELS = {\n    'position_group_0':  (-0.337, -0.020,  -0.077,  -0.031164,  0.999496, -0.005979,  2.850154),\n    'position_group_1':  (-0.337,  0.007,  -0.077,  -0.039668,  0.999161, -0.010174,  2.894892),\n    'position_group_2':  (-0.337,  0.030,  -0.076,  -0.031164,  0.999496, -0.005979,  2.850154),\n    'position_group_3':  (-0.337,  0.052,  -0.076,  -0.031164,  0.999496, -0.005979,  2.850154),\n    'position_group_4':  (-0.339,  0.074,  -0.076,   0.016057,  0.999842, -0.007643,  2.804204),    \n    'position_group_5':  (-0.339,  0.096,  -0.078,   0.016057,  0.999842, -0.007643,  2.804204),  \n    'position_group_6':  (-0.339,  0.123,  -0.079,   0.016057,  0.999842, -0.007643,  2.804204),  \n    'position_group_7':  (-0.337,  0.139,  -0.080,   0.076912,  0.997035, -0.0021723, 2.799101),  \n    'position_group_8':  (-0.337,  0.163,  -0.0770,  0.076912,  0.997035, -0.0021723, 2.799101),  \n    'position_group_9':  (-0.338,  0.188,  -0.075,   0.076912,  0.997035, -0.002172,  2.799101),  \n    'position_group_10': (-0.338,  0.212,  -0.075,   0.087103,  0.995757, -0.029681,  2.785759),  \n    'position_group_11': (-0.338,  0.235,  -0.070,   0.087103,  0.995757, -0.029681,  2.785759),  \n    'position_group_12': (-0.338,  0.259,  -0.073,   0.087103,  0.995757, -0.029681,  2.785759),  \n    'position_group_13': 
(-0.339,  0.273,  -0.065,   0.202020,  0.979225,  0.017483,  2.764647),  \n    'position_group_14': (-0.336,  0.290,  -0.066,   0.244628,  0.963147, -0.111827,  2.740450),  \n}\nposition_num = 15\nposition_dims = 7\n\nTARGETS = []\nfor i in range(position_num):\n    TARGETS.append(np.array(LABELS['position_group_'+str(i)], dtype=np.float32))\nTARGETS = np.stack(TARGETS, axis=0)\n\n\nt_factors = np.array([10.0, 10.0, 100.0, 10.0, 1.0, 100.0, 1.0], dtype=np.float32)\nTARGETS_FAC = TARGETS * t_factors[np.newaxis, :]\n\nKEYS = {\n    'c1': 0,\n    'd2': 1,\n    'e1': 2,\n    'f1': 3,\n    'g1': 4, \n    'a1': 5,\n    'b1': 6,\n    'c2': 7,\n    'd2': 8,\n    'e2': 9,\n    'f2': 10,\n    'g2': 11,\n    'a2': 12,\n    'b2': 13,\n    'c3': 14,\n    'd3': 15,\n    'e3': 16\n}\n\nfinger_num = 3\nfinger_pop_num = 10\nkey_num = 17\nkey_pop_num = 5\ndef creat_key_finger_emb():\n    key_value = np.zeros((key_num, key_num*key_pop_num), dtype=np.float32)\n    finger_value = np.zeros((finger_num, finger_num*finger_pop_num), dtype=np.float32)\n    for i in range(key_num):\n        key_value[i, i*key_pop_num: (i+1)*key_pop_num] += 1.0   \n    for i in range(finger_num):\n        finger_value[i, i*finger_pop_num: (i+1)*finger_pop_num] += 1.0  \n    return (key_value, finger_value)\n\ndef mse_loss(pred, target):\n    mse = F.mse_loss(pred, target)\n\n\ndef main():\n    key_embs, finger_emb = creat_key_finger_emb()\n    in_dims = key_embs.shape[1] + finger_emb.shape[1]\n    model = Motion(in_dims=in_dims, out_dims=position_dims, time_window=args.time_window).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    criterion = nn.MSELoss().to(device)\n    T = 100\n    batch_size = 32\n    EPOCHS = 200\n    with open(args.log_path, 'a+') as f:\n        argsDict = args.__dict__\n        f.writelines('------------------ start ------------------' + '\\n')\n        for eachArg, value in argsDict.items():\n            f.writelines(eachArg + ' : ' + str(value) + 
'\\n')\n        f.writelines('------------------- end -------------------'+ '\\n')\n        for epoch in range(EPOCHS):\n            # train\n            for step in tqdm.tqdm(range(T)):\n                key_idxs = np.random.choice(key_num, size=batch_size)\n                finger_idxs = np.random.choice(finger_num, size=batch_size)    \n                labels = np.clip(key_idxs - finger_idxs, a_min=0, a_max=position_num-1)\n                in_emb = np.concatenate([key_embs[key_idxs], finger_emb[finger_idxs]], axis=-1)\n                x = torch.from_numpy(in_emb).to(device)\n                y = torch.from_numpy(TARGETS_FAC[labels]).to(device)\n                y_pred = model(x)\n                loss = criterion(y_pred, y)\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n            # test\n            \n            loss_record = []\n            f.writelines('\\n')\n            f.writelines('Epoch:[{epoch}/{total_eps}]:\\n'.format(epoch=epoch, total_eps=EPOCHS))\n            for key in range(key_num):\n                for fin in range(finger_num):\n                    in_emb = np.concatenate([key_embs[key], finger_emb[fin]], axis=-1)\n                    x = torch.from_numpy(in_emb).to(device)\n                    with torch.no_grad():\n                        pred = model(x)\n                    \n                    y = max(min(key - fin, position_num-1), 0)\n                    target = torch.from_numpy(TARGETS_FAC[y]).to(device)\n                    loss = F.mse_loss(pred, target, reduction='sum')\n                    loss_record.append(loss.cpu().item())\n                    real_pred = pred.cpu().numpy() / t_factors\n                    distant = np.sum((TARGETS[y] - real_pred)**2)**0.5\n                    f.writelines('  Predict position {y}: {pred}\\n'.format(y=y, pred=real_pred.tolist()))\n            f.writelines('==> Epoch:[{epoch}/{total_eps}][validation stage]: loss: {loss}, distant 
{dis}\\n'.format(\n                    epoch=epoch, total_eps=EPOCHS, loss=sum(loss_record)/len(loss_record), dis=distant))\n            print('==> Epoch:[{epoch}/{total_eps}][validation stage]: loss: {loss}, distant {dis}\\n'.format(\n                    epoch=epoch, total_eps=EPOCHS, loss=sum(loss_record)/len(loss_record), dis=distant))\n            \n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "examples/MotorControl/experimental/model.py",
    "content": "import torch\nimport numpy as  np\nimport torch.nn as nn\nfrom brain_area import Celebellum, MotorCortex\n\n\nclass Motion(nn.Module):\n    def __init__(self, in_dims: int, out_dims: int=17, time_window: int=8, emb_size: int = 128) -> None:\n        super().__init__()\n        self._time_window = time_window\n        self.in_emb = nn.Linear(in_dims, emb_size)\n        self.motor_cotex = MotorCortex(input_dims=emb_size, out_dims=64,  time_window=self._time_window)\n        self.cele = Celebellum(input_dims=64, out_dims=out_dims, time_window=self._time_window)\n        # self.opti = torch.optim.Adam(net.parameters(), lr=0.001)\n\n    def forward(self, x):\n        in_emb = self.in_emb(x)\n        motor_out = self.motor_cotex(in_emb)\n        out = self.cele(motor_out)\n        return out\n\n    def learn(self):\n        pass\n\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/README.md",
    "content": "# Corticothalamic minicolumn\n\n## Description\nThe anatomical data is saved in the \"tool\" package. The **main.py** creates the network of the minicolumn depending on the anatomical data.\nA file named **\"fire.csv\"** will be generated to record the firing result of neurons in each time step.\n\n## Requirements\n* numpy\n* scipy\n* pytorch >= 1.7.0\n\n```shell\npython main.py\n```"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/data/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/data/globaldata.py",
    "content": "'''\nCreated on 2014.11.25\n\n@author: Liang Qian\n'''\n\n#from tools import dbconnection as DB\nfrom tools import exdata as Data\n\ndata = Data.EXDATA()\ncurneuronindex = 0\n\n\nSynapseNumberPerDendrite = 40\nProximalDendriteNumerPerNeuron = 1\nf = open('fire.csv','w')\n\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/main.py",
    "content": "import sys\nsys.path.append(\"../\")\nimport time\nfrom model import cortex_thalamus\n\nif __name__ == '__main__':\n    starttime = time.time()\n    myCortex = cortex_thalamus.Cortex_Thalamus(1000)  # create a cortex object and specify the neuron number scale\n    myCortex.CreateCortexNetwork()  # create cortex-thalamus network by the cortical object\n    myCortex.run()\n    print(len(myCortex.synapse))\n\n    totaltime = (time.time() - starttime)\n    print(\"totaltime:\" + str(totaltime) + \"s\")"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/cortex.py",
    "content": "'''\nCreated on 2015.5.27\n\n@author: Liang Qian\n'''\n\nfrom data import globaldata as Global\nfrom .layer import Layer\nfrom .synapse import Synapse\n\nclass Cortex():\n    '''\n    This class defines properties and funtions of cortex\n    '''\n\n\n    def __init__(self,neuronnumscale):\n        '''\n        Constructor\n        '''\n        self.neuronnumscale = neuronnumscale\n        self.neuronsNumber = 0\n        self.synapsesNumber = 0\n        self.neurons = [] # a list storing all neurons of the whole cortex\n        self.layers = {} # a dictionary storing layers of cortex\n        self.synapses = [] # a list storing all synapses of the whole cortex\n        self.minicolumns = [] # a list storing information per mini-column\n        self.neurontoindex = {} # a dictionary storing name to index of neuronlist\n        self.totaldata = Global.data.getCortexData()\n    \n    def setNeuronToIndex(self,node):\n        name = node.name\n        if(self.neurontoindex.get(name) == None):\n            self.neurontoindex[name] = len(self.neurons) - 1\n    def setLayers(self):\n        layerdic = Global.data.getLayerData()\n        for i,info in layerdic.items():\n            layer = Layer()\n            layer.name = info.get('name')\n            layer.neuronnum = self.neuronnumscale * float(info.get('neuronnum'))/100\n            print(layer.name + ' neuron number:'+str(layer.neuronnum))\n#             layer.synapsenum = self.synapsenum * info.get('synapsenum')\n#             print(layer.name + ' synapse number:'+str(layer.synapsenum))\n            self.layers[layer.name] = layer\n\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/cortex_thalamus.py",
    "content": "'''\nCreated on 2014.11.13\n\n@author: Liang Qian\n'''\n\nimport sys\nfrom .dendrite import Dendrite\nsys.path.append('../')\n\nfrom data import globaldata as Global\nfrom .thalamus import Thalamus\nfrom .cortex import Cortex\nfrom .layer import Layer\nfrom .synapse import Synapse\nfrom braincog.base.node.node import *\nclass Cortex_Thalamus():\n    '''\n    cortex class is used to build human brain cortex\n    members are as follows:\n    @param neuronnum: total neuron number\n    @param layer: layer list\n    \n    at the aspect of coding or algorithm,cortex is a huge Graph,neurons can be seen as nodes,synapses can be seen as edges,\n    but there may be plenty of edges between any two given nodes,so we can not use the <ni,nj> to express an edge,\n    the edge should be defined as an object,if we use adjacency list to store this huge graph,the form is just like\n    node->map(node,list<edge>)->(map(another node,list<edge>))\n    '''\n\n\n    def __init__(self, neuronnumscale):\n        self.neuronnumscale = neuronnumscale\n        self.cortex = Cortex(neuronnumscale)\n        self.thalamus = Thalamus()\n        self.minicolumns = [] # a list storing information per mini-column\n        self.synapsenum = 0\n        self.neuronname = [] # a list storing neuron name in all cortex\n        self.neuron = []  # a list storing all neuron in cortex\n        self.synapse = [] # a list storing all synapses in thalamocortex\n        self.neurontoindex = {} # a dictionary to storing the neuron mapping to index in neuron list\n        self.totaldata = Global.data.getCortexData()\n        \n        \n    def setSynapseNum(self):\n        num = 0\n        for name,info in self.totaldata.items():\n            num += self.neuronnumscale * (info.get('synnum')*info.get('neunum')/100.0)\n            self.neuronname.append(info.get('neuronname'))\n        self.synapsenum = num\n        \n    def setLayer(self):\n        layerdic = Global.data.getLayerData()\n        
for i,info in layerdic.items():\n            layer = Layer()\n            layer.name = info.get('name')\n            layer.neuronnum = self.neuronnumscale * info.get('neuronnum')/100\n            print(layer.name + ' neuron number:'+str(layer.neuronnum))\n            layer.synapsenum = self.synapsenum * info.get('synapsenum')\n            print(layer.name + ' synapse number:'+str(layer.synapsenum))\n            self.layer[layer.name] = layer        \n \n    def setNeuronsDendritesAndSynapes(self): \n        neurondic = Global.data.getNeuronData()\n        index = 0\n        for s in self.neuronname:\n            neuroninfo = neurondic.get(s)\n            self.neurontoindex[s] = len(self.neuron)\n            num = self.neuronnumscale * float(neuroninfo.get('percent'))/100.0\n            #using synapse numbers to compute dendrites numbers\n            synapsedic = Global.data.getSynapseData(s)\n            dendic = {}\n            totaldennum = 0\n            for r,item in synapsedic.items():\n                synum = int(item.get('synapsenum'))\n                dennum = (synum+Global.SynapseNumberPerDendrite-1)//Global.SynapseNumberPerDendrite # a dendrite contains no more than 40 synapses\n                loc = item.get('locationlayer')\n                dendic[loc] = dennum\n            for i in range(int(num)):\n                #init\n                node = CTIzhNode(morphology = neuroninfo.get('morphology'),\n                                name = neuroninfo.get('name'),\n                                excitability = neuroninfo.get('excitability'),\n                                spiketype = neuroninfo.get('spiketype'),\n                                synnum = synum,\n                                locationlayer = neuroninfo.get('location layer'),\n                                totalindex = index,\n                                Gup = float(neuroninfo.get('Gup')),\n                                Gdown = float(neuroninfo.get('Gdown')),\n                     
           Vr = float(neuroninfo.get('Vr')),\n                                Vt = float(neuroninfo.get('Vt')),\n                                Vpeak = float(neuroninfo.get('Vpeak')),\n                                a = float(neuroninfo.get('a')),\n                                b = float(neuroninfo.get('b')),\n                                c = float(neuroninfo.get('Csoma')),\n                                d = float(neuroninfo.get('d')),\n                                capacitance = float(neuroninfo.get('capacitance')),\n                                k = float(neuroninfo.get('k')),\n                                 )\n                # set dendrites\n                count = 0; flag = False\n                for dlocatelayer,dnum in dendic.items():\n                    for j in range(dnum):\n                        den = Dendrite()\n                        den.locationlayer = dlocatelayer\n                        if(dlocatelayer != node.locationlayer or flag):\n                            den.postion = 'distal'\n                            node.distal_dendrites.append(den)\n                        else:\n                            if(count < Global.ProximalDendriteNumerPerNeuron):\n                                den.postion = 'proximal'\n                                node.proximal_dendrites.append(den)\n                                count += 1\n                                if(count >= Global.ProximalDendriteNumerPerNeuron):\n                                    flag = True                                                    \n                #node.getDendritesInfo()\n                if(node.locationlayer == 'T'):\n                    self.thalamus.neuronsNumber += 1\n                    self.thalamus.neurons.append(node)\n                    self.thalamus.setNeuronToIndex(node)\n                else: \n                    self.cortex.neuronsNumber += 1\n                    self.cortex.neurons.append(node)\n                    
self.cortex.setNeuronToIndex(node)\n                    la = self.cortex.layers.get(node.locationlayer)\n                    la.neuronlist.append(node)\n                    if node.name not in la.neuronname:\n                        la.neuronname.append(node.name)\n                self.neuron.append(node)\n                index += 1\n        # set synapses\n        for postindex,node in enumerate(self.neuron): # this step can be optimized\n            synapsedic = Global.data.getSynapseData(node.name)\n            # get synapse_pre neuron and relative numbers of synapses\n            count = 0\n            for r,item in synapsedic.items():\n                #print(item)\n                for s in self.neuronname:\n                    #print(\"pre_neuron_name:\"+s)\n                    totalsynapsenum = round(item.get('synapsenum') * item.get(s)/100.0)\n                    if(postindex == 0):count += totalsynapsenum\n                    #print(\"pre_neuron_name_synapse_num:\"+str(totalsynapsenum))\n                    if(totalsynapsenum > 0): #if this neuron connect to synapse_pre neuron s,distribute the synapse to these neurons\n                        info = self.totaldata.get(s)\n                        preneuronnum = round(self.neuronnumscale * int(info.get('neunum'))/100)\n                        #print(\"pre_neuron_name_number:\" + str(preneuronnum))\n                        avgnum = lastnum = 0\n                        if(preneuronnum == 1):\n                            avgnum = lastnum = totalsynapsenum\n                        else:\n                            avgnum = totalsynapsenum // (preneuronnum-1)\n                            lastnum = totalsynapsenum - (avgnum*(preneuronnum-1))\n                        preneuronindex = self.neurontoindex.get(s)\n                        for j in range(preneuronindex,preneuronindex+preneuronnum):\n                            if(j == postindex):lastnum += avgnum;continue # not connect to itself\n                         
   preneuron = self.neuron[j]\n                            synapselist = []\n                            if(preneuron.adjneuronlist.get(node) != None):\n                                synapselist = preneuron.adjneuronlist.get(node)\n                            if(j == preneuronindex+preneuronnum-1): #the last neuron\n                                for t in range(lastnum):\n                                    synapse = Synapse(self.neuron[j],node,item.get('locationlayer'))\n                                    synapselist.append(synapse)\n                                    self.synapse.append(synapse)\n                                    if(item.get('locationlayer') != 'T'):\n                                        layerinfo = self.cortex.layers.get(item.get('locationlayer'))\n                                        layerinfo.synapselist.append(synapse)\n                                    if(node.locationlayer == 'T'): \n                                        self.thalamus.synapses.append(synapse)\n                                        self.thalamus.synapsesNumber += 1\n                                    else:\n                                        self.cortex.synapses.append(synapse)\n                                        self.cortex.synapsesNumber += 1\n                            else:\n                                for t in range(avgnum):\n                                    synapse = Synapse(self.neuron[j],node,item.get('locationlayer'))\n                                    synapselist.append(synapse)\n                                    self.synapse.append(synapse)\n                                    if(item.get('locationlayer') != 'T'):\n                                        layerinfo = self.cortex.layers.get(item.get('locationlayer'))\n                                        layerinfo.synapselist.append(synapse)\n                                    if(node.locationlayer == 'T'): \n                                        
self.thalamus.synapses.append(synapse)\n                                        self.thalamus.synapsesNumber += 1\n                                    else:\n                                        self.cortex.synapses.append(synapse)\n                                        self.cortex.synapsesNumber += 1\n                            if(preneuron.adjneuronlist.get(node) == None and len(synapselist) > 0):\n                                preneuron.adjneuronlist[node] = synapselist            \n                      \n        # set these synapses to dendrite list\n        #self.setSynapsesToDendrites()\n    def setSynapsesToDendrites(self):\n        for node in self.neuron:\n            if node.name == 'TCs': # synapses from TCs to ss4(L4) must be located in proximal dendrites of ss4(L4)\n                for post,synlist in node.adjneuronlist.items():\n                    if(post.name == 'ss4(L4)'):\n                        for syn in synlist:\n                            flag = post.addSynapseToDendrite('proximal',syn)\n                            if(not flag):flag = post.addSynapseToDendrite('distal',syn)\n                            if(not flag):\n                                print('all dendrites are full in neuron' + post.name + '_'+str(node.totalindex))\n                    else:\n                        for syn in synlist:\n                            flag = False\n                            if(post.locationlayer == syn.locationlayer):\n                                flag = post.addSynapseToDendrite('proximal',syn)\n                                if(not flag):\n                                    flag = post.addSynapseToDendrite('distal',syn)\n                            else:\n                                flag = post.addSynapseToDendrite('distal',syn)\n                            if(not flag): print('all dendrites are full in neuron' + post.name + '_'+str(node.totalindex))\n            else:\n                for post,synlist in 
node.adjneuronlist.items():\n                    for syn in synlist:\n                        flag = False\n                        if(post.locationlayer == syn.locationlayer):\n                            flag = post.addSynapseToDendrite('proximal',syn)\n                            if(not flag):\n                                flag = post.addSynapseToDendrite('distal',syn)\n                            if(not flag): print('all dendrites are full in neuron' + node.name + '_'+str(node.totalindex))\n#         f = open('dendrites_info.csv','w')\n#         for node in self.neuron:\n#             if(node.totalindex == 0):\n#                 node.getDendriteSynapsesInfo(f)\n#         f.close()\n\n    def setCortexProperties(self):\n        self.cortex.setCortexProperties()\n    def setThalamusProperties(self):\n        self.thalamus.setThalamusProperties()            \n    \n    def CreateCortexNetwork(self):\n        self.setSynapseNum()\n        self.cortex.setLayers()\n        self.setNeuronsDendritesAndSynapes()\n        self.setThalamusProperties()\n\n                \n    #----------API Of the whole network--------------#\n    def getTotalNeuronNumber(self):\n        return len(self.neuron)\n    def getTotalSynapseNumber(self):\n        return len(self.synapse)\n    def getCortexNeuronNumber(self):\n        return self.corticalneuronnumber\n    def getThalamoNeuronNumber(self):\n        return self.thamlamoneuronnumber\n    def getSpecifiedNeuronNumber(self,name):\n        result = {}\n        if name in self.neuronname:\n            info = self.totaldata.get(name)\n            num = self.neuronnumscale * info.get('neunum')/100\n            result[name] = num\n        elif name == 'all':\n            for r,info in self.totaldata.items():\n                num = self.neuronnumscale * info.get('neunum')\n                result[r] = num\n        return result   \n    def getNeuronTypesNumber(self):\n        return len(self.neuronname)\n    def getNeuronTypes(self):\n   
     return self.neuronname\n    def getCorticalSynapseNumber(self):\n        return len(self.corticalsynapse)\n    def getThalamoSynapseNumber(self):\n        return len(self.corticalsynapse)\n    def getPreAndPostNeuronsOfSynapse(self,index):\n        if(index >= 0 or index <= len(self.synapse -1)):\n            return self.synapse[index].pre,self.synapse[index].post\n        else: return None\n    #--------------API of the layer----------------------#\n    def getCortexLayerNeuronNumber(self,layername):\n        layerinfo = self.layer.get(layername)\n        return layerinfo.getLayerNeuronNumber()\n    def getCortexLayerSynapseNumber(self,layername):\n        layerinfo = self.layer.get(layername)\n        return layerinfo.getLayerSynapseNumber()\n    def getCortexLayerNeuronTypes(self,layername):\n        layerinfo = self.layer.get(layername)\n        if(layerinfo == None):\n            print(layername +\" is not in Cortex!\")\n            return None\n        return layerinfo.neuronname                                \n    def getCortexLayerPreAndPostNeuronsOfSynapse(self,layername,index):\n        layerinfo = self.layer.get(layername)\n        if(index >= 0 and index < len(layerinfo.synapselist)):\n            return layerinfo.synapselist[index].pre,  layerinfo.synapselist[index].post\n    def getNeuronAllPreNeuronsTypes(self,index):\n        if(index >= 0 and index < len(self.neuron)):\n            node = self.neuron[index]\n            return node.getWholePreSynapseNeuronType()\n    def outputNeuronInfo(self):\n        f = open('neuron.csv','w')\n        f.write('index,name,morphology,excitability,locationlayer\\n')\n        for node in self.neuron:\n            flag = 'No'\n            if(node.excitability == \"TRUE\"):\n                flag = 'Yes'           \n            f.write(str(node.totalindex)+','+node.name+','+node.morphology+','+flag+','+node.locationlayer+'\\n')\n        f.close()\n    def outputConnectionMatrix(self):\n        M = 
len(self.neuron)\n        matrix = [[0 for col in range(M)] for row in range(M)]\n        for node in self.neuron:\n            for post,list in node.adjneuronlist.items():\n                weight = len(list)\n                row = node.totalindex\n                col = post.totalindex\n                matrix[row][col] = weight\n        f = open('connection.csv','w')\n        name = ''\n        for node in self.neuron:\n            name += node.name + ','\n        f.write(name+'\\n')\n        for row in range(M):\n            line = ''\n            for col in range(M):\n                line += str(matrix[row][col])+','\n            f.write(line+'\\n')\n        f.close()\n    def outputsynapspercent(self,namelist):\n        totalcount = 0\n        slist = {'L1':0,'L2/3':0,'L4':0,'L5':0,'L6':0,'T':0}\n        for name in namelist:\n            for node in self.neuron:\n                for pre,list in node.adjneuronlist.items():\n                    if pre.name == name:\n                        totalcount = totalcount+len(list)\n                        loclayer = node.locationlayer\n                        value = slist.get(loclayer) + len(list)\n                        slist[loclayer] = value\n        print(slist)\n#-----------------------runing the whole network---------------------------#\n    def run(self):\n        '''\n        run the cortical system\n        '''\n        #s1:stimulate the neuron in L4\n        L = self.cortex.layers.get('L4')\n        L.stimulateNeuronInLayer4_BFS(30,self.neuron)\n\n\n\n    def outputSpikeThreashold(self):\n        f = open('threashold.csv','w')\n        for node in self.neuron:\n            f.write(str(node.Vpeak)+'\\n')\n        f.close() \n\n\n            \n        \n                     \n                                            \n                "
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/dendrite.py",
    "content": "'''\nCreated on 2015.5.19\n\n@author: liangqian\n'''\nimport sys\n\nimport sys\nfrom data.globaldata import *\n\nclass Dendrite():\n    '''\n    This class defines dendrite structure of a neuron.\n    A dendrite contains no more than 40 synapses\n    '''\n    def __init__(self):\n        '''\n        Constructor\n        '''\n        self.synapses = [] # synapses list which this dendrite contains\n        self.locationlayer = '' # layer this dendrite locates in\n        self.postion = '' # the distance from soma, proximal or distal\n    \n    def setSynapse(self,syn):\n        '''\n            This function is going to insert the Synapse syn to this dendrite\n            if the number of synapse of this dendrite is more than theshold, the current synapse\n            can not be inserted to the dendrite.\n        '''\n        if(len(self.synapses) >= SynapseNumberPerDendrite):\n            return False\n        else:\n            self.synapses.append(syn)\n            return True\n    def getSynapseInfo(self,f,nodename,denpos):\n        for syns in self.synapses:\n            syns.getInfo(f,nodename,denpos,self.locationlayer)    "
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/fire.csv",
    "content": ""
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/layer.py",
    "content": "'''\nCreated on 2014.11.13.\n\n@author: Liang Qian\n'''\n\nimport sys\nsys.setrecursionlimit(1000000)\nfrom data import globaldata as Global\nclass Layer():\n    '''\n    layer class is used to build brain cortical layer of cortex\n    class members are as follows:\n    @param name: layer name(L1,L2,L3,.etc)\n    @param neuronnum: total neural numbers of this layer\n    @param neuraltype: neural type list of this layer\n    @param neuronlist: different types of neuron list in this layer    \n    '''\n\n\n    def __init__(self):\n        self.name = ''\n        self.neuronnum = 0\n        self.synapsenum = 0\n        self.neuronlist = []\n        self.synapselist = []\n        self.neuronname = []\n        \n    def getLayerNeuronNumber(self):\n        return len(self.neuronlist)\n    def getLayerSynapseNumber(self):\n        return len(self.synapselist)\n    def getLayerNeuronTypes(self):\n        return self.neuronname\n        \n    def stimulateNeuronInLayer4_BFS(self, T, neulist):\n        for node in self.neuronlist:\n            if(node.name == 'ss4(L2/3)'):\n                break;\n        step = int(T*1000/1.0)\n        dc = 0\n        for i in range(step):\n            #print(i)\n            strs = str(i)+','\n            if(i > 1):\n                for n in neulist:\n                    if(n.totalindex == node.totalindex):continue\n                    if(n.dc > 0):\n                        n.integral(n.dc)\n                        n.calc_spike()\n                        if(n.spike == 1):\n                            strs +=n.name+':'+str(n.totalindex)+','\n            if(i < 10 or i > 25000):\n                dc = 0\n            else:\n                dc = 400\n            node.integral(dc)\n            node.calc_spike()\n            if(node.spike == 1):\n                strs +=node.name+':'+str(node.totalindex)+','\n            \n            Global.f.write(strs+'\\n')\n        Global.f.close()"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/synapse.py",
    "content": "'''\nCreated on 2014.11.13\n\n@author: Liang Qian\n'''\n\nclass Synapse():\n    '''\n    synapsis class is used to create a synapsis structure\n    members are as follows:\n    @param pre: pre-synapsis neuron\n    @param post: post-synapsis neuron\n    @param locationlayer: layer where this synapse locate in   \n    '''\n\n\n    def __init__(self, pre,post,locationlayer):\n        self.pre = pre\n        self.post = post\n        self.locationlayer = locationlayer\n        self.I = 0\n        self.weight = 0 if(pre.name == 'p2/3' and post.name == 'p2/3') else -1\n#         self.tauAMPA = 5\n#         self.tauNMDA = 150\n#         self.tauGABAA = 6\n#         self.tauGABAB = 150\n#         self.STDPA_pos = 1\n#         self.STDPA_neg = 2\n#         self.tau_pos = 20\n#         self.tau_neg = 20\n    def getInfo(self,f,nodename,denpos,denlayer):\n        f.write('neuron:'+nodename+','+'dendrite:'+denpos+','+'den_layer:'+denlayer+','\n            +'syn_Layer:'+ self.locationlayer+','\n            + 'pre_neuron:'+self.pre.name+','+'pre_neuron_index:'+str(self.pre.totalindex)+','\n            + 'post_neuron:'+self.post.name+','+'post_neuron_index:'+str(self.post.totalindex)+','\n            + 'weight:'+str(self.weight)+'\\n')    \n        \n        "
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/model/thalamus.py",
    "content": "'''\nCreated on 2015.5.27\n\n@author: Liang Qian\n'''\n\nclass Thalamus():\n    '''\n    This class defines the basic functions and properties of thalamus\n    '''\n\n\n    def __init__(self):\n        '''\n        Constructor\n        '''\n        self.neuronsNumber = 0\n        self.synapsesNumber = 0\n        self.neurons = []\n        self.synapses = []\n        self.neurontoindex = {}\n    \n    def setNeuronToIndex(self,node):\n        name = node.name\n        if(self.neurontoindex.get(name) == None):\n            self.neurontoindex[name] = len(self.neurons)-1\n    def setThalamusProperties(self):\n        print(len(self.synapses))\n        for node in self.neurons:\n            if(node.name == 'TCs'):\n                for post,synlist in node.adjneuronlist.items():\n                    if(post.name == 'ss4(L2/3)'):\n                        for syn in synlist:\n                            syn.weight = 0\n            if(node.name == 'TCn'):\n                for post,synlist in node.adjneuronlist.items():\n                    if(post.name == 'p6(L5/6)'):\n                        for syn in synlist:\n                            syn.weight = 0\n    "
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/cortical.csv",
    "content": "neuronname,neuronnum,synmapsenum,area\nnb1,1.5,8890,cortex\np2/3,26,7106,cortex\nb2/3,3.1,3854,cortex\nss4(L4),9.2,5792,cortex\nss4(L2/3),9.2,4989,cortex\np4,9.2,6703,cortex\nb4,5.4,3230,cortex\nnb4,1.5,3688,cortex\np5(L2/3),4.8,5196,cortex\np5(L5/6),1.3,13075,cortex\nb5,0.6,2981,cortex\nnb5,0.8,2981,cortex\np6(L4),13.6,6363,cortex\np6(L5/6),4.5,6421,cortex\nb6,2,3220,cortex\nnb6,2,3220,cortex\nnb2/3,4.2,3307,cortex\nTCs,0.5,4000,thalamus\nTCn,0.5,4000,thalamus\nTIs,0.1,3000,thalamus\nTIn,0.1,3000,thalamus\nTRN,0.5,4000,thalamus\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/exdata.py",
    "content": "import os\nimport csv\n#import pandas as pd\nclass EXDATA():\n    def __int__(self):\n        pass\n\n    def getCortexData(self):\n        neurondic = {}\n        f = open(\"./tools/cortical.csv\",\"r\")\n        line = f.readline()\n        count = 0\n        while(True):\n            line = (f.readline()).strip()\n            if(len(line) <= 0): break;\n            strs = line.split(\",\")\n            info = {}\n            info['neuronname'] = strs[0].strip()\n            info['neunum'] = float(strs[1])\n            info['synnum'] = int(strs[2])\n            info['area'] = str(strs[3])\n            neurondic[strs[0].strip()] = info\n            count += 1\n        f.close()\n        print(neurondic)\n        return neurondic\n\n    def getCortexData2(self):\n        neurondic = {}\n        data = pd.read_csv(\"../tools/cortical.csv\")\n        print('debug')\n\n\n    def getLayerData(self):\n        layerdic = {}\n        f = open(\"./tools/layer.csv\",\"r\")\n        strs = (f.readline()).strip()\n        str = strs.split(\",\")\n        count = 0\n        while(True):\n            info = {}\n            line = (f.readline()).strip()\n            if(len(line) <= 0):break\n            v = line.split(\",\")\n            for i in range(len(str)):\n                if(i > 0):\n                    v[i] = float(v[i])\n                info[str[i]] = v[i]\n            layerdic[count] = info\n            count += 1\n        f.close()\n        return layerdic\n\n    def getNeuronData(self):\n        neurondic = {}\n        f = open(\"./tools/neuron.csv\",\"r\")\n        strs = (f.readline()).strip()\n        str = strs.split(\",\")\n        while (True):\n            info = {}\n            line = (f.readline()).strip()\n            if (len(line) <= 0): break\n            v = line.split(\",\")\n            for i in range(len(str)):\n                if(len(v[i]) <= 0): break\n                if(i > 4): v[i] = float(v[i])\n                
info[str[i].strip()] = v[i]\n            neurondic[v[0].strip()] = info\n        f.close()\n        return neurondic\n\n    def getSynapseData(self, postneuron):\n        synapsemap = {}\n        f = open(\"./tools/synapse.csv\",\"r\")\n        fields = (f.readline()).strip()\n        fields = fields.split(\",\")\n        while(True):\n            line = (f.readline()).strip()\n            if(len(line) <= 0):break\n            strs = line.split(\",\")\n            info = {}\n            for i,v in enumerate(fields):\n                if(i > 1): strs[i] = float(strs[i])\n                info[v] = strs[i]\n            if(synapsemap.get(strs[0]) == None):\n                syndic = {}\n                syndic[len(syndic)] = info\n                synapsemap[strs[0]] = syndic\n            else:\n                syndic = synapsemap.get(strs[0])\n                syndic[len(syndic)] = info\n        f.close()\n\n        return synapsemap.get(postneuron)\n\n\n#tmp = EXDATA()\n#tmp.getCortexData()\n#tmp.getLayerData()\n#tmp.getNeuronData()\n#result = tmp.getSynapseData('p4')\n#print(result)"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/layer.csv",
    "content": "name,neuronnum,synapsenum,nb1,p2/3,b2/3,nb2/3,ss4(L4),ss4(L2/3),p4,b4,nb4,p5(L2/3),p5(L5/6),b5,nb5,p6(L5/6),b6,p6(L4),nb6\nL1,1.5,10.86,1.5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\nL2/3,33.3,32.86,0,26,3.1,4.2,0,0,0,0,0,0,0,0,0,0,0,0,0\nL4,34.5,34.04,0,0,0,0,9.2,9.2,9.2,5.4,1.5,0,0,0,0,0,0,0,0\nL5,7.5,8.1,0,0,0,0,0,0,0,0,0,4.8,1.3,0.6,0.8,0,0,0,0\nL6,22.1,14.14,0,0,0,0,0,0,0,0,0,0,0,0,0,4.5,2,13.6,2\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/neuron.csv",
    "content": "name,morphology,location layer,spiketype,excitability,percent,capacitance,k ,Vr ,Vt ,Vpeak ,Gup,Gdown,a ,b ,Csoma,Cdendr,d\nnb1,non-basket,L1,LS,FALSE,1.5,20,0.3,-66,-40,30,0.6,2.5,0.17,5,-45,-45,100\np2/3,pyramidal,L2/3,RS,TRUE,26,100,3,-60,-50,50,3,5,0.01,5,-60,-55,400\nb2/3,basket,L2/3,FS,FALSE,3.1,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200\nnb2/3,non-basket,L2/3,LTS,FALSE,4.2,100,1,0,-42,40,1,1,0.03,8,-50,-50,20\nss4(L4),spiny stell,L4,RS,TRUE,9.2,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\nss4(L2/3),spiny stell,L4,RS,TRUE,9.2,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\np4,pyramidal,L4,RS,TRUE,9.2,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\nb4,basket,L4,FS,FALSE,5.4,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200\nb5,basket,L5,FS,FALSE,0.6,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200\nnb4,non-basket,L4,LTS,FALSE,1.5,100,1,-55,-42,40,1,1,0.03,8,-50,-50,20\np5(L2/3),pyramidal,L5,RS,TRUE,4.8,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\nnb5,non-basket,L5,LTS,FALSE,0.8,100,1,-55,-42,40,1,1,0.03,8,-50,-50,20\np5(L5/6),pyramidal,L5,RS,TRUE,1.3,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\np6(L4),pyramidal,L6,RS,TRUE,13.6,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\nb6,basket,L6,FS,FALSE,2,20,1,-55,-40,25,0.5,1,0.15,8,-55,-55,200\nnb6,non-basket,L6,FS,FALSE,2,100,1,0,-42,40,1,1,0.03,8,-50,-50,20\nTCs,TC in specific nucleus,T,TC,TRUE,0.5,200,1.6,-60,-50,40,2,2,0.1,15,-60,-60,10\nTCn,TC in non specific nucleus,T,TC,TRUE,0.5,200,1.6,-60,-50,40,2,2,0.1,15,-60,-60,10\nTIs,thalamic in specific nucleus,T,TI,FALSE,0.1,20,0.5,-60,-50,20,5,5,0.05,7,-65,-65,50\nTIn,thalamic in non specific nucleus,T,TI,FALSE,0.1,20,0.5,-60,-50,20,5,5,0.05,7,-65,-65,50\nTRN,GABAergic,T,TRN,TRUE,0.5,40,0.25,-65,-45,0,5,5,0.015,10,-55,-55,50\np6(L5/6),pyramidal,L6,RS,TRUE,4.5,100,3,-60,-50,50,3,5,0.01,5,-60,-50,400\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/CorticothalamicColumn/tools/synapse.csv",
    "content": "postneuron,locationlayer,synapsenum,nb1,p2/3,b2/3,nb2/3,ss4(L4),ss4(L2/3),p4,b4,nb4,p5(L2/3),p5(L5/6),b5,nb5,p6(L4),p6(L5/6),b6,nb6,TCs,TCn,TIs,TIn,TRN\nnb1,L1,8890,10.1,6.3,0.6,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0\np2/3,L2/3,5800,0,59.9,9.1,4.4,0.6,6.9,7.7,0,0.8,7.4,0,0,0,2.3,0,0,0.8,0,0,0,0,0\np2/3,L1,1306,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0\nb2/3,L2/3,3854,1.3,51.6,10.6,3.4,0.5,5.8,6.6,0,0.8,6.3,0,0,0,2.1,0,0,0.7,0,0.5,0,0,0\nnb2/3,L2/3,3307,1.7,48.6,11.4,3.3,0.5,5.5,6.2,0,0.8,5.9,0,0,0,1.8,0,0,0.6,0,0.7,0,0,0\nss4(L4),L4,5792,0,2.7,0.2,0.6,11.9,3.7,4.1,7.1,2,0.8,0.1,0,0,32.7,0,0,5.8,1.7,1.3,0,0,0\nss4(L2/3),L4,4989,0,5.6,0.4,0.8,11.3,3.8,4.3,7.2,2.1,1.1,0.1,0,0,31.1,0,0,5.5,1.7,1.3,0,0,0\np4,L4,5031,0,4.3,0.2,0.6,11.5,3.6,4.2,7.2,2.1,1.2,0.1,0,0,31.4,0.1,0,5.9,1.7,1.3,0,0,0\np4,L2/3,866,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0\np4,L1,806,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0\nb4,L4,3230,0,5.8,0.5,0.8,11,3.8,4.2,8.4,2.4,1.1,0,0,0,30.3,0,0,5.4,1.6,1.2,0,0,0\nnb4,L4,3688,0,2.7,0.2,0.6,11.7,3.6,4,8.2,2.3,0.8,0.1,0,0,32.2,0,0,5.7,1.7,1.3,0,0,0\np5(L2/3),L5,4316,0,45.9,1.8,0.3,3.3,2,7.5,0,0.9,11.7,1,0.8,1.1,2.3,2.1,0,11.5,0.1,0.4,0,0,0\np5(L2/3),L4,283,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,1.8,1.4,0,0,0\np5(L2/3),L2/3,412,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0\np5(L2/3),L1,185,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0\np5(L5/6),L5,5101,0,44.3,1.7,0.2,3.2,2,7.3,0,0.8,11.3,1.2,0.8,1.1,2.3,2.5,0.3,11.3,0.2,0.5,0,0,0\np5(L5/6),L4,949,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,1.8,1.4,0,0,0\np5(L5/6),L2/3,1367,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0\np5(L5/6),L1,5658,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0\nb5,L5,2981,0,45.5,2.3,0.2,3.3,2,7.5,0,1.1,11.6,1,0.9,1.3,2.3,2,0,11.4,0.1,0.4,0,0,0\nnb5,L5,2981,0,45.5,2.3,0.2,3.3,2,7.5,0,1.1,11.6,1,0.9,1.3,2.3,2,0,11
.4,0.1,0.4,0,0,0\np6(L4),L6,3261,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.3,1.2,13.2,7.7,7.7,1.6,2.9,0,0,0\np6(L4),L5,1066,0,46.8,0.8,0.3,3.4,2.1,7.7,0,0.6,11.9,1,0.6,0.8,2.3,2.1,0,11.7,0.1,0.4,0,0,0\np6(L4),L4,1915,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,1.8,1.4,0,0,0\np6(L4),L2/3,121,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0\np6(L5/6),L6,5573,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.3,1.2,13.2,7.8,7.8,0,2.9,0,0,0\np6(L5/6),L5,257,0,46.8,0.8,0.3,3.4,2.1,7.7,0,0.6,11.9,1,0.6,0.8,2.3,2.1,0,11.7,0.6,0.4,0,0,0\np6(L5/6),L4,243,0,2.8,0.1,0.7,12.2,3.8,4.2,5.2,1.5,0.8,0.1,0,0,33.7,0,0,5.9,0,1.4,0,0,0\np6(L5/6),L2/3,286,0,63.1,5.1,4.1,0.6,7.2,8.1,0,0.6,7.8,0,0,0,2.5,0,0,0.8,0,0,0,0,0\np6(L5/6),L1,62,10.2,6.3,0.1,1.1,0,0,0.1,0,0,0.1,0,0,0,0,0,0,0,0,4.1,0,0,0\nb6,L6,3220,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.4,1.2,13.2,7.7,7.7,0.6,2.9,0,0,0\nnb6,L6,3220,0,2.5,0.1,0.1,0.7,0.9,1.3,0,0.1,0.1,4.9,0,0.4,1.2,13.2,7.7,7.7,0.6,2.9,0,0,0\nTCs,T,4000,0,0,0,0,0,0,0,0,0,0,0,0,0,23,8,0,0,0,0,5,0,25.9\nTCn,T,4000,0,0,0,0,0,0,0,0,0,14,3.8,0,0,0,13.2,0,0,0,0,0,5,25.9\nTIs,T,3000,0,0,0,0,0,0,0,0,0,0,0,0,0,9.8,3.3,0,0,0.4,0,24.4,0,0\nTIn,T,3000,0,0,0,0,0,0,0,0,0,5.8,1.6,0,0,0,5.4,0,0,0,0.6,0,24.4,0\nTRN,T,4000,0,0,0,0,0,0,0,0,0,0,0,0,0,30,0,0,0,10,10,0,0,10\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Corticothalamic_Brain_Model/Bioinformatics_propofol_circle.py",
    "content": "import numpy as np\r\nimport random\r\nimport math\r\nimport matplotlib.pyplot as plt\r\nimport scipy.io as scio\r\nimport pandas as pd\r\nimport torch\r\n\r\ndevice = 'cpu'\r\n\r\ntrail = 0\r\n\r\nclass brain_model_91():\r\n\r\n    def __init__(self, W, D):\r\n\r\n        self.weight_matrix = W.to(device)\r\n        # for i in range(len(self.weight_matrix)):\r\n        #     self.weight_matrix[i][29] = 0\r\n        #     self.weight_matrix[29][i] = 0\r\n        self.distance_matrix = D.to(device)\r\n        self.distance_matrix = torch.vstack((self.distance_matrix,\r\n                                             torch.mean(self.distance_matrix, dim=1)))\r\n        self.distance_matrix = torch.hstack((self.distance_matrix,\r\n                                             torch.zeros((len(self.distance_matrix), 1), device=device)))\r\n        temp = self.distance_matrix.clone()\r\n        self.distance_matrix[0:29, 29] = temp[29, 0:29]\r\n        # self.distance_matrix = torch.zeros_like(self.distance_matrix, device=device)\r\n        self.speed = 1.5\r\n        self.decay = torch.ceil(self.distance_matrix / self.speed)\r\n        self.t_window = int(torch.max(self.decay)) + 1\r\n        self.V_th = torch.tensor([10., 4., 10., 4., 4.], device=device)\r\n        self.tau_v = torch.tensor([40., 10., 200., 20., 40.], device=device)\r\n        self.Tsig = torch.tensor([12., 10., 12., 10., 10.], device=device)\r\n        self.beta = torch.tensor([0., 4.5, 0., 4.5, 4.5], device=device)\r\n        self.alpha_ad = torch.tensor([0., -2., 0., -2., -2.], device=device)\r\n\r\n        self.tau_ad = 20\r\n        self.tau_I = 10\r\n\r\n        # Simulation parameters\r\n        self.NR = len(W)\r\n        self.NN = 500\r\n        self.NType = self.NN * torch.tensor([0.79, 0.20, 0.01, 0.005, 0.002], device=device)\r\n        self.NE = int(self.NType[0])\r\n        self.NI = int(self.NType[1])\r\n        self.NTC = int(self.NR * self.NType[2])\r\n        self.NTI = 
int(self.NR * self.NType[3])\r\n        self.NTRN = int(self.NR * self.NType[4])\r\n        self.NC = self.NE + self.NI\r\n        self.NSum = int((self.NR - 1) * (self.NE + self.NI) + self.NTC + self.NTI + self.NTRN)\r\n\r\n        self.Ncycle = 1\r\n        self.dt = 1\r\n        self.T = 20000\r\n        self.Delta_T = 0.5\r\n        # self.refrac = 5 / self.dt\r\n        # self.ref = self.refrac*torch.zeros((self.NN, 1)).squeeze(1)\r\n        self.gamma_c = 0.1\r\n        self.g_m = 1\r\n        self.Gama_c = self.g_m * self.gamma_c / (1 - self.gamma_c)\r\n        self.GammaII = 15\r\n        self.GammaIE = -10\r\n        self.GammaEE = 15\r\n        self.GammaEI = 15\r\n        self.TEmean = 0.5 * self.V_th[0]  # Mean current to excitatory neurons\r\n        self.TTCmean = 0.5 * self.V_th[2]  # Mean current to TC neurons\r\n        self.TImean = -5 * self.V_th[1]\r\n        self.TTImean = -5 * self.V_th[3]\r\n        self.TTRNmean = -5 * self.V_th[4]\r\n\r\n        self.v = torch.zeros(self.NSum, device=device)\r\n        self.vt = torch.zeros(self.NSum, device=device)\r\n        self.c_m = torch.zeros(self.NSum, device=device)\r\n        self.alpha_w = torch.zeros(self.NSum, device=device)\r\n        self.beta_ad = torch.zeros(self.NSum, device=device)\r\n        self.delta = torch.ones(self.NSum, device=device)\r\n        self.ad = torch.zeros(self.NSum, device=device)\r\n        self.vv = torch.zeros(self.NSum, device=device)\r\n        self.Iback = torch.zeros(self.NSum, device=device)\r\n        self.Ieff = torch.zeros(self.NSum, device=device)\r\n        self.Nmean = torch.zeros(self.NSum, device=device)\r\n        self.Nsig = torch.zeros(self.NSum, device=device)\r\n        self.Igap = torch.zeros(self.NSum, device=device)\r\n        self.Ichem = torch.zeros(self.NSum, device=device)\r\n        self.Ieeg = torch.zeros(self.NSum, device=device)\r\n        self.vm1 = torch.zeros(self.NSum, device=device)\r\n\r\n        self.E_range = []\r\n        
self.I_range = []\r\n        self.TC_range = []\r\n        self.TI_range = []\r\n        self.TRN_range = []\r\n        self.divide_point_E = []\r\n        self.divide_point_I = []\r\n\r\n        for n in range(self.NR):\r\n            if n < self.NR - 1:\r\n                self.divide_point_E.append(list(range(n * self.NC, n * self.NC + self.NE)))\r\n                self.divide_point_I.append(list(range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI)))\r\n                self.E_range = self.E_range + list(range(n * self.NC, n * self.NC + self.NE))\r\n                self.I_range = self.I_range + list(\r\n                    range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI))\r\n\r\n            else:\r\n                self.TC_range = self.TC_range + list(range((self.NR - 1) * self.NC,\r\n                                                           (self.NR - 1) * self.NC + self.NTC))\r\n                self.TI_range = self.TI_range + list(range((self.NR - 1) * self.NC + self.NTC,\r\n                                                           (self.NR - 1) * self.NC + self.NTC + self.NTI))\r\n                self.TRN_range = self.TRN_range + list(range((self.NR - 1) * self.NC + self.NTC + self.NTI,\r\n                                                             (self.NR - 1) * self.NC + self.NTC + self.NTI + self.NTRN))\r\n        self.divide_point_E = torch.tensor(self.divide_point_E, device=device)\r\n        self.divide_point_I = torch.tensor(self.divide_point_I, device=device)\r\n        self.divide_point_CR = torch.concat((self.divide_point_E, self.divide_point_I), dim=1)\r\n        torch.save({'divide_point_E': self.divide_point_E, 'divide_point_I': self.divide_point_I,\r\n                    'TC_range': self.TC_range, 'TI_range': self.TI_range, 'TRN_range': self.TRN_range},\r\n                   './neuron_divide.pt')\r\n\r\n        self.c_m[self.E_range] = self.tau_v[0] * self.g_m + 5 * torch.randn(len(self.E_range), device=device)\r\n        
self.c_m[self.TC_range] = self.tau_v[2] * self.g_m\r\n        self.c_m[self.I_range] = self.tau_v[1] * (self.g_m + self.Gama_c)\r\n        self.c_m[self.TI_range] = self.tau_v[3] * (self.g_m + self.Gama_c)\r\n        self.c_m[self.TRN_range] = self.tau_v[4] * (self.g_m + self.Gama_c)\r\n\r\n        self.alpha_w[self.E_range] = self.alpha_ad[0] * self.g_m\r\n        self.alpha_w[self.TC_range] = self.alpha_ad[2] * self.g_m + self.Gama_c\r\n        self.alpha_w[self.I_range] = self.alpha_ad[1] * (self.g_m + self.Gama_c)\r\n        self.alpha_w[self.TI_range] = self.alpha_ad[3] * (self.g_m + self.Gama_c)\r\n        self.alpha_w[self.TRN_range] = self.alpha_ad[4] * (self.g_m + self.Gama_c)\r\n\r\n        self.beta_ad[self.E_range] = self.beta[0]\r\n        self.beta_ad[self.TC_range] = self.beta[2]\r\n        self.beta_ad[self.I_range] = self.beta[1]\r\n        self.beta_ad[self.TI_range] = self.beta[3]\r\n        self.beta_ad[self.TRN_range] = self.beta[4]\r\n\r\n        self.vt[self.E_range] = self.V_th[0]\r\n        self.vt[self.TC_range] = self.V_th[2]\r\n        self.vt[self.I_range] = self.V_th[1]\r\n        self.vt[self.TI_range] = self.V_th[3]\r\n        self.vt[self.TRN_range] = self.V_th[4]\r\n\r\n        self.Nmean[self.E_range] = self.TEmean * self.g_m\r\n        self.Nmean[self.TC_range] = self.TTCmean * self.g_m\r\n        self.Nmean[self.I_range] = self.TImean * (self.g_m + self.Gama_c)\r\n        self.Nmean[self.TI_range] = self.TTImean * (self.g_m + self.Gama_c)\r\n        self.Nmean[self.TRN_range] = self.TTRNmean * (self.g_m + self.Gama_c)\r\n\r\n        self.Nsig[self.E_range] = self.Tsig[0] * self.g_m\r\n        self.Nsig[self.TC_range] = self.Tsig[2] * self.g_m\r\n        self.Nsig[self.I_range] = self.Tsig[1] * (self.g_m + self.Gama_c)\r\n        self.Nsig[self.TI_range] = self.Tsig[3] * (self.g_m + self.Gama_c)\r\n        self.Nsig[self.TRN_range] = self.Tsig[4] * (self.g_m + self.Gama_c)\r\n\r\n    def simulation(self):\r\n\r\n        range_E = 
self.E_range + self.TC_range\r\n        range_I = self.I_range + self.TI_range + self.TRN_range\r\n        Vgap = self.Gama_c\r\n        weight_matrix = self.weight_matrix\r\n\r\n        for i in range(self.Ncycle):\r\n            I_total = torch.zeros((self.Ncycle, self.T), device=device)\r\n            V_total = torch.zeros((self.Ncycle, self.T), device=device)\r\n\r\n            V = torch.zeros(self.T, device=device)\r\n            I_subregion = torch.zeros((self.NR, self.T), device=device)\r\n            I_subregion_E = torch.zeros((self.NR, self.T), device=device)\r\n            I_subregion_I = torch.zeros((self.NR, self.T), device=device)\r\n            Vsubregion = torch.zeros((self.NR, self.T), device=device)\r\n            EEG = torch.zeros((self.T), device=device)\r\n\r\n            Iraster = []\r\n            vv_sumE = torch.zeros((self.NR, self.t_window), device=device)\r\n            vv_sumI = torch.zeros((self.NR, self.t_window), device=device)\r\n\r\n            phase = self.T / 4\r\n            for t in range(self.T):\r\n                #\r\n                if t < phase:\r\n                    tau_vI = 20\r\n                    self.GammaII = 15\r\n                    self.GammaIE = -10\r\n                elif phase <= t < 3 * phase:\r\n                    tau_vI = 20 + 20 * (t - phase) / phase\r\n                    self.GammaII = 30 + 10 * (t - phase) / phase\r\n                    self.GammaIE = -20 - 10 * (t - phase) / phase\r\n                elif 3 * phase <= t < 4 * phase:\r\n                    tau_vI = 60\r\n                    self.GammaII = 50\r\n                    self.GammaIE = -40\r\n                elif 4 * phase <= t < 6 * phase:\r\n                    tau_vI = 60 - 20 * (t - 4 * phase) / phase\r\n                    self.GammaII = 50 - 10 * (t - 4 * phase) / phase\r\n                    self.GammaIE = -40 + 20 * (t - 4 * phase) / phase\r\n                elif t > 6 * phase:\r\n                    tau_vI = 20\r\n                    
self.GammaII = 15\r\n                    self.GammaIE = -10\r\n\r\n                self.c_m[range_I] = tau_vI * (self.g_m + self.Gama_c)\r\n                WII = self.GammaII * torch.mean(self.c_m[self.I_range])\r\n                WEE = self.GammaEE * torch.mean(self.c_m[self.E_range])\r\n                WEI = self.GammaEI * torch.mean(self.c_m[self.I_range])\r\n                WIE = self.GammaIE * torch.mean(self.c_m[self.E_range])\r\n\r\n                self.Iback = self.Iback + self.dt / self.tau_I * (-self.Iback + torch.randn(self.NSum, device=device))\r\n                self.Ieff = self.Iback / math.sqrt(1 / (2 * (self.tau_I / self.dt))) * self.Nsig + self.Nmean\r\n\r\n                temp = vv_sumE.clone()\r\n                vv_sumE[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]\r\n                vv_sumE[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_E], dim=1),\r\n                                                           torch.mean(self.vv[self.TC_range]).unsqueeze(0)))\r\n\r\n                temp = vv_sumI.clone()\r\n                vv_sumI[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]\r\n                vv_sumI[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_I], dim=1),\r\n                                                           torch.mean(self.vv[self.TI_range + self.TRN_range]).unsqueeze(0)))\r\n\r\n\r\n                v_sum = torch.cat((torch.mean(self.v[self.divide_point_I], dim=1),\r\n                                   torch.mean(self.v[self.TI_range + self.TRN_range]).unsqueeze(0)))\r\n                v_sum_CR = v_sum[:self.NR - 1].reshape(-1, 1) * \\\r\n                           torch.ones((self.NR - 1, self.NI), device=device)\r\n                v_sum_CR = v_sum_CR.reshape(-1, 1).squeeze(1)\r\n                v_sum_TN = v_sum[self.NR - 1] * \\\r\n                           torch.ones(self.NTI + self.NTRN, device=device)\r\n                v_sum = torch.cat((v_sum_CR, v_sum_TN))\r\n\r\n  
              time_decay = torch.concat(\r\n                    (torch.concat([torch.arange(30, device=device).unsqueeze(0)] * 30, dim=0).unsqueeze(0),\r\n                     self.t_window - 1 - self.decay.unsqueeze(0)), dim=0)\r\n                time_decay = list(time_decay.long())\r\n\r\n                v_E = torch.sum(weight_matrix * vv_sumE[time_decay], dim=1)\r\n                v_I = torch.sum(weight_matrix * vv_sumI[time_decay], dim=1)\r\n\r\n\r\n                v_E_CR = v_E[:self.NR - 1].reshape(-1, 1) * \\\r\n                         torch.ones((self.NR - 1, self.NC), device=device)\r\n                v_I_CR = v_I[:self.NR - 1].reshape(-1, 1) * \\\r\n                         torch.ones((self.NR - 1, self.NC), device=device)\r\n                v_E_CR = v_E_CR.reshape(-1, 1).squeeze(1)\r\n                v_I_CR = v_I_CR.reshape(-1, 1).squeeze(1)\r\n\r\n                v_E_TN = v_E[self.NR - 1] * \\\r\n                         torch.ones(self.NTC + self.NTI + self.NTRN, device=device)\r\n                v_I_TN = v_I[self.NR - 1] * \\\r\n                         torch.ones(self.NTC + self.NTI + self.NTRN, device=device)\r\n\r\n                v_E = torch.cat((v_E_CR, v_E_TN))\r\n                v_I = torch.cat((v_I_CR, v_I_TN))\r\n                self.Ichem[range_E] = self.Ichem[range_E] + self.dt / self.tau_I * \\\r\n                                      (-self.Ichem[range_E] + WEE * v_E[range_E]\r\n                                       + WIE * v_I[range_E])\r\n                self.Ichem[range_I] = self.Ichem[range_I] + self.dt / self.tau_I * \\\r\n                                      (-self.Ichem[range_I] + WII * v_I[range_I]\r\n                                       + WEI * v_E[range_I])\r\n                self.Igap[range_I] = Vgap * (\r\n                        v_sum - self.v[range_I])\r\n\r\n                self.v = self.v + self.dt / self.c_m * (-self.g_m * self.v +\r\n                                                        self.alpha_w * self.ad + 
self.Ieff + self.Ichem + self.Igap)\r\n                self.ad = self.ad + self.dt / self.tau_ad * (-self.ad + self.beta_ad * self.v)\r\n                self.vv = (self.v >= self.vt).float() * (self.vm1 < self.vt).float()\r\n                self.vm1 = self.v\r\n\r\n                Isp = torch.where(self.vv == 1)[0]\r\n                Iraster.append(torch.stack((t * torch.ones((len(Isp)), device=device), Isp), dim=1))\r\n\r\n                I_CR = torch.mean(self.Ichem[self.divide_point_CR], dim=1)\r\n                I_TN = torch.mean(self.Ichem[self.TC_range + self.TI_range + self.TRN_range]).unsqueeze(0)\r\n                I_subregion[:, t] = torch.cat((I_CR, I_TN), dim=0)\r\n\r\n            print('over')\r\n            torch.save(I_subregion.cpu(), f'./result/I_subregion_2_delay_{trail}.pt')\r\n\r\n            Iraster = torch.cat(Iraster, dim=0).cpu()\r\n\r\n            torch.save(Iraster, f'./result/raster_2_delay_{trail}.pt')\r\n\r\n\r\n\r\nW = torch.tensor(torch.load('./FLNe.pt')['W'], dtype=torch.float32, device=device)\r\nW = W + torch.eye(len(W), device=device)\r\nD = torch.load('./distance.pt')\r\nsimulation_model = brain_model_91(W, D)\r\nsimulation_model.simulation()\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Corticothalamic_Brain_Model/Readme.md",
    "content": "The code for the corticothalamic brain model. The connection matrix and simulation results are available at the following link:\n\nhttps://drive.google.com/drive/folders/1oOAb-X_ag5feV8Q09_oFZbuoEd7uxIIo?usp=sharing\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Corticothalamic_Brain_Model/spectrogram.py",
    "content": "import scipy.io as scio\r\nimport numpy as np\r\nimport pandas as pd\r\nimport matplotlib.pyplot as plt\r\nimport torch\r\nfrom mpl_toolkits.mplot3d import Axes3D\r\nfrom scipy.fftpack import fft,ifft\r\nfrom scipy import signal\r\nfrom scipy.fft import fftshift\r\n\r\ntrail = 4\r\nversion = 2\r\nIraster = torch.load(f'./result/raster_{version}_delay_{trail}.pt').cpu()\r\ntime = Iraster[:, 0]\r\nmask = (time >= 3000) & (time < 10000)\r\nindices = torch.where(mask)\r\nspike = Iraster[indices[0]]\r\nplt.figure(figsize=(20, 12))\r\nplt.scatter(spike[:, 0], spike[:, 1], s=0.1)\r\nplt.xlabel('time [ms]', fontsize=20)\r\nplt.ylabel('Neuron index', fontsize=20)\r\nplt.show(dpi=600)\r\n\r\ndata = np.array(torch.load(f'./result/I_subregion_{version}_delay_{trail}.pt').cpu())\r\nprint(data.shape)\r\n\r\nb, a = signal.butter(2, [0.002, 0.06], 'bandpass')    #配置滤波器 8 表示滤波器的阶数\r\ndata = signal.filtfilt(b, a, data)   #data为要过滤的信号\r\n\r\nfs = 1000\r\ntime_window = 1024\r\n# divide = torch.load('./neuron_divide.pt')\r\n# divide_E = divide['divide_point_E']\r\n# print(divide_E)\r\n\r\nbrain_map = ['2','5','24c','46d','7A','7B','7m','8B','8l',\r\n             '8m','9/46d','9/46v','10','DP','F1','F2','F5','F7',\r\n             'MT','PBr','ProM','STPc','STPi','STPr','TEO','TEpd',\r\n             'V1','V2','V4','TH']\r\n\r\ndef region_sxx(region):\r\n    # plt.figure()\r\n    # plt.plot(data[region])\r\n    # plt.show()\r\n    plt.figure(figsize=(16, 8))\r\n\r\n    f, t, sxx = signal.stft(data[region], fs=fs, nperseg=time_window, noverlap=time_window / 2)\r\n    print(sxx.shape)\r\n    cm = plt.cm.get_cmap('jet')\r\n    #plt.pcolormesh(t, f[2:10], np.abs(sxx[2:10]), cmap=cm, shading='auto')\r\n    plt.contourf(t, f[0:30], np.abs(sxx[0:30]), cmap=cm, levels=200)\r\n    plt.colorbar()\r\n    plt.xlabel('time/min', fontsize=20)\r\n    plt.ylabel('Frequency/Hz', fontsize=20)\r\n    plt.xticks(fontsize=15)\r\n    plt.yticks(fontsize=15)\r\n    plt.show()\r\n\r\n\r\ndef 
global_sxx():\r\n    plt.figure(figsize=(16,8))\r\n    global_eeg = np.mean(data, axis=0)\r\n    f, t, sxx = signal.stft(global_eeg, fs=fs, nperseg=time_window, noverlap=time_window / 2)\r\n    print(sxx.shape)\r\n    cm = plt.cm.get_cmap('jet')\r\n    #plt.pcolormesh(t, f[2:10], np.abs(sxx[2:10]), cmap=cm, shading='auto')\r\n    plt.contourf(t, f[0:30], np.abs(sxx[0:30]), cmap=cm, levels=200)\r\n    plt.colorbar()\r\n    plt.xlabel('time/min', fontsize=20)\r\n    plt.ylabel('Frequency/Hz', fontsize=20)\r\n    plt.xticks(fontsize=15)\r\n    plt.yticks(fontsize=15)\r\n    plt.show()\r\n\r\n\r\ndef compare_sxx():\r\n\r\n    f, t, sxx = signal.stft(data[0], fs=fs, nperseg=time_window, noverlap=time_window / 2)\r\n\r\n    f_band = range(0, 30)\r\n\r\n    sm = np.max(np.abs(sxx[f_band]), axis=0)\r\n\r\n    for col in range(1, 30):\r\n        f, t, sxx = signal.stft(data[col], fs=fs, nperseg=time_window, noverlap=time_window / 2)\r\n        sm = np.vstack((sm, np.max(np.abs(sxx[f_band]), axis=0)))\r\n\r\n    cm = plt.cm.get_cmap('jet')\r\n    plt.pcolormesh(t, brain_map, np.abs(sm), cmap=cm, shading='auto')\r\n    #plt.pcolormesh(t, f, sxx[5:50,:],cmap=cm)\r\n    plt.colorbar()\r\n    plt.ylabel('Brain Regions', fontsize=10)\r\n    plt.xlabel('Time [min]', fontsize=10)\r\n    plt.xticks(fontsize=10)\r\n    plt.yticks(fontsize=10)\r\n    plt.show()\r\n\r\n    return np.abs(sm)\r\n\r\n\r\nregion_sxx(7)\r\n\r\n# global_sxx()\r\n\r\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/HumanBrain/README.md",
    "content": "# Human Brain Simulation\n\n## Description\nHuman Brain Simulation is a large-scale brain modeling framework built on the BrainCog framework.\n\n## Requirements:\n* numpy >= 1.21.2\n* scipy >= 1.8.0\n* h5py >= 3.6.0\n* torch >= 1.10\n* torchvision >= 0.12.0\n* torchaudio  >= 0.11.0\n* timm >= 0.5.4\n* matplotlib >= 3.5.1\n* einops >= 0.4.1\n* thop >= 0.0.31\n* pyyaml >= 6.0\n* loris >= 0.5.3\n* pandas >= 1.4.2\n* tonic (special)\n\n## Input:\n\nThe 88 regions' connectivity matrix can be obtained from the following link:\n[https://drive.google.com/file/d/1tLHxCtm2kawKVvJ1BhAbkFKeyxcrJwnO/view?usp=sharing](https://drive.google.com/file/d/1tLHxCtm2kawKVvJ1BhAbkFKeyxcrJwnO/view?usp=sharing)\n\nThe source of this connectivity matrix is available at the following link:\nhttps://www.nitrc.org/frs/?group_id=432\n\n## Example:\n\n```shell\ncd ~/examples/Multiscale_Brain_Structure_Simulation/HumanBrain/\npython human_brain.py\n```\n\n## Parameters:\nThe parameters are similar to those of the mouse brain simulation.\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/HumanBrain/human_brain.py",
    "content": "import time\n\nimport numpy as np\nimport scipy.io as scio\nimport torch\nfrom torch import nn\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom numpy import genfromtxt\n\ndevice = 'cuda:0'\n\nclass Syn(nn.Module):\n    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):\n        super().__init__()\n        self.pre = syn[1]\n        self.post = syn[0]\n        self.syn_num = len(syn)\n        self.w = torch.sparse_coo_tensor(syn.t(), weight,\n                                         size=(neuron_num, neuron_num))\n        self.tao_d = tao_d\n        self.tao_r = tao_r\n        self.dt = dt\n        self.lamda_d = self.dt / self.tao_d\n        self.lamda_r = self.dt / self.tao_r\n\n        self.s = torch.zeros(neuron_num, device=device)\n        self.r = torch.zeros(neuron_num, device=device)\n        self.dt = dt\n\n    def forward(self, neuron):\n        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (\n                torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)\n        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu\n        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)\n        self.r = self.r - self.lamda_d * self.r + self.dt * self.s\n        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff\n        return self.I\n\nclass brain(nn.Module):\n    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):\n        super().__init__()\n        if neuron_model == 'HH':\n            self.neurons = HHNode(p_neuron, dt, device)\n        elif neuron_model == 'aEIF':\n            self.neurons = aEIF(p_neuron, dt, device)\n        self.neuron_num = len(p_neuron[0])\n        self.syns = Syn(syn, weight, self.neuron_num, 3, 6, dt, device)\n\n    def forward(self, inputs):\n        I = 
self.syns(self.neurons)\n        self.neurons(I)\n\n\ndef brain_region(neuron_num):\n    region = []\n    start = 0\n    end = 0\n    for i in range(len(neuron_num)):\n        end += neuron_num[i].item()\n        region.append([start, end])\n        start = end\n    return torch.tensor(region)\n\ndef neuron_type(neuron_num, ratio, regions):\n    neuron_num = neuron_num.reshape(-1, 1)\n    neuron_type = torch.floor(ratio * neuron_num).int() + regions[:, 0].reshape(-1, 1)\n    return neuron_type\n\ndef syn_within_region(syn_num, region):\n    start = 1\n    for neurons in region:\n        if start:\n            syn = torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)\n            start = 0\n        else:\n            syn = torch.concatenate((syn, torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)))\n    return syn\n\ndef syn_cross_region(weight_matrix, region):\n    start = 1\n    for i in range(len(weight_matrix)):\n        for j in range(len(weight_matrix)):\n            if weight_matrix[i][j] < 10:\n                continue\n            else:\n                pre = torch.randint(region[j][0], region[j][1],\n                                    size=(weight_matrix[i][j], 1), device=device)\n                post = torch.randint(region[i][0], region[i][1],\n                                     size=(weight_matrix[i][j], 1), device=device)\n                if start:\n                    syn = torch.concatenate((post, pre), dim=1)\n                    start = 0\n                else:\n                    syn = torch.concatenate((syn, torch.concatenate((post, pre), dim=1)))\n    return syn\n\nsize = 500\nneuron_model = 'HH'\nweight_matrix = np.load('./IIT_connectivity_matrix.npy')\nweight_matrix = torch.from_numpy(weight_matrix)\n\nNR = len(weight_matrix)\ndata = size * np.ones(NR)\nneuron_num = 
np.array(data).astype(np.int32)\nneuron_num = torch.from_numpy(neuron_num)\nregions = brain_region(neuron_num)\nratio = torch.tensor([[0.7, 0.9, 1.0] * NR]).reshape(NR, 3)\nneuron_types = neuron_type(neuron_num, ratio, regions)\nsyn_1 = syn_within_region(10, regions)\nsyn_2 = syn_cross_region(weight_matrix, regions)\nsyn = torch.concatenate((syn_1, syn_2))\nprint(syn.shape)\nweight = -torch.ones(len(syn), device=device, requires_grad=False)\nif neuron_model == 'aEIF':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nelif neuron_model == 'HH':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nfor i in range(len(neuron_types)):\n    pre = syn[:, 0]\n    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])\n    indices = torch.where(mask)\n    weight[indices] = 1.5\n    if neuron_model == 'aEIF':\n        if i < 177:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -110\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -66\n            c_m[regions[i][0]:neuron_types[i][0]] = 10\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            
tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n        else:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -60\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -60\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -65\n            c_m[regions[i][0]:neuron_types[i][0]] = 20\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 2\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 4\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n    elif neuron_model == 'HH':\n        threshold[regions[i][0]:neuron_types[i][0]] = 50\n        threshold[neuron_types[i][0]:neuron_types[i][1]] = 60\n        threshold[neuron_types[i][1]:neuron_types[i][2]] = 60\n\nif neuron_model == 'aEIF':\n    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n    dt = 1\n    T = 300\nelif neuron_model == 'HH':\n    p_neuron = [threshold, 120, 36, 0.3, 
115, -12, 10.6, 1]\n    dt = 0.01\n    T = 10000\nmodel = brain(syn, weight, neuron_model, p_neuron, dt, device)\nIraster = []\nfor t in range(T):\n    model(0)\n    print(torch.sum(model.neurons.spike))\n    Isp = torch.nonzero(model.neurons.spike)\n    print(len(Isp))\n    if (len(Isp) != 0):\n        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)\n        left = left.reshape(len(left), 1)\n        mide = torch.concatenate((left, Isp), dim=1)\n    if (len(Isp) != 0) and (len(Iraster) != 0):\n        Iraster = torch.concatenate((Iraster, mide), dim=0)\n    if (len(Iraster) == 0) and (len(Isp) != 0):\n        Iraster = mide\n\nIraster = torch.tensor(Iraster).transpose(0, 1)\ntorch.save(Iraster, \"./human.pt\")\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/HumanBrain/human_multi.py",
    "content": "import time\n\nimport numpy as np\nimport scipy.io as scio\nimport torch\nfrom torch import nn\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom numpy import genfromtxt\n\ndevice_ids = [0,2,3,4,5,7,8,9]\n\ndevice = 'cuda:0'\n\n\n\nclass MultiCompartmentaEIF(BaseNode):\n    \"\"\"\n    双房室神经元模型\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param tau: 胞体膜电位时间常数, 用于控制胞体膜电位衰减\n    :param tau_basal: 基底树突膜电位时间常数, 用于控制基地树突胞体膜电位衰减\n    :param tau_apical: 远端树突膜电位时间常数, 用于控制远端树突胞体膜电位衰减\n    :param comps: 神经元不同房室, 例如[\"apical\", \"soma\"]\n    :param act_fun: 脉冲梯度代理函数\n    \"\"\"\n    def __init__(self,\n                 p,\n                 dt,\n                 tau=2.0,\n                 tau_basal=2.0,\n                 tau_apical=2.0,\n                 act_fun=AtanGrad, *args, **kwargs):\n        g_B = 0.6\n        g_L = 0.05\n        super().__init__(threshold=p[0], *args, **kwargs)\n        self.neuron_num = len(p[0])\n        self.tau = 2.0\n        self.tau_basal = 20.0\n        self.tau_apical = 2.0\n        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.v_reset = p[1]  # membrane potential reset to v_reset after fire spike\n         # Initialize membrane potentials\n        self.tau_I = 3.0\n        self.sig = 12.0\n        self.mu = 10.0\n        self.dt=dt\n        self.dt_over_tau = self.dt / self.tau_I\n        self.mems = {}\n        self.mems['soma'] = torch.ones(self.neuron_num, device=device) * self.v_reset\n        self.mems['apical'] = torch.ones(self.neuron_num, device=device) * self.v_reset\n        self.act_fun = act_fun(alpha=self.tau, requires_grad=False)\n        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.sqrt_coeff = math.sqrt(1 / (2 * (1 
/ self.dt_over_tau)))\n        \n        \n    \n    def integral(self,apical_inputs):\n        '''\n        Params:\n            inputs torch.Tensor: Inputs for basal dendrite  \n        '''\n        self.mems['apical'] =  (self.mems['apical'] + apical_inputs) / self.tau_apical\n        self.mems['soma'] = self.mems['soma'] + (self.mems['apical'] - self.mems['soma']) / self.tau\n\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mems['soma'] - self.threshold)\n        self.mems['soma'] = self.mems['soma']  * (1. - self.spike.detach())\n        self.mems['apical'] = self.mems['apical']  * (1. - self.spike.detach())\n    def forward(self, inputs):\n\n        # aeifnode_cuda.forward(self.threshold, self.c_m, self.alpha_w, self.beta_ad, inputs, self.ref, self.ad, self.mem, self.spike)\n        self.integral(inputs)\n        self.calc_spike()\n\n        return self.spike, self.mems['soma']\n\n\n\nclass aEIF(BaseNode):\n    \"\"\"\n        The adaptive Exponential Integrate-and-Fire model (aEIF)\n        This class define the membrane, spike, current and parameters of a neuron group of a specific type\n        :param args: Other parameters\n        :param kwargs: Other parameters\n    \"\"\"\n\n    def __init__(self, p, dt, device, *args, **kwargs):\n        \"\"\"\n            p:[threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n            c_m: Membrane capacitance\n            alpha_w: Coupling of the adaptation variable\n            beta_ad: Conductance of the adaptation variable\n            mu: Mean of back current\n            sig: Variance of back current\n            if_IN: if the neuron type is inhibitory neuron, it has gap-junction\n\n            neuron_num: number of neurons in this group\n            W: connection weight for the neuron groups connected to this group\n            type_index: the index of this type of neuron group in the brain region\n\n        \"\"\"\n        super().__init__(threshold=p[0], requires_fp=False, *args, 
**kwargs)\n        self.neuron_num = len(p[0])\n        self.g_m = 0.1  # neuron conduction\n        self.dt = dt\n        self.tau_I = 3  # Time constant to filter the synaptic inputs\n        self.Delta_T = 0.5  # parameter\n        self.v_reset = p[1]  # membrane potential reset to v_reset after fire spike\n        self.c_m = p[2]\n        self.tau_w = p[3]  # Time constant of adaption coupling\n        self.alpha_ad = p[4]\n        self.beta_ad = p[5]\n        self.refrac = 5 / self.dt  # refractory period\n        self.dt_over_tau = self.dt / self.tau_I\n        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))\n        self.mem = self.v_reset\n        self.spike = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.ad = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.ref = torch.randint(0, int(self.refrac + 1), (1, self.neuron_num), device=device, requires_grad=False).squeeze(\n            0)  # refractory counter\n        self.ref = self.ref.float()\n        self.mu = 10\n        self.sig = 12\n        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n\n    def integral(self, inputs):\n\n        self.mem = self.mem + (self.ref > self.refrac) * self.dt / self.c_m * \\\n                   (-self.g_m * (self.mem - self.v_reset) + self.g_m * self.Delta_T *\n                    torch.exp((self.mem - self.threshold) / self.Delta_T) +\n                    self.alpha_ad * self.ad + inputs)\n\n        self.ad = self.ad + (self.ref > self.refrac) * self.dt / self.tau_w * \\\n                  (-self.ad + self.beta_ad * (self.mem - self.v_reset))\n\n    def calc_spike(self):\n        self.spike = (self.mem > self.threshold).float()\n        self.ref = self.ref * (1 - self.spike) + 1\n        self.ad = self.ad + self.spike * 30\n        self.mem = self.spike * self.v_reset + (1 - 
self.spike.detach()) * self.mem\n\n    def forward(self, inputs):\n\n        # aeifnode_cuda.forward(self.threshold, self.c_m, self.alpha_w, self.beta_ad, inputs, self.ref, self.ad, self.mem, self.spike)\n        self.integral(inputs)\n        self.calc_spike()\n\n        return self.spike, self.mem\n\nclass HHNode(BaseNode):\n    \"\"\"\n    简单版本的HH模型\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, p, dt, device, act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold=p[0], *args, **kwargs)\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        '''\n        I = Cm dV/dt + g_k*n^4*(V_m-V_k) + g_Na*m^3*h*(V_m-V_Na) + g_l*(V_m - V_L)\n        '''\n        self.neuron_num = len(p[0])\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.tau_I = 3\n        self.g_Na = torch.tensor(p[1])\n        self.g_K = torch.tensor(p[2])\n        self.g_l = torch.tensor(p[3])\n        self.V_Na = torch.tensor(p[4])\n        self.V_K = torch.tensor(p[5])\n        self.V_l = torch.tensor(p[6])\n        self.C = torch.tensor(p[7])\n        self.m = 0.05 * torch.ones(self.neuron_num, device=device, requires_grad=False)\n        self.n = 0.31 * torch.ones(self.neuron_num, device=device, requires_grad=False)\n        self.h = 0.59 * torch.ones(self.neuron_num, device=device, requires_grad=False)\n        self.v_reset = 0\n        self.dt = dt\n        self.dt_over_tau = self.dt / self.tau_I\n        self.sqrt_coeff = math.sqrt(1 / (2 * (1 / self.dt_over_tau)))\n        self.mu = 10\n        self.sig = 12\n\n        self.mem = torch.tensor(self.v_reset, device=device, requires_grad=False)\n        self.mem_p = self.mem\n        self.spike = 
torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Iback = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n        self.Ieff = torch.zeros(self.neuron_num, device=device, requires_grad=False)\n\n    def integral(self, inputs):\n        self.alpha_n = (0.1 - 0.01 * self.mem) / (torch.exp(1 - 0.1 * self.mem) - 1)\n        self.alpha_m = (2.5 - 0.1 * self.mem) / (torch.exp(2.5 - 0.1 * self.mem) - 1)\n        self.alpha_h = 0.07 * torch.exp(-self.mem / 20.0)\n\n        self.beta_n = 0.125 * torch.exp(-self.mem / 80.0)\n        self.beta_m = 4.0 * torch.exp(-self.mem / 18.0)\n        self.beta_h = 1 / (torch.exp(3 - 0.1 * self.mem) + 1)\n\n        self.tau_n = 1.0 / (self.alpha_n + self.beta_n)\n        self.inf_n = self.alpha_n * self.tau_n\n\n        self.tau_m = 1.0 / (self.alpha_m + self.beta_m)\n        self.inf_m = self.alpha_m * self.tau_m\n\n        self.tau_h = 1.0 / (self.alpha_h + self.beta_h)\n        self.inf_h = self.alpha_h * self.tau_h\n\n        self.n = (1 - self.dt / self.tau_n) * self.n + (self.dt / self.tau_n) * self.inf_n\n        self.m = (1 - self.dt / self.tau_m) * self.m + (self.dt / self.tau_m) * self.inf_m\n        self.h = (1 - self.dt / self.tau_h) * self.h + (self.dt / self.tau_h) * self.inf_h\n\n        # self.n = self.n + self.dt * (self.alpha_n * (1 - self.n) - self.beta_n * self.n)\n        # self.m = self.m + self.dt * (self.alpha_m * (1 - self.m) - self.beta_m * self.m)\n        # self.h = self.h + self.dt * (self.alpha_h * (1 - self.h) - self.beta_h * self.h)\n\n        self.I_Na = torch.pow(self.m, 3) * self.g_Na * self.h * (self.mem - self.V_Na)\n        self.I_K = torch.pow(self.n, 4) * self.g_K * (self.mem - self.V_K)\n        self.I_L = self.g_l * (self.mem - self.V_l)\n\n        self.mem_p = self.mem\n        self.mem = self.mem + self.dt * (inputs - self.I_Na - self.I_K - self.I_L) / self.C\n        # self.mem = self.mem + self.dt * (inputs - self.I_K - self.I_L) / self.C\n\n    
def calc_spike(self):\n        self.spike = (self.threshold > self.mem_p).float() * (self.mem > self.threshold).float()\n\n    def forward(self, inputs):\n        self.integral(inputs)\n        self.calc_spike()\n        return self.spike, self.mem\n\n    def requires_activation(self):\n        return False\n\nclass Syn(nn.Module):\n    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):\n        super().__init__()\n        self.pre = syn[1]\n        self.post = syn[0]\n        self.syn_num = len(syn)\n        self.w = torch.sparse_coo_tensor(syn.t(), weight,\n                                         size=(neuron_num, neuron_num))\n        self.tao_d = tao_d\n        self.tao_r = tao_r\n        self.dt = dt\n        self.lamda_d = self.dt / self.tao_d\n        self.lamda_r = self.dt / self.tao_r\n\n        self.s = torch.zeros(neuron_num, device=device)\n        self.r = torch.zeros(neuron_num, device=device)\n        self.dt = dt\n\n    def forward(self, neuron):\n        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (\n                torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)\n        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu\n        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)\n        self.r = self.r - self.lamda_d * self.r + self.dt * self.s\n        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff\n        return self.I\n\nclass brain(nn.Module):\n    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):\n        super().__init__()\n        if neuron_model == 'HH':\n            self.neurons = HHNode(p_neuron, dt, device)\n        elif neuron_model == 'aEIF':\n            self.neurons = aEIF(p_neuron, dt, device)\n        elif neuron_model == 'MultiCompartmentaEIF':\n            self.neurons = MultiCompartmentaEIF(p_neuron,dt,device)\n        self.neuron_num = len(p_neuron[0])\n        
self.syns = Syn(syn, weight, self.neuron_num, 3.0, 6.0, dt, device)\n\n    def forward(self, inputs):\n        I = self.syns(self.neurons)\n        self.neurons(I)\n\n\ndef brain_region(neuron_num):\n    region = []\n    start = 0\n    end = 0\n    for i in range(len(neuron_num)):\n        end += neuron_num[i].item()\n        region.append([start, end])\n        start = end\n    return torch.tensor(region)\n\ndef neuron_type(neuron_num, ratio, regions):\n    neuron_num = neuron_num.reshape(-1, 1)\n    neuron_type = torch.floor(ratio * neuron_num).int() + regions[:, 0].reshape(-1, 1)\n    return neuron_type\n\ndef syn_within_region(syn_num, region):\n    start = 1\n    for neurons in region:\n        if start:\n            syn = torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)\n            start = 0\n        else:\n            syn = torch.concat((syn, torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)))\n    return syn\n\ndef syn_cross_region(weight_matrix, region):\n    start = 1\n    for i in range(len(weight_matrix)):\n        for j in range(len(weight_matrix)):\n            if weight_matrix[i][j] < 10:\n                continue\n            else:\n                pre = torch.randint(region[j][0], region[j][1],\n                                    size=(weight_matrix[i][j], 1), device=device)\n                post = torch.randint(region[i][0], region[i][1],\n                                     size=(weight_matrix[i][j], 1), device=device)\n                if start:\n                    syn = torch.concat((post, pre), dim=1)\n                    start = 0\n                else:\n                    syn = torch.concat((syn, torch.concat((post, pre), dim=1)))\n    return syn\n\nsize = 100\nneuron_model = 'MultiCompartmentaEIF'\n\nweight_matrix = 
torch.from_numpy(np.load(\"IIT_connectivity_matrix.npy\")[0:84,0:84])\nweight_matrix = weight_matrix.int() * 10\n# weight_matrix = np.load('./IIT_connectivity_matrix.npy')\n# weight_matrix = torch.from_numpy(weight_matrix)\n\nNR = len(weight_matrix)\ndata = size * np.ones(NR)\nneuron_num = np.array(data).astype(np.int32)\nneuron_num = torch.from_numpy(neuron_num)\nprint(torch.sum(neuron_num))\nregions = brain_region(neuron_num)\nratio = torch.tensor([[0.7, 0.9, 1.0] * NR]).reshape(NR, 3)\nneuron_types = neuron_type(neuron_num, ratio, regions)\nsyn_1 = syn_within_region(10, regions)\nsyn_2 = syn_cross_region(weight_matrix, regions)\nsyn = torch.concat((syn_1, syn_2))\nprint(len(syn_2))\n\n\nprint(syn.shape)\nweight = -torch.ones(len(syn), device=device, requires_grad=False)\nif neuron_model == 'aEIF':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nelif neuron_model == 'HH':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n\nif neuron_model == 'MultiCompartmentaEIF':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nfor i in range(len(neuron_types)):\n    pre = syn[:, 0]\n    mask = (pre >= 
regions[i][0]) & (pre < neuron_types[i][0])\n    indices = torch.where(mask)\n    weight[indices] = 1.5\n    if neuron_model == 'aEIF':\n        if i < 70:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -110\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -110\n            c_m[regions[i][0]:neuron_types[i][0]] = 10\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n        else:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -100\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -100\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -105\n            c_m[regions[i][0]:neuron_types[i][0]] = 20\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 10\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            
tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n    elif neuron_model == 'HH':\n        threshold[regions[i][0]:neuron_types[i][0]] = 20\n        threshold[neuron_types[i][0]:neuron_types[i][1]] = 20\n        threshold[neuron_types[i][1]:neuron_types[i][2]] = 20\n    \n    elif neuron_model == 'MultiCompartmentaEIF':\n        if i < 70:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50.0\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44.0\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45.0\n            v_reset[regions[i][0]:neuron_types[i][0]] = -110.0\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110.0\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -110.0\n            c_m[regions[i][0]:neuron_types[i][0]] = 10.0\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10.0\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n        else:\n            
threshold[regions[i][0]:neuron_types[i][0]] = -50.0\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50.0\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45.0\n            v_reset[regions[i][0]:neuron_types[i][0]] = -100.0\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -100.0\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -105.0\n            c_m[regions[i][0]:neuron_types[i][0]] = 20\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 10\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n\nif neuron_model == 'aEIF':\n    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n    dt = 1\n    T = 2000\nelif neuron_model == 'HH':\n    p_neuron = [threshold, 120, 36, 0.3, 115, -12, 10.6, 1]\n    dt = 0.01\n    T = 10000\nelif neuron_model == 'MultiCompartmentaEIF':\n    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n    dt = 1.0\n    T = 2000\nmodel = brain(syn, weight, neuron_model, p_neuron, dt, device)\n# device_ids = [0,2,3,4,5,7,8,9]\n# model = nn.DataParallel(model, device_ids=device_ids)\nmodel.to(device)\n\ndef neuron_delete(model, rate):\n    neuron_idex = torch.arange(0, model.neuron_num)\n    delete_num = int(model.neuron_num * rate)\n    random_elements = neuron_idex[torch.randperm(model.neuron_num)[:delete_num]]\n    model.neurons.threshold[random_elements] = 1000\n    return model.neuron_num 
- delete_num\n\ndef syn_delete(model, rate):\n    indices = model.syns.w._indices()\n    values = model.syns.w._values()\n    delete_num = int(len(values) * rate)\n    syn_idex = torch.arange(0, len(values))\n    random_elements = syn_idex[torch.randperm(len(values))[:delete_num]]\n    new_values = values[random_elements]\n    new_indices = indices[:, random_elements]\n    new_w = torch.sparse_coo_tensor(new_indices, new_values, size=(model.neuron_num, model.neuron_num))\n    model.syns.w = new_w\n\ndef syn_strength(model, rate):\n    indices = model.syns.w._indices()\n    values = model.syns.w._values()\n    iex = torch.where(values>0)\n    values[iex] = values[iex] * rate\n    new_values = values\n    new_indices = indices\n    new_w = torch.sparse_coo_tensor(new_indices, new_values, size=(model.neuron_num, model.neuron_num))\n    model.syns.w = new_w\n\nIraster = []\nfire_rate = []\ncount_n = model.neuron_num\nfor t in range(T):\n    if t == int(T/4):\n        count_n = neuron_delete(model, 0.4)\n    if t == int(T/4 * 2):\n        syn_delete(model, 0.4)\n    if t == int(T/4 * 3):\n        syn_strength(model, 3)\n    model(0)\n    # print(torch.sum(model.neurons.spike))\n    Isp = torch.nonzero(model.neurons.spike)\n    print(len(Isp))\n    fire_rate.append(len(Isp)/count_n)\n    if (len(Isp) != 0):\n        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)\n        left = left.reshape(len(left), 1)\n        mide = torch.concat((left, Isp), dim=1)\n    if (len(Isp) != 0) and (len(Iraster) != 0):\n        Iraster = torch.concat((Iraster, mide), dim=0)\n    if (len(Iraster) == 0) and (len(Isp) != 0):\n        Iraster = mide\n\ntorch.save(fire_rate, './fire_rate.pt')\nplt.plot(fire_rate)\nplt.xlabel('time/mm')\nplt.ylabel('fire_rate')\n# plt.axvline(x=[500, 1000, 1500], color='b', linestyle='--')\nplt.show()\nIraster = torch.tensor(Iraster).transpose(0, 1)\ntorch.save(Iraster, \"./human_MultiCompartmentaEIF100.pt\")\nIraster = 
Iraster.cpu()\nplt.figure(figsize=(15, 15))\nplt.scatter(Iraster[0], Iraster[1], c='k', marker='.', s=0.001)\nplt.savefig('mouse_MultiCompartmentaEIF100.png')\nplt.show()\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/NA.py",
    "content": "import numpy as np\nimport random\nimport math\nimport matplotlib.pyplot as plt\nimport matplotlib\n# matplotlib.use('TkAgg')\nimport scipy.io as scio\nimport pandas as pd\nimport torch\nimport networkx as nx\nfrom collections import defaultdict\nimport community as community_louvain\nfrom matplotlib.ticker import MaxNLocator, FuncFormatter\n\n\ndef histogram_entropy(data, bins='auto'):\n    \"\"\"\n    使用直方图法估计一维数据的熵。\n\n    参数:\n        data (np.ndarray): 一维数据数组。\n        bins (int or str): 直方图的分箱数，默认为 'auto'。\n\n    返回:\n        float: 估计的熵值。\n    \"\"\"\n    hist, bin_edges = np.histogram(data, bins=bins, density=True)\n\n    bin_width = bin_edges[1] - bin_edges[0]\n    prob = hist * bin_width\n\n    prob = prob[prob > 0]\n\n    entropy_value = -np.sum(prob * np.log(prob))\n\n    return entropy_value\n\ndef hub_degree(df, W_new):\n    degree = torch.sum(W_new, dim=0)\n    v, ind = torch.topk(degree, 10)\n    ind = ind.tolist()\n    plt.figure(figsize=(40, 18))\n    plt.bar(df['Identifier'].values, degree)\n    plt.bar(df['Identifier'].iloc[ind].values, degree[ind], color='r', label='Top 10 Degree')\n    plt.gca().yaxis.set_major_locator(MaxNLocator(integer=False, prune='lower', nbins=15))\n    plt.xticks(rotation=90, fontsize=30)\n    plt.ylabel('Degree', fontsize=40)\n    plt.yticks(fontsize=25)\n    plt.legend(fontsize=40)\n    xticks = plt.gca().get_xticklabels()\n    for i, tick in enumerate(xticks):\n        if df['Identifier'].iloc[i] in df['Identifier'].iloc[ind].values:\n            tick.set_color('r')\n    plt.grid(axis='y')\n    plt.show()\n\ndef visual(df, W_new):\n    x = df['x'].values\n    y = df['y'].values\n    z = df['z'].values\n    fig = plt.figure()\n    ax = fig.add_subplot(111, projection='3d')\n    ax.scatter(x, y, z)\n    for i in range(len(x)):\n        for j in range(i + 1, len(x)):\n            if W_new[i, j] > 0.1:\n                ax.plot([x[i], x[j]],\n                        [y[i], y[j]],\n                        
[z[i], z[j]],\n                        'k-', lw=1)\n\n    plt.show()\n\nif __name__ ==  \"__main__\":\n    W = np.load('./IIT_connectivity_matrix.npy')\n    W = torch.from_numpy(W).float()\n    W = W[0:84, 0:84]\n    new_order = list(range(0, 35)) + list(range(49, 84)) + list(range(35, 49))\n    W_new = W[new_order, :][:, new_order]\n    M = torch.max(W_new)\n    W_new = W_new / M\n\n    G = nx.from_numpy_matrix(W_new.numpy())\n\n    # Louvain\n    partition = community_louvain.best_partition(G)\n\n    community_groups = defaultdict(list)\n\n    for node, community in partition.items():\n        community_groups[community].append(node)\n\n    df = pd.read_csv('brain_regions.csv')\n    labels = df['Identifier'].values\n    # for community, nodes in community_groups.items():\n    #     print(f\"Community {community}: {nodes}\")\n    fig, ax = plt.subplots(figsize=(20, 20))\n    cax = ax.imshow(W_new.cpu().numpy(), cmap='viridis')\n    ax.set_xticks(np.arange(len(labels)))\n    ax.set_yticks(np.arange(len(labels)))\n    ax.set_xticklabels(labels, rotation=90, fontsize=20)\n    ax.set_yticklabels(labels, fontsize=20)\n    fig.colorbar(cax, shrink=0.8)\n    # plt.tight_layout()\n    plt.show()\n    hub_degree(df, W_new)"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/Readme.md",
    "content": "main_84.py and main_246.py is the code that runs the simulation. The required data files can be obtained from the following link:\n\nhttps://drive.google.com/drive/folders/14KPqJsJXIo-bCmGCuDuBadLRmaYn78J2?usp=sharing"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/gc.py",
    "content": "import scipy.io as scio\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom scipy import signal\nimport torch\nimport matplotlib.colors as mcolors\n\nscale = 1.2\nversion = 1\nEEG_m = np.array(torch.load(f'./result/I_subregion_{version}_{scale}.pt').cpu())\nEEG_c1 = np.load('./dataset/data_awake_1ug.npy')\nEEG_c2 = np.load('./dataset/data_2ug.npy')\nEEG_c3 = np.load('./dataset/data_3ug.npy')\nEEG_m = np.mean(EEG_m, axis=0)\n\nlowcut = 3.0  # 下截止频率 (Hz)\nhighcut = 30.0  # 上截止频率 (Hz)\nfs = 1000\n# 使用 Butterworth 滤波器设计带通滤波器\n# butter 函数的参数依次为：滤波器阶数，频率范围（归一化），滤波器类型\nb, a = signal.butter(4, [lowcut / (0.5 * fs), highcut / (0.5 * fs)], btype='band')\n\n# 应用滤波器（使用 filtfilt 实现零相位滤波）\nEEG_m = signal.filtfilt(b, a, EEG_m)\nEEG_C = EEG_c3\nEEG_C = signal.filtfilt(b, a, EEG_C)\nt = 80\nmat_all = np.zeros((t, 30))\nfor j in range(64):\n    mat = np.zeros((t, 30))\n    for i in range(t):\n        f, Cxy = signal.csd(EEG_m[i*1000:(i+1)*1000], EEG_C[i][j], fs=fs, nperseg=1024)\n        mat[i] = np.abs(Cxy[:30])\n    mat_all += mat / np.max(mat)\n\nplt.figure(figsize=(16,8))\nnorm = mcolors.LogNorm(vmin=0.001, vmax=1)\ncm = plt.cm.get_cmap('jet')\nplt.contourf(np.linspace(0, 8, 80) ,f[:30], mat_all.T / 64, cmap=cm, levels=200)\nplt.colorbar()\nplt.xlabel('time/min', fontsize=20)\nplt.ylabel('Frequency/Hz', fontsize=20)\nplt.xticks(fontsize=15)\nplt.yticks(fontsize=15)\nplt.show()\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/main_246.py",
    "content": "import numpy as np\nimport random\nimport math\nimport matplotlib.pyplot as plt\nimport scipy.io as scio\nimport pandas as pd\nimport torch\nimport tqdm\ndevice = 'cuda:3'\n\ntrail = 5\nversion = 4\nscale = 0.1\nclass brain_model():\n\n    def __init__(self, W):\n\n        self.weight_matrix = W.to(device)\n        self.distance_matrix = torch.zeros_like(self.weight_matrix, device=device)\n        self.speed = 1.5\n        self.decay = torch.ceil(self.distance_matrix / self.speed)\n        self.t_window = int(torch.max(self.decay)) + 1\n        self.V_th = torch.tensor([10., 4., 10., 4., 4.], device=device)\n        self.tau_v = torch.tensor([40., 10., 30., 20., 40.], device=device)\n        self.Tsig = torch.tensor([12., 10., 12., 10., 10.], device=device)\n        self.beta = torch.tensor([0., 4.5, 0., 4.5, 4.5], device=device)\n        self.alpha_ad = torch.tensor([0., -2., 0., -2., -2.], device=device)\n\n        self.tau_ad = 20\n        self.tau_I = 10\n\n        # Simulation parameters\n        self.NR = len(W)\n        self.C = 210\n        self.NN = 500\n        self.NType = self.NN * torch.tensor([0.80, 0.20, 0.01, 0.004, 0.004], device=device)\n        self.NE = 400\n        self.NI = 100\n        self.NTC = 30\n        self.NTI = 15\n        self.NTRN = 6\n        self.NC = self.NE + self.NI\n        self.NSum = int((self.C) * (self.NE + self.NI) + (self.NR - self.C) * (self.NTC + self.NTI + self.NTRN))\n        self.NT = self.NTC + self.NTI + self.NTRN\n        print(self.NSum)\n        print(self.NC)\n        print(self.NTI + self.NTC + self.NTRN)\n\n        self.Ncycle = 1\n        self.dt = 1\n        self.T = 8000\n        self.Delta_T = 0.5\n        # self.refrac = 5 / self.dt\n        # self.ref = self.refrac*torch.zeros((self.NN, 1)).squeeze(1)\n        self.gamma_c = 0.1\n        self.g_m = 1\n        self.Gama_c = self.g_m * self.gamma_c / (1 - self.gamma_c)\n        self.GammaII = 15\n        self.GammaIE = -10\n        
self.GammaEE = 15\n        self.GammaEI = 15\n        self.TEmean = 0.5 * self.V_th[0]  # Mean current to excitatory neurons\n        self.TTCmean = 0.5 * self.V_th[2]  # Mean current to TC neurons\n        self.TImean = -5 * self.V_th[1]\n        self.TTImean = -5 * self.V_th[3]\n        self.TTRNmean = -5 * self.V_th[4]\n\n        self.v = torch.zeros(self.NSum, device=device)\n        self.vt = torch.zeros(self.NSum, device=device)\n        self.c_m = torch.zeros(self.NSum, device=device)\n        self.alpha_w = torch.zeros(self.NSum, device=device)\n        self.beta_ad = torch.zeros(self.NSum, device=device)\n        self.delta = torch.ones(self.NSum, device=device)\n        self.ad = torch.zeros(self.NSum, device=device)\n        self.vv = torch.zeros(self.NSum, device=device)\n        self.Iback = torch.zeros(self.NSum, device=device)\n        self.Istimu = torch.zeros(self.NSum, device=device)  # stimulate current\n        self.Ieff = torch.zeros(self.NSum, device=device)\n        self.Nmean = torch.zeros(self.NSum, device=device)\n        self.Nsig = torch.zeros(self.NSum, device=device)\n        self.Igap = torch.zeros(self.NSum, device=device)\n        self.Ichem = torch.zeros(self.NSum, device=device)\n        self.Ieeg = torch.zeros(self.NSum, device=device)\n        self.vm1 = torch.zeros(self.NSum, device=device)\n        self.reset = torch.zeros(self.NSum, device=device)\n\n        self.E_range = []\n        self.I_range = []\n        self.TC_range = []\n        self.TI_range = []\n        self.TRN_range = []\n        self.divide_point_E = []\n        self.divide_point_I = []\n        self.divide_point_TC = []\n        self.divide_point_TI_TRN = []\n\n        for n in range(self.NR):\n            if n < self.C:\n                self.divide_point_E.append(list(range(n * self.NC, n * self.NC + self.NE)))\n                self.divide_point_I.append(list(range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI)))\n                self.E_range = 
self.E_range + list(range(n * self.NC, n * self.NC + self.NE))\n                self.I_range = self.I_range + list(\n                    range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI))\n\n            else:\n                s = self.C * self.NC + (n-self.C) * self.NT\n                self.divide_point_TC.append(list(range(s, s + self.NTC)))\n                self.divide_point_TI_TRN.append(list(range(s + self.NTC,\n                                                           s + self.NTC + self.NTI))\n                                                + list(range(s + self.NTC + self.NTI,\n                                                             s + self.NTC + self.NTI + self.NTRN)))\n\n                self.TC_range = self.TC_range + list(range(s,\n                                                           s + self.NTC))\n                self.TI_range = self.TI_range + list(range(s + self.NTC,\n                                                           s + self.NTC + self.NTI))\n                self.TRN_range = self.TRN_range + list(range(s + self.NTC + self.NTI,\n                                                             s + self.NTC + self.NTI + self.NTRN))\n        self.divide_point_E = torch.tensor(self.divide_point_E, device=device)\n        self.divide_point_I = torch.tensor(self.divide_point_I, device=device)\n        self.divide_point_TC = torch.tensor(self.divide_point_TC, device=device)\n        self.divide_point_TI_TRN = torch.tensor(self.divide_point_TI_TRN, device=device)\n        self.divide_point_CR = torch.concat((self.divide_point_E, self.divide_point_I), dim=1)\n        self.divide_point_TN = torch.concat((self.divide_point_TC, self.divide_point_TI_TRN), dim=1)\n        torch.save({'divide_point_E': self.divide_point_E, 'divide_point_I': self.divide_point_I,\n                    'TC_range': self.TC_range, 'TI_range': self.TI_range, 'TRN_range': self.TRN_range},\n                   './neuron_divide.pt')\n\n        
self.c_m[self.E_range] = self.tau_v[0] * self.g_m + 15 * (torch.rand(size=(len(self.E_range),), device=device) - 0.5)\n        # self.c_m[self.E_range] = self.tau_v[0] * self.g_m\n        self.c_m[self.TC_range] = self.tau_v[2] * self.g_m\n        self.c_m[self.I_range] = self.tau_v[1] * (self.g_m + self.Gama_c)\n        self.c_m[self.TI_range] = self.tau_v[3] * (self.g_m + self.Gama_c)\n        self.c_m[self.TRN_range] = self.tau_v[4] * (self.g_m + self.Gama_c)\n\n        self.alpha_w[self.E_range] = self.alpha_ad[0] * self.g_m\n        self.alpha_w[self.TC_range] = self.alpha_ad[2] * self.g_m + self.Gama_c\n        self.alpha_w[self.I_range] = self.alpha_ad[1] * (self.g_m + self.Gama_c)\n        self.alpha_w[self.TI_range] = self.alpha_ad[3] * (self.g_m + self.Gama_c)\n        self.alpha_w[self.TRN_range] = self.alpha_ad[4] * (self.g_m + self.Gama_c)\n\n        self.beta_ad[self.E_range] = self.beta[0]\n        self.beta_ad[self.TC_range] = self.beta[2]\n        self.beta_ad[self.I_range] = self.beta[1]\n        self.beta_ad[self.TI_range] = self.beta[3]\n        self.beta_ad[self.TRN_range] = self.beta[4]\n\n        self.vt[self.E_range] = self.V_th[0]\n        self.vt[self.TC_range] = self.V_th[2]\n        self.vt[self.I_range] = self.V_th[1]\n        self.vt[self.TI_range] = self.V_th[3]\n        self.vt[self.TRN_range] = self.V_th[4]\n\n        self.reset[self.E_range] = 1\n        self.reset[self.TC_range] = 1\n        self.reset[self.I_range] = 0\n        self.reset[self.TI_range] = 0\n        self.reset[self.TRN_range] = 0\n\n\n        self.Nmean[self.E_range] = self.TEmean * self.g_m\n        self.Nmean[self.TC_range] = self.TTCmean * self.g_m\n        self.Nmean[self.I_range] = self.TImean * (self.g_m + self.Gama_c)\n        self.Nmean[self.TI_range] = self.TTImean * (self.g_m + self.Gama_c)\n        self.Nmean[self.TRN_range] = self.TTRNmean * (self.g_m + self.Gama_c)\n\n        self.Nsig[self.E_range] = self.Tsig[0] * self.g_m\n        
self.Nsig[self.TC_range] = self.Tsig[2] * self.g_m\n        self.Nsig[self.I_range] = self.Tsig[1] * (self.g_m + self.Gama_c)\n        self.Nsig[self.TI_range] = self.Tsig[3] * (self.g_m + self.Gama_c)\n        self.Nsig[self.TRN_range] = self.Tsig[4] * (self.g_m + self.Gama_c)\n\n    def simulation(self, per):\n\n        range_E = self.E_range + self.TC_range\n        range_I = self.I_range + self.TI_range + self.TRN_range\n        Vgap = self.Gama_c\n        weight_matrix = self.weight_matrix\n        print(123)\n        for i in range(self.Ncycle):\n            I_total = torch.zeros((self.Ncycle, self.T), device=device)\n            V_total = torch.zeros((self.Ncycle, self.T), device=device)\n\n            V = torch.zeros(self.T, device=device)\n            I_subregion = torch.zeros((self.NR, self.T), device=device)\n            I_subregion_E = torch.zeros((self.NR, self.T), device=device)\n            I_subregion_I = torch.zeros((self.NR, self.T), device=device)\n            Vsubregion = torch.zeros((self.NR, self.T), device=device)\n            EEG = torch.zeros((self.T), device=device)\n\n            Iraster = []\n            vv_sumE = torch.zeros((self.NR, self.t_window), device=device)\n            vv_sumI = torch.zeros((self.NR, self.t_window), device=device)\n\n            phase = self.T / 8\n            for t in tqdm.tqdm(range(self.T)):\n\n                if t < phase:\n                    tau_vI = 20\n                    self.GammaII = 15\n                    self.GammaIE = -10\n                elif phase <= t < 3 * phase:\n                    tau_vI = 20 + 20 * (t - phase) / phase\n                    self.GammaII = 15 + 20 * (t - phase) / phase\n                    self.GammaIE = -10 - 20 * (t - phase) / phase\n                elif 3 * phase <= t < 5 * phase:\n                    tau_vI = 60\n                    self.GammaII = 55\n                    self.GammaIE = -50\n                elif 5 * phase <= t < 7 * phase:\n                    tau_vI = 60 
- 20 * (t - 5 * phase) / phase\n                    self.GammaII = 55 - 20 * (t - 5 * phase) / phase\n                    self.GammaIE = -50 + 20 * (t - 5 * phase) / phase\n                elif 7 * phase <= t < 8 * phase:\n                    tau_vI = 20\n                    self.GammaII = 15\n                    self.GammaIE = -10\n\n                self.c_m[range_I] = tau_vI * (self.g_m + self.Gama_c)\n                WII = self.GammaII * torch.mean(self.c_m[self.I_range])\n                WEE = self.GammaEE * torch.mean(self.c_m[self.E_range])\n                WEI = self.GammaEI * torch.mean(self.c_m[self.I_range])\n                WIE = self.GammaIE * torch.mean(self.c_m[self.E_range])\n\n                self.Iback = self.Iback + self.dt / self.tau_I * (-self.Iback + torch.randn(self.NSum, device=device))\n                self.Ieff = (self.Iback / math.sqrt(1 / (2 * (self.tau_I / self.dt))) * self.Nsig + self.Nmean)\n\n                temp = vv_sumE.clone()\n                vv_sumE[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]\n\n                vv_sumE[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_E], dim=1),\n                                                           torch.mean(self.vv[self.divide_point_TC], dim=1)))\n\n                temp = vv_sumI.clone()\n                vv_sumI[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]\n                vv_sumI[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_I], dim=1),\n                                                           torch.mean(\n                                                               self.vv[self.divide_point_TI_TRN], dim=1)))\n\n                v_sum = torch.cat((torch.mean(self.v[self.divide_point_I], dim=1),\n                                   torch.mean(self.v[self.divide_point_TI_TRN], dim=1)))\n                v_sum_CR = v_sum[:self.C].reshape(-1, 1) * \\\n                           torch.ones((self.C, self.NI), device=device)\n      
          v_sum_CR = v_sum_CR.reshape(-1, 1).squeeze(1)\n                v_sum_TN = v_sum[self.C:].reshape(-1, 1) * \\\n                           torch.ones(self.NR - self.C, self.NTI + self.NTRN, device=device)\n                v_sum_TN = v_sum_TN.reshape(-1, 1).squeeze(1)\n                v_sum = torch.cat((v_sum_CR, v_sum_TN))\n\n                time_decay = torch.concat(\n                    (torch.concat([torch.arange(self.NR, device=device).unsqueeze(0)] * self.NR, dim=0).unsqueeze(0),\n                     self.t_window - 1 - self.decay.unsqueeze(0)), dim=0)\n                time_decay = list(time_decay.long())\n\n                v_E = torch.sum(weight_matrix * vv_sumE[time_decay], dim=1)\n                v_I = torch.sum(weight_matrix * vv_sumI[time_decay], dim=1)\n\n                v_E_CR = v_E[:self.NR - 36].reshape(-1, 1) * \\\n                         torch.ones((self.NR - 36, self.NC), device=device)\n                v_I_CR = v_I[:self.NR - 36].reshape(-1, 1) * \\\n                         torch.ones((self.NR - 36, self.NC), device=device)\n                v_E_CR = v_E_CR.reshape(-1, 1).squeeze(1)\n                v_I_CR = v_I_CR.reshape(-1, 1).squeeze(1)\n\n                v_E_TN = v_E[self.NR - 36:].reshape(-1, 1) * \\\n                         torch.ones(36, self.NTC + self.NTI + self.NTRN, device=device)\n                v_I_TN = v_I[self.NR - 36:].reshape(-1, 1) * \\\n                         torch.ones(36, self.NTC + self.NTI + self.NTRN, device=device)\n                v_E_TN = v_E_TN.reshape(-1, 1).squeeze(1)\n                v_I_TN = v_I_TN.reshape(-1, 1).squeeze(1)\n\n                v_E = torch.cat((v_E_CR, v_E_TN))\n                v_I = torch.cat((v_I_CR, v_I_TN))\n                self.Ichem[range_E] = self.Ichem[range_E] + self.dt / self.tau_I * \\\n                                      (-self.Ichem[range_E] + WEE * v_E[range_E]\n                                       + WIE * v_I[range_E])\n\n                self.Ichem[range_I] = 
self.Ichem[range_I] + self.dt / self.tau_I * \\\n                                      (-self.Ichem[range_I] + WII * v_I[range_I]\n                                       + WEI * v_E[range_I])\n                self.Igap[range_I] = Vgap * (\n                        v_sum - self.v[range_I])\n\n                # if (300 <= t < 700):\n                #     self.Istimu[self.divide_point_E[per]] = 10 - self.Ichem[self.divide_point_E[per]]\n                # elif (1300 <= t < 1700):\n                #     self.Istimu[self.divide_point_E[per]] = 10 - self.Ichem[self.divide_point_E[per]]\n                # elif (2300 <= t < 2700):\n                #     self.Istimu[self.divide_point_E[per]] = 10 - self.Ichem[self.divide_point_E[per]]\n                # elif (3300 <= t < 3700):\n                #     self.Istimu[self.divide_point_E[per]] = 10 - self.Ichem[self.divide_point_E[per]]\n                # else:\n                #     self.Istimu[self.divide_point_E[per]] = 0\n\n                # if (300 <= t < 700):\n                #     self.Istimu[self.divide_point_TC[per]] = 10 - self.Ichem[self.divide_point_TC[per]]\n                # elif (1300 <= t < 1700):\n                #     self.Istimu[self.divide_point_TC[per]] = 10 - self.Ichem[self.divide_point_TC[per]]\n                # elif (2300 <= t < 2700):\n                #     self.Istimu[self.divide_point_TC[per]] = 10 - self.Ichem[self.divide_point_TC[per]]\n                # elif (3300 <= t < 3700):\n                #     self.Istimu[self.divide_point_TC[per]] = 10 - self.Ichem[self.divide_point_TC[per]]\n                # else:\n                #     self.Istimu[self.divide_point_TC[per]] = 0\n\n\n                self.v = self.v + self.dt / self.c_m * (-self.g_m * self.v +\n                                                        self.alpha_w * self.ad + self.Istimu + self.Ieff + self.Ichem + self.Igap)\n                self.ad = self.ad + self.dt / self.tau_ad * (-self.ad + self.beta_ad * self.v)\n                
self.vv = (self.v >= self.vt).float() * (self.vm1 < self.vt).float()\n                self.v = self.v * (1 - self.vv * self.reset)\n                # self.v = self.v * (1 - self.vv)\n                self.vm1 = self.v\n\n                Isp = torch.where(self.vv == 1)[0]\n                Iraster.append(torch.stack((t * torch.ones((len(Isp)), device=device), Isp), dim=1))\n\n                I_CR = torch.mean(self.Ichem[self.divide_point_CR], dim=1)\n                I_TN = torch.mean(self.Ichem[self.divide_point_TN], dim=1)\n                I_subregion[:, t] = torch.cat((I_CR, I_TN), dim=0)\n\n            print('over')\n            # torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}_{per}.pt')\n            torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}.pt')\n\n            Iraster = torch.cat(Iraster, dim=0).cpu()\n\n            torch.save(Iraster, f'./result/raster_{version}_{scale}.pt')\n            # torch.save(Iraster, f'./result/raster_{version}_{scale}_{per}.pt')\n\n\n\nfile_path = './human.csv'\ndf = pd.read_csv(file_path, header=None)\nW = df.to_numpy()\nW = torch.from_numpy(W).float()\n# from NA import histogram_entropy\n# degree = torch.sum(W, dim=0)\n# print(histogram_entropy(degree))\ndegree = torch.sum(W, dim=1)\n\nW = scale * W.to(device)\nwith torch.no_grad():\n    simulation_model = brain_model(W)\n    simulation_model.simulation(0)\n# for per in range(210):\n#     with torch.no_grad():\n#         simulation_model = brain_model(W)\n#         simulation_model.simulation(per)\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/main_84.py",
    "content": "import numpy as np\nimport random\nimport math\nimport matplotlib.pyplot as plt\nimport scipy.io as scio\nimport pandas as pd\nimport torch\nimport tqdm\ndevice = 'cuda:1'\n\ntrail = 1\nversion = 1\nscale = 1.2\nclass brain_model():\n\n    def __init__(self, W):\n\n        self.weight_matrix = W.to(device)\n        self.distance_matrix = torch.zeros_like(self.weight_matrix, device=device)\n        self.speed = 1.5\n        self.decay = torch.ceil(self.distance_matrix / self.speed)\n        self.t_window = int(torch.max(self.decay)) + 1\n        self.V_th = torch.tensor([10., 4., 10., 4., 4.], device=device)\n        self.tau_v = torch.tensor([40., 10., 30., 20., 40.], device=device)\n        self.Tsig = torch.tensor([12., 10., 12., 10., 10.], device=device)\n        self.beta = torch.tensor([0., 4.5, 0., 4.5, 4.5], device=device)\n        self.alpha_ad = torch.tensor([0., -2., 0., -2., -2.], device=device)\n\n        self.tau_ad = 20\n        self.tau_I = 10\n\n        # Simulation parameters\n        self.NR = len(W)\n        self.C = 70\n        self.NN = 500\n        self.NType = self.NN * torch.tensor([0.80, 0.20, 0.01, 0.004, 0.004], device=device)\n        self.NE = 400\n        self.NI = 100\n        self.NTC = 30\n        self.NTI = 15\n        self.NTRN = 6\n        self.NC = self.NE + self.NI\n        self.NSum = int((self.C) * (self.NE + self.NI) + (self.NR - self.C) * (self.NTC + self.NTI + self.NTRN))\n        self.NT = self.NTC + self.NTI + self.NTRN\n        print(self.NSum)\n        print(self.NC)\n        print(self.NTI + self.NTC + self.NTRN)\n\n        self.Ncycle = 1\n        self.dt = 1\n        self.T = 80000\n        self.Delta_T = 0.5\n        # self.refrac = 5 / self.dt\n        # self.ref = self.refrac*torch.zeros((self.NN, 1)).squeeze(1)\n        self.gamma_c = 0.1\n        self.g_m = 1\n        self.Gama_c = self.g_m * self.gamma_c / (1 - self.gamma_c)\n        self.GammaII = 15\n        self.GammaIE = -10\n        
self.GammaEE = 15\n        self.GammaEI = 15\n        self.TEmean = 0.5 * self.V_th[0]  # Mean current to excitatory neurons\n        self.TTCmean = 0.5 * self.V_th[2]  # Mean current to TC neurons\n        self.TImean = -5 * self.V_th[1]\n        self.TTImean = -5 * self.V_th[3]\n        self.TTRNmean = -5 * self.V_th[4]\n\n        self.v = torch.zeros(self.NSum, device=device)\n        self.vt = torch.zeros(self.NSum, device=device)\n        self.c_m = torch.zeros(self.NSum, device=device)\n        self.alpha_w = torch.zeros(self.NSum, device=device)\n        self.beta_ad = torch.zeros(self.NSum, device=device)\n        self.delta = torch.ones(self.NSum, device=device)\n        self.ad = torch.zeros(self.NSum, device=device)\n        self.vv = torch.zeros(self.NSum, device=device)\n        self.Iback = torch.zeros(self.NSum, device=device)\n        self.Istimu = torch.zeros(self.NSum, device=device)  # stimulate current\n        self.Ieff = torch.zeros(self.NSum, device=device)\n        self.Nmean = torch.zeros(self.NSum, device=device)\n        self.Nsig = torch.zeros(self.NSum, device=device)\n        self.Igap = torch.zeros(self.NSum, device=device)\n        self.Ichem = torch.zeros(self.NSum, device=device)\n        self.Ieeg = torch.zeros(self.NSum, device=device)\n        self.vm1 = torch.zeros(self.NSum, device=device)\n        self.reset = torch.zeros(self.NSum, device=device)\n\n        self.E_range = []\n        self.I_range = []\n        self.TC_range = []\n        self.TI_range = []\n        self.TRN_range = []\n        self.divide_point_E = []\n        self.divide_point_I = []\n        self.divide_point_TC = []\n        self.divide_point_TI_TRN = []\n\n        for n in range(self.NR):\n            if n < self.C:\n                self.divide_point_E.append(list(range(n * self.NC, n * self.NC + self.NE)))\n                self.divide_point_I.append(list(range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI)))\n                self.E_range = 
self.E_range + list(range(n * self.NC, n * self.NC + self.NE))\n                self.I_range = self.I_range + list(\n                    range(n * self.NC + self.NE, n * self.NC + self.NE + self.NI))\n\n            else:\n                s = self.C * self.NC + (n-self.C) * self.NT\n                self.divide_point_TC.append(list(range(s, s + self.NTC)))\n                self.divide_point_TI_TRN.append(list(range(s + self.NTC,\n                                                           s + self.NTC + self.NTI))\n                                                + list(range(s + self.NTC + self.NTI,\n                                                             s + self.NTC + self.NTI + self.NTRN)))\n\n                self.TC_range = self.TC_range + list(range(s,\n                                                           s + self.NTC))\n                self.TI_range = self.TI_range + list(range(s + self.NTC,\n                                                           s + self.NTC + self.NTI))\n                self.TRN_range = self.TRN_range + list(range(s + self.NTC + self.NTI,\n                                                             s + self.NTC + self.NTI + self.NTRN))\n        self.divide_point_E = torch.tensor(self.divide_point_E, device=device)\n        self.divide_point_I = torch.tensor(self.divide_point_I, device=device)\n        self.divide_point_TC = torch.tensor(self.divide_point_TC, device=device)\n        self.divide_point_TI_TRN = torch.tensor(self.divide_point_TI_TRN, device=device)\n        self.divide_point_CR = torch.concat((self.divide_point_E, self.divide_point_I), dim=1)\n        self.divide_point_TN = torch.concat((self.divide_point_TC, self.divide_point_TI_TRN), dim=1)\n        torch.save({'divide_point_E': self.divide_point_E, 'divide_point_I': self.divide_point_I,\n                    'TC_range': self.TC_range, 'TI_range': self.TI_range, 'TRN_range': self.TRN_range},\n                   './neuron_divide.pt')\n\n        
self.c_m[self.E_range] = self.tau_v[0] * self.g_m + 15 * (torch.rand(size=(len(self.E_range),), device=device) - 0.5)\n        # self.c_m[self.E_range] = self.tau_v[0] * self.g_m\n        self.c_m[self.TC_range] = self.tau_v[2] * self.g_m\n        self.c_m[self.I_range] = self.tau_v[1] * (self.g_m + self.Gama_c)\n        self.c_m[self.TI_range] = self.tau_v[3] * (self.g_m + self.Gama_c)\n        self.c_m[self.TRN_range] = self.tau_v[4] * (self.g_m + self.Gama_c)\n\n        self.alpha_w[self.E_range] = self.alpha_ad[0] * self.g_m\n        self.alpha_w[self.TC_range] = self.alpha_ad[2] * self.g_m + self.Gama_c\n        self.alpha_w[self.I_range] = self.alpha_ad[1] * (self.g_m + self.Gama_c)\n        self.alpha_w[self.TI_range] = self.alpha_ad[3] * (self.g_m + self.Gama_c)\n        self.alpha_w[self.TRN_range] = self.alpha_ad[4] * (self.g_m + self.Gama_c)\n\n        self.beta_ad[self.E_range] = self.beta[0]\n        self.beta_ad[self.TC_range] = self.beta[2]\n        self.beta_ad[self.I_range] = self.beta[1]\n        self.beta_ad[self.TI_range] = self.beta[3]\n        self.beta_ad[self.TRN_range] = self.beta[4]\n\n        self.vt[self.E_range] = self.V_th[0]\n        self.vt[self.TC_range] = self.V_th[2]\n        self.vt[self.I_range] = self.V_th[1]\n        self.vt[self.TI_range] = self.V_th[3]\n        self.vt[self.TRN_range] = self.V_th[4]\n\n        self.reset[self.E_range] = 1\n        self.reset[self.TC_range] = 1\n        self.reset[self.I_range] = 0\n        self.reset[self.TI_range] = 0\n        self.reset[self.TRN_range] = 0\n\n\n        self.Nmean[self.E_range] = self.TEmean * self.g_m\n        self.Nmean[self.TC_range] = self.TTCmean * self.g_m\n        self.Nmean[self.I_range] = self.TImean * (self.g_m + self.Gama_c)\n        self.Nmean[self.TI_range] = self.TTImean * (self.g_m + self.Gama_c)\n        self.Nmean[self.TRN_range] = self.TTRNmean * (self.g_m + self.Gama_c)\n\n        self.Nsig[self.E_range] = self.Tsig[0] * self.g_m\n        
self.Nsig[self.TC_range] = self.Tsig[2] * self.g_m\n        self.Nsig[self.I_range] = self.Tsig[1] * (self.g_m + self.Gama_c)\n        self.Nsig[self.TI_range] = self.Tsig[3] * (self.g_m + self.Gama_c)\n        self.Nsig[self.TRN_range] = self.Tsig[4] * (self.g_m + self.Gama_c)\n\n    def simulation(self, per):\n\n        range_E = self.E_range + self.TC_range\n        range_I = self.I_range + self.TI_range + self.TRN_range\n        Vgap = self.Gama_c\n        weight_matrix = self.weight_matrix\n        print(123)\n        for i in range(self.Ncycle):\n            I_total = torch.zeros((self.Ncycle, self.T), device=device)\n            V_total = torch.zeros((self.Ncycle, self.T), device=device)\n\n            V = torch.zeros(self.T, device=device)\n            I_subregion = torch.zeros((self.NR, self.T), device=device)\n            I_subregion_E = torch.zeros((self.NR, self.T), device=device)\n            I_subregion_I = torch.zeros((self.NR, self.T), device=device)\n            Vsubregion = torch.zeros((self.NR, self.T), device=device)\n            EEG = torch.zeros((self.T), device=device)\n\n            Iraster = []\n            vv_sumE = torch.zeros((self.NR, self.t_window), device=device)\n            vv_sumI = torch.zeros((self.NR, self.t_window), device=device)\n\n            phase = self.T / 8\n            for t in tqdm.tqdm(range(self.T)):\n                if t < phase:\n                    tau_vI = 20\n                    self.GammaII = 15\n                    self.GammaIE = -10\n                elif phase <= t < 3 * phase:\n                    # break\n                    tau_vI = 20 + 20 * (t - phase) / phase\n                    self.GammaII = 15 + 20 * (t - phase) / phase\n                    self.GammaIE = -10 - 20 * (t - phase) / phase\n                elif 3 * phase <= t < 5 * phase:\n                    tau_vI = 60\n                    self.GammaII = 55\n                    self.GammaIE = -50\n                elif 5 * phase <= t < 7 * phase:\n     
               tau_vI = 60 - 20 * (t - 5 * phase) / phase\n                    self.GammaII = 55 - 20 * (t - 5 * phase) / phase\n                    self.GammaIE = -50 + 20 * (t - 5 * phase) / phase\n                elif 7 * phase <= t < 8 * phase:\n                    tau_vI = 20\n                    self.GammaII = 15\n                    self.GammaIE = -10\n\n                self.c_m[range_I] = tau_vI * (self.g_m + self.Gama_c)\n                WII = self.GammaII * torch.mean(self.c_m[self.I_range])\n                WEE = self.GammaEE * torch.mean(self.c_m[self.E_range])\n                WEI = self.GammaEI * torch.mean(self.c_m[self.I_range])\n                WIE = self.GammaIE * torch.mean(self.c_m[self.E_range])\n\n                self.Iback = self.Iback + self.dt / self.tau_I * (-self.Iback + torch.randn(self.NSum, device=device))\n                self.Ieff = (self.Iback / math.sqrt(1 / (2 * (self.tau_I / self.dt))) * self.Nsig + self.Nmean)\n\n                temp = vv_sumE.clone()\n                vv_sumE[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]\n\n                vv_sumE[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_E], dim=1),\n                                                           torch.mean(self.vv[self.divide_point_TC], dim=1)))\n\n                temp = vv_sumI.clone()\n                vv_sumI[:, 0:self.t_window - 1] = temp[:, 1:self.t_window]\n                vv_sumI[:, self.t_window - 1] = torch.cat((torch.mean(self.vv[self.divide_point_I], dim=1),\n                                                           torch.mean(\n                                                               self.vv[self.divide_point_TI_TRN], dim=1)))\n\n                v_sum = torch.cat((torch.mean(self.v[self.divide_point_I], dim=1),\n                                   torch.mean(self.v[self.divide_point_TI_TRN], dim=1)))\n                v_sum_CR = v_sum[:self.C].reshape(-1, 1) * \\\n                           torch.ones((self.C, 
self.NI), device=device)\n                v_sum_CR = v_sum_CR.reshape(-1, 1).squeeze(1)\n                v_sum_TN = v_sum[self.C:].reshape(-1, 1) * \\\n                           torch.ones(self.NR - self.C, self.NTI + self.NTRN, device=device)\n                v_sum_TN = v_sum_TN.reshape(-1, 1).squeeze(1)\n                v_sum = torch.cat((v_sum_CR, v_sum_TN))\n\n                time_decay = torch.concat(\n                    (torch.concat([torch.arange(self.NR, device=device).unsqueeze(0)] * self.NR, dim=0).unsqueeze(0),\n                     self.t_window - 1 - self.decay.unsqueeze(0)), dim=0)\n                time_decay = list(time_decay.long())\n\n                v_E = torch.sum(weight_matrix * vv_sumE[time_decay], dim=1)\n                v_I = torch.sum(weight_matrix * vv_sumI[time_decay], dim=1)\n\n                v_E_CR = v_E[:self.NR - 14].reshape(-1, 1) * \\\n                         torch.ones((self.NR - 14, self.NC), device=device)\n                v_I_CR = v_I[:self.NR - 14].reshape(-1, 1) * \\\n                         torch.ones((self.NR - 14, self.NC), device=device)\n                v_E_CR = v_E_CR.reshape(-1, 1).squeeze(1)\n                v_I_CR = v_I_CR.reshape(-1, 1).squeeze(1)\n\n                v_E_TN = v_E[self.NR - 14:].reshape(-1, 1) * \\\n                         torch.ones(14, self.NTC + self.NTI + self.NTRN, device=device)\n                v_I_TN = v_I[self.NR - 14:].reshape(-1, 1) * \\\n                         torch.ones(14, self.NTC + self.NTI + self.NTRN, device=device)\n                v_E_TN = v_E_TN.reshape(-1, 1).squeeze(1)\n                v_I_TN = v_I_TN.reshape(-1, 1).squeeze(1)\n\n                v_E = torch.cat((v_E_CR, v_E_TN))\n                v_I = torch.cat((v_I_CR, v_I_TN))\n                self.Ichem[range_E] = self.Ichem[range_E] + self.dt / self.tau_I * \\\n                                      (-self.Ichem[range_E] + WEE * v_E[range_E]\n                                       + WIE * v_I[range_E])\n\n               
 self.Ichem[range_I] = self.Ichem[range_I] + self.dt / self.tau_I * \\\n                                      (-self.Ichem[range_I] + WII * v_I[range_I]\n                                       + WEI * v_E[range_I])\n                self.Igap[range_I] = Vgap * (\n                        v_sum - self.v[range_I])\n\n                # stimulation current\n                # if (300 <= t < 700):\n                #     self.Istimu[self.divide_point_E[per]] = 15 - self.Ichem[self.divide_point_E[per]]\n                # elif (1300 <= t < 1700):\n                #     self.Istimu[self.divide_point_E[per]] = 15 - self.Ichem[self.divide_point_E[per]]\n                # elif (2300 <= t < 2700):\n                #     self.Istimu[self.divide_point_E[per]] = 15 - self.Ichem[self.divide_point_E[per]]\n                # elif (3300 <= t < 3700):\n                #     self.Istimu[self.divide_point_E[per]] = 15 - self.Ichem[self.divide_point_E[per]]\n                # else:\n                #     self.Istimu[self.divide_point_E[per]] = 0\n\n                # stimulation current\n                # if (300 <= t < 700):\n                #     self.Istimu[self.divide_point_TC[per]] = 15 - self.Ichem[self.divide_point_TC[per]]\n                # elif (1300 <= t < 1700):\n                #     self.Istimu[self.divide_point_TC[per]] = 15 - self.Ichem[self.divide_point_TC[per]]\n                # elif (2300 <= t < 2700):\n                #     self.Istimu[self.divide_point_TC[per]] = 15 - self.Ichem[self.divide_point_TC[per]]\n                # elif (3300 <= t < 3700):\n                #     self.Istimu[self.divide_point_TC[per]] = 15 - self.Ichem[self.divide_point_TC[per]]\n                # else:\n                #     self.Istimu[self.divide_point_TC[per]] = 0\n\n\n                self.v = self.v + self.dt / self.c_m * (-self.g_m * self.v +\n                                                        self.alpha_w * self.ad + self.Istimu + self.Ieff + self.Ichem + self.Igap)\n                
self.ad = self.ad + self.dt / self.tau_ad * (-self.ad + self.beta_ad * self.v)\n                self.vv = (self.v >= self.vt).float() * (self.vm1 < self.vt).float()\n                self.v = self.v * (1 - self.vv * self.reset)\n                # self.v = self.v * (1 - self.vv)\n                self.vm1 = self.v\n\n                Isp = torch.where(self.vv == 1)[0]\n                Iraster.append(torch.stack((t * torch.ones((len(Isp)), device=device), Isp), dim=1))\n\n                I_CR = torch.mean(self.Ichem[self.divide_point_CR], dim=1)\n                I_TN = torch.mean(self.Ichem[self.divide_point_TN], dim=1)\n                I_subregion[:, t] = torch.cat((I_CR, I_TN), dim=0)\n\n            print('over')\n            torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}.pt')\n            # torch.save(I_subregion.cpu(), f'./result/I_subregion_{version}_{scale}.pt')\n\n            Iraster = torch.cat(Iraster, dim=0).cpu()\n\n            # torch.save(Iraster, f'./result/raster_{version}_{scale}:{trail}.pt')\n            torch.save(Iraster, f'./result/raster_{version}_{scale}.pt')\n\n\n\nW = np.load('./IIT_connectivity_matrix.npy')\nW = torch.from_numpy(W).float()\nW = W[0:84, 0:84]\nnew_order = list(range(0,35)) + list(range(49,84))  + list(range(35,49))\nW_new = W[new_order, :][:, new_order]\nM = torch.max(W_new)\nW_new = W_new / M\n\nW_new = scale * W_new.to(device)\n\n# Converts continuous value weights to binary weights\n# for i in range(len(W_new)):\n#     for j in range(len(W_new)):\n#         if W_new[i][j] > 0.1 * scale:\n#             W_new[i][j] = scale\n#         else:\n#             W_new[i][j] = 0\n\n# The entropy of the degree distribution of the connection matrix is calculated by the histogram method\n# from NA import histogram_entropy\n# degree = torch.sum(W_new, dim=0)\n# print(histogram_entropy(degree))\nsimulation_model = brain_model(W_new)\nsimulation_model.simulation(0)\nfor per in range(70): # Select the brain region to be 
injected with the stimulation current\n    simulation_model = brain_model(W_new)\n    simulation_model.simulation(per)\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/pci.py",
    "content": "import matplotlib.pyplot as plt\nimport torch\nimport numpy as np\nimport pandas as pd\n\nrange_list = []\n\nfor i in range(84):\n    if i < 70:\n        range_list.append([i * 500, (i+1) * 500])\n    else:\n        range_list.append([70 * 500 + (i-70)\n                           * 51, 70 * 500 + (i+1-70) * 51])\n\ndef generate_rm(Iraster):\n    time_window = 40\n    bm1 = np.zeros((len(range_list), int(1000/time_window)))\n    bm2 = np.zeros((len(range_list), int(1000/time_window)))\n    bm3 = np.zeros((len(range_list), int(1000/time_window)))\n    bm4 = np.zeros((len(range_list), int(1000/time_window)))\n    for i in range(len(range_list)):\n        for ji, j in enumerate(range(0, 1000, time_window)):\n\n            time = Iraster[:, 0]\n            mask = (time >= j) & (time < j + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm1[i][ji] = rate\n\n            time = Iraster[:, 0]\n            mask = (time >= j+1000) & (time < j+1000 + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm2[i][ji] = rate\n\n            time = Iraster[:, 0]\n            mask = (time >= j+2000) & (time < j+2000 + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < 
range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm3[i][ji] = rate\n\n            time = Iraster[:, 0]\n            mask = (time >= j+3000) & (time < j+3000 + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm4[i][ji] = rate\n\n    return bm1, bm2, bm3, bm4\n\ndef lempel_ziv_complexity(data):\n    c=1\n    r=1\n    q=1\n    k=1\n    i=1\n    L1 = data.shape[0]\n    L2 = data.shape[1]\n\n    while 1:\n        if q == r:\n            a = i+k-1\n        else:\n            a=L1\n        if ''.join(map(str, data[i:i+k,r-1])) in ''.join(map(str, data[0:a,q-1])):\n            k=k+1\n            if i+k>L1:\n                r=r+1\n                if r>L2:\n                    break\n                else:\n                    i=0\n                    q=r-1\n                    k=1\n        else:\n            q = q-1\n            if q<1:\n                c=c+1\n                i=i+k\n                if i+1>L1:\n                    r=r+1\n                    if r>L2:\n                        break\n                    else:\n                        i=0\n                        q=r-1\n                        k=1\n                else:\n                    q=r\n                    k=1\n    c = c+1\n    return c\n\nscale = 1.2\nversion = 2\nIraster1 = torch.load(f'./result/raster_{version}_{scale}:1.pt').cpu()\nIraster2 = torch.load(f'./result/raster_{version}_{scale}:2.pt').cpu()\nIraster3 = torch.load(f'./result/raster_{version}_{scale}:3.pt').cpu()\nIraster4 = 
torch.load(f'./result/raster_{version}_{scale}:4.pt').cpu()\nIraster5 = torch.load(f'./result/raster_{version}_{scale}:5.pt').cpu()\nx=0\nfor Iraster in [Iraster1,Iraster2,Iraster3,Iraster4,Iraster5]:\n    pcis = [[], [], [], []]\n    for per in range(0, 84):\n        print(per)\n        Iraster_p = torch.load(f'./result/raster_{version}_{scale}_{per}.pt').cpu()\n        rm1, rm2, rm3, rm4  = generate_rm(Iraster)\n        rm1_p, rm2_p, rm3_p, rm4_p = generate_rm(Iraster_p)\n\n        d = rm1_p - rm1\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = - p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci1 = c * L1 / HL\n        print(pci1)\n        pcis[0].append(pci1)\n\n        d = rm2_p - rm2\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = (-p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12)\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci2 = c * L1 / HL\n        print(pci2)\n        pcis[1].append(pci2)\n\n        d = rm3_p - rm3\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = - p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci3 = c * L1 / HL\n        print(pci3)\n        pcis[2].append(pci3)\n\n        d = rm4_p - rm4\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = - p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci4 = c * L1 / HL\n        print(pci4)\n        pcis[3].append(pci4)\n\n    np.save(f'pci_all_{version}_{x}.npy', pcis)\n    x = x + 1\n\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/pci_246.py",
    "content": "import matplotlib.pyplot as plt\nimport torch\nimport numpy as np\nimport pandas as pd\n\nrange_list = []\n\nfor i in range(246):\n    if i < 210:\n        range_list.append([i * 500, (i+1) * 500])\n    else:\n        range_list.append([210 * 500 + (i-210)\n                           * 51, 210 * 500 + (i+1-210) * 51])\n\ndef generate_rm(Iraster):\n    time_window = 40\n    bm1 = np.zeros((len(range_list), int(1000/time_window)))\n    bm2 = np.zeros((len(range_list), int(1000/time_window)))\n    bm3 = np.zeros((len(range_list), int(1000/time_window)))\n    bm4 = np.zeros((len(range_list), int(1000/time_window)))\n    for i in range(len(range_list)):\n        for ji, j in enumerate(range(0, 1000, time_window)):\n\n            time = Iraster[:, 0]\n            mask = (time >= j) & (time < j + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm1[i][ji] = rate\n\n            time = Iraster[:, 0]\n            mask = (time >= j+1000) & (time < j+1000 + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm2[i][ji] = rate\n\n            time = Iraster[:, 0]\n            mask = (time >= j+2000) & (time < j+2000 + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & 
(neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm3[i][ji] = rate\n\n            time = Iraster[:, 0]\n            mask = (time >= j+3000) & (time < j+3000 + time_window)\n            indices = torch.where(mask)\n            spike = Iraster[indices[0]]\n            neuron = spike[:, 1]\n            mask = (neuron >= range_list[i][0]) & (neuron < range_list[i][1])\n            indices = torch.where(mask)\n            spike = spike[indices[0]]\n            rate = len(spike) / (time_window * (range_list[i][1] - range_list[i][0]))\n            bm4[i][ji] = rate\n\n    return bm1, bm2, bm3, bm4\n\ndef lempel_ziv_complexity(data):\n    c=1\n    r=1\n    q=1\n    k=1\n    i=1\n    L1 = data.shape[0]\n    L2 = data.shape[1]\n\n    while 1:\n        if q == r:\n            a = i+k-1\n        else:\n            a=L1\n        if ''.join(map(str, data[i:i+k,r-1])) in ''.join(map(str, data[0:a,q-1])):\n            k=k+1\n            if i+k>L1:\n                r=r+1\n                if r>L2:\n                    break\n                else:\n                    i=0\n                    q=r-1\n                    k=1\n        else:\n            q = q-1\n            if q<1:\n                c=c+1\n                i=i+k\n                if i+1>L1:\n                    r=r+1\n                    if r>L2:\n                        break\n                    else:\n                        i=0\n                        q=r-1\n                        k=1\n                else:\n                    q=r\n                    k=1\n    c = c+1\n    return c\n\nscale = 0.1\nversion = 4\nIraster1 = torch.load(f'./result/raster_{version}_{scale}.pt').cpu()\nx=0\nfor Iraster in [Iraster1]:\n    pcis = [[], [], [], []]\n    for per in range(0, 246):\n        print(per)\n        Iraster_p = 
torch.load(f'./result/raster_{version}_{scale}_{per}.pt').cpu()\n        rm1, rm2, rm3, rm4  = generate_rm(Iraster)\n        rm1_p, rm2_p, rm3_p, rm4_p = generate_rm(Iraster_p)\n\n        d = rm1_p - rm1\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = - p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci1 = c * L1 / HL\n        print(pci1)\n        pcis[0].append(pci1)\n\n        d = rm2_p - rm2\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = (-p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12)\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci2 = c * L1 / HL\n        print(pci2)\n        pcis[1].append(pci2)\n\n        d = rm3_p - rm3\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = - p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci3 = c * L1 / HL\n        print(pci3)\n        pcis[2].append(pci3)\n\n        d = rm4_p - rm4\n        bm = (np.abs(d) > 0.001).astype(int)\n        c = lempel_ziv_complexity(bm)\n        p1 = np.mean(bm)\n        HL = - p1 * np.log2(p1+1e-12) - (1 - p1) * np.log2(1 - p1)+1e-12\n        L = bm.shape[0] * bm.shape[1]\n        L1 = np.log2(L) / L\n        pci4 = c * L1 / HL\n        print(pci4)\n        pcis[3].append(pci4)\n\n    np.save(f'pci_all_{version}_246.npy', pcis)"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_Brain_Model/spectrogram.py",
    "content": "import scipy.io as scio\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport torch\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom scipy.fftpack import fft,ifft\nfrom scipy import signal\nfrom scipy.fft import fftshift\nimport json\n\ntrail = 5\nscale = 1.2\nversion = 2\nper = 2\n# Iraster = torch.load(f'./result/raster_{version}_{scale}.pt').cpu()\nIraster = torch.load(f'./result/raster_{version}_{scale}.pt').cpu()\ntime = Iraster[:, 0]\nmask = (time >= 0) & (time < 8000)\nindices = torch.where(mask)\nspike = Iraster[indices[0]]\nneuron = spike[:, 1]\n# mask = (neuron >= 17000) & (neuron < 18000)\nmask = (neuron >= 0)\nindices = torch.where(mask)\nspike = spike[indices[0]]\n\nplt.figure(figsize=(20, 12))\nplt.scatter(spike[:, 0], spike[:, 1], s=0.1)\nplt.xlabel('time [ms]', fontsize=20)\nplt.ylabel('Neuron index', fontsize=20)\nplt.title(f'{scale}')\nplt.show(dpi=600)\n\ndata = np.array(torch.load(f'./result/I_subregion_{version}_{scale}.pt').cpu())\nfs = 1000\ntime_window = 1024\n\nb, a = signal.butter(2, [0.002, 0.06], 'bandpass')    #配置滤波器 8 表示滤波器的阶数\ndata = signal.filtfilt(b, a, data)   #data为要过滤的信号\n\ndef region_sxx(region):\n\n    plt.figure(figsize=(16, 8))\n\n    f, t, sxx = signal.stft(data[region], fs=fs, nperseg=time_window, noverlap=time_window / 2)\n    cm = plt.cm.get_cmap('jet')\n    plt.contourf(t, f[0:30], np.abs(sxx[0:30]), cmap=cm, levels=200)\n    plt.colorbar()\n    plt.xlabel('time/min', fontsize=20)\n    plt.ylabel('Frequency/Hz', fontsize=20)\n    plt.xticks(fontsize=15)\n    plt.yticks(fontsize=15)\n    plt.show()\n    return np.abs(sxx[0:30])\n\n\ndef global_sxx():\n    plt.figure(figsize=(16,8))\n    global_eeg = np.mean(data, axis=0)\n    f, t, sxx = signal.stft(global_eeg, fs=fs, nperseg=time_window, noverlap=time_window / 2)\n    print(sxx.shape)\n    cm = plt.cm.get_cmap('jet')\n    #plt.pcolormesh(t, f[2:10], np.abs(sxx[2:10]), cmap=cm, shading='auto')\n    plt.contourf(t, f[0:30], 
np.abs(sxx[0:30]), cmap=cm, levels=200)\n    plt.colorbar()\n    plt.xlabel('time/min', fontsize=20)\n    plt.ylabel('Frequency/Hz', fontsize=20)\n    plt.xticks(fontsize=15)\n    plt.yticks(fontsize=15)\n    plt.show()\n\n\ndef compare_sxx():\n\n    f, t, sxx = signal.stft(data[0], fs=fs, nperseg=time_window, noverlap=time_window / 2)\n\n    f_band = range(0, 30)\n\n    sm = np.max(np.abs(sxx[f_band]), axis=0)\n\n    for col in range(1, 84):\n        f, t, sxx = signal.stft(data[col], fs=fs, nperseg=time_window, noverlap=time_window / 2)\n        sm = np.vstack((sm, np.max(np.abs(sxx[f_band]), axis=0)))\n\n    cm = plt.cm.get_cmap('jet')\n    plt.pcolormesh(t, range(0, 84), np.abs(sm), cmap=cm, shading='auto')\n    #plt.pcolormesh(t, f, sxx[5:50,:],cmap=cm)\n    plt.colorbar()\n    plt.ylabel('Brain Regions', fontsize=10)\n    plt.xlabel('Time [min]', fontsize=10)\n    plt.xticks(fontsize=10)\n    plt.yticks(fontsize=10)\n    plt.show()\n\n    return np.abs(sm)\n\n\ndef fit(xx, yy):\n    M = len(xx)\n    x_bar = np.average(xx)\n\n    sum_yx = 0\n    sum_x2 = 0\n    sum_delta = 0\n\n    for i in range(M):\n        x = xx[i]\n        y = yy[i]\n        sum_yx += y * (x - x_bar)\n        sum_x2 += x ** 2\n    # 根据公式计算w\n    w = sum_yx / (sum_x2 - M * (x_bar ** 2))\n\n    for i in range(M):\n        x = xx[i]\n        y = yy[i]\n        sum_delta += (y - w * x)\n\n    b = sum_delta / M\n\n    return w, b\n\nW = np.load('./IIT_connectivity_matrix.npy')\nW = torch.from_numpy(W).float()\nW = W[0:84, 0:84]\nnew_order = list(range(0,35)) + list(range(49,84))  + list(range(35,49))\nW_new = W[new_order, :][:, new_order]\nM = torch.max(W_new)\nW_new = W_new / M\nW = scale * W_new\n\nin_degree = torch.sum(W, dim=1).numpy()\nout_degree = torch.sum(W, dim=0).numpy()\n\nglobal_sxx()\nsm = compare_sxx()\n\nfig, axs = plt.subplots(1, 3, figsize=(15, 5))\naxs[0].scatter(in_degree[0:70],np.mean(sm[0:70,0:3], axis=1), c='blue', 
label='Cortical')\naxs[0].scatter(in_degree[70:],np.mean(sm[70:,0:3], axis=1), c='orange', label='Subcortical', marker=\"^\")\nw1, b1 = fit(in_degree[0:70],np.mean(sm[0:70,0:3], axis=1))\nw2, b2 = fit(in_degree[70:],np.mean(sm[70:,0:3], axis=1))\nx1 = np.linspace(0, 6, 100)\ny1 = w1 * x1 + b1\nx2 = np.linspace(0, 6, 100)\ny2 = w2 * x2 + b2\naxs[0].plot(x1, y1, c='blue')\naxs[0].plot(x2, y2, c='orange', linestyle='--')\nr1 = np.corrcoef(in_degree[0:70],np.mean(sm[0:70,0:3], axis=1))[0, 1]\nr2 = np.corrcoef(in_degree[70:],np.mean(sm[70:,0:3], axis=1))[0, 1]\naxs[0].text(x1[-20], y1[-20]+0.5, f'$r^2={r1:.2f}$', fontsize=15, color='black')\naxs[0].text(x2[-20], y2[-20]+1, f'$r^2={r2:.2f}$', fontsize=15, color='black')\naxs[0].set_title(\"Awake\", fontsize=15)\naxs[0].tick_params(axis='x', labelsize=15)\naxs[0].tick_params(axis='y', labelsize=15)\naxs[0].legend()\n\naxs[1].scatter(in_degree[0:70],np.mean(sm[0:70,3:7], axis=1), c='blue', label='Cortical')\naxs[1].scatter(in_degree[70:],np.mean(sm[70:,3:7], axis=1), c='orange', label='Subcortical', marker=\"^\")\nw1, b1 = fit(in_degree[0:70],np.mean(sm[0:70,3:7], axis=1))\nw2, b2 = fit(in_degree[70:],np.mean(sm[70:,3:7], axis=1))\nx1 = np.linspace(0, 6, 100)\ny1 = w1 * x1 + b1\nx2 = np.linspace(0, 6, 100)\ny2 = w2 * x2 + b2\naxs[1].plot(x1, y1, c='blue')\naxs[1].plot(x2, y2, c='orange', linestyle='--')\nr1 = np.corrcoef(in_degree[0:70],np.mean(sm[0:70,3:7], axis=1))[0, 1] - 0.01\nr2 = np.corrcoef(in_degree[70:],np.mean(sm[70:,3:7], axis=1))[0, 1]-0.01\naxs[1].text(x1[-20], y1[-20]+1, f'$r^2={r1:.2f}$', fontsize=15, color='black')\naxs[1].text(x2[-20], y2[-20]+1, f'$r^2={r2:.2f}$', fontsize=15, color='black')\naxs[1].set_title(\"Micro-consciousness\", fontsize=15)\naxs[1].tick_params(axis='x', labelsize=15)\naxs[1].tick_params(axis='y', labelsize=15)\naxs[1].legend()\n\naxs[2].scatter(in_degree[0:70],np.mean(sm[0:70,7:10], axis=1), c='blue', label='Cortical')\naxs[2].scatter(in_degree[70:],np.mean(sm[70:,7:10], axis=1), 
c='orange', label='Subcortical', marker=\"^\")\nw1, b1 = fit(in_degree[0:70],np.mean(sm[0:70,7:10], axis=1))\nw2, b2 = fit(in_degree[70:],np.mean(sm[70:,7:10], axis=1))\nx1 = np.linspace(0, 6, 100)\ny1 = w1 * x1 + b1\nx2 = np.linspace(0, 6, 100)\ny2 = w2 * x2 + b2\naxs[2].plot(x1, y1, c='blue')\naxs[2].plot(x2, y2, c='orange', linestyle='--')\nr1 = np.corrcoef(in_degree[0:70],np.mean(sm[0:70,7:10], axis=1))[0, 1] - 0.01\nr2 = np.corrcoef(in_degree[70:],np.mean(sm[70:,7:10], axis=1))[0, 1]-0.01\naxs[2].text(x1[-20], y1[-20]+1, f'$r^2={r1:.2f}$', fontsize=15, color='black')\naxs[2].text(x2[-20], y2[-20]+1, f'$r^2={r2:.2f}$', fontsize=15, color='black')\naxs[2].set_title(\"Unconsciousness\", fontsize=15)\naxs[2].tick_params(axis='x', labelsize=15)\naxs[2].tick_params(axis='y', labelsize=15)\naxs[2].legend()\n\nplt.tight_layout()\nplt.show()"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_PFC_Model/README.md",
    "content": "## Input:\n\n* The program inputs electrophysiological data from six cortical columns; the number in each data file name indicates the number of neurons. The program has background current input by default. Files named data+number contain random input stimuli, and the input picture stimulus files are provided for human parameters and mouse parameters, respectively. Both support four shapes of picture files: circle, square, triangle and star.\n\nLink: https://drive.google.com/drive/folders/1AVc2aNTxkcsGAPlq1SuWtatGzyQRPCmp?usp=sharing\n\n\n\n## Output\n\n* A data file generated by the program that records each neuron's firing time points.\n\n## Application\n\n* The program can be modified to change the discharge environment of each PFC model.\n\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/Human_PFC_Model/Six_Layer_PFC.py",
    "content": "import scipy.io as scio\nimport math\nimport random as rand\nimport copy\nimport os\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom braincog.base.learningrule.STP import short_time\n\n\nclass six_layer_pfc(): \n    \"\"\"\n    Define global parameters\n    :param SizeHistOutput: Set the peak value of the number of EPSP considered to be modified\n    :param SizeHistInput: Set the number of possible spikes in the input neuron\n    \"\"\"\n    def __init__(self):\n        self.pi = 3.14159265418\n        self.MaxNumSTperN = 20\n        self.SizeHistOutput = 10  \n        self.SizeHistInput = 1000000  \n        self.NumCtrPar = 5\n        self.NumVar = 2\n        self.NumNeuPar = 12\n        self.NumSynTypePar = 8\n        self.NumSynPar = 7\n        self.TRUE = 1\n        self.FALSE = 0\n    def picture(self,path=None):\n        data=scio.loadmat(path)\n        STMtx=data['STMtx']\n        neuron=[]\n        time=[]\n        n=1\n        for i in STMtx[0]:\n            for j in i[0]:\n                if j==-1:\n\n                    break\n                neuron.append(n)\n                time.append(j)\n            n=n+1\n        neuron=np.array(neuron)\n        time=np.array(time)\n        plt.scatter(time,neuron,c='k',marker='.')\n        plt.show()\n    def mex_function(self, path=None):\n        \"\"\"\n        Create arrays and parameters related to synaptic preservation of neuronal groups\n        :param CtrPar: Store electrophysiological parameters of neurons\n        :param NumViewGroups: Create arrays and parameters related to synaptic preservation of neuronal groups\n        \"\"\"\n        data = scio.loadmat(path) \n        pi = self.pi\n        MaxNumSTperN = self.MaxNumSTperN\n        SizeHistOutput = self.SizeHistOutput \n        SizeHistInput = self.SizeHistInput  \n        NumCtrPar = self.NumCtrPar\n        NumVar = self.NumVar\n        NumNeuPar = self.NumNeuPar\n        NumSynTypePar = 
self.NumSynTypePar\n        NumSynPar = self.NumSynPar\n        TRUE = self.TRUE\n        FALSE = self.FALSE\n        CtrPar = data['CtrPar']  \n        NeuPar = data['NeuPar']  \n        NPList = data['NPList']  \n        STypPar = data['STypPar']  \n        SynPar = data['SynPar']  \n        SPMtx = data['SPMtx']  \n        EvtMtx = data['evtmtx']  \n        EvtTimes = data['evttimes']  \n        ViewList = data['ViewList']  \n        InpSTtrains = data['InpSTtrains'] \n        NoiseDistr = data['NoiseDistr']  \n        V0 = data['V0'] \n        UniqueNum = data['UniqueNum']  \n        NeuronGroupsSaveArray = data['NeuronGroupsSaveArray']  \n        SimPar = data['SimPar']\n        NumViewGroups = NeuronGroupsSaveArray.shape[0]\n        NumNeuronsPerGroup = NeuronGroupsSaveArray.shape[1]\n        UniquePrint = UniqueNum\n        Tstart = int(CtrPar[0][0])\n        Tstop = int(CtrPar[0][1])\n        dt0 = CtrPar[0][2]\n        WriteST = CtrPar[0][4]\n        t_display = 0\n        stop_flag = 0\n        \n        NumViewGroups = NeuronGroupsSaveArray.shape[0]\n        NumNeuronsPerGroup = NeuronGroupsSaveArray.shape[1]\n        i = NPList.shape[1]\n        j = NPList.shape[0]\n        N = i * j\n        k = NeuPar.shape[1]\n        NPtr0 = []\n        NumSpike = []\n        gsyn1 = []\n        gsyn2 = []\n        Isyn = []\n        flag_osc = []\n        for i in range(N):\n            NPtr0.append(Neuron())\n            NPtr0[i].Cm = NeuPar[0][i]\n            NPtr0[i].gL = NeuPar[1][i]\n            NPtr0[i].EL = NeuPar[2][i]\n            NPtr0[i].sf = NeuPar[3][i]\n            NPtr0[i].Vup = NeuPar[4][i]\n            NPtr0[i].tcw = NeuPar[5][i]\n            NPtr0[i].a = NeuPar[6][i]\n            NPtr0[i].b = NeuPar[7][i]\n            NPtr0[i].Vr = NeuPar[8][i]\n            NPtr0[i].Vth = NeuPar[9][i]\n            NPtr0[i].I_ref = NeuPar[10][i]\n            NPtr0[i].v_dep = NeuPar[11][i]\n            NPtr0[i].Iinj = 0\n            NPtr0[i].v[0] = V0[0][i]\n        
    NPtr0[i].v[1] = V0[1][i]\n            NPtr0[i].NumSynType = 0\n            NPtr0[i].NumPreSyn = 0\n            for j in range(MaxNumSTperN):\n                NPtr0[i].STList.append(None)\n            NumSpike.append(0)\n            gsyn1.append(0)\n            gsyn2.append(0)\n            Isyn.append(0)\n            flag_osc.append(0)\n\n        M = InpSTtrains.shape[0]\n        InpNPtr0 = []\n        for i in range(M):\n            InpNPtr0.append(InpNeuron())\n            InpNPtr0[i].SP_ind = 0\n            eom_ind = SizeHistInput\n            j = 0\n            while (j < eom_ind):\n                if (eom_ind == SizeHistInput) and (\n                        InpSTtrains[i][j + 1] == -1):\n                    eom_ind = j + 1\n                InpNPtr0[i].SPtrain[j] = InpSTtrains[i][j]\n                j = j + 1\n            for j in range(eom_ind, SizeHistInput):\n                InpNPtr0[i].SPtrain[j] = -1\n            InpNPtr0[i].NumSynType = 0\n            InpNPtr0[i].NumPreSyn = 0\n\n        NumSpike = []\n        for i in range(N + M):\n            NumSpike.append(0)\n\n        NumSynType = STypPar.shape[1]\n\n        SynTPtr0 = []\n        for i in range(NumSynType):\n            SynTPtr0.append(SynType())\n            SynTPtr0[i].No = i\n            SynTPtr0[i].gmax = STypPar[0][i]\n            SynTPtr0[i].tc_on = STypPar[1][i]\n            SynTPtr0[i].tc_off = STypPar[2][i]\n            SynTPtr0[i].Erev = STypPar[3][i]\n            SynTPtr0[i].Mg_gate = STypPar[4][i]\n            SynTPtr0[i].Mg_fac = STypPar[5][i]\n            SynTPtr0[i].Mg_slope = STypPar[6][i]\n            SynTPtr0[i].Mg_half = STypPar[7][i]\n            SynTPtr0[i].Gsyn = SynTPtr0[i].gmax * SynTPtr0[i].tc_on * \\\n                SynTPtr0[i].tc_off / (SynTPtr0[i].tc_off - SynTPtr0[i].tc_on)\n\n        numST = SynPar.shape[1]\n        SPList = SPMtx\n        MaxNumSyn = SPList.shape[2]\n\n        ConMtx0 = []\n        com_c = []\n        for i in range(N):\n            for j in 
range(N + M):\n                com_c.append(SynList())\n            ConMtx0.append(com_c)\n            com_c = []\n        for i in range(N):\n            for j in range(N + M):\n\n                ConMtx0[i][j].NumSyn = 0\n                while int(SPList[i][j][ConMtx0[i][j].NumSyn]) > 0:\n                    ConMtx0[i][j].NumSyn = ConMtx0[i][j].NumSyn + 1\n                    if (ConMtx0[i][j].NumSyn >= MaxNumSyn):\n                        break\n\n                if (ConMtx0[i][j].NumSyn > 0):\n                    for a in range(ConMtx0[i][j].NumSyn):\n                        ConMtx0[i][j].Syn.append(Synapse())\n                else:\n                    ConMtx0[i][j].Syn = []\n                k = 0\n\n                for k in range(ConMtx0[i][j].NumSyn):\n                    nst = SPList[i][j][k] - 1\n                   \n                    if (j < N):  \n                        InList = FALSE\n                        kk = 0\n                        while (kk < NPtr0[j].NumPreSyn):\n                            if (nst == NPtr0[j].PreSynList[kk]):\n                                InList = TRUE\n                                break\n                            kk = kk + 1\n                        ConMtx0[i][j].Syn[k].PreSynIdx = kk\n                        if (InList == FALSE):\n                            NPtr0[j].NumPreSyn = NPtr0[j].NumPreSyn + 1\n                            NPtr0[j].PreSynList = [0] * NPtr0[j].NumPreSyn\n                            NPtr0[j].PreSynList[kk] = nst\n                            for num in range(NPtr0[j].NumPreSyn):\n                                NPtr0[j].SDf.append(SynDepr())\n                            NPtr0[j].SDf[kk].use = SynPar[1][nst]\n                            NPtr0[j].SDf[kk].tc_rec = SynPar[2][nst]\n                            NPtr0[j].SDf[kk].tc_fac = SynPar[3][nst]\n                            for k2 in range(SizeHistOutput):\n                                NPtr0[j].SDf[kk].Adepr[k2] = 1.0\n                       
     NPtr0[j].SDf[kk].uprev[0] = SynPar[1][nst]\n                            NPtr0[j].SDf[kk].Rprev[0] = 1.0\n                        STno = int(SynPar[0][nst] - 1)\n                        ConMtx0[i][j].Syn[k].STPtr = SynTPtr0[STno]\n                        ConMtx0[i][j].Syn[k].wgt = SynPar[4][nst]\n                        ConMtx0[i][j].Syn[k].dtax = SynPar[5][nst]\n                        ConMtx0[i][j].Syn[k].p_fail = SynPar[6][nst]\n                        InList = FALSE\n                        kk = 0\n                        while (\n                                NPtr0[i].STList[kk] is not None and kk < NPtr0[i].NumSynType):\n                            if (NPtr0[i].STList[kk].No ==\n                                    ConMtx0[i][j].Syn[k].STPtr.No):\n                                InList = TRUE\n                            kk = kk + 1\n                        if (InList == FALSE):\n                            NPtr0[i].STList[kk] = ConMtx0[i][j].Syn[k].STPtr\n                            NPtr0[i].NumSynType = NPtr0[i].NumSynType + 1\n                            NPtr0[i].gfONsyn[kk] = 0.0\n                            NPtr0[i].gfOFFsyn[kk] = 0.0\n                    else:\n                        InList = FALSE\n                        kk = 0\n                        while (kk < InpNPtr0[j - N].NumPreSyn):\n                            if (nst == InpNPtr0[j - N].PreSynList[kk]):\n                                InList = TRUE\n                                break\n                            kk = kk + 1\n                        ConMtx0[i][j].Syn[k].PreSynIdx = kk\n                        if (InList == FALSE):\n                            InpNPtr0[j - N].NumPreSyn = InpNPtr0[j -\n                                                                 N].NumPreSyn + 1\n                            InpNPtr0[j - N].PreSynList = [0] * \\\n                                InpNPtr0[j - N].NumPreSyn\n                            InpNPtr0[j - N].PreSynList[kk] = nst\n             
               for num in range(InpNPtr0[j - N].NumPreSyn):\n                                InpNPtr0[j - N].SDf.append(SynDepr())\n                            InpNPtr0[j - N].SDf[kk].use = SynPar[1][nst]\n                            InpNPtr0[j - N].SDf[kk].tc_rec = SynPar[2][nst]\n                            InpNPtr0[j - N].SDf[kk].tc_fac = SynPar[3][nst]\n                            for k2 in range(SizeHistOutput):\n                                InpNPtr0[j - N].SDf[kk].Adepr[k2] = 1.0\n                            InpNPtr0[j - N].SDf[kk].uprev[0] = SynPar[1][nst]\n                            InpNPtr0[j - N].SDf[kk].Rprev[0] = 1.0\n                        STno = int(SynPar[0][nst] - 1)\n                        ConMtx0[i][j].Syn[k].STPtr = SynTPtr0[STno]\n                        ConMtx0[i][j].Syn[k].wgt = SynPar[4][nst]\n                        ConMtx0[i][j].Syn[k].dtax = SynPar[5][nst]\n                        ConMtx0[i][j].Syn[k].p_fail = SynPar[6][nst]\n                        InList = FALSE\n                        kk = 0\n                        while (\n                                NPtr0[i].STList[kk] is not None and kk < NPtr0[i].NumSynType):\n                            if (NPtr0[i].STList[kk].No ==\n                                    ConMtx0[i][j].Syn[k].STPtr.No):\n                                InList = TRUE\n                            kk = kk + 1\n                        if (InList == FALSE):\n                            NPtr0[i].STList[kk] = ConMtx0[i][j].Syn[k].STPtr\n                            NPtr0[i].NumSynType = NPtr0[i].NumSynType + 1\n                            NPtr0[i].gfONsyn[kk] = 0.0\n                            NPtr0[i].gfOFFsyn[kk] = 0.0\n       \n        NoiseSyn = SynList()\n        NoiseSyn.NumSyn = NumSynType\n        NoiseSyn.Syn = []\n        for i in range(NoiseSyn.NumSyn):\n            NoiseSyn.Syn.append(Synapse())\n        for i in range(N):\n            for j in range(NoiseSyn.NumSyn):\n                STno = 
int(SynPar[0][numST - NoiseSyn.NumSyn + j] - 1)\n                NoiseSyn.Syn[j].STPtr = SynTPtr0[STno]\n                NoiseSyn.Syn[j].wgt = SynPar[4][numST - NoiseSyn.NumSyn + j]\n                NoiseSyn.Syn[j].dtax = SynPar[5][numST - NoiseSyn.NumSyn + j]\n                NoiseSyn.Syn[j].p_fail = SynPar[6][numST - NoiseSyn.NumSyn + j]\n                NPtr0[i].gfONnoise[j] = 0.0\n                NPtr0[i].gfOFFnoise[j] = 0.0\n      \n        NoiseStep = 1 / (NoiseDistr.shape[1] - 1)\n        \n        SynExpOn = [0] * NumSynType\n        SynExpOff = [0] * NumSynType\n       \n        NumEvt = EvtTimes.shape[1]\n      \n        NumView = ViewList.shape[0] * ViewList.shape[1]\n        fpOut = open(\"IDN_%i.dat\" % UniquePrint, \"w\")\n        fpOut2 = open(\"IDN2_%i.dat\" % UniquePrint, \"w\")\n        if (CtrPar[0][3] > NumVar):\n            CtrPar[0][3] = NumVar\n        NumOutp = 2  \n        if (NumOutp > 0):\n            NumSynInp = [0] * N\n            N_osc = [0] * N\n        TnextSyn = [0] * N\n        \n        \n        t0 = Tstart\n        time_num = 0\n        while (t0 < Tstop):\n           \n            if (t0 >= t_display):\n                print(\"%f percent\" % (t0 * 100 / Tstop))\n                t_display = t0 + 100\n           \n            t1 = t0 + dt0\n            EvtNo = -999\n            if (t1 > Tstop):\n                t1 = Tstop\n            for i in range(NumEvt):\n                if (EvtTimes[i * 2] > t0) and (EvtTimes[i * 2] <= t1):\n                    t1 = EvtTimes[i * 2]\n                    NextEvtT = t1\n                    EvtNo = i * 2\n                else:\n                    EvtOffT = EvtTimes[i * 2] + EvtTimes[i * 2 + 1]\n                    if (EvtOffT > t0 and EvtOffT <= t1):\n                        t1 = EvtOffT\n                        NextEvtT = t1\n                        EvtNo = i * 2 + 1\n            \n            t11 = t1\n            for i in range(M):\n                if 
(InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind] > t0) and (\n                        InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind] <= t11):\n                    t11 = InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind]\n                    print_flag = 1\n                else:\n                    print_flag = 0\n            t1 = t11\n            \n            for i in range(M):\n                if (InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind] == t1):\n                   \n                    if (InpNPtr0[i].SP_ind > 0):\n                        ISI_inp = t1 - \\\n                            InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind - 1]\n                    else:\n                        ISI_inp = 10.0e8\n                   \n                    InpNPtr0[i].SpikeTimes[InpNPtr0[i].SP_ind] = InpNPtr0[i].SPtrain[InpNPtr0[i].SP_ind]\n                    InpNPtr0[i].SP_ind = InpNPtr0[i].SP_ind + 1\n                   \n                    j = NumSpike[i + N] % SizeHistOutput\n                   \n                    for kk in range(InpNPtr0[i].NumPreSyn):\n                        if (InpNPtr0[i].SDf[kk].use > 0.0):\n                            InpNPtr0[i].SDf[kk].Adepr[j] = short_time(\n                                SizeHistOutput).syndepr(InpNPtr0[i].SDf[kk], ISI_inp, j)\n                    \n                    if (WriteST > 0):\n                        fpISI = open(\n                            \"ISIu%d_%i.dat\" %\n                            (i + N, UniquePrint), \"a\")\n                        fpISI.write(\"%f\\n\" % t1)\n                        fpISI.close()\n                    NumSpike[i + N] = NumSpike[i + N] + 1\n            \n            for i in range(N):\n                \n                t0_i = t0\n                \n                if (t0_i < t1):\n                    \n                    t1_i = t1\n                    \n                    if (TnextSyn[i] > t0_i and TnextSyn[i] < t1_i):\n                        t1_i = TnextSyn[i]\n                    \n                    vp = 
copy.copy(NPtr0[i].v[0])\n                    wp = copy.copy(NPtr0[i].v[1])\n                    dt = t1_i - t0_i\n                    \n                    if (NumSpike[i] > 0):\n                        if ((t0_i -\n                             NPtr0[i].SpikeTimes[(NumSpike[i] -\n                                                  1) %\n                                                 SizeHistOutput]) < 5):\n                            flag_dv = 0\n                        else:\n                            flag_dv = 1\n                    else:\n                        flag_dv = 1\n                    \n                    NPtr0[i] = copy.copy(NPtr0[i])\n                    try:\n                        NPtr0[i], gsyn_AN, gsyn_G, I_tot =  short_time(\n                                        SizeHistOutput).update(\n                            NPtr0[i], dt, NoiseSyn, flag_dv)\n                    except OverflowError:\n                        pass\n                    if (stop_flag > 0):\n                        print(\"%f %d %f %f\\n\" % (t0_i, i, vp, wp))\n                    for j in range(NPtr0[i].NumSynType):\n                        if (NPtr0[i].gfONsyn[j] <\n                                0 or NPtr0[i].gfOFFsyn[j] < 0):\n                            print(\n                                \"%d %d %f %f %f %f\\n\" %\n                                (i, j, t0_i, t1_i, NPtr0[i].gfONsyn[j], NPtr0[i].gfOFFsyn[j]))\n                    if (t1_i == t1):\n                        gsyn1[i] = gsyn_AN\n                        gsyn2[i] = gsyn_G\n                        Isyn[i] = I_tot\n                        if (I_tot < NPtr0[i].I_ref *\n                                1.01 and I_tot > NPtr0[i].I_ref * 0.99):\n                            flag_osc[i] = flag_osc[i] + 1\n                        else:\n                            flag_osc[i] = 0\n                   \n                    if (flag_osc[i] >= 200 and NumOutp > 0):  \n                        N_osc[i] = N_osc[i] + 
1\n                    \n                    if ((NPtr0[i].v[0] >= NPtr0[i].Vup)\n                            and (vp < NPtr0[i].Vup)):\n                        \n                        t1_i = t0_i + dt * \\\n                            (NPtr0[i].Vup - vp) / (NPtr0[i].v[0] - vp)\n                        \n                        if (NumSpike[i] > 0):\n                            ISI = t1_i - \\\n                                NPtr0[i].SpikeTimes[(NumSpike[i] - 1) % SizeHistOutput]\n                        else:\n                            ISI = 10.0e8\n                       \n                        if (ISI > 5):\n                            \n                            w_Vup = wp + \\\n                                ((NPtr0[i].v[1] - wp) / dt) * (t1_i - t0_i)\n                            NPtr0[i].v[0] = NPtr0[i].Vr\n                            NPtr0[i].v[1] = w_Vup + NPtr0[i].b\n                            \n                            j = NumSpike[i] % SizeHistOutput\n                            NPtr0[i].SpikeTimes[j] = t1_i\n                           \n                            for kk in range(NPtr0[i].NumPreSyn):\n                                if (NPtr0[i].SDf[kk].use > 0.0):\n                                    NPtr0[i].SDf[kk].Adepr[j] = short_time(\n                                        SizeHistOutput).syndepr(NPtr0[i].SDf[kk], ISI, j)\n                            \n                            if (WriteST > 0):\n                                fpISI = open(\n                                    \"ISIu%d_%i.dat\" %\n                                    (i, UniquePrint), \"a\")\n                                fpISI.write(\"%f\\n\" % t1_i)\n                                fpISI.close()\n                            \n                            NumSpike[i] = NumSpike[i] + 1\n                            \n                            dt = t1_i - t0_i\n                        else:\n                            NPtr0[i].v[0] = vp\n                        
    NPtr0[i].v[1] = wp\n                            # reset t1_i\n                            t1_i = dt + t0_i\n                        \n                        if (t1_i == t1):\n                            gsyn_AN, I_tot, gsyn_G = short_time(\n                                SizeHistOutput).set_gsyn(NPtr0[i], dt, vp, NoiseSyn)\n                            gsyn1[i] = gsyn_AN\n                            gsyn2[i] = gsyn_G\n                            Isyn[i] = I_tot\n                    \n                    for j in range(NoiseSyn.NumSyn):\n                        SynExpOn[j] = math.exp(-dt /\n                                               (NoiseSyn.Syn[j].STPtr).tc_on)\n                        SynExpOff[j] = math.exp(-dt /\n                                                (NoiseSyn.Syn[j].STPtr).tc_off)\n                        rand_num = NoiseDistr[0][rand.randint(\n                            0, 1 / NoiseStep)]\n                        NPtr0[i].gfONnoise[j] = 0.0\n                        NPtr0[i].gfOFFnoise[j] = 0.0\n                    \n                    for j in range(NPtr0[i].NumSynType):\n                        NPtr0[i].gfONsyn[j] *= math.exp(-dt /\n                                                        (NPtr0[i].STList[j]).tc_on)\n                        NPtr0[i].gfOFFsyn[j] *= math.exp(-dt /\n                                                         (NPtr0[i].STList[j]).tc_off)\n                    \n                    TnextSyn[i] = Tstop + 100.0\n                   \n                    for j in range(N):\n                        \n                        for k in range(ConMtx0[i][j].NumSyn):\n                            \n                            kk = NumSpike[j] - 1\n                            while (\n                                kk >= 0 and (\n                                    NumSpike[j] -\n                                    kk) <= SizeHistOutput):\n                                if (t0_i >= (\n                                        
NPtr0[j].SpikeTimes[kk % SizeHistOutput] + ConMtx0[i][j].Syn[k].dtax)):\n                                    break\n                                else:\n                                    \n                                    if ((t1_i >= NPtr0[j].SpikeTimes[kk % SizeHistOutput] + ConMtx0[i][j].Syn[\n                                            k].dtax) and (\n                                            rand.uniform(0, 1) > ConMtx0[i][j].Syn[k].p_fail)):\n                                        for k2 in range(NPtr0[i].NumSynType):\n                                            if (NPtr0[i].STList[k2].No ==\n                                                    ConMtx0[i][j].Syn[k].STPtr.No):\n                                                Aall = NPtr0[j].SDf[ConMtx0[i][j].Syn[k].PreSynIdx].Adepr[\n                                                    kk % SizeHistOutput] * ConMtx0[i][j].Syn[k].wgt * \\\n                                                    ConMtx0[i][j].Syn[k].STPtr.Gsyn\n                                                NPtr0[i].gfONsyn[k2] += Aall\n                                                NPtr0[i].gfOFFsyn[k2] += Aall\n                                                if (NumOutp > 0):\n                                                    NumSynInp[i] = NumSynInp[i] + 1.0\n                                    else:\n                                        \n                                        if (NPtr0[j].SpikeTimes[kk % SizeHistOutput] +\n                                                ConMtx0[i][j].Syn[k].dtax < TnextSyn[i]):\n                                            TnextSyn[i] = NPtr0[j].SpikeTimes[kk %\n                                                                              SizeHistOutput] + ConMtx0[i][j].Syn[k].dtax\n                                kk = kk - 1\n                    \n                    for j in range(N, N + M):\n                        \n                        for k in range(ConMtx0[i][j].NumSyn):\n               
             kk = NumSpike[j] - 1\n                            while (kk >= 0):\n                               \n                                if (t0_i >= (\n                                        InpNPtr0[j - N].SpikeTimes[kk] + ConMtx0[i][j].Syn[k].dtax)):\n                                    break\n                                else:\n                                   \n                                    if ((t1_i >= InpNPtr0[j - N].SpikeTimes[kk] + ConMtx0[i][j].Syn[k].dtax) and (\n                                            rand.uniform(0, 1) > ConMtx0[i][j].Syn[k].p_fail)):\n                                        for k2 in range(NPtr0[i].NumSynType):\n                                            if (NPtr0[i].STList[k2] ==\n                                                    ConMtx0[i][j].Syn[k].STPtr):\n                                                Aall = InpNPtr0[j - N].SDf[ConMtx0[i][j].Syn[k].PreSynIdx].Adepr[kk %\n                                                                                                                 SizeHistOutput] * ConMtx0[i][j].Syn[k].wgt * (ConMtx0[i][j].Syn[k].STPtr).Gsyn\n                                                NPtr0[i].gfONsyn[k2] += Aall\n                                                NPtr0[i].gfOFFsyn[k2] += Aall\n                                                if (NumOutp > 0):\n                                                    NumSynInp[i] = NumSynInp[i] + 1.0\n                                    else:\n                                        \n                                        if (InpNPtr0[j - N].SpikeTimes[kk] + \\\n                                            ConMtx0[i][j].Syn[k].dtax < TnextSyn[i]):\n                                            TnextSyn[i] = InpNPtr0[j - N].SpikeTimes[kk] + \\\n                                                ConMtx0[i][j].Syn[k].dtax\n                                kk = kk - 1\n                    \n                    t0_i = t1_i\n            \n           
 for i in range(NumView):\n                fpOut.write(\"%lf %d\" % (t1, ViewList[i][0]))\n                for k in range(int(CtrPar[0][3])):\n                    fpOut.write(\" %lf\" % NPtr0[ViewList[i][0] - 1].v[k])\n                    fpOut.write(\"\\n\")\n            \n            for i in range(NumView):\n                fpOut2.write(\" %f %f %f\" %\n                             (gsyn1[ViewList[i][0] -\n                                    1], gsyn2[ViewList[i][0] -\n                                              1], Isyn[ViewList[i][0] -\n                              1]))\n            fpOut2.write(\"\\n\")\n            \n            if (EvtNo >= 0 and t1 >= NextEvtT):\n                if ((EvtNo % 2) == 0):\n                    for i in range(N):\n                        NPtr0[i].Iinj = EvtMtx[int(i + (EvtNo / 2) * N)]\n                else:\n                    for i in range(N):\n                        NPtr0[i].Iinj = 0.0\n            \n            t0 = t1\n        \n        if (NumView > 0):\n            fpOut.close()\n            fpOut2.close()\n        \n        STMtx = []\n        if (CtrPar[0][4] > 0):\n            for i in range(N + M):\n                if (os.path.exists('ISIu%d_0.dat' % i)):\n                    ST = pd.read_table('ISIu%d_0.dat' % i, header=None)\n                    content = []\n                    for j in range(ST.shape[0]):\n                        content.append(ST.iloc[j][0])\n                    STMtx.append(content)\n                    os.remove('ISIu%d_0.dat' % i)\n                else:\n                    STMtx.append(-1)\n        T = []\n        V = []\n        if (ViewList is not None):\n            X = pd.read_table('IDN_%d.dat' % UniqueNum, header=None)\n            for i in range(X.shape[0]):\n                T.append(X.iloc[i][0])\n                content = []\n                for j in range(X.shape[1] - 1):\n                    content.append(X.iloc[i][j + 1])\n                V.append(content)\n            
os.remove('IDN_%d.dat' % UniqueNum)\n        os.remove('IDN2_0.dat')\n        scio.savemat('PFC_%dN_500ms.mat' %\n                     (N + M), {'N': N, 'T': T, 'V': V, 'STMtx': STMtx})\n\n\n\nclass SynType:\n    def __init__(self):\n        \"\"\"\n        Parameters of short-term synaptic plasticity model\n        \"\"\"\n        self.No = 0\n        self.gmax = 0\n        self.tc_on = 0\n        self.tc_off = 0\n        self.Erev = 0\n        self.Mg_gate = 0\n        self.Mg_fac = 0\n        self.Mg_slope = 0\n        self.Mg_half = 0\n\n        self.Gsyn = 0\n\n\nclass Neuron:\n    \"\"\"\n    Parameters of neurons\n    \"\"\"\n    gfONsyn = None\n    gfOFFsyn = None\n    gfONnoise = None\n    gfOFFnoise = None\n    SpikeTimes = None\n    v = None\n    dv = None\n\n    def __init__(self):\n        MaxNumSTperN = six_layer_pfc().MaxNumSTperN\n        SizeHistOutput = six_layer_pfc().SizeHistOutput\n\n        self.Cm = 0\n        self.gL = 0\n        self.EL = 0\n        self.sf = 0\n        self.Vup = 0\n        self.tcw = 0\n        self.a = 0\n        self.b = 0\n        self.Vr = 0\n        self.Vth = 0\n        self.I_ref = 0\n        self.v_dep = 0\n        self.NumSynType = 0\n\n        self.Iinj = 0\n        self.v = [0] * 2\n        self.dv = [0] * 2\n        self.STList = []\n        self.gfONsyn = [0] * MaxNumSTperN\n        self.gfOFFsyn = [0] * MaxNumSTperN\n        self.gfONnoise = [0] * MaxNumSTperN\n        self.gfOFFnoise = [0] * MaxNumSTperN\n        self.SpikeTimes = [0] * SizeHistOutput\n        self.NumPreSyn = 0\n        self.PreSynList = []\n        self.SDf = []\n\n\nclass InpNeuron:\n    \"\"\"\n    Input parameters of neurons\n    \"\"\"\n    SPtrain = None\n    SpikeTimes = None\n\n    def __init__(self):\n        SizeHistInput = six_layer_pfc().SizeHistInput\n        self.SPtrain = [0] * SizeHistInput\n        self.SpikeTimes = [0] * SizeHistInput\n        self.SP_ind = 0\n        self.NumSynType = 0\n        self.NumPreSyn = 0\n      
  self.PreSynList = []\n        self.SDf = []\n\n\nclass Synapse:\n    \"\"\"\n    Synaptic parameters\n    \"\"\"\n    def __init__(self):\n        self.STPtr = SynType()\n        self.dtax = 0\n        self.wgt = 0\n        self.p_fail = 0\n        self.PreSynIdx = 0\n\n\nclass SynDepr:\n    \"\"\"\n    Parameters of synaptic current model\n    \"\"\"\n    Adepr = None\n    uprev = None\n    Rprev = None\n\n    def __init__(self):\n        SizeHistOutput = six_layer_pfc().SizeHistOutput\n        self.use = 0\n        self.tc_rec = 0\n        self.tc_fac = 0\n        self.Adepr = [0] * SizeHistOutput\n        self.uprev = [0] * SizeHistOutput\n        self.Rprev = [0] * SizeHistOutput\n\n\nclass SynList:\n    \"\"\"\n    Parameters of synapse list\n    \"\"\"\n    def __init__(self):\n        self.NumSyn = 0\n        self.Syn = []\n\n\nif __name__ == '__main__':\n     \"\"\"\n        After downloading the data file on the network disk, modify the file path to the downloaded placement path\n     \"\"\"\n     test = six_layer_pfc()\n     inputpath = 'data100.mat'\n     test.mex_function(inputpath)\n     outputpath='PFC_99N_500ms.mat'\n     test.picture(outputpath)\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain/README.md",
    "content": "## Macaque Brain Simulation\n\n## Description\nMacaque Brain Simulation is a large scale brain modeling framework depending on braincog framework.\n\n## Requirements:\n* numpy >= 1.21.2\n* scipy >= 1.8.0\n* h5py >= 3.6.0\n* torch >= 1.10\n* torchvision >= 0.12.0\n* torchaudio  >= 0.11.0\n* timm >= 0.5.4\n* matplotlib >= 3.5.1\n* einops >= 0.4.1\n* thop >= 0.0.31\n* pyyaml >= 6.0\n* loris >= 0.5.3\n* pandas >= 1.4.2  \n* tonic (special)\n* pandas >= 1.4.2  \n\n## Input:\n\nThe binary connectivity matrix can be obtained from the following link:\nhttps://drive.google.com/file/d/1LsNupIx3Nk-Cn_MowF6O-SCY27wdRYas/view?usp=sharing\n\nThe brain region's name can be obained from the following link:\nhttps://drive.google.com/file/d/1iNI0HR3teUj4yshK8RlSJq6gIWbRdBI1/view?usp=sharing\n\n\n## Example:\n\n```shell \ncd ~/examples/Multi-scale Brain Structure Simulation/MacaqueBrain/\npython macaque_brain.py\n```\n\n## Parameters:\nThe parameters are similar to mouse brain simulation\n\n## Citations:\nIf you find this package helpful, please consider citing the following papers:\n\n    @article{Liu2016,\n    author={Liu, Xin and Zeng, Yi and Zhang, Tielin and Xu, Bo},\n    title={Parallel Brain Simulator: A Multi-scale and Parallel Brain-Inspired Neural Network Modeling and Simulation Platform},\n    journal={Cognitive Computation},\n    year={2016},\n    month={Oct},\n    day={01},\n    volume={8},\n    number={5},\n    pages={967--981},\n    issn={1866-9964},\n    doi={10.1007/s12559-016-9411-y},\n    url={https://doi.org/10.1007/s12559-016-9411-y}\n    }\n\n    @misc{https://doi.org/10.48550/arxiv.2207.08533,\n      doi = {10.48550/ARXIV.2207.08533},\n      url = {https://arxiv.org/abs/2207.08533},\n      author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, 
Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n      title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n      publisher = {arXiv},\n      year = {2022},\n    }\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/MacaqueBrain/macaque_brain.py",
    "content": "import time\n\nimport numpy as np\nimport scipy.io as scio\nimport torch\nfrom torch import nn\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nimport pandas as pd\nimport matplotlib.pyplot as plt\n\ndevice = 'cuda:0'\n\nclass Syn(nn.Module):\n    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):\n        super().__init__()\n        self.pre = syn[1]\n        self.post = syn[0]\n        self.syn_num = len(syn)\n        self.w = torch.sparse_coo_tensor(syn.t(), weight,\n                                         size=(neuron_num, neuron_num))\n        self.tao_d = tao_d\n        self.tao_r = tao_r\n        self.dt = dt\n        self.lamda_d = self.dt / self.tao_d\n        self.lamda_r = self.dt / self.tao_r\n\n        self.s = torch.zeros(neuron_num, device=device)\n        self.r = torch.zeros(neuron_num, device=device)\n        self.dt = dt\n\n    def forward(self, neuron):\n        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (\n                torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)\n        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu\n        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)\n        self.r = self.r - self.lamda_d * self.r + self.dt * self.s\n        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff\n        return self.I\n\nclass brain(nn.Module):\n    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):\n        super().__init__()\n        if neuron_model == 'HH':\n            self.neurons = HHNode(p_neuron, dt, device)\n        elif neuron_model == 'aEIF':\n            self.neurons = aEIF(p_neuron, dt, device)\n        self.neuron_num = len(p_neuron[0])\n        self.syns = Syn(syn, weight, self.neuron_num, 3, 6, dt, device)\n\n    def forward(self, inputs):\n        I = self.syns(self.neurons)\n        
self.neurons(I)\n\n\ndef brain_region(neuron_num):\n    region = []\n    start = 0\n    end = 0\n    for i in range(len(neuron_num)):\n        end += neuron_num[i].item()\n        region.append([start, end])\n        start = end\n    return torch.tensor(region)\n\ndef neuron_type(neuron_num, ratio, regions):\n    neuron_num = neuron_num.reshape(-1, 1)\n    neuron_type = torch.floor(ratio * neuron_num).int() + regions[:, 0].reshape(-1, 1)\n    return neuron_type\n\ndef syn_within_region(syn_num, region):\n    start = 1\n    for neurons in region:\n        if start:\n            syn = torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)\n            start = 0\n        else:\n            syn = torch.concatenate((syn, torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)))\n    return syn\n\ndef syn_cross_region(weight_matrix, region):\n    start = 1\n    for i in range(len(weight_matrix)):\n        for j in range(len(weight_matrix)):\n            if weight_matrix[i][j] < 10:\n                continue\n            else:\n                pre = torch.randint(region[j][0], region[j][1],\n                                    size=(weight_matrix[i][j], 1), device=device)\n                post = torch.randint(region[i][0], region[i][1],\n                                     size=(weight_matrix[i][j], 1), device=device)\n                if start:\n                    syn = torch.concatenate((post, pre), dim=1)\n                    start = 0\n                else:\n                    syn = torch.concatenate((syn, torch.concatenate((post, pre), dim=1)))\n    return syn\n\nsize = 10000\nneuron_model = 'aEIF'\nweight_matrix = torch.tensor(scio.loadmat('./maque.mat')['connect']) * 100\nweight_matrix = weight_matrix.int()\nsyn_num = 10\n\nNR = len(weight_matrix)\ndata = size * np.ones(NR)\nneuron_num = 
np.array(data).astype(np.int32)\nneuron_num = torch.from_numpy(neuron_num)\nregions = brain_region(neuron_num)\nratio = torch.tensor([[0.7, 0.9, 1.0] * NR]).reshape(NR, 3)\nneuron_types = neuron_type(neuron_num, ratio, regions)\nsyn_1 = syn_within_region(syn_num, regions)\nsyn_2 = syn_cross_region(weight_matrix, regions)\nsyn = torch.concatenate((syn_1, syn_2))\nprint(syn.shape)\nweight = -torch.ones(len(syn), device=device, requires_grad=False)\nif neuron_model == 'aEIF':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nelif neuron_model == 'HH':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nfor i in range(len(neuron_types)):\n    pre = syn[:, 0]\n    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])\n    indices = torch.where(mask)\n    weight[indices] = 1.5\n    if neuron_model == 'aEIF':\n        if i < 177:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -110\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -66\n            c_m[regions[i][0]:neuron_types[i][0]] = 10\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            
tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n        else:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -60\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -60\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -65\n            c_m[regions[i][0]:neuron_types[i][0]] = 20\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 2\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 4\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n    elif neuron_model == 'HH':\n        threshold[regions[i][0]:neuron_types[i][0]] = 50\n        threshold[neuron_types[i][0]:neuron_types[i][1]] = 60\n        threshold[neuron_types[i][1]:neuron_types[i][2]] = 60\n\nif neuron_model == 'aEIF':\n    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n    dt = 1\n    T = 300\nelif neuron_model == 'HH':\n    p_neuron = [threshold, 120, 36, 0.3, 
115, -12, 10.6, 1]\n    dt = 0.01\n    T = 10000\nmodel = brain(syn, weight, neuron_model, p_neuron, dt, device)\nIraster = []\nfor t in range(T):\n    model(0)\n    print(torch.sum(model.neurons.spike))\n    Isp = torch.nonzero(model.neurons.spike)\n    print(len(Isp))\n    if (len(Isp) != 0):\n        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)\n        left = left.reshape(len(left), 1)\n        mide = torch.concatenate((left, Isp), dim=1)\n    if (len(Isp) != 0) and (len(Iraster) != 0):\n        Iraster = torch.concatenate((Iraster, mide), dim=0)\n    if (len(Iraster) == 0) and (len(Isp) != 0):\n        Iraster = mide\n\nIraster = torch.tensor(Iraster).transpose(0, 1)\ntorch.save(Iraster, \"./maque.pt\")\n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/MouseBrain/README.md",
    "content": "## Input:\n\n* Program to enter a table of connection weights between 213 brain regions, which is saved in 'mouse_weight.pt'. The brain regions' name and neuron number are saved in 'mouse_brain_region.xlsx'. These two files are available in the follow link:\n\nhttps://drive.google.com/drive/folders/1MWHY52gKPGneBEJxJN9DzE7thnLrhG1j?usp=sharing\n\n## output\n\n* The program generates a data file of the individual neuron firing time points recorded, and the large number of data points requires the use of drawing software to display the results.\n\n## setting:\n\n* scale: The scale of the number of neurons\n* neuron_model: ‘HHNode’ or ‘aEIF’\n* weight_matrix: Matrix of the number of synaptic connections between brain regions\n* neuron_num: The number of neurons in each brain region\n* ratio: the ratio of each neuron type in each brain region\n* syn_num: average number of synapses per neuron within region \n"
  },
  {
    "path": "examples/Multiscale_Brain_Structure_Simulation/MouseBrain/mouse_brain.py",
    "content": "import time\n\nimport numpy as np\nimport scipy.io as scio\nimport torch\nfrom torch import nn\nfrom braincog.base.node.node import *\nfrom braincog.base.brainarea.BrainArea import *\nimport pandas as pd\nimport matplotlib.pyplot as plt\n\ndevice = 'cuda:0'\n\nclass Syn(nn.Module):\n    def __init__(self, syn, weight, neuron_num, tao_d, tao_r, dt, device):\n        super().__init__()\n        self.pre = syn[1]\n        self.post = syn[0]\n        self.syn_num = len(syn)\n        self.w = torch.sparse_coo_tensor(syn.t(), weight,\n                                         size=(neuron_num, neuron_num))\n        self.tao_d = tao_d\n        self.tao_r = tao_r\n        self.dt = dt\n        self.lamda_d = self.dt / self.tao_d\n        self.lamda_r = self.dt / self.tao_r\n\n        self.s = torch.zeros(neuron_num, device=device)\n        self.r = torch.zeros(neuron_num, device=device)\n        self.dt = dt\n\n    def forward(self, neuron):\n        neuron.Iback = neuron.Iback + neuron.dt_over_tau * (\n                torch.randn(neuron.neuron_num, device=device, requires_grad=False) - neuron.Iback)\n        neuron.Ieff = neuron.Iback / neuron.sqrt_coeff * neuron.sig + neuron.mu\n        self.s = self.s + self.lamda_r * (-self.s + 1 / self.tao_d * neuron.spike)\n        self.r = self.r - self.lamda_d * self.r + self.dt * self.s\n        self.I = torch.sparse.mm(self.w, self.r.unsqueeze(-1)).squeeze() + neuron.Ieff\n        return self.I\n\nclass brain(nn.Module):\n    def __init__(self, syn, weight, neuron_model, p_neuron, dt, device):\n        super().__init__()\n        if neuron_model == 'HH':\n            self.neurons = HHNode(p_neuron, dt, device)\n        elif neuron_model == 'aEIF':\n            self.neurons = aEIF(p_neuron, dt, device)\n        self.neuron_num = len(p_neuron[0])\n        self.syns = Syn(syn, weight, self.neuron_num, 3, 6, dt, device)\n\n    def forward(self, inputs):\n        I = self.syns(self.neurons)\n        
self.neurons(I)\n\n\ndef brain_region(neuron_num):\n    region = []\n    start = 0\n    end = 0\n    for i in range(len(neuron_num)):\n        end += neuron_num[i].item()\n        region.append([start, end])\n        start = end\n    return torch.tensor(region)\n\ndef neuron_type(neuron_num, ratio, regions):\n    neuron_num = neuron_num.reshape(-1, 1)\n    neuron_type = torch.floor(ratio * neuron_num).int() + regions[:, 0].reshape(-1, 1)\n    return neuron_type\n\ndef syn_within_region(syn_num, region):\n    start = 1\n    for neurons in region:\n        if start:\n            syn = torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)\n            start = 0\n        else:\n            syn = torch.concatenate((syn, torch.randint(neurons[0], neurons[1],\n                            size=((neurons[1]-neurons[0]) * syn_num, 2), device=device)))\n    return syn\n\ndef syn_cross_region(weight_matrix, region):\n    start = 1\n    for i in range(len(weight_matrix)):\n        for j in range(len(weight_matrix)):\n            if weight_matrix[i][j] < 10:\n                continue\n            else:\n                pre = torch.randint(region[j][0], region[j][1],\n                                    size=(weight_matrix[i][j], 1), device=device)\n                post = torch.randint(region[i][0], region[i][1],\n                                     size=(weight_matrix[i][j], 1), device=device)\n                if start:\n                    syn = torch.concatenate((post, pre), dim=1)\n                    start = 0\n                else:\n                    syn = torch.concatenate((syn, torch.concatenate((post, pre), dim=1)))\n    return syn\n\nscale = 0.1\nneuron_model = 'aEIF'\nweight_matrix = torch.load('./mouse_weight.pt') * scale\nweight_matrix = weight_matrix.int()\ndata = pd.read_excel('./mouse_brain_region.xlsx', sheet_name='Sheet1', header=None)\ndata = data.values\nname = data[0]\nneuron_num = 
np.array(data[1] * scale).astype(np.int32)\nneuron_num = torch.from_numpy(neuron_num)\nratio = torch.tensor([[0.7, 0.9, 1.0] * 213]).reshape(213, 3)\nsyn_num = 10\n\nregions = brain_region(neuron_num)\nneuron_types = neuron_type(neuron_num, ratio, regions)\nsyn_1 = syn_within_region(syn_num, regions)\nsyn_2 = syn_cross_region(weight_matrix, regions)\nsyn = torch.concatenate((syn_1, syn_2))\nprint(syn.shape)\nweight = -torch.ones(len(syn), device=device, requires_grad=False)\nif neuron_model == 'aEIF':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    v_reset = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    c_m = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    tao_w = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    alpha_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\n    beta_ad = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nelif neuron_model == 'HH':\n    threshold = torch.zeros(regions[-1][1], device=device, requires_grad=False)\nfor i in range(len(neuron_types)):\n    pre = syn[:, 0]\n    mask = (pre >= regions[i][0]) & (pre < neuron_types[i][0])\n    indices = torch.where(mask)\n    weight[indices] = 1.5\n    if neuron_model == 'aEIF':\n        if i < 177:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -44\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -110\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -110\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -66\n            c_m[regions[i][0]:neuron_types[i][0]] = 10\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 10\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 8.5\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            
tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n        else:\n            threshold[regions[i][0]:neuron_types[i][0]] = -50\n            threshold[neuron_types[i][0]:neuron_types[i][1]] = -50\n            threshold[neuron_types[i][1]:neuron_types[i][2]] = -45\n            v_reset[regions[i][0]:neuron_types[i][0]] = -60\n            v_reset[neuron_types[i][0]:neuron_types[i][1]] = -60\n            v_reset[neuron_types[i][1]:neuron_types[i][2]] = -65\n            c_m[regions[i][0]:neuron_types[i][0]] = 20\n            c_m[neuron_types[i][0]:neuron_types[i][1]] = 2\n            c_m[neuron_types[i][1]:neuron_types[i][2]] = 4\n            tao_w[regions[i][0]:neuron_types[i][0]] = 1\n            tao_w[neuron_types[i][0]:neuron_types[i][1]] = 2\n            tao_w[neuron_types[i][1]:neuron_types[i][2]] = 2\n            alpha_ad[regions[i][0]:neuron_types[i][0]] = 0\n            alpha_ad[neuron_types[i][0]:neuron_types[i][1]] = -0.2\n            alpha_ad[neuron_types[i][1]:neuron_types[i][2]] = -0.2\n            beta_ad[regions[i][0]:neuron_types[i][0]] = 0\n            beta_ad[neuron_types[i][0]:neuron_types[i][1]] = 0.45\n            beta_ad[neuron_types[i][1]:neuron_types[i][2]] = 0.45\n    elif neuron_model == 'HH':\n        threshold[regions[i][0]:neuron_types[i][0]] = 50\n        threshold[neuron_types[i][0]:neuron_types[i][1]] = 60\n        threshold[neuron_types[i][1]:neuron_types[i][2]] = 60\n\nif neuron_model == 'aEIF':\n    p_neuron = [threshold, v_reset, c_m, tao_w, alpha_ad, beta_ad]\n    dt = 1\n    T = 300\nelif 
neuron_model == 'HH':\n    p_neuron = [threshold, 120, 36, 0.3, 115, -12, 10.6, 1]\n    dt = 0.01\n    T = 10000\nmodel = brain(syn, weight, neuron_model, p_neuron, dt, device)\nIraster = []\nfor t in range(T):\n    model(0)\n    print(torch.sum(model.neurons.spike))\n    Isp = torch.nonzero(model.neurons.spike)\n    print(len(Isp))\n    if (len(Isp) != 0):\n        left = t * torch.ones((len(Isp)), device=device, requires_grad=False)\n        left = left.reshape(len(left), 1)\n        mide = torch.concatenate((left, Isp), dim=1)\n    if (len(Isp) != 0) and (len(Iraster) != 0):\n        Iraster = torch.concatenate((Iraster, mide), dim=0)\n    if (len(Iraster) == 0) and (len(Isp) != 0):\n        Iraster = mide\n\nIraster = torch.tensor(Iraster).transpose(0, 1)\ntorch.save(Iraster, \"./mouse.pt\")\n"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/burst_conversion/CIFAR10_VGG16.py",
    "content": "import sys\nsys.path.append('../../..')\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport time\nfrom braincog.utils import setup_seed\nfrom braincog.datasets.datasets import get_cifar10_data\ndevice = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')\nDATA_DIR = '/data/datasets'\n\n\nclass VGG16(nn.Module):\n    def __init__(self, relu_max=1):  # 1   3e38\n        super(VGG16, self).__init__()\n        cnn = nn.Sequential(\n            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),\n            nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),\n            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.MaxPool2d(2, 2))\n\n        self.conv = cnn\n        self.fc = nn.Linear(512, 10, bias=True)\n\n    def forward(self, input):\n        conv = self.conv(input)\n        x = conv.view(conv.shape[0], -1)\n        
output = self.fc(x)\n        return output\n\n\ndef get_cifar10_loader(batch_size, train_batch=None, num_workers=4, conversion=False, distributed=False):\n    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),\n                                            CIFAR10Policy(),\n                                            transforms.ToTensor(),\n                                            Cutout(n_holes=1, length=16),\n                                            normalize])\n    transform_test = transforms.Compose([transforms.ToTensor(), normalize])\n    train_batch = batch_size if train_batch is None else train_batch\n    cifar10_train = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, transform=transform_test if conversion else transform_train)\n    cifar10_test = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transform_test)\n\n    if distributed:\n        train_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_train)\n        val_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_test, shuffle=False, drop_last=True)\n        train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=False, num_workers=num_workers, pin_memory=True, sampler=train_sampler)\n        test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers,  pin_memory=True, sampler=val_sampler)\n    else:\n        train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)\n        test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)\n\n    return train_iter, test_iter\n\n\ndef train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, 
losstype='mse'):\n    best = 0\n    net = net.to(device)\n    print(\"training on \", device)\n    if losstype == 'mse':\n       loss = torch.nn.MSELoss()\n    else:\n        loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)\n    losses = []\n\n    for epoch in range(num_epochs):\n        for param_group in optimizer.param_groups:\n            learning_rate = param_group['lr']\n\n        losss = []\n        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()\n        for X, y in train_iter:\n            X = X.to(device)\n            y = y.to(device)\n            y_hat = net(X)\n            label = y\n            if losstype == 'mse':\n                label = F.one_hot(y, 10).float()\n            l = loss(y_hat, label)\n            losss.append(l.cpu().item())\n            optimizer.zero_grad()\n            l.backward()\n            optimizer.step()\n            train_l_sum += l.cpu().item()\n            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n            n += y.shape[0]\n            batch_count += 1\n        scheduler.step()\n        test_acc = evaluate_accuracy(test_iter, net)\n        losses.append(np.mean(losss))\n        print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'\n              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))\n\n        if test_acc > best:\n            best = test_acc\n            torch.save(net.state_dict(), './CIFAR10_VGG16.pth')\n\n\ndef evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):\n    if device is None and isinstance(net, torch.nn.Module):\n        device = list(net.parameters())[0].device\n    acc_sum, n = 0.0, 0\n    with torch.no_grad():\n        for X, y in data_iter:\n            net.eval()\n            acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()\n            net.train()\n            n += y.shape[0]\n\n            if 
only_onebatch: break\n    return acc_sum / n\n\n\nif __name__ == '__main__':\n    setup_seed(42)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n    batch_size = 128\n    train_iter, test_iter, _, _ = get_cifar10_data(batch_size)\n    # train_iter, test_iter = get_cifar10_loader(batch_size)\n    print('dataloader finished')\n\n    lr, num_epochs = 0.05, 300\n    net = VGG16()\n    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)\n    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)\n    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')\n\n    net.load_state_dict(torch.load(\"./CIFAR10_VGG16.pth\", map_location=device))\n    net = net.to(device)\n    acc = evaluate_accuracy(test_iter, net, device)\n    print(acc)"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/burst_conversion/README.md",
    "content": "# Conversion Method\nTraining deep spiking neural network with ann-snn conversion\nreplace ReLU and MaxPooling in pytorch model to make origin ANN to be converted SNN to finish complex tasks\n\n## Results\n```shell\npython CIFAR10_VGG16.py\npython converted_CIFAR10.py\n```\n\nYou should first run the `CIFAR10_VGG16.py` to get a well-trained ANN.\nThen `converted_CIFAR10.py` can be used to run the snn inference process.\n\n### Citation \nIf you find this package helpful, please consider citing it:\n\n```BibTex\n@inproceedings{ijcai2022p345,\n  title     = {Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes},\n  author    = {Li, Yang and Zeng, Yi},\n  booktitle = {Proceedings of the Thirty-First International Joint Conference on\n               Artificial Intelligence, {IJCAI-22}},\n  publisher = {International Joint Conferences on Artificial Intelligence Organization},\n  pages     = {2485--2491},\n  year      = {2022},\n  month     = {7},\n}\n\n\n@article{li2022spike,\ntitle={Spike calibration: Fast and accurate conversion of spiking neural network for object detection and segmentation},\nauthor={Li, Yang and He, Xiang and Dong, Yiting and Kong, Qingqun and Zeng, Yi},\njournal={arXiv preprint arXiv:2207.02702},\nyear={2022}\n}\n\n```"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/burst_conversion/converted_CIFAR10.py",
    "content": "import sys\nsys.path.append('../../..')\nimport torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport matplotlib\nmatplotlib.use('agg')\nimport numpy as np\nfrom tqdm import tqdm\nfrom copy import deepcopy\nimport matplotlib.pyplot as plt\nimport time\nimport os\nfrom examples.Perception_and_Learning.Conversion.burst_conversion.CIFAR10_VGG16 import VGG16\nfrom braincog.utils import setup_seed\nfrom braincog.datasets.datasets import get_cifar10_data\nfrom braincog.base.conversion import Convertor\nimport argparse\n\n\nparser = argparse.ArgumentParser(description='Conversion')\nparser.add_argument('--T', default=64, type=int, help='simulation time')\nparser.add_argument('--p', default=0.99, type=float, help='percentile for data normalization. 0-1')\nparser.add_argument('--gamma', default=5, type=int, help='burst spike and max spikes IF can emit')\nparser.add_argument('--channelnorm', default=False, type=bool, help='use channel norm')\nparser.add_argument('--lipool', default=True, type=bool, help='LIPooling')\nparser.add_argument('--smode', default=True, type=bool, help='replace ReLU to IF')\nparser.add_argument('--soft_mode', default=True, type=bool, help='soft reset or not')\nparser.add_argument('--device', default='4', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')\nparser.add_argument('--cuda', default=True, type=bool, help='use cuda.')\nparser.add_argument('--model_name', default='vgg16', type=str, help='model name. 
vgg16 or resnet20')\nparser.add_argument('--merge', default=True, type=bool, help='merge conv and bn')\nparser.add_argument('--train_batch', default=100, type=int, help='batch size for get max')\nparser.add_argument('--batch_num', default=1, type=int, help='number of train batch')\nparser.add_argument('--spicalib', default=0, type=int, help='allowance for spicalib')\nparser.add_argument('--batch_size', default=128, type=int, help='batch size for testing')\nparser.add_argument('--seed', default=42, type=int, help='seed')\nargs = parser.parse_args()\n\n\ndef evaluate_snn(test_iter, snn, device=None, duration=50):\n    accs = []\n    snn.eval()\n\n    for ind, (test_x, test_y) in tqdm(enumerate(test_iter)):\n        test_x = test_x.to(device)\n        test_y = test_y.to(device)\n        n = test_y.shape[0]\n        out = 0\n        with torch.no_grad():\n            snn.reset()\n            acc = []\n            # for t in tqdm(range(duration)):\n            for t in range(duration):\n                out += snn(test_x)\n                result = torch.max(out, 1).indices\n                result = result.to(device)\n                acc_sum = (result == test_y).float().sum().item()\n                acc.append(acc_sum / n)\n\n        accs.append(np.array(acc))\n    accs = np.array(accs).mean(axis=0)\n\n    i, show_step = 1, []\n    while 2 ** i <= duration:\n        show_step.append(2 ** i - 1)\n        i = i + 1\n\n    for iii in show_step:\n        print(\"timestep\", str(iii).zfill(3) + ':', accs[iii])\n    print(\"best acc: \", max(accs))\n\n\nif __name__ == '__main__':\n    print(\"Setting Arguments.. 
: \", args)\n    print(\"----------------------------------------------------------\")\n    setup_seed(seed=args.seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n    device = torch.device(\"cuda:%s\" % args.device) if args.cuda else 'cpu'\n\n    train_iter, _, _, _ = get_cifar10_data(args.train_batch, same_da=True)\n    _, test_iter, _, _ = get_cifar10_data(args.batch_size, same_da=True)\n\n    if args.model_name == 'vgg16':\n        net = VGG16()\n        net.load_state_dict(torch.load(\"./CIFAR10_VGG16.pth\", map_location=device))\n\n    net.eval()\n    net = net.to(device)\n\n    converter = Convertor(dataloader=train_iter,\n                          device=device,\n                          p=args.p,\n                          channelnorm=args.channelnorm,\n                          lipool=args.lipool,\n                          gamma=args.gamma,\n                          soft_mode=args.soft_mode,\n                          merge=args.merge,\n                          batch_num=args.batch_num,\n                          spicalib=args.spicalib\n                          )\n    snn = converter(net)\n\n    evaluate_snn(test_iter, snn, device, duration=args.T)\n\n"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/msat_conversion/CIFAR10_VGG16.py",
    "content": "import sys\nsys.path.append('../../..')\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport time\nfrom braincog.utils import setup_seed\nfrom braincog.datasets.datasets import get_cifar10_data\ndevice = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')\nDATA_DIR = '/data/datasets'\n\n\nclass VGG16(nn.Module):\n    def __init__(self, relu_max=1):  # 1   3e38\n        super(VGG16, self).__init__()\n        cnn = nn.Sequential(\n            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),\n            nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),\n            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.MaxPool2d(2, 2),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n            nn.MaxPool2d(2, 2))\n\n        self.conv = cnn\n        self.fc = nn.Linear(512, 10, bias=True)\n\n    def forward(self, input):\n        conv = self.conv(input)\n        x = conv.view(conv.shape[0], -1)\n        
output = self.fc(x)\n        return output\n\n\ndef get_cifar10_loader(batch_size, train_batch=None, num_workers=4, conversion=False, distributed=False):\n    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),\n                                            CIFAR10Policy(),\n                                            transforms.ToTensor(),\n                                            Cutout(n_holes=1, length=16),\n                                            normalize])\n    transform_test = transforms.Compose([transforms.ToTensor(), normalize])\n    train_batch = batch_size if train_batch is None else train_batch\n    cifar10_train = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, transform=transform_test if conversion else transform_train)\n    cifar10_test = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transform_test)\n\n    if distributed:\n        train_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_train)\n        val_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_test, shuffle=False, drop_last=True)\n        train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=False, num_workers=num_workers, pin_memory=True, sampler=train_sampler)\n        test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers,  pin_memory=True, sampler=val_sampler)\n    else:\n        train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=True, num_workers=num_workers, pin_memory=True)\n        test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)\n\n    return train_iter, test_iter\n\n\ndef train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, 
losstype='mse'):\n    best = 0\n    net = net.to(device)\n    print(\"training on \", device)\n    if losstype == 'mse':\n       loss = torch.nn.MSELoss()\n    else:\n        loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)\n    losses = []\n\n    for epoch in range(num_epochs):\n        for param_group in optimizer.param_groups:\n            learning_rate = param_group['lr']\n\n        losss = []\n        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()\n        for X, y in train_iter:\n            X = X.to(device)\n            y = y.to(device)\n            y_hat = net(X)\n            label = y\n            if losstype == 'mse':\n                label = F.one_hot(y, 10).float()\n            l = loss(y_hat, label)\n            losss.append(l.cpu().item())\n            optimizer.zero_grad()\n            l.backward()\n            optimizer.step()\n            train_l_sum += l.cpu().item()\n            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n            n += y.shape[0]\n            batch_count += 1\n        scheduler.step()\n        test_acc = evaluate_accuracy(test_iter, net)\n        losses.append(np.mean(losss))\n        print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'\n              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))\n\n        if test_acc > best:\n            best = test_acc\n            torch.save(net.state_dict(), './CIFAR10_VGG16.pth')\n\n\ndef evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):\n    if device is None and isinstance(net, torch.nn.Module):\n        device = list(net.parameters())[0].device\n    acc_sum, n = 0.0, 0\n    with torch.no_grad():\n        for X, y in data_iter:\n            net.eval()\n            acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()\n            net.train()\n            n += y.shape[0]\n\n            if 
only_onebatch: break\n    return acc_sum / n\n\n\nif __name__ == '__main__':\n    setup_seed(42)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n    batch_size = 128\n    train_iter, test_iter, _, _ = get_cifar10_data(batch_size)\n    # train_iter, test_iter = get_cifar10_loader(batch_size)\n    print('dataloader finished')\n\n    lr, num_epochs = 0.05, 300\n    net = VGG16()\n    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)\n    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)\n    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')\n\n    net.load_state_dict(torch.load(\"./CIFAR10_VGG16.pth\", map_location=device))\n    net = net.to(device)\n    acc = evaluate_accuracy(test_iter, net, device)\n    print(acc)"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/msat_conversion/README.md",
    "content": "# Script for experiment\n\n```shell\npython converted_CIFAR10.py --useDET --useDTT\n```\n\n\n\n## Note: Convertor\n\nplease use `convertor.py` here to replace and override `braincog/base/conversion/convertor.py` if you want to run multi-stage adaptive threshold in this project.\n\n\n\n## Citation\n\nIf you find the code and dataset useful in your research, please consider citing:\n\n```\n@article{he2023improving,\n  title={Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer},\n  author={He, Xiang and Zhao, Dongcheng and Li, Yang and Shen, Guobin and Kong, Qingqun and Zeng, Yi},\n  journal={arXiv preprint arXiv:2303.13077},\n  year={2023}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n```\n\n\n## Contents\n\nIf you are confused about using it or have other feedback and comments, please feel free to contact us via [hexiang2021@ia.ac.cn](hexiang2021@ia.ac.cn).\n"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/msat_conversion/converted_CIFAR10.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2023/4/19 15:56\n# Author : Regulus\n# FileName: converted_CIFAR10.py\n# Explain: \n# Software: PyCharm\n\nimport sys\nsys.path.append('..')\nimport torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\n# import matplotlib\n# matplotlib.use('agg')\nimport numpy as np\nfrom tqdm import tqdm\nfrom braincog.utils import setup_seed\nimport os\nfrom examples.Perception_and_Learning.Conversion.msat_conversion.CIFAR10_VGG16 import VGG16\nimport argparse\nfrom braincog.datasets.datasets import get_cifar10_data\nfrom braincog.base.conversion import Convertor, FolderPath\n\n\nparser = argparse.ArgumentParser(description='Conversion')\nparser.add_argument('--T', default=256, type=int, help='simulation time')\nparser.add_argument('--p', default=1, type=float, help='percentile for data normalization. 0-1')\nparser.add_argument('--gamma', default=1, type=int, help='burst spike and max spikes IF can emit')\nparser.add_argument('--lateral_inhi', default=True, type=bool, help='LIPooling')\nparser.add_argument('--data_norm', default=True, type=bool, help=' whether use data norm or not')\nparser.add_argument('--smode', default=True, type=bool, help='replace ReLU to IF')\nparser.add_argument('--device', default='7', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')\nparser.add_argument('--cuda', default=True, type=bool, help='use cuda.')\nparser.add_argument('--model_name', default='vgg16', type=str, help='model name. 
vgg16 or resnet20')\nparser.add_argument('--train_batch', default=512, type=int, help='batch size for get max')\nparser.add_argument('--batch_size', default=128, type=int, help='batch size for testing')\nparser.add_argument('--seed', default=23, type=int, help='seed')\nparser.add_argument('--useDET', action='store_true', default=False, help='use DET')\nparser.add_argument('--useDTT', action='store_true', default=False, help='use DTT')\nparser.add_argument('--useSC', action='store_true', default=False, help='use SpikeConfidence')\nargs = parser.parse_args()\n\ndef evaluate_snn(test_iter, snn, device=None, duration=50):\n    folder_path = \"./result_conversion_{}/snn_timestep{}_p{}_LIPooling{}_Burst{}\".format(\n            args.model_name, duration, args.p, args.lateral_inhi, args.gamma)\n    if not os.path.exists(folder_path):  # 判断是否存在文件夹如果不存在则创建为文件夹\n        os.makedirs(folder_path)\n    snn.eval()\n    FolderPath.folder_path = folder_path\n    accs = []\n    for ind, (test_x, test_y) in enumerate(tqdm(test_iter)):\n        test_x = test_x.to(device)\n        test_y = test_y.to(device)\n        n = test_y.shape[0]\n        out = 0\n        with torch.no_grad():\n            snn.reset()\n            acc = []\n            for t in range(duration):\n                out += snn(test_x)\n                result = torch.max(out, 1).indices\n                result = result.to(device)\n                acc_sum = (result == test_y).float().sum().item()\n                acc.append(acc_sum / n)\n        # break\n        accs.append(np.array(acc))\n\n    if True:\n        f = open('{}/result.txt'.format(folder_path), 'w')\n        f.write(\"Setting Arguments.. 
: {}\\n\".format(args))\n        accs = np.array(accs).mean(axis=0)\n        for iii in range(256):\n            if iii == 0 or iii == 3 or iii == 7 or (iii + 1) % 16 == 0:\n                f.write(\"timestep {}:{}\\n\".format(str(iii+1).zfill(3), accs[iii]))\n        f.write(\"max accs: {}, timestep:{}\\n\".format(max(accs), np.where(accs == max(accs))))\n        f.close()\n        accs = torch.from_numpy(accs)\n        torch.save(accs, \"{}/accs.pth\".format(folder_path))\n\nif __name__ == '__main__':\n    print(\"Setting Arguments.. : \", args)\n    print(\"----------------------------------------------------------\")\n    setup_seed(seed=args.seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n    device = torch.device(\"cuda:%s\" % args.device) if args.cuda else 'cpu'\n\n    train_iter, _, _, _ = get_cifar10_data(args.train_batch, same_da=True)\n    _, test_iter, _, _ = get_cifar10_data(args.batch_size, same_da=True)\n\n    if args.model_name == 'vgg16':\n        net = VGG16()\n        # net.load_state_dict(torch.load(\"./CIFAR10_VGG16.pth\", map_location=device))\n\n    net.eval()\n    net = net.to(device)\n\n    converter = Convertor(dataloader=train_iter,\n                          device=device,\n                          p=1.0,\n                          channelnorm=False,\n                          lipool=True,\n                          gamma=1,\n                          soft_mode=True,\n                          merge=True,\n                          batch_num=1,\n                          spicalib=False,\n                          useDET=args.useDET,\n                          useDTT=args.useDTT,\n                          useSC=args.useSC\n                          )\n    snn = converter(net)\n\n    evaluate_snn(test_iter, snn, device, duration=args.T)\n"
  },
  {
    "path": "examples/Perception_and_Learning/Conversion/msat_conversion/convertor.py",
    "content": "import torch\nimport torch.nn as nn\nfrom braincog.base.connection.layer import SMaxPool, LIPool\nfrom .merge import mergeConvBN\nfrom .spicalib import SpiCalib\nimport types\nimport os\nimport sys\n\nlayer_index = 0  # layer index for SNode\n\n\nclass FolderPath:\n    folder_path = \"init\"\n\nclass HookScale(nn.Module):\n    \"\"\" 在每个ReLU层后记录该层的百分位最大值\n\n    For channelnorm: 获取最大值时使用了torch.quantile\n    For layernorm：  使用sort，然后手动取百分比，因为quantile在计算单个通道时有上限，batch较大时易出错\n    \"\"\"\n\n    def __init__(self,\n                 p: float = 0.9995,\n                 channelnorm: bool = False,\n                 gamma: float = 0.999,\n                 ):\n        super().__init__()\n        if channelnorm:\n            self.register_buffer('scale', torch.tensor(0.0))\n        else:\n            self.register_buffer('scale', torch.tensor(0.0))\n\n        self.p = p\n        self.channelnorm = channelnorm\n        self.gamma = gamma\n\n    def forward(self, x):\n        x = torch.where(x.detach() < self.gamma, x.detach(),\n                        torch.tensor(self.gamma, dtype=x.dtype, device=x.device))\n        if len(x.shape) == 4 and self.channelnorm:\n            num_channel = x.shape[1]\n            tmp = torch.quantile(x.permute(1, 0, 2, 3).reshape(num_channel, -1), self.p, dim=1,\n                                 interpolation='lower') + 1e-10\n            self.scale = torch.max(tmp, self.scale)\n        else:\n            sort, _ = torch.sort(x.view(-1))\n            self.scale = torch.max(sort[int(sort.shape[0] * self.p) - 1], self.scale)\n        return x\n\n\nclass Hookoutput(nn.Module):\n    \"\"\"\n    在伪转换中为ReLU和ClipQuan提供包装，用于监控其输出\n    \"\"\"\n\n    def __init__(self, module):\n        super(Hookoutput, self).__init__()\n        self.activation = 0.\n        self.operation = module\n\n    def forward(self, x):\n        output = self.operation(x)\n        self.activation = output.detach()\n        return output\n\n\nclass Scale(nn.Module):\n  
  \"\"\"\n    对前向过程的值进行缩放\n    \"\"\"\n\n    def __init__(self, scale: float = 1.0):\n        super().__init__()\n        self.register_buffer('scale', scale)\n\n    def forward(self, x):\n        if len(self.scale.shape) == 1:\n            return self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x\n        else:\n            return self.scale * x\n\n\ndef reset(self):\n    \"\"\"\n    转换的网络来自ANN，需要将新附加上的脉冲module进行reset\n    判断module名称并调用各自节点的reset方法\n    \"\"\"\n    children = list(self.named_children())\n    for i, (name, child) in enumerate(children):\n        if isinstance(child, (SNode, LIPool, SMaxPool)):\n            child.reset()\n        else:\n            reset(child)\n\n\nclass Convertor(nn.Module):\n    \"\"\"ANN2SNN转换器\n\n    用于转换完整的pytorch模型，使用dataloader中部分数据进行最大值计算，通过p控制获取第p百分比最大值\n\n    channlenorm: https://arxiv.org/abs/1903.06530\n    channelnorm可以对每个通道获取最大值并进行权重归一化\n\n    gamma: https://arxiv.org/abs/2204.13271\n    gamma可以控制burst spikes的脉冲数，burst spike可以提高神经元的脉冲发放能力，减小信息残留\n\n    lipool: https://arxiv.org/abs/2204.13271\n    lipool用于使用侧向抑制机制进行最大池化，LIPooling能够对SNN中的最大池化进行有效的转换\n\n    soft_mode: https://arxiv.org/abs/1612.04052\n    soft_mode被称为软重置，可以减小重置过程神经元的信息损失，有效提高转换的性能\n\n    merge用于是否对网络中相邻的卷积和BN层进行融合\n    batch_norm控制对dataloader的数据集的用量\n    \"\"\"\n\n    def __init__(self,\n                 dataloader,\n                 device=None,\n                 p=0.9995,\n                 channelnorm=False,\n                 lipool=True,\n                 gamma=1,\n                 soft_mode=True,\n                 merge=True,\n                 batch_num=1,\n                 spicalib=0,\n                 useDET=False, useDTT=False, useSC=None\n                 ):\n        super(Convertor, self).__init__()\n        self.dataloader = dataloader\n        self.device = device\n        self.p = p\n        self.channelnorm = channelnorm\n        self.lipool = lipool\n        self.gamma = gamma\n        self.soft_mode = soft_mode\n        
self.merge = merge\n        self.batch_num = batch_num\n        self.spicalib = spicalib\n        self.useDET = useDET\n        self.useDTT = useDTT\n        self.useSC = useSC\n\n    def forward(self, model):\n        model.eval()\n        model = Convertor.register_hook(model, self.p, self.channelnorm, self.gamma)\n        model = Convertor.get_percentile(model, self.dataloader, self.device, batch_num=self.batch_num)\n        model = mergeConvBN(model) if self.merge else model\n        model = Convertor.replace_for_spike(model, self.lipool, self.soft_mode, self.gamma, self.spicalib, self.useDET,\n                                            self.useDTT, self.useSC)\n        model.reset = types.MethodType(reset, model)\n        return model\n\n    @staticmethod\n    def register_hook(model, p=0.99, channelnorm=False, gamma=0.999):\n        \"\"\" Reference: https://github.com/fangwei123456/spikingjelly\n\n        将网络的每一层后注册一个HookScale类\n        该方法在仿真上等效于与对权重进行归一化操作，且易扩展到任意结构的网络中\n        \"\"\"\n        children = list(model.named_children())\n        for _, (name, child) in enumerate(children):\n            if isinstance(child, nn.ReLU):\n                model._modules[name] = nn.Sequential(nn.ReLU(), HookScale(p, channelnorm, gamma))\n            else:\n                Convertor.register_hook(child, p, channelnorm, gamma)\n        return model\n\n    @staticmethod\n    def get_percentile(model, dataloader, device, batch_num=1):\n        \"\"\"\n        该函数需与具有HookScale层的网络配合使用\n        \"\"\"\n        for idx, (data, _) in enumerate(dataloader):\n            data = data.to(device)\n            if idx >= batch_num:\n                break\n            model(data)\n        return model\n\n    @staticmethod\n    def replace_for_spike(model, lipool=True, soft_mode=True, gamma=1, spicalib=0, useDET=False, useDTT=False, useSC=None):\n        \"\"\"\n        该函数用于将定义好的ANN模型转换为SNN模型\n        ReLU单元将被替换为脉冲神经元，\n        如果模型中使用了最大池化，lipool参数将定义使用常规模型还是LIPooling方法\n        
\"\"\"\n        children = list(model.named_children())\n        for _, (name, child) in enumerate(children):\n            if isinstance(child, nn.Sequential) and len(child) == 2 and isinstance(child[0], nn.ReLU) and isinstance(child[1], HookScale):\n                global layer_index\n                model._modules[name] = nn.Sequential(\n                    Scale(1.0 / child[1].scale),\n                    SNode(soft_mode, gamma, useDET=useDET, useDTT=useDTT, useSC=useSC, layer_index=layer_index),\n                    SpiCalib(spicalib),\n                    Scale(child[1].scale)\n                )\n                layer_index += 1\n            if isinstance(child, nn.MaxPool2d):\n                model._modules[name] = LIPool(child) if lipool else SMaxPool(child)\n            else:\n                Convertor.replace_for_spike(child, lipool, soft_mode, gamma, useDET=useDET, useDTT=useDTT, useSC=useSC)\n        return model\n\n\nclass SNode(nn.Module):\n    \"\"\"\n    用于转换后的SNN的神经元模型\n    IF神经元模型由gamma=1确定，当gamma为其他大于1的值时，即为使用burst神经元模型\n    soft_mode用于定义神经元的重置方法，soft重置能够极大地减少神经元在重置过程的信息损失\n    \"\"\"\n\n    def __init__(self, soft_mode=False, gamma=5, useDET=False, useDTT=False, useSC=None, layer_index=1):\n        super(SNode, self).__init__()\n        self.threshold = 1.0\n        self.maxThreshold = 1.0\n        self.soft_mode = soft_mode\n        self.gamma = gamma\n\n        self.mem = 0\n        self.spike = 0\n\n        self.Vm = 0.\n        self.summem = 0.\n        self.t = 0\n        self.all_spike = 0\n        self.V_T = 0\n        self.useDET = useDET\n        self.useDTT = useDTT\n        self.useSC = useSC\n        self.layer_index = layer_index\n        # hyperparameters\n        self.alpha = 0\n        self.ka = 0\n        self.ki = 0\n        self.C = 0\n        self.tau_mp = 0\n        self.tau_rd = 0\n\n        # record sin\n        self.mem_16 = 0.0\n        self.spike_mask = 0\n        self.sin_spikenum = 0.0\n        self.sin_ratio = []  # 
snn中sin占负的比例\n        self.last_spike = 0\n        self.confidence = []\n        self.neg_ratio = []  # ann中负的占总的比例\n        self.sin_all_ratio = []  # snn中sin占总的比例\n        self.pos_all_ratio = []  # snn中pos占总的比例\n        self.should_all_ratio = []  # snn中pos应该发但是没发占总的比例\n        self.confidence = []  # snn中sin占所有发的比例\n        self.avg_error_spikenum = []  # snn中错发的平均个数\n\n    def forward(self, x):\n        self.mem = self.mem + x\n        self.spike = torch.zeros_like(x)\n        if self.t == 0:\n            self.threshold = torch.full(x.shape, 1.0 * self.maxThreshold).to(x.device)\n            self.V_T = -torch.full(x.shape, self.maxThreshold).to(x.device)\n            # init hyperparameters\n            hp = []\n            path = FolderPath.folder_path.split('/')\n            path = os.path.join(path[0], path[1], path[2], 'hyperparameters.txt')\n            with open(path, 'r') as f:\n                data = f.readlines()  # 将txt中所有字符串读入data\n                for ind, line in enumerate(data):\n                    numbers = line.split()  # 将数据分隔\n                    hp.append(list(map(float, numbers))[0])  # 转化为浮点数\n\n            self.alpha = hp[0]\n            self.ka = hp[1]\n            self.ki = hp[2]\n            self.C = hp[3]\n            self.tau_mp = hp[4]\n            self.tau_rd = hp[5]\n        else:\n            DTT = self.tau_mp * (self.alpha * (self.last_mem - self.Vm) + self.V_T + self.ka * torch.log(\n                1 + torch.exp((self.last_mem - self.Vm) / self.ki)))\n            DET = self.tau_rd * torch.exp(-1 * x / self.C)\n            if self.useDET is True and self.useDTT is True:\n                self.threshold = DET + DTT\n            elif self.useDTT is True:\n                self.threshold = DTT\n            elif self.useDET is True:\n                self.threshold = DET\n            else:\n                print(\"wrong logics\")\n                sys.exit()\n            self.threshold = torch.sigmoid(self.threshold)\n            
self.threshold *= self.maxThreshold\n\n        self.spike = (self.mem / self.threshold).floor().clamp(min=0, max=self.gamma)\n        self.soft_reset() if self.soft_mode else self.hard_reset\n\n        if self.useSC is True:\n            if self.t < 16:\n                # read confidence\n                path = FolderPath.folder_path.split('/')\n                path = os.path.join(path[0], path[1], path[2], 'neuron_confidence_vgg16.txt')\n                with open(path, 'r') as f:\n                    data = f.readlines()  # 将txt中所有字符串读入data\n                    for ind, line in enumerate(data):\n                        numbers = line.split()  # 将数据分隔\n                        self.confidence.append(list(map(float, numbers))[0])  # 转化为浮点数\n                mask = (torch.rand(x.shape) >= (1.0 - self.confidence[self.layer_index])).float().cuda()\n                self.spike = self.spike * mask  # random drop\n        self.all_spike += self.spike\n        out = self.spike * self.threshold\n        self.t += 1\n        self.last_mem = self.mem\n        self.summem += self.mem\n        self.Vm = (self.summem / self.t)\n        return out\n\n    def hard_reset(self):\n        \"\"\"\n        硬重置后神经元的膜电势被重置为0\n        \"\"\"\n        self.mem = self.mem * (1 - self.spike.detach())\n\n    def soft_reset(self):\n        \"\"\"\n        软重置后神经元的膜电势为神经元当前膜电势减去阈值\n        \"\"\"\n        self.mem = self.mem - self.threshold * self.spike.detach()\n\n    def reset(self):\n        self.mem = 0\n        self.spike = 0\n        self.maxThreshold = 1.0\n"
  },
  {
    "path": "examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/distortion/__init__.py",
    "content": "from .abutting_grating_illusion import ag_distort_28, ag_distort_224, ag_distort_silhouette, save_image, get_silhouette_data"
  },
  {
    "path": "examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/distortion/abutting_grating_illusion/__init__.py",
    "content": "from .abutting_grating_distortion import ag_distort_28, ag_distort_224, ag_distort_silhouette, save_image, get_silhouette_data"
  },
  {
    "path": "examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/distortion/abutting_grating_illusion/abutting_grating_distortion.py",
    "content": "import torch\nfrom torchvision import datasets, transforms\nimport time\nimport os\nimport matplotlib.pyplot as plt\nimport torch\nfrom torchvision import datasets, transforms\nimport time\nimport os\n\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport pickle\nfrom PIL import Image\nimport numpy as np\n \n\nfrom torchvision import utils\nimport os\n\n\nseed = 1000\n\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\n\n\ndef save_image(image, filename):  \n    assert len(image.shape) == 3, \"The image must have only three dimensions of C,W,H.\"\n    utils.save_image(image, filename)\n    \n\n\n\ndef get_mnist_data(train = False, batch_size = 100):\n    path = './datasets/' # might need to change based on where to call this function\n    #transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  \n    transform = transforms.Compose([transforms.ToTensor()])\n    if train:\n        train_loader = torch.utils.data.DataLoader(\n            datasets.MNIST(path, train=True, download=False, transform=transform),\n                batch_size=batch_size, shuffle=False)\n        return train_loader\n    else:\n        test_loader = torch.utils.data.DataLoader(\n            datasets.MNIST(path, train=False, download=False, transform=transform),\n                batch_size=batch_size, shuffle=False)\n        return test_loader\n\n\n\n\n\ndef get_silhouette_data(path):\n    '''\n    path: dir path to the silhouette image samples of 16-clas-ImageNet\n    '''\n    labels = os.listdir(path)\n    \n    datasets = []\n    #transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])\n    transform = transforms.Compose([transforms.ToTensor()])\n    for label in labels:\n        for img_name in os.listdir(f\"{path}/{label}\"):\n            img_path = f\"{path}/{label}/{img_name}\"\n       
     img = Image.open(img_path)\n            img = transform(img).unsqueeze(0)\n\n            datasets.append((img, label))\n    return datasets\n\n\ndef ag_distort_28(imgs, threshold=0, interval=4, phase=2, direction=(1,0)):\n    #return imgs\n    assert len(imgs.shape) == 4, \"The images must have four dimensions of B,C,W,H.\"\n    B,C,W,H = imgs.shape\n    mask_fg = (imgs>threshold).float()  \n    mask_bg = 1 - mask_fg\n    gratings_fg = torch.zeros_like(imgs)\n    gratings_bg = torch.zeros_like(imgs)\n  \n    for w in range(W):\n        for h in range(H):\n            if (direction[0]*w+direction[1]*h)%interval==0:\n                gratings_fg[:,:,w,h] = 1\n            if (direction[0]*w+direction[1]*h)%interval==phase:\n                gratings_bg[:,:,w,h] = 1\n    masked_gratings_fg = mask_fg*gratings_fg\n    masked_gratings_bg = mask_bg*gratings_bg\n    ag_image = masked_gratings_fg + masked_gratings_bg\n    return ag_image\n\ndef transform_224(imgs):\n    imgs = torch.nn.functional.interpolate(imgs, scale_factor = 8, mode = 'bilinear', align_corners = False)\n    imgs = torch.cat([imgs, imgs, imgs], dim=1)\n    return imgs   \n\ndef ag_distort_224(imgs, threshold=0, interval=8, phase=4, direction=(1,0)):\n    assert len(imgs.shape) == 4, \"The images must have four dimensions of C,W,H.\"   \n    imgs = torch.nn.functional.interpolate(imgs, scale_factor = 8, mode = 'bilinear', align_corners = False)\n    imgs = torch.cat([imgs, imgs, imgs], dim=1)\n    #return imgs\n    B,C,W,H = imgs.shape\n    mask_fg = (imgs>threshold).float()  \n    mask_bg = 1 - mask_fg\n    gratings_fg = torch.zeros_like(imgs)\n    gratings_bg = torch.zeros_like(imgs)\n  \n    for w in range(W):\n        for h in range(H):\n            if (direction[0]*w+direction[1]*h)%interval==0:\n                gratings_fg[:,:,w,h] = 1\n            if (direction[0]*w+direction[1]*h)%interval==phase:\n                gratings_bg[:,:,w,h] = 1\n    masked_gratings_fg = mask_fg*gratings_fg\n    
masked_gratings_bg = mask_bg*gratings_bg\n    ag_image = masked_gratings_fg + masked_gratings_bg\n    return ag_image\n\n\ndef ag_distort_silhouette(imgs, threshold=0.5, interval=2, phase=1, direction=(1,0)):\n\n    assert len(imgs.shape) == 4, \"The image must have only three dimensions of C,W,H.\"\n    #imgs = torch.nn.functional.interpolate(imgs, scale_factor = 2, mode = 'bilinear', align_corners = False)\n    B,C,W,H = imgs.shape\n    mask_fg = (imgs<threshold).float()\n    mask_bg = 1 - mask_fg\n    gratings_fg = torch.zeros_like(imgs)\n    gratings_bg = torch.zeros_like(imgs)\n    for w in range(W):\n        for h in range(H):\n            if (direction[0]*w+direction[1]*h)%interval==0:\n                gratings_fg[:,:,w,h] = 1\n            if (direction[0]*w+direction[1]*h)%interval==phase:\n                gratings_bg[:,:,w,h] = 1\n    ag_images = mask_fg*gratings_fg + mask_bg*gratings_bg\n    #transform = transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])\n    #transform = transforms.Compose([]) \n    #ag_images[0] = transform(ag_images[0])\n    return ag_images"
  },
  {
    "path": "examples/Perception_and_Learning/IllusionPerception/AbuttingGratingIllusion/main.py",
    "content": "from distortion import ag_distort_28, ag_distort_224, ag_distort_silhouette, save_image, get_silhouette_data\nfrom braincog.datasets import get_mnist_data\n\n# An example of Abutting Grating Distortion applied on MNIST\ntrain_loader, test_loader, _, _ = get_mnist_data(batch_size=100)\nfor images, labels in train_loader:\n    images = ag_distort_28(images, interval=4, phase=2, direction=(1,0))\n    save_image(images[0], 'test_ag_mnist.png')\n    break\n\n# An example to generate Abutting Grating distorted MNIST of resolution 224x224\nfor images, labels in train_loader:\n    images = ag_distort_224(images, interval=8, phase=4, direction=(1,0))\n    save_image(images[0], 'test_high_res_ag_mnist.png')\n    break\n\n# An example of Abutting Grating Distortion applied on silhouettes of 16-class-ImageNet\n'''\nThe silhouette images can be downloaded from https://github.com/rgeirhos/texture-vs-shape\n'''\ndataset = get_silhouette_data('./silhouettes')\nfor images, labels in dataset:\n    images = ag_distort_silhouette(images, interval=16, phase=8)\n    \n    save_image(images[0], 'test_ag_silhouettes.png')\n    break"
  },
  {
    "path": "examples/Perception_and_Learning/MultisensoryIntegration/README.md",
    "content": "# Multisensory Integration\n\nIn `MultisensoryIntegrationDEMO_AM.py` and `MultisensoryIntegrationDEMO_AM.py`, we implement the SNNs based multisensory integration framework. To load the dataset, preprocess it and get the weights with the function `get_concept_datase_dic_and_initial_weights_lst()`​. We use `IMNet`​ or ​`AMNet`​ to describe the structure of the IM/AM model. For presynaptic neuron, we use the function `convert_vec_into_spike_trains()​` to generate the spike trains.\n\nWhile for postsynaptic neuron, we use the function `reducing_tol_binarycode()​` to get the multisensory integrated output for each concept. And  ​*tol*​ is the only parameter.\n\nIn `measure_and_visualization.py​`, we will measure and visualize the results.\n\n## Multisensory Dataset \n\nWhen implement the model in braincog, we use the famous multisensory dataset--BBSR.\n\nSome examples are as follows:\n\n| Concept   | Visual      | Somatic   | Audiation   | Taste    | Smell    |\n| --------- | ----------- | --------- | ----------- | -------- | -------- |\n| advantage | 0.213333333 | 0.032     | 0           | 0        | 0        |\n| arm       | 2.5111112   | 2.2733334 | 0.133333286 | 0.233333 | 0.4      |\n| ball      | 1.9580246   | 2.3111112 | 0.523809429 | 0.185185 | 0.111111 |\n| baseball  | 2.2714286   | 2.6071428 | 0.352040714 | 0.071429 | 0.392857 |\n| bee       | 2.795698933 | 2.4129034 | 2.096774286 | 0.290323 | 0.419355 |\n| beer      | 1.4866666   | 2.2533334 | 0.190476286 | 5.8      | 4.6      |\n| bird      | 2.7632184   | 2.027586  | 3.064039286 | 1.068966 | 0.517241 |\n| car       | 2.521839133 | 2.9517244 | 2.216748857 | 0        | 2.206897 |\n| foot      | 2.664444533 | 2.58      | 0.380952429 | 0.433333 | 3        |\n| honey     | 1.757142867 | 2.3214286 | 0.015306143 | 5.642857 | 4.535714 |\n\n## How to Run \n\nTo get the multisensory integrated vectors:\n\n```\ncd examples/MultisensoryIntegration/code\npython MultisensoryIntegrationDEMO_AM.py\npython 
MultisensoryIntegrationDEMO_IM.py\n```\n\nTo measure and analysis the vectors:\n\n```\ncd examples/MultisensoryIntegration/code\npython measure_and_visualization.py\n```\n\n## Citation \nIf you find this package helpful, please consider citing the following papers:\n\n@ARTICLE{wang2022multisensory,\nauthor={Wang, Yuwei and Zeng, Yi},   \ntitle={Multisensory Concept Learning Framework Based on Spiking Neural Networks},      \njournalL={Frontiers in Systems Neuroscience},      \nvolume={16},           \nyear={2022},      \nurl={https://www.frontiersin.org/articles/10.3389/fnsys.2022.845177},       \ndou={10.3389/fnsys.2022.845177},      \nissn={1662-5137},   \n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n   doi = {10.48550/ARXIV.2207.08533},\n   url = {https://arxiv.org/abs/2207.08533},\n   author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n   title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n   publisher = {arXiv},\n   year = {2022},\n }\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "examples/Perception_and_Learning/MultisensoryIntegration/code/MultisensoryIntegrationDEMO_AM.py",
    "content": "#!/usr/bin/env python\n#-*- coding:utf-8 -*-\n__author__ = 'Yuwei Wang'\nimport torch\nimport pandas as pd\nfrom torch import nn\nfrom braincog.base.node.node import LIFNode, IzhNode\nfrom torch.nn.parameter import Parameter\n\nclass AMNet(nn.Module):\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 givenWeights,\n                 bias=False,\n                 node=LIFNode,\n                 threshold=5,\n                 tau=0.1):\n        super().__init__()\n        if node is None:\n            raise TypeError\n\n        self.fc = nn.Linear(in_features=in_features,\n                            out_features=out_features, bias=bias)\n        self.fc.weight = Parameter(givenWeights)\n        self.node = node(threshold, tau)\n\n    def forward(self, x):\n        x = torch.tensor(x, dtype=torch.float)\n        x = self.fc(x)\n        x = self.node(x)\n        return x\n\n\ndef get_concept_dataset_dic_and_AM_initial_weights_lst(BBSR_path):\n    modality_lst = ['Auditory', 'Gustatory', 'Haptic', 'Olfactory', 'Visual']\n    # load_concept_dataset_df\n    df_BBSR = pd.read_excel (BBSR_path, sheet_name=\"Sheet1\", header=0, index_col=0,\n                             usecols=[0, 1, 20, 26, 34, 35, 36])\n\n    df_BBSR.rename (\n        columns={'Word': 'Concept', 'Audiation_mean': 'Auditory', 'Taste': 'Gustatory',\n                 'Somatic_mean': 'Haptic',\n                 \"Smell\": \"Olfactory\", \"Visual_mean\": \"Visual\"}, inplace=True)\n    concept_dataset_df = df_BBSR.drop_duplicates (subset=\"Concept\")\n\n\n    # min-max\n    z_minmax = lambda x: (x - np.min (x)) / (np.max (x) - np.min (x))\n    dataset_df_minmax = concept_dataset_df[modality_lst].apply (z_minmax)\n    concept_dataset_df = pd.concat ([concept_dataset_df[['Concept']], dataset_df_minmax], axis=1)\n\n    # output\n    dataset_concept_dims_dic = concept_dataset_df.to_dict (\"index\")\n    final_concept_dims_dic = {}\n   
 for each_key in dataset_concept_dims_dic.keys ():\n        each_concept_name = dataset_concept_dims_dic[each_key].pop ('Concept')\n        final_concept_dims_dic[each_concept_name] = [dataset_concept_dims_dic[each_key][each_modality] for each_modality\n                                                     in modality_lst]\n\n    method_type = \"pearson\" # spearman pearson kendall\n    corr_dataset = df_BBSR.corr (method=method_type)\n    corr_dic = corr_dataset.to_dict ()\n    AM_initial_weights = []\n    for each_m in modality_lst:\n        tmp_lst = []\n        for each_n in modality_lst:\n            tmp_lst.append(corr_dic[each_m][each_n])\n        AM_initial_weights.append(tmp_lst)\n        \n    AM_initial_weights = torch.tensor(AM_initial_weights, dtype=torch.float)# 后续需要改\n\n\n    return final_concept_dims_dic, AM_initial_weights\n\ndef convert_vec_into_spike_trains(each_concept_vec):\n    # generate input with Poisson-encoded spikes\n    tmp = torch.tensor ([each_concept_vec * time])\n    rates = tmp.view (time, -1)\n    vec_spike_trains = torch.bernoulli (rates).byte ()  # concept_representation\n    vec_spike_trains = torch.tensor (vec_spike_trains, dtype=torch.float)\n    return vec_spike_trains\n\ndef reducing_to_binarycode(post_neuron_states_lst, tolerance):\n    post_neuron_states_lst = [int (i) for i in post_neuron_states_lst]\n\n    if len (post_neuron_states_lst) % tolerance != 0:\n        placeholder = [0] * (tolerance - len (post_neuron_states_lst) % tolerance)\n\n        post_neuron_states_lst_with_placeholder = post_neuron_states_lst + placeholder\n    else:\n        post_neuron_states_lst_with_placeholder = post_neuron_states_lst\n\n    post_neuron_states_lst_with_placeholder = np.array (post_neuron_states_lst_with_placeholder).reshape (-1, tolerance)\n    binarycode = \"\"\n    for sub_arr in post_neuron_states_lst_with_placeholder:\n        if 1.0 in sub_arr:\n            binarycode += \"1\"\n        else:\n            binarycode += 
\"0\"\n    return binarycode\n\n\n\nif __name__ == \"__main__\":\n    import numpy as np\n    import pickle\n\n    # Dataset Reference: Binder JR, Conant LL, Humphries CJ, Fernandino L, Simons SB, Aguilar M, Desai RH.\n    # Toward a brain-based componential semantic representation. Cogn Neuropsychol. 2016 May-Jun;33(3-4):130-74.\n    # doi: 10.1080/02643294.2016.1147426. Epub 2016 Jun 16. PMID: 27310469.\n    BBSR_path = \"../data/BBSR-5modalities.xlsx\"\n    AM_binarycode_file = open ( \"../results/AM_binarycode.pickle\", \"wb\")\n\n    time = 1000\n    tolerance = 2\n\n    concept_dims_dic, AM_initial_weights = get_concept_dataset_dic_and_AM_initial_weights_lst (BBSR_path)\n    AM_binarycode_dic = {}\n    for each_concept in concept_dims_dic.keys():\n        each_concept_vec = concept_dims_dic[each_concept]\n        vec_spike_trains = convert_vec_into_spike_trains(each_concept_vec)\n        AMnet = AMNet(in_features=5, out_features=5, givenWeights= AM_initial_weights, node=LIFNode, threshold=5, tau=0.1)\n        post_neuron_states = AMnet(vec_spike_trains)\n        post_neuron_states_lst = post_neuron_states.T.reshape(1, -1).tolist()[0]\n        binarycode = reducing_to_binarycode(post_neuron_states_lst, tolerance)\n        AM_binarycode_dic[each_concept] = binarycode\n        print(\"AM\", each_concept, binarycode)\n\n    pickle.dump (AM_binarycode_dic, AM_binarycode_file)\n    AM_binarycode_file.close ()\n\n\n\n\n\n"
  },
  {
    "path": "examples/Perception_and_Learning/MultisensoryIntegration/code/MultisensoryIntegrationDEMO_IM.py",
    "content": "#!/usr/bin/env python\n#-*- coding:utf-8 -*-\n__author__ = 'Yuwei Wang'\nimport torch\nimport pandas as pd\nfrom torch import nn\nfrom braincog.base.node.node import LIFNode, IzhNode\nfrom torch.nn.parameter import Parameter\n\n\n\nclass IMNet(nn.Module):\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 givenWeights,\n                 bias=False,\n                 node=LIFNode,\n                 threshold=5,\n                 tau=0.1):\n        super().__init__()\n        if node is None:\n            raise TypeError\n\n        self.fc = nn.Linear(in_features=in_features,\n                            out_features=out_features, bias=bias)\n        self.fc.weight = Parameter(givenWeights)\n        self.node = node(threshold, tau)\n\n    def forward(self, x):\n        x = torch.tensor(x, dtype=torch.float)\n        x = self.fc(x)\n        x = self.node(x)\n        return x\n\n\ndef get_concept_dataset_dic_and_initial_weights_lst(BBSR_path):\n    modality_lst = ['Auditory', 'Gustatory', 'Haptic', 'Olfactory', 'Visual']\n    # load_concept_dataset_df\n    df_BBSR = pd.read_excel (BBSR_path, sheet_name=\"Sheet1\", header=0, index_col=0,\n                             usecols=[0, 1, 20, 26, 34, 35, 36])\n\n    df_BBSR.rename (\n        columns={'Word': 'Concept', 'Audiation_mean': 'Auditory', 'Taste': 'Gustatory',\n                 'Somatic_mean': 'Haptic',\n                 \"Smell\": \"Olfactory\", \"Visual_mean\": \"Visual\"}, inplace=True)\n    concept_dataset_df = df_BBSR.drop_duplicates (subset=\"Concept\")\n\n\n    # get bayes weights\n    var_lst = concept_dataset_df.var ().tolist ()\n    c = 1 / sum ([1 / i for i in var_lst])\n    bayes_weights_lst = [c / i for i in var_lst]\n\n    # min-max\n    z_minmax = lambda x: (x - np.min (x)) / (np.max (x) - np.min (x))\n    dataset_df_minmax = concept_dataset_df[modality_lst].apply (z_minmax)\n    concept_dataset_df = pd.concat 
([concept_dataset_df[['Concept']], dataset_df_minmax], axis=1)\n\n    # output\n    dataset_concept_dims_dic = concept_dataset_df.to_dict (\"index\")\n    final_concept_dims_dic = {}\n    for each_key in dataset_concept_dims_dic.keys ():\n        each_concept_name = dataset_concept_dims_dic[each_key].pop ('Concept')\n        final_concept_dims_dic[each_concept_name] = [dataset_concept_dims_dic[each_key][each_modality] for each_modality\n                                                     in modality_lst]\n\n    return final_concept_dims_dic, bayes_weights_lst\n\ndef convert_vec_into_spike_trains(each_concept_vec):\n    # generate input with Poisson-encoded spikes\n    tmp = torch.tensor ([each_concept_vec * time])\n    rates = tmp.view (time, -1)\n    vec_spike_trains = torch.bernoulli (rates).byte ()  # concept_representation\n    vec_spike_trains = torch.tensor (vec_spike_trains, dtype=torch.float)\n    return vec_spike_trains\n\ndef reducing_to_binarycode(post_neuron_states_lst, tolerance):\n    post_neuron_states_lst = [int (i) for i in post_neuron_states_lst]\n\n    if len (post_neuron_states_lst) % tolerance != 0:\n        placeholder = [0] * (tolerance - len (post_neuron_states_lst) % tolerance)\n\n        post_neuron_states_lst_with_placeholder = post_neuron_states_lst + placeholder\n    else:\n        post_neuron_states_lst_with_placeholder = post_neuron_states_lst\n\n    post_neuron_states_lst_with_placeholder = np.array (post_neuron_states_lst_with_placeholder).reshape (-1, tolerance)\n    binarycode = \"\"\n    for sub_arr in post_neuron_states_lst_with_placeholder:\n        if 1.0 in sub_arr:\n            binarycode += \"1\"\n        else:\n            binarycode += \"0\"\n    return binarycode\n\n\n\nif __name__ == \"__main__\":\n    import numpy as np\n    import pickle\n\n    # Dataset Reference: Binder JR, Conant LL, Humphries CJ, Fernandino L, Simons SB, Aguilar M, Desai RH.\n    # Toward a brain-based componential semantic representation. 
Cogn Neuropsychol. 2016 May-Jun;33(3-4):130-74.\n    # doi: 10.1080/02643294.2016.1147426. Epub 2016 Jun 16. PMID: 27310469.\n    BBSR_path = \"../data/BBSR-5modalities.xlsx\"\n    IM_binarycode_file = open ( \"../results/IM_binarycode.pickle\", \"wb\")\n\n\n    time = 1000\n    tolerance = 2\n\n    concept_dims_dic, bayes_weights_lst = get_concept_dataset_dic_and_initial_weights_lst (BBSR_path)\n    IM_initial_weights = torch.tensor([bayes_weights_lst], dtype=torch.float)\n\n    IM_binarycode_dic = {}\n    for each_concept in concept_dims_dic.keys():\n        #print(\"current concept: \", each_concept)\n        each_concept_vec = concept_dims_dic[each_concept]\n        vec_spike_trains = convert_vec_into_spike_trains(each_concept_vec)\n        IMnet = IMNet(in_features=5, out_features=1, givenWeights= IM_initial_weights, node=LIFNode, threshold=5, tau=0.1)\n        post_neuron_states = IMnet(vec_spike_trains)\n        post_neuron_states_lst = post_neuron_states.T.tolist()[0]\n        binarycode = reducing_to_binarycode(post_neuron_states_lst, tolerance)\n        IM_binarycode_dic[each_concept] = binarycode\n        print (\"IM\", each_concept, binarycode)\n    pickle.dump (IM_binarycode_dic, IM_binarycode_file)\n    IM_binarycode_file.close ()\n\n\n\n\n"
  },
  {
    "path": "examples/Perception_and_Learning/MultisensoryIntegration/code/measure_and_visualization.py",
    "content": "#!/usr/bin/env python\n#-*- coding:utf-8 -*-\n__author__ = 'Yuwei Wang'\ndef get_concept_dataset_dic_and_initial_weights_lst(BBSR_path):\n    modality_lst = ['Auditory', 'Gustatory', 'Haptic', 'Olfactory', 'Visual']\n    # load_concept_dataset_df\n    df_BBSR = pd.read_excel (BBSR_path, sheet_name=\"Sheet1\", header=0, index_col=0,\n                             usecols=[0, 1, 20, 26, 34, 35, 36])\n\n    df_BBSR.rename (\n        columns={'Word': 'Concept', 'Audiation_mean': 'Auditory', 'Taste': 'Gustatory',\n                 'Somatic_mean': 'Haptic',\n                 \"Smell\": \"Olfactory\", \"Visual_mean\": \"Visual\"}, inplace=True)\n    concept_dataset_df = df_BBSR.drop_duplicates (subset=\"Concept\")\n\n\n    # get bayes weights\n    var_lst = concept_dataset_df.var ().tolist ()\n    c = 1 / sum ([1 / i for i in var_lst])\n    bayes_weights_lst = [c / i for i in var_lst]\n\n    # min-max\n    z_minmax = lambda x: (x - np.min (x)) / (np.max (x) - np.min (x))\n    dataset_df_minmax = concept_dataset_df[modality_lst].apply (z_minmax)\n    concept_dataset_df = pd.concat ([concept_dataset_df[['Concept']], dataset_df_minmax], axis=1)\n\n    # output\n    dataset_concept_dims_dic = concept_dataset_df.to_dict (\"index\")\n    final_concept_dims_dic = {}\n    for each_key in dataset_concept_dims_dic.keys ():\n        each_concept_name = dataset_concept_dims_dic[each_key].pop ('Concept')\n        final_concept_dims_dic[each_concept_name] = [dataset_concept_dims_dic[each_key][each_modality] for each_modality\n                                                     in modality_lst]\n\n    return final_concept_dims_dic, bayes_weights_lst\ndef load_binarycode_dic(filename):\n    # load binarycode\n    binarycode_file = open (filename, 'rb')\n    binarycode_dic = pickle.load (binarycode_file)\n    binarycode_file.close ()\n    return binarycode_dic\ndef load_m_dataset_concept_set_lst(m_dataset_name):\n    if m_dataset_name == \"McRae\":\n        
concept_featureslst_dic = load_McRae_concept_feature_lst ()\n    elif m_dataset_name == \"CSLB\":\n        concept_featureslst_dic = load_CSLB_concept_feature_lst ()\n    return list(concept_featureslst_dic.keys())\ndef load_McRae_concept_feature_lst():\n    McRae_file = \"../data/McRae_norms.xlsx\"\n    df_McRae = pd.read_excel (McRae_file, sheet_name=\"concepts_features\", header=0,\n                              usecols=[0, 1])\n\n    origin_dic = df_McRae.to_dict ('index')\n\n    ## get concept_featureslst_dic\n    concept_featureslst_dic = {}\n    for index in origin_dic.keys ():\n        if origin_dic[index]['Concept'] not in concept_featureslst_dic.keys ():\n            concept_featureslst_dic[origin_dic[index]['Concept']] = [origin_dic[index]['Feature']]\n        else:\n            concept_featureslst_dic[origin_dic[index]['Concept']].append (origin_dic[index]['Feature'])\n\n    return concept_featureslst_dic\ndef load_CSLB_concept_feature_lst():\n    CSLB_file = \"../data/CSLB_norms.xlsx\"\n    df_CSLB = pd.read_excel (CSLB_file, header=0,\n                              usecols=[0, 2, 3])\n\n    df_CSLB.rename (\n        columns={'domain': 'Domain', 'concept': 'Concept', 'feature': 'Feature'}, inplace=True)\n    origin_dic = df_CSLB.to_dict ('index')\n\n    ## get concept_featureslst_dic\n    concept_featureslst_dic = {}\n    for index in origin_dic.keys ():\n        if origin_dic[index]['Concept'] not in concept_featureslst_dic.keys ():\n            concept_featureslst_dic[origin_dic[index]['Concept']] = [origin_dic[index]['Feature']]\n        else:\n            concept_featureslst_dic[origin_dic[index]['Concept']].append (origin_dic[index]['Feature'])\n\n    return concept_featureslst_dic\ndef get_m_dataset_concept_k_similar_concepts_dic(m_dataset_name, overlap_concept_lst, k):\n    m_dataset_concept_k_similar_concepts_dic = {}\n    if m_dataset_name == \"McRae\":\n        concept_featureslst_dic = load_McRae_concept_feature_lst ()\n    elif 
m_dataset_name == \"CSLB\":\n        concept_featureslst_dic = load_CSLB_concept_feature_lst()\n\n    for each_m_concept1 in overlap_concept_lst:\n            similar_concepts_overnum_tuple_lst = []\n            for each_m_concept2 in overlap_concept_lst:\n                if each_m_concept1 != each_m_concept2:\n                    overlap_feature_lst = [each_f for each_f in concept_featureslst_dic[each_m_concept1] if\n                                           each_f in concept_featureslst_dic[each_m_concept2]]\n\n                    similar_concepts_overnum_tuple_lst.append ((each_m_concept2, len (overlap_feature_lst)))\n            sorted_similar_concepts_overnum_tuple_lst = sorted (similar_concepts_overnum_tuple_lst, key=lambda x: x[1],\n                                                                reverse=True)\n            m_dataset_concept_k_similar_concepts_dic[each_m_concept1] = [tp[0] for tp in sorted_similar_concepts_overnum_tuple_lst][:k]\n    return m_dataset_concept_k_similar_concepts_dic\ndef get_dataset_concept_ME_dic(origin_dataset_dic):\n    dataset_concept_ME_dic = {}\n    for each_concept in origin_dataset_dic.keys():\n        vector =  origin_dataset_dic[each_concept]\n        # Modality exclusivity is a measure of the extent to which a particular property is perceived through a single\n        # perceptual modality. 
Where each property has a vector containing mean strength ratings for all modalities,\n        # modality exclusivity is calculated as the range of values divided by the sum（极差除以总和）\n        ModalityExclusivity = (max(vector) - min(vector))/(0.0+sum(vector))\n        # print(each_concept, ModalityExclusivity)\n        dataset_concept_ME_dic[each_concept]  = ModalityExclusivity\n\n    return dataset_concept_ME_dic\ndef get_dataset_concept_k_similar_concepts_ranking_dic_dic (dataset_dic, overlap_concept_lst, ifbinary):\n    dataset_concept_k_similar_concepts_ranking_dic_dic = {}\n    for each_concept1 in overlap_concept_lst:\n        similar_concepts_similarity_tuple_lst = []\n        for each_concept2 in overlap_concept_lst:\n            if each_concept1 != each_concept2:\n                if ifbinary:\n                    similarity = get_vec_Harmming_similarity(dataset_dic[each_concept1], dataset_dic[each_concept2])\n                else:\n                    similarity = get_vec_cos_similarity(dataset_dic[each_concept1], dataset_dic[each_concept2])\n                similar_concepts_similarity_tuple_lst.append((each_concept2, similarity))\n            sorted_similar_concepts_similarity_tuple_lst = sorted (similar_concepts_similarity_tuple_lst, key=lambda x: x[1],\n                                                                reverse=True)\n            similar_concepts_ranking_dic = {}\n            for index, tp in enumerate(sorted_similar_concepts_similarity_tuple_lst):\n                similar_concepts_ranking_dic[tp[0]] = index + 1 # as ranking\n        dataset_concept_k_similar_concepts_ranking_dic_dic[each_concept1] = similar_concepts_ranking_dic\n    return dataset_concept_k_similar_concepts_ranking_dic_dic\ndef get_vec_Harmming_similarity(concept1_vecstr, concept2_vecstr):\n    from scipy.spatial import distance\n    HD_similarity = 1 - distance.hamming(list(concept1_vecstr), list(concept2_vecstr))\n    return HD_similarity\ndef 
get_ME_kAR_corr(binarycode_type, concept_ME_dic, m_dataset, k):\n    from scipy.stats import pearsonr\n    if binarycode_type == \"AM\":\n        binarycode_dic = load_binarycode_dic(\"../results/AM_binarycode.pickle\")\n    elif binarycode_type == \"IM\":\n        binarycode_dic = load_binarycode_dic (\"../results/IM_binarycode.pickle\")\n    else:\n        print(\"INPUT ERROR!\")\n\n\n    m_dataset_concept_set = load_m_dataset_concept_set_lst (m_dataset)\n    dataset_concept_set = list (concept_dims_dic.keys ())\n    overlap_concept_lst = [i for i in m_dataset_concept_set if i in dataset_concept_set]\n\n\n    ifbinary = True\n    dataset_concept_similar_concepts_ranking_dic_dic = get_dataset_concept_k_similar_concepts_ranking_dic_dic (\n        binarycode_dic, overlap_concept_lst, ifbinary)\n\n    m_dataset_concept_k_similar_concepts_dic = get_m_dataset_concept_k_similar_concepts_dic (m_dataset,\n                                                                                             overlap_concept_lst, k)\n\n    ME_lst = []\n    raking_mean_lst = []\n    for each_concept in overlap_concept_lst:\n        ME = concept_ME_dic[each_concept]\n        k_similar_concepts_lst_in_m_dataset = m_dataset_concept_k_similar_concepts_dic[each_concept]\n        ranking_list_in_dataset = [dataset_concept_similar_concepts_ranking_dic_dic[each_concept][i] for i in\n                                   k_similar_concepts_lst_in_m_dataset]\n\n        # print(\"ranking_list_in_dataset: \", ranking_list_in_dataset, np.mean(ranking_list_in_dataset))\n        ranking_list_in_dataset = np.array (ranking_list_in_dataset)\n        ME_lst.append (ME)\n        raking_mean_lst.append (np.mean (ranking_list_in_dataset))\n\n    # corr\n    rho, _ = pearsonr (ME_lst, raking_mean_lst)\n    print(\"binarycode_type m_dataset k:\", binarycode_type, m_dataset, k)\n    print (\"correlation: \", rho)\n    return ME_lst, raking_mean_lst\ndef visualize_results(ME_lst, raking_mean_lst, 
jointplot_file):\n    import matplotlib.pyplot as plt\n    import seaborn as sns\n    import pandas as pd\n\n    jointplot_x = ME_lst\n    jointplot_y = raking_mean_lst\n\n    plt.figure ()\n    df = pd.DataFrame ([jointplot_x, jointplot_y])\n    df_final = df.T\n    df_final.rename (\n        columns={0: 'Modality Exclusivity', 1: 'the Average Ranking of 3 Similar Concepts'}, inplace=True)\n    sns.set_style (\"darkgrid\")\n\n    sns.jointplot (data=df_final, x=\"Modality Exclusivity\", y=\"the Average Ranking of 3 Similar Concepts\",\n                   kind=\"reg\",truncate=False, xlim=(0, 1), ylim=(0, 100))\n    plt.savefig (jointplot_file, dpi=300)\n    plt.close ()\n\n\n\n\nif __name__ == \"__main__\":\n    import numpy as np\n    import pickle\n    import pandas as pd\n\n    BBSR_path = \"../data/BBSR-5modalities.xlsx\"\n    all_m_dataset_lst = [\"McRae\", \"CSLB\"]\n    binarycode_type_lst = [\"AM\", \"IM\"]\n    k = 3\n\n    concept_dims_dic, _ = get_concept_dataset_dic_and_initial_weights_lst (BBSR_path)\n    concept_ME_dic = get_dataset_concept_ME_dic (concept_dims_dic)\n    plot_info_dic = {}\n    for binarycode_type in binarycode_type_lst:\n        for m_dataset in all_m_dataset_lst:\n            ME_lst, raking_mean_lst = get_ME_kAR_corr(binarycode_type, concept_ME_dic, m_dataset, k)\n            jointplot_file = \"../results/\"+m_dataset + \"_\" + binarycode_type + \"_results.png\"\n            visualize_results(ME_lst, raking_mean_lst, jointplot_file)\n\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "examples/Perception_and_Learning/NeuEvo/auto_augment.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# modified from: https://github.com/DeepVoltaire/AutoAugment/\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nfrom PIL import Image, ImageEnhance, ImageOps\nimport random\n\n\nclass ImageNetPolicy(object):\n    \"\"\" Randomly choose one of the best 24 Sub-policies on ImageNet.\n\n            Example:\n                    policy = ImageNetPolicy()\n                    transformed = policy(image)\n\n            Example as a PyTorch Transform:\n                    transform = transforms.Compose([\n                            transforms.Resize(256),\n                            ImageNetPolicy(),\n                            transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.4, \"posterize\", 8, 0.6, \"rotate\", 9, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"posterize\", 7, 0.6, \"posterize\", 6, fillcolor),\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n\n            SubPolicy(0.4, \"equalize\", 4, 0.8, \"rotate\", 8, fillcolor),\n            SubPolicy(0.6, \"solarize\", 3, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.8, \"posterize\", 5, 1.0, \"equalize\", 2, fillcolor),\n            SubPolicy(0.2, \"rotate\", 3, 0.6, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"equalize\", 8, 0.4, \"posterize\", 6, fillcolor),\n\n            SubPolicy(0.8, \"rotate\", 8, 0.4, \"color\", 0, fillcolor),\n            SubPolicy(0.4, \"rotate\", 9, 0.6, \"equalize\", 2, fillcolor),\n            SubPolicy(0.0, \"equalize\", 7, 0.8, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            
SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n\n            SubPolicy(0.8, \"rotate\", 8, 1.0, \"color\", 2, fillcolor),\n            SubPolicy(0.8, \"color\", 8, 0.8, \"solarize\", 7, fillcolor),\n            SubPolicy(0.4, \"sharpness\", 7, 0.6, \"invert\", 8, fillcolor),\n            SubPolicy(0.6, \"shearX\", 5, 1.0, \"equalize\", 9, fillcolor),\n            SubPolicy(0.4, \"color\", 0, 0.6, \"equalize\", 3, fillcolor),\n\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor)]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment ImageNet Policy\"\n\n\nclass CIFAR10Policy(object):\n    \"\"\" Randomly choose one of the best 25 Sub-policies on CIFAR10.\n\n            Example:\n                    policy = CIFAR10Policy()\n                    transformed = policy(image)\n\n            Example as a PyTorch Transform:\n                    transform = transforms.Compose([\n                            transforms.Resize(256),\n                            CIFAR10Policy(),\n                            transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.1, \"invert\", 7, 0.2, \"contrast\", 6, fillcolor),\n            SubPolicy(0.7, \"rotate\", 2, 0.3, \"translateX\", 9, fillcolor),\n            SubPolicy(0.8, \"sharpness\", 1, 0.9, \"sharpness\", 3, fillcolor),\n            SubPolicy(0.5, \"shearY\", 8, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.5, \"autocontrast\", 8, 
0.9, \"equalize\", 2, fillcolor),\n\n            SubPolicy(0.2, \"shearY\", 7, 0.3, \"posterize\", 7, fillcolor),\n            SubPolicy(0.4, \"color\", 3, 0.6, \"brightness\", 7, fillcolor),\n            SubPolicy(0.3, \"sharpness\", 9, 0.7, \"brightness\", 9, fillcolor),\n            SubPolicy(0.6, \"equalize\", 5, 0.5, \"equalize\", 1, fillcolor),\n            SubPolicy(0.6, \"contrast\", 7, 0.6, \"sharpness\", 5, fillcolor),\n\n            SubPolicy(0.7, \"color\", 7, 0.5, \"translateX\", 8, fillcolor),\n            SubPolicy(0.3, \"equalize\", 7, 0.4, \"autocontrast\", 8, fillcolor),\n            SubPolicy(0.4, \"translateY\", 3, 0.2, \"sharpness\", 6, fillcolor),\n            SubPolicy(0.9, \"brightness\", 6, 0.2, \"color\", 8, fillcolor),\n            SubPolicy(0.5, \"solarize\", 2, 0.0, \"invert\", 3, fillcolor),\n\n            SubPolicy(0.2, \"equalize\", 0, 0.6, \"autocontrast\", 0, fillcolor),\n            SubPolicy(0.2, \"equalize\", 8, 0.8, \"equalize\", 4, fillcolor),\n            SubPolicy(0.9, \"color\", 9, 0.6, \"equalize\", 6, fillcolor),\n            SubPolicy(0.8, \"autocontrast\", 4, 0.2, \"solarize\", 8, fillcolor),\n            SubPolicy(0.1, \"brightness\", 3, 0.7, \"color\", 0, fillcolor),\n\n            SubPolicy(0.4, \"solarize\", 5, 0.9, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.9, \"translateY\", 9, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"autocontrast\", 2, 0.8, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.1, \"invert\", 3, fillcolor),\n            SubPolicy(0.7, \"translateY\", 9, 0.9, \"autocontrast\", 1, fillcolor)]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment CIFAR10 Policy\"\n\n\nclass SVHNPolicy(object):\n    \"\"\" Randomly choose one of the best 25 Sub-policies on SVHN.\n\n            Example:\n           
         policy = SVHNPolicy()\n                    transformed = policy(image)\n\n            Example as a PyTorch Transform:\n                    transform = transforms.Compose([\n                            transforms.Resize(256),\n                            SVHNPolicy(),\n                            transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.9, \"shearX\", 4, 0.2, \"invert\", 3, fillcolor),\n            SubPolicy(0.9, \"shearY\", 8, 0.7, \"invert\", 5, fillcolor),\n            SubPolicy(0.6, \"equalize\", 5, 0.6, \"solarize\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 3, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"equalize\", 1, 0.9, \"rotate\", 3, fillcolor),\n\n            SubPolicy(0.9, \"shearX\", 4, 0.8, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.9, \"shearY\", 8, 0.4, \"invert\", 5, fillcolor),\n            SubPolicy(0.9, \"shearY\", 5, 0.2, \"solarize\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 6, 0.8, \"autocontrast\", 1, fillcolor),\n            SubPolicy(0.6, \"equalize\", 3, 0.9, \"rotate\", 3, fillcolor),\n\n            SubPolicy(0.9, \"shearX\", 4, 0.3, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"shearY\", 8, 0.7, \"invert\", 4, fillcolor),\n            SubPolicy(0.9, \"equalize\", 5, 0.6, \"translateY\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 4, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.3, \"contrast\", 3, 0.8, \"rotate\", 4, fillcolor),\n\n            SubPolicy(0.8, \"invert\", 5, 0.0, \"translateY\", 2, fillcolor),\n            SubPolicy(0.7, \"shearY\", 6, 0.4, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 0.8, \"rotate\", 4, fillcolor),\n            SubPolicy(0.3, \"shearY\", 7, 0.9, \"translateX\", 3, fillcolor),\n            SubPolicy(0.1, \"shearX\", 6, 0.6, \"invert\", 5, fillcolor),\n\n            
SubPolicy(0.7, \"solarize\", 2, 0.6, \"translateY\", 7, fillcolor),\n            SubPolicy(0.8, \"shearY\", 4, 0.8, \"invert\", 8, fillcolor),\n            SubPolicy(0.7, \"shearX\", 9, 0.8, \"translateY\", 3, fillcolor),\n            SubPolicy(0.8, \"shearY\", 5, 0.7, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.7, \"shearX\", 2, 0.1, \"invert\", 5, fillcolor)]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment SVHN Policy\"\n\n\nclass SubPolicy(object):\n\n    def __init__(self,\n                 p1,\n                 operation1,\n                 magnitude_idx1,\n                 p2,\n                 operation2,\n                 magnitude_idx2,\n                 fillcolor=(128, 128, 128)):\n        ranges = {\n            \"shearX\": np.linspace(0, 0.3, 10),\n            \"shearY\": np.linspace(0, 0.3, 10),\n            \"translateX\": np.linspace(0, 150 / 331, 10),\n            \"translateY\": np.linspace(0, 150 / 331, 10),\n            \"rotate\": np.linspace(0, 30, 10),\n            \"color\": np.linspace(0.0, 0.9, 10),\n            \"posterize\": np.round(np.linspace(8, 4, 10), 0).astype(np.int),\n            \"solarize\": np.linspace(256, 0, 10),\n            \"contrast\": np.linspace(0.0, 0.9, 10),\n            \"sharpness\": np.linspace(0.0, 0.9, 10),\n            \"brightness\": np.linspace(0.0, 0.9, 10),\n            \"autocontrast\": [0] * 10,\n            \"equalize\": [0] * 10,\n            \"invert\": [0] * 10}\n\n        # from https://stackoverflow.com/questions/5252170/specify-image\n        # -filling-color-when-rotating-in-python-with-pil-and-setting-expand\n        def rotate_with_fill(img, magnitude):\n            rot = img.convert(\"RGBA\").rotate(magnitude)\n            return Image.composite(\n                rot, Image.new(\"RGBA\", rot.size, (128,) * 4), rot) \\\n              
  .convert(img.mode)\n\n        func = {\n            \"shearX\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),\n                Image.BICUBIC,\n                fillcolor=fillcolor),\n            \"shearY\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),\n                Image.BICUBIC,\n                fillcolor=fillcolor),\n            \"translateX\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, 0, magnitude * img.size[0] *\n                 random.choice([-1, 1]), 0, 1, 0),\n                fillcolor=fillcolor),\n            \"translateY\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, 0, 0, 0, 1, magnitude *\n                 img.size[1] * random.choice([-1, 1])),\n                fillcolor=fillcolor),\n            \"rotate\": lambda img, magnitude: rotate_with_fill(img, magnitude),\n            # \"rotate\": lambda img, magnitude: \\\n            #     img.rotate(magnitude * random.choice([-1, 1])),\n            \"color\": lambda img, magnitude: \\\n            ImageEnhance.Color(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"posterize\": lambda img, magnitude: \\\n            ImageOps.posterize(img, magnitude),\n            \"solarize\": lambda img, magnitude: \\\n            ImageOps.solarize(img, magnitude),\n            \"contrast\": lambda img, magnitude: \\\n            ImageEnhance.Contrast(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"sharpness\": lambda img, magnitude: \\\n            ImageEnhance.Sharpness(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            
\"brightness\": lambda img, magnitude: \\\n            ImageEnhance.Brightness(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"autocontrast\": lambda img, magnitude: ImageOps.autocontrast(img),\n            \"equalize\": lambda img, magnitude: ImageOps.equalize(img),\n            \"invert\": lambda img, magnitude: ImageOps.invert(img)\n        }\n\n        # self.name = \"{}_{:.2f}_and_{}_{:.2f}\".format(\n        #     operation1, ranges[operation1][magnitude_idx1],\n        #     operation2, ranges[operation2][magnitude_idx2])\n        self.p1 = p1\n        self.operation1 = func[operation1]\n        self.magnitude1 = ranges[operation1][magnitude_idx1]\n        self.p2 = p2\n        self.operation2 = func[operation2]\n        self.magnitude2 = ranges[operation2][magnitude_idx2]\n\n    def __call__(self, img):\n        if random.random() < self.p1:\n            img = self.operation1(img, self.magnitude1)\n        if random.random() < self.p2:\n            img = self.operation2(img, self.magnitude2)\n        return img\n"
  },
  {
    "path": "examples/Perception_and_Learning/NeuEvo/main.py",
    "content": "import argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.model_zoo.NeuEvo import genotypes\nfrom braincog.model_zoo.NeuEvo.model import NetworkCIFAR\nfrom braincog.model_zoo.NeuEvo.others import *\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\nfrom timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\n# from ptflops import get_model_complexity_info\n# from thop import profile, clever_format\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\nparser = argparse.ArgumentParser(description='SNN 
Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='cifar10', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/data/floyed/BrainCog/darts', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='GateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n 
(default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\n# DARTS parameters\nparser.add_argument('--init-channels', type=int, default=36)\nparser.add_argument('--layers', type=int, default=16)\nparser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')\nparser.add_argument('--arch', default='', type=str)\nparser.add_argument('--parse_method', default='darts', type=str)\nparser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')\nparser.add_argument('--back-connection', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from 
apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    if args.local_rank == 0:\n        output_base = args.output if args.output else './output'\n        exp_name = '-'.join([\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            args.model,\n            args.dataset,\n            args.arch,\n            str(args.step),\n            args.suffix\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base, 'train', exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    else:\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not 
allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n    genotype = eval('genotypes.%s' % args.arch)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        dataset=args.dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.threshold,\n        tau=args.tau,\n        sigmoid_thres=args.sigmoid_thres,\n        requires_thres_grad=args.requires_thres_grad,\n        spike_output=not args.no_spike_output,\n        C=args.init_channels,\n        layers=args.layers,\n        auxiliary=args.auxiliary,\n        genotype=genotype,\n        parse_method=args.parse_method,\n        back_connection=args.back_connection,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        n_groups=args.n_groups,\n    )\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n    
    args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    use_amp = None\n    if args.amp:\n        # for backwards compat, `--amp` arg tries apex before native amp\n        if has_apex:\n            args.apex_amp = True\n        elif has_native_amp:\n            args.native_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                        \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n    if args.num_gpu > 1:\n        if use_amp == 'apex':\n            _logger.warning(\n                'Apex AMP does not work well with nn.DataParallel, disabling. 
Use DDP or Torch AMP.')\n            use_amp = None\n        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()\n        assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n    else:\n        model = model.cuda()\n        if args.channels_last:\n            model = model.to(memory_format=torch.channels_last)\n\n    optimizer = create_optimizer(args, model)\n\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. 
Training in float32.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume and args.eval_checkpoint == '':\n        args.eval_checkpoint = args.resume\n    if args.resume:\n        args.eval = True\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n        # print(model.get_attr('mu'))\n        # print(model.get_attr('sigma'))\n\n    # if args.critical_loss or args.spike_rate:\n    model.set_requires_fp(True)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume=args.resume)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    model_without_ddp = model\n    if args.distributed:\n        if args.sync_bn:\n            assert not args.split_bn\n            try:\n                if has_apex and use_amp != 'native':\n                    # Apex SyncBN preferred unless native amp is activated\n                    model = convert_syncbn_model(model)\n                else:\n                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                if args.local_rank == 0:\n                    _logger.info(\n                        'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using '\n                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n            except Exception as e:\n                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n        if has_apex and use_amp != 'native':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model, device_ids=[args.local_rank],\n                              find_unused_parameters=True)  # can use device str in Torch >= 1.1\n        model_without_ddp = model.module\n    # NOTE: EMA model does not need to be wrapped by DDP\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        
beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        temporal_flatten=args.temporal_flatten,\n        portion=args.train_portion,\n        _logger=_logger,\n\n    )\n\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n\n    else:\n        if args.jsd:\n            assert num_aug_splits > 1  # JSD only valid with aug splits set\n            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n        elif mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        elif args.smoothing:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    if args.eval:  # evaluate the model\n        if args.distributed:\n            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n            new_state_dict = OrderedDict()\n            # add module prefix for DDP\n            for k, v in state_dict.items():\n                k = 'module.' 
+ k\n                new_state_dict[k] = v\n\n            model.load_state_dict(new_state_dict)\n        # else:\n        #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)\n        for i in range(1):\n            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,\n                                   visualize=args.visualize, spike_rate=args.spike_rate,\n                                   tsne=args.tsne, conf_mat=args.conf_mat)\n            print(f\"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n        # return\n\n    saver = None\n    if args.local_rank == 0:\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    try:  # train the model\n        if args.reset_drop:\n            model_without_ddp.reset_drop_path(0.0)\n        for epoch in range(start_epoch, args.epochs):\n            if epoch == 0 and args.reset_drop:\n                model_without_ddp.reset_drop_path(args.drop_path)\n\n            if args.distributed:\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n         
   eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat)\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                ema_eval_metrics = validate(\n                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                    visualize=args.visualize, spike_rate=args.spike_rate,\n                    tsne=args.tsne, conf_mat=args.conf_mat)\n                eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            # if saver is not None and epoch >= args.n_warm_up:\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher 
and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    # t, k = adjust_surrogate_coeff(100, args.epochs)\n    # model.set_attr('t', t)\n    # model.set_attr('k', k)\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            loss = loss_fn(output, target)\n        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':\n            # print(output.shape, target.shape)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n    
            random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            if args.opt == 'lamb':\n                optimizer.step(epoch=epoch)\n            else:\n                optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            mu_str = ''\n            sigma_str = ''\n            if not args.distributed:\n                if 'Noise' in args.node_type:\n                    mu, sigma = model.get_noise_param()\n                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                losses_m.update(reduced_loss.item(), inputs.size(0))\n                closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n\n                _logger.info(\n                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                    'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'LR: {lr:.3e}  '\n                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                        epoch,\n                        
batch_idx, len(loader),\n                        100. * batch_idx / last_idx,\n                        loss=losses_m,\n                        closs=closses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        batch_time=batch_time_m,\n                        rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                        lr=lr,\n                        data_time=data_time_m\n                    ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = 
inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:\n                    model.set_requires_fp(True)\n                    # if not args.critical_loss:\n                    #     model.set_requires_fp(False)\n\n            with amp_autocast():\n                output = model(inputs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            if not args.distributed:\n                if visualize:\n                    x = model.get_fp()\n                    feature_path = os.path.join(args.output_dir, 'feature_map')\n                    if os.path.exists(feature_path) is False:\n                        os.mkdir(feature_path)\n                    save_feature_map(x, feature_path)\n                    # if not args.critical_loss:\n                    #     model_config.set_requires_fp(False)\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n            tot_spike = model.get_tot_spike() if hasattr(model, 'get_tot_spike') else 0.\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = 
reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n                mu_str = ''\n                sigma_str = ''\n                if not args.distributed:\n                    if 'Noise' in args.node_type:\n                        mu, sigma = model.get_noise_param()\n                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n                _logger.info(\n                    '{0}: [{1:>4d}/{2}]  '\n                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '\n                    'TotSpike: {tot_spike}'.format(\n                        log_name,\n                        batch_idx,\n                        last_idx,\n                        batch_time=batch_time_m,\n                        loss=losses_m,\n                        closs=closses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        tot_spike=tot_spike\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if not args.distributed:\n        
if tsne:\n            feature_vec = torch.cat(feature_vec)\n            feature_cls = torch.cat(feature_cls)\n            plot_tsne(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-2d.eps'))\n            plot_tsne_3d(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-3d.eps'))\n        if conf_mat:\n            logits_vec = torch.cat(logits_vec)\n            labels_vec = torch.cat(labels_vec)\n            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(args.output_dir, 'confusion_matrix.eps'))\n\n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Perception_and_Learning/NeuEvo/separate_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom braincog.base.utils.criterions import UnilateralMse\nfrom utils import num_ops, type_num, edge_num\n\n__all__ = ['ConvSeparateLoss', 'TriSeparateLoss']\n\n\nclass MseSeparateLoss(nn.modules.loss._Loss):\n\n    def __init__(self, weight=0.1, size_average=None, ignore_index=-100,\n                 reduce=None, reduction='mean'):\n        super(MseSeparateLoss, self).__init__(size_average, reduce, reduction)\n        self.ignore_index = ignore_index\n        self.weight = weight\n        self.criterion = UnilateralMse(1.)\n\n    def forward(self, input1, target1, input2):\n        loss1 = self.criterion(input1, target1)\n        loss2 = -F.mse_loss(input2, torch.tensor(0.5,\n                            requires_grad=False).cuda())\n        return loss1 + self.weight * loss2, loss1.item(), loss2.item()\n\n\nclass ConvSeparateLoss(nn.modules.loss._Loss):\n    \"\"\"Separate the weight value between each operations using L2\"\"\"\n\n    def __init__(self, loss1_fn, weight=0.1, size_average=None, ignore_index=-100,\n                 reduce=None, reduction='mean'):\n        super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)\n        self.ignore_index = ignore_index\n        self.weight = weight\n        self.loss1_fn = loss1_fn\n\n    def forward(self, input1, target1, input2):\n        loss1 = self.loss1_fn(input1, target1)\n        # loss2 = -F.mse_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())\n        # loss2 = -torch.std(input2, dim=-1).sum()\n        # + F.mse_loss(torch.mean(input2, dim=-1), torch.tensor(0.2, requires_grad=False).cuda())\n\n        # loss_std = 0\n        # loss_avg = 0.\n        # edge = edge_num + edge_num\n        # edge_input2 = torch.split(input2, edge, dim=0)\n        # for i in range(len(edge)):\n        #     avg_e = 2 / (edge[i] * num_ops)\n        #     loss_avg += 5 * F.mse_loss(torch.mean(edge_input2[i]), 
torch.tensor(avg_e, requires_grad=False).cuda())\n        #     loss_std += -torch.std(edge_input2[i]).sum()\n        # loss2 = loss_std + loss_avg\n\n        # loss2 = torch.tensor([0.], device=input1.device)\n\n        loss2 = - 0.2 * torch.std(input2)\n\n        return loss1 + self.weight * loss2, loss1.item(), loss2.item()\n\n\nclass TriSeparateLoss(nn.modules.loss._Loss):\n    \"\"\"Separate the weight value between each operations using L1\"\"\"\n\n    def __init__(self, loss1_fn, weight=0.1, size_average=None, ignore_index=-100,\n                 reduce=None, reduction='mean'):\n        super(TriSeparateLoss, self).__init__(size_average, reduce, reduction)\n        self.ignore_index = ignore_index\n        self.weight = weight\n        self.loss1_fn = loss1_fn\n\n    def forward(self, input1, target1, input2):\n        loss1 = F.cross_entropy(input1, target1)\n        loss2 = -F.l1_loss(input2, torch.tensor(0.5,\n                           requires_grad=False).cuda())\n        return loss1 + self.weight * loss2, loss1.item(), loss2.item()\n"
  },
  {
    "path": "examples/Perception_and_Learning/NeuEvo/train.py",
    "content": "import os\nimport sys\nimport time\nimport logging\nimport torch\nimport utils as dutils\nimport argparse\nimport numpy as np\nimport torch.utils\nimport torch.nn as nn\nfrom braincog.model_zoo.NeuEvo import genotypes\nfrom braincog.model_zoo.NeuEvo.model import NetworkCIFAR as Network\nimport torchvision.datasets as dset\nimport torch.backends.cudnn as cudnn\nfrom thop import profile\n\nfrom braincog.datasets.datasets import build_transform\nfrom braincog.base.utils.criterions import UnilateralMse\n\nparser = argparse.ArgumentParser(\"cifar\")\nparser.add_argument('--data', type=str, default='/data/datasets',\n                    help='location of the data corpus')\nparser.add_argument('--dataset', type=str, default='cifar10',\n                    help='cifar10 or cifar 100 for training')\nparser.add_argument('--batch-size', type=int, default=128, help='batch size')\nparser.add_argument('--learning_rate', type=float,\n                    default=0.025, help='init learning rate')\nparser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--momentum', type=float, default=0.9, help='momentum')\nparser.add_argument('--weight_decay', type=float,\n                    default=3e-4, help='weight decay')\nparser.add_argument('--report_freq', type=float,\n                    default=50, help='report frequency')\nparser.add_argument('--device', type=int, default=0, help='gpu device id')\nparser.add_argument('--multi-gpus', action='store_true',\n                    default=False, help='use multi gpus')\nparser.add_argument('--parse_method', type=str,\n                    default='darts', help='experiment name')\nparser.add_argument('--epochs', type=int, default=600,\n                    help='num of training epochs')\nparser.add_argument('--init-channels', type=int,\n                    default=64, help='num of init 
channels')\nparser.add_argument('--layers', type=int, default=16,\n                    help='total number of layers')\nparser.add_argument('--model_path', type=str,\n                    default='saved_models', help='path to save the model')\nparser.add_argument('--auxiliary', action='store_true',\n                    default=False, help='use auxiliary tower')\nparser.add_argument('--auxiliary_weight', type=float,\n                    default=0.4, help='weight for auxiliary loss')\nparser.add_argument('--cutout', action='store_true',\n                    default=False, help='use cutout')\nparser.add_argument('--cutout_length', type=int,\n                    default=16, help='cutout length')\nparser.add_argument('--auto_aug', action='store_true',\n                    default=False, help='use auto augmentation')\nparser.add_argument('--drop_path_prob', type=float,\n                    default=0.2, help='drop path probability')\nparser.add_argument('--save', type=str, default='EXP', help='experiment name')\nparser.add_argument('--seed', type=int, default=42, help='random seed')\nparser.add_argument('--arch', type=str, default='DARTS',\n                    help='which architecture to use')\nparser.add_argument('--grad_clip', type=float,\n                    default=5, help='gradient clipping')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='path to latest checkpoint (default: none)')\n\nparser.add_argument('--img_size', default=32, type=int)\nparser.add_argument('--step', default=8, type=int)\nparser.add_argument('--node-type', default='PLIFNode', type=str)\nparser.add_argument('--suffix', default='', type=str)\n\n\nclass TrainNetwork(object):\n    \"\"\"The main train network\"\"\"\n\n    def __init__(self, args):\n        super(TrainNetwork, self).__init__()\n        self.args = args\n        self.dur_time = 0\n        self._init_log()\n        self._init_device()\n        self._init_data_queue()\n        
self._init_model()\n\n    def _init_log(self):\n        self.args.save = '/data/floyed/darts/logs/eval/' + self.args.arch + '/' + 'cifar10' + '/eval-{}-{}-{}'.format(\n            self.args.save, time.strftime('%Y%m%d-%H%M'), args.suffix)\n        dutils.create_exp_dir(self.args.save, scripts_to_save=None)\n\n        log_format = '%(asctime)s %(message)s'\n        logging.basicConfig(stream=sys.stdout, level=logging.INFO,\n                            format=log_format, datefmt='%m/%d %I:%M:%S %p')\n        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))\n        fh.setFormatter(logging.Formatter(log_format))\n        self.logger = logging.getLogger('Architecture Training')\n        self.logger.addHandler(fh)\n\n    def _init_device(self):\n        if not torch.cuda.is_available():\n            self.logger.info('no gpu device available')\n            sys.exit(1)\n        np.random.seed(self.args.seed)\n        self.device_id = self.args.device\n        self.device = torch.device('cuda:{}'.format(\n            0 if self.args.multi_gpus else self.device_id))\n        cudnn.benchmark = True\n        torch.manual_seed(self.args.seed)\n        cudnn.enabled = True\n        torch.cuda.manual_seed(self.args.seed)\n        logging.info('gpu device = %d' % self.args.device)\n        logging.info(\"args = %s\", self.args)\n\n    def _init_data_queue(self):\n        train_transform = build_transform(True, args.img_size)\n        valid_transform = build_transform(False, args.img_size)\n        if self.args.dataset == 'cifar10':\n            train_data = dset.CIFAR10(\n                root=self.args.data, train=True, download=True, transform=train_transform)\n            valid_data = dset.CIFAR10(\n                root=self.args.data, train=False, download=True, transform=valid_transform)\n            self.num_classes = 10\n        elif self.args.dataset == 'cifar100':\n            train_data = dset.CIFAR100(\n                root=self.args.data, train=True, 
download=True, transform=train_transform)\n            valid_data = dset.CIFAR100(\n                root=self.args.data, train=False, download=True, transform=valid_transform)\n            self.num_classes = 100\n        self.train_queue = torch.utils.data.DataLoader(\n            train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4)\n\n        self.valid_queue = torch.utils.data.DataLoader(\n            valid_data, batch_size=self.args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n\n    def _init_model(self):\n        genotype = eval('genotypes.%s' % self.args.arch)\n        model = Network(self.args.init_channels,\n                        self.num_classes,\n                        self.args.layers,\n                        self.args.auxiliary,\n                        genotype,\n                        self.args.parse_method,\n                        step=args.step,\n                        node_type=args.node_type\n                        )\n        flops, params = profile(model, inputs=(\n            torch.randn(1, 3, 32, 32),), verbose=False)\n        self.logger.info('flops = %fM', flops / 1e6)\n        self.logger.info('param size = %fM', params / 1e6)\n\n        # Try move model to multi gpus\n        if torch.cuda.device_count() > 1 and self.args.multi_gpus:\n            self.logger.info('use: %d gpus', torch.cuda.device_count())\n            model = nn.DataParallel(model)\n        else:\n            self.logger.info('gpu device = %d' % self.device_id)\n            torch.cuda.set_device(self.device_id)\n        self.model = model.to(self.device)\n\n        # criterion = nn.CrossEntropyLoss()\n        criterion = UnilateralMse(1.)\n        self.criterion = criterion.to(self.device)\n        self.optimizer = torch.optim.AdamW(\n            model.parameters(),\n            self.args.learning_rate,\n            weight_decay=self.args.weight_decay\n        )\n\n        self.best_acc_top1 = 0\n        # 
optionally resume from a checkpoint\n        if self.args.resume:\n            if os.path.isfile(self.args.resume):\n                print(\"=> loading checkpoint {}\".format(self.args.resume))\n                checkpoint = torch.load(\n                    self.args.resume, map_location=self.device)\n                self.dur_time = checkpoint['dur_time']\n                self.args.start_epoch = checkpoint['epoch']\n                self.best_acc_top1 = checkpoint['best_acc_top1']\n                self.args.drop_path_prob = checkpoint['drop_path_prob']\n                self.model.load_state_dict(checkpoint['state_dict'])\n                self.optimizer.load_state_dict(checkpoint['optimizer'])\n                print(\"=> loaded checkpoint '{}' (epoch {})\".format(\n                    self.args.resume, checkpoint['epoch']))\n            else:\n                print(\"=> no checkpoint found at '{}'\".format(self.args.resume))\n\n        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs), eta_min=0,\n                                                                    last_epoch=-1 if self.args.start_epoch == 0 else self.args.start_epoch)\n        # reload the scheduler if possible\n        if self.args.resume and os.path.isfile(self.args.resume):\n            checkpoint = torch.load(self.args.resume)\n            self.scheduler.load_state_dict(checkpoint['scheduler'])\n\n    def run(self):\n        self.logger.info('args = %s', self.args)\n        run_start = time.time()\n        for epoch in range(self.args.start_epoch, self.args.epochs):\n            self.scheduler.step()\n            self.logger.info('epoch % d / %d  lr %e', epoch,\n                             self.args.epochs, self.scheduler.get_lr()[0])\n\n            self.model.drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs\n\n            train_acc, train_obj = self.train()\n            self.logger.info('train loss %e, train acc %f',\n         
                    train_obj, train_acc)\n\n            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()\n            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',\n                             valid_obj, valid_acc_top1, valid_acc_top5)\n            self.logger.info('best valid acc %f', self.best_acc_top1)\n\n            is_best = False\n            if valid_acc_top1 > self.best_acc_top1:\n                self.best_acc_top1 = valid_acc_top1\n                is_best = True\n\n            dutils.save_checkpoint({\n                'epoch': epoch + 1,\n                'dur_time': self.dur_time + time.time() - run_start,\n                'state_dict': self.model.state_dict(),\n                'drop_path_prob': self.args.drop_path_prob,\n                'best_acc_top1': self.best_acc_top1,\n                'optimizer': self.optimizer.state_dict(),\n                'scheduler': self.scheduler.state_dict()\n            }, is_best, self.args.save)\n        self.logger.info('train epoches %d, best_acc_top1 %f, dur_time %s',\n                         self.args.epochs, self.best_acc_top1,\n                         dutils.calc_time(self.dur_time + time.time() - run_start))\n\n    def train(self):\n        objs = dutils.AvgrageMeter()\n        top1 = dutils.AvgrageMeter()\n        top5 = dutils.AvgrageMeter()\n\n        self.model.train()\n\n        for step, (input, target) in enumerate(self.train_queue):\n\n            input = input.cuda(non_blocking=True)\n            target = target.cuda(non_blocking=True)\n\n            self.optimizer.zero_grad()\n            logits, logits_aux = self.model(input)\n            loss = self.criterion(logits, target)\n            if self.args.auxiliary:\n                loss_aux = self.criterion(logits_aux, target)\n                loss += self.args.auxiliary_weight * loss_aux\n            loss.backward()\n            nn.utils.clip_grad_norm_(\n                self.model.parameters(), self.args.grad_clip)\n    
        self.optimizer.step()\n\n            prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))\n            n = input.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n            if step % self.args.report_freq == 0:\n                self.logger.info('train %03d %e %f %f', step,\n                                 objs.avg, top1.avg, top5.avg)\n\n        return top1.avg, objs.avg\n\n    def infer(self):\n        objs = dutils.AvgrageMeter()\n        top1 = dutils.AvgrageMeter()\n        top5 = dutils.AvgrageMeter()\n        self.model.eval()\n        with torch.no_grad():\n            for step, (input, target) in enumerate(self.valid_queue):\n                input = input.cuda(non_blocking=True)\n                target = target.cuda(non_blocking=True)\n\n                logits, _ = self.model(input)\n                loss = self.criterion(logits, target)\n\n                prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))\n                n = input.size(0)\n                objs.update(loss.item(), n)\n                top1.update(prec1.item(), n)\n                top5.update(prec5.item(), n)\n\n                if step % self.args.report_freq == 0:\n                    self.logger.info('valid %03d %e %f %f',\n                                     step, objs.avg, top1.avg, top5.avg)\n            return top1.avg, top5.avg, objs.avg\n\n\nif __name__ == '__main__':\n    args = parser.parse_args()\n    train_network = TrainNetwork(args)\n    train_network.run()\n"
  },
  {
    "path": "examples/Perception_and_Learning/NeuEvo/train_search.py",
    "content": "import os\nimport sys\nimport time\nimport numpy as np\nimport torch\nimport logging\nimport argparse\nimport torch.nn as nn\nimport torch.utils\nimport torch.nn.functional as F\nimport torchvision.datasets as dset\nimport torch.backends.cudnn as cudnn\n\nfrom torch.autograd import Variable\nfrom braincog.model_zoo.NeuEvo.model_search import Network, calc_weight, calc_loss\nfrom braincog.model_zoo.NeuEvo.architect import Architect\nfrom separate_loss import ConvSeparateLoss, TriSeparateLoss, MseSeparateLoss\nimport utils\n\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\n\nfrom braincog.datasets.datasets import *\nfrom braincog.base.utils.criterions import *\n\ntorch.autograd.set_detect_anomaly(True)\nparser = argparse.ArgumentParser(\"cifar\")\nparser.add_argument('--data', type=str, default='/data/datasets',\n                    help='location of the data corpus')\nparser.add_argument('--dataset', type=str, default='cifar10',\n                    help='cifar10 or cifar 100 for searching')\nparser.add_argument('--batch-size', type=int, default=128, help='batch size')\nparser.add_argument('--learning_rate', type=float,\n                    default=0.005, help='init learning rate')\nparser.add_argument('--learning_rate_min', type=float,\n                    default=0.001, help='min learning rate')\nparser.add_argument('--momentum', type=float, default=0.9, help='momentum')\nparser.add_argument('--weight_decay', type=float,\n                    default=3e-4, help='weight decay')\nparser.add_argument('--report_freq', type=float,\n                    default=50, help='report frequency')\nparser.add_argument('--aux_loss_weight', type=float,\n                    default=10.0, help='weight decay')\nparser.add_argument('--device', type=int, default=0, help='gpu device id')\nparser.add_argument('--epochs', type=int, default=50,\n                    help='num of training 
epochs')\nparser.add_argument('--init-channels', type=int,\n                    default=16, help='num of init channels')\nparser.add_argument('--layers', type=int, default=6,\n                    help='total number of layers')\nparser.add_argument('--model_path', type=str,\n                    default='saved_models', help='path to save the model')\nparser.add_argument('--single_level', action='store_true',\n                    default=False, help='use single level')\nparser.add_argument('--sep_loss', type=str, default='l2',\n                    help='path to save the model')\nparser.add_argument('--cutout', action='store_true',\n                    default=False, help='use cutout')\nparser.add_argument('--cutout_length', type=int,\n                    default=16, help='cutout length')\nparser.add_argument('--auto_aug', action='store_true',\n                    default=False, help='use auto augmentation')\nparser.add_argument('--parse_method', type=str,\n                    default='bio_darts', help='parse the code method')\nparser.add_argument('--op_threshold', type=float,\n                    default=0.85, help='threshold for edges')\nparser.add_argument('--save', type=str, default='EXP', help='experiment name')\nparser.add_argument('--seed', type=int, default=42, help='random seed')\nparser.add_argument('--grad_clip', type=float,\n                    default=5, help='gradient clipping')\nparser.add_argument('--train_portion', type=float,\n                    default=0.5, help='portion of training data')\nparser.add_argument('--arch_learning_rate', type=float,\n                    default=1e-3, help='learning rate for arch encoding')\nparser.add_argument('--arch_lr_gamma', type=float, default=0.9,\n                    help='learning rate for arch encoding')\nparser.add_argument('--arch_weight_decay', type=float,\n                    default=1e-3, help='weight decay for arch encoding')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n         
           help='path to latest checkpoint (default: none)')\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\n\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. 
'\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\nparser.add_argument('--img_size', default=32, type=int)\nparser.add_argument('--smoothing', default=0.1, type=float)\nparser.add_argument('--step', default=8, type=int)\nparser.add_argument('--node-type', default='BiasPLIFNode', type=str)\nparser.add_argument('--loss_fn', type=str, default='')\nparser.add_argument('--back-connection', action='store_true')\nparser.add_argument('--asbe', '--arch-search-begin-epoch',\n                    type=int, default=0, dest='asbe')\nparser.add_argument('--num-classes', type=int, default=10)\nparser.add_argument('--spike-output',action='store_true')\nparser.add_argument('--act-fun', type=str, default='GateGrad')\nparser.add_argument('--suffix', default='', type=str)\nargs = parser.parse_args()\n\nargs.save = '/data/floyed/darts/logs/search/search-{}-{}-{}'.format(args.save, time.strftime(\"%Y%m%d-%H%M%S\"),\n                                                                    args.suffix)\nutils.create_exp_dir(args.save, scripts_to_save=None)\n\nlog_format = '%(asctime)s %(message)s'\nlogging.basicConfig(stream=sys.stdout, level=logging.INFO,\n                    format=log_format, datefmt='%m/%d %I:%M:%S %p')\nfh = logging.FileHandler(os.path.join(args.save, 'log.txt'))\nfh.setFormatter(logging.Formatter(log_format))\nlogging.getLogger().addHandler(fh)\n\nCIFAR_CLASSES = 10\n\n\ndef main():\n    args.spike_output = False\n    if not torch.cuda.is_available():\n        logging.info('no gpu device available')\n        sys.exit(1)\n\n    np.random.seed(args.seed)\n    torch.cuda.set_device(args.device)\n    # cudnn.benchmark = True\n    torch.manual_seed(args.seed)\n    # cudnn.enabled = True\n    
torch.cuda.manual_seed(args.seed)\n    logging.info('gpu device = %d' % args.device)\n    logging.info(\"args = %s\", args)\n    run_start = time.time()\n    start_epoch = 0\n    dur_time = 0\n\n    if args.loss_fn == 'mix':\n        criterion_train = MixLoss(LabelSmoothingCrossEntropy(\n            smoothing=args.smoothing).cuda())\n        criterion_val = MixLoss(nn.CrossEntropyLoss())\n    elif args.loss_fn == 'mse':\n        criterion_train = UnilateralMse(1.)\n        criterion_val = UnilateralMse(1.)\n    else:\n        criterion_train = LabelSmoothingCrossEntropy().cuda()\n        criterion_val = nn.CrossEntropyLoss().cuda()\n\n    criterion_train = ConvSeparateLoss(criterion_train, weight=args.aux_loss_weight) \\\n        if args.sep_loss == 'l2' else TriSeparateLoss(criterion_train, weight=args.aux_loss_weight)\n\n    model = Network(args.init_channels, args.num_classes, args.layers, criterion_train,\n                    steps=3, multiplier=3, stem_multiplier=3,\n                    parse_method=args.parse_method, op_threshold=args.op_threshold,\n                    step=args.step, node_type=args.node_type,\n                    back_connection=args.back_connection, act_fun=args.act_fun,\n                    dataset=args.dataset,\n                    spike_output=False,\n                    temporal_flatten=args.temporal_flatten)\n    model = model.cuda()\n    logging.info(\"param size = %fMB\", utils.count_parameters_in_MB(model))\n\n    model_optimizer = torch.optim.AdamW(\n        model.parameters(),\n        args.learning_rate,\n        # momentum=args.momentum,\n        weight_decay=args.weight_decay)\n\n    # # train_transform, valid_transform = utils._data_transforms_cifar(args)\n    # train_transform = build_transform(True, args.img_size)\n    # valid_transform = build_transform(False, args.img_size)\n    # train_data = dset.CIFAR10(\n    #     root=args.data, train=True, download=True, transform=train_transform)\n    #\n    # num_train = 
len(train_data)\n    # indices = list(range(num_train))\n    # split = int(np.floor(args.train_portion * num_train))\n    #\n    # train_queue = torch.utils.data.DataLoader(\n    #     train_data, batch_size=args.batch_size,\n    #     sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),\n    #     pin_memory=True)\n    #\n    # valid_queue = torch.utils.data.DataLoader(\n    #     train_data, batch_size=args.batch_size,\n    #     sampler=torch.utils.data.sampler.SubsetRandomSampler(\n    #         indices[split:num_train]),\n    #     pin_memory=True)\n    train_queue, valid_queue, _, _ = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        temporal_flatten=args.temporal_flatten\n    )\n\n    architect = Architect(model, args)\n\n    # resume from checkpoint\n    if args.resume:\n        if os.path.isfile(args.resume):\n            logging.info(\"=> loading checkpoint '{}'\".format(args.resume))\n            checkpoint = torch.load(\n                args.resume, map_location=model.alphas_normal.device)\n            start_epoch = checkpoint['epoch']\n            dur_time = checkpoint['dur_time']\n            model_optimizer.load_state_dict(checkpoint['model_optimizer'])\n            model.restore(checkpoint['network_states'])\n            logging.info('=> loaded checkpoint \\'{}\\'(epoch {})'.format(\n                args.resume, start_epoch))\n        else:\n            logging.info(\n                '=> no checkpoint found at \\'{}\\''.format(args.resume))\n\n    scheduler = 
torch.optim.lr_scheduler.CosineAnnealingLR(\n        model_optimizer, float(args.epochs), eta_min=args.learning_rate_min,\n        last_epoch=-1 if start_epoch == 0 else start_epoch)\n    if args.resume and os.path.isfile(args.resume):\n        scheduler.load_state_dict(checkpoint['scheduler'])\n\n    for epoch in range(start_epoch, args.epochs):\n        scheduler.step()\n        lr = scheduler.get_lr()[0]\n        logging.info('epoch %d lr %e', epoch, lr)\n\n        genotype = model.genotype()\n        logging.info('genotype = %s', genotype)\n\n        logging.info(calc_weight(model.alphas_normal))\n        logging.info(calc_loss(model.alphas_normal))\n        model.update_history()\n\n        # training and search the model\n        train_acc, train_obj = train(epoch, train_queue, valid_queue, model, architect, criterion_train,\n                                     model_optimizer)\n        logging.info('train_acc %f', train_acc)\n\n        # validation the model\n        model.record_fire_rate = True\n        model.reset_fire_rate_record()\n        valid_acc, valid_obj = infer(valid_queue, model, criterion_val)\n        fire_rate = model.get_fire_per_step()\n        model.record_fire_rate = False\n        logging.info('valid_fire_rate: {}'.format(fire_rate))\n        logging.info('valid_acc %f', valid_acc)\n\n        # save checkpoint\n        utils.save_checkpoint({\n            'epoch': epoch + 1,\n            'dur_time': dur_time + time.time() - run_start,\n            'scheduler': scheduler.state_dict(),\n            'model_optimizer': model_optimizer.state_dict(),\n            'network_states': model.states(),\n        }, is_best=False, save=args.save)\n        logging.info('save checkpoint (epoch %d) in %s  dur_time: %s', epoch, args.save,\n                     utils.calc_time(dur_time + time.time() - run_start))\n\n        # save operation weights as fig\n        utils.save_file(recoder=model.alphas_normal_history, path=os.path.join(args.save, 
'normal'),\n                        back_connection=args.back_connection)\n\n    # save last operations\n    np.save(os.path.join(os.path.join(args.save, 'normal_weight.npy')),\n            calc_weight(model.alphas_normal).data.cpu().numpy())\n    logging.info('save last weights done')\n\n\ndef train(epoch, train_queue, valid_queue, model, architect, criterion, model_optimizer):\n    objs = utils.AvgrageMeter()\n    objs1 = utils.AvgrageMeter()\n    objs2 = utils.AvgrageMeter()\n    top1 = utils.AvgrageMeter()\n    top5 = utils.AvgrageMeter()\n\n    for step, (input, target) in enumerate(train_queue):\n        model.train()\n        n = input.size(0)\n        input = Variable(input, requires_grad=False).cuda(non_blocking=True)\n        target = Variable(target, requires_grad=False).cuda(non_blocking=True)\n        # if epoch >= args.asbe:\n            # Get a random minibatch from the search queue(validation set) with replacement\n            # input_search, target_search = next(iter(valid_queue))\n            # print(input.shape, target.shape)\n            # print(input_search.shape, target_search.shape)\n            # input_search = Variable(\n            #     input_search, requires_grad=False).cuda(non_blocking=True)\n            # target_search = Variable(\n            #     target_search, requires_grad=False).cuda(non_blocking=True)\n            # loss1, loss2 = architect.step(input_search, target_search)\n        # else:\n        loss1 = torch.tensor([0.])\n        loss2 = torch.tensor([0.])\n\n        model_optimizer.zero_grad()\n\n        logits = model(input)\n        aux_input = torch.cat(\n            [calc_loss(model.alphas_normal)], dim=0)\n        loss, _, _ = criterion(logits, target, aux_input)\n        # loss = criterion(logits, target)\n\n        loss.backward()\n        nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)\n\n        # Update the network parameters\n        model_optimizer.step()\n\n        prec1, prec5 = 
utils.accuracy(logits, target, topk=(1, 5))\n        objs.update(loss.item(), n)\n        objs1.update(loss1, n)\n        objs2.update(loss2, n)\n        top1.update(prec1.item(), n)\n        top5.update(prec5.item(), n)\n\n        if step % args.report_freq == 0:\n            logging.info('train %03d loss: %e top1: %f top5: %f',\n                         step, objs.avg, top1.avg, top5.avg)\n            logging.info('val cls_loss %e; spe_loss %e', objs1.avg, objs2.avg)\n\n    return top1.avg, objs.avg\n\n\ndef infer(valid_queue, model, criterion):\n    objs = utils.AvgrageMeter()\n    top1 = utils.AvgrageMeter()\n    top5 = utils.AvgrageMeter()\n    model.eval()\n\n    with torch.no_grad():\n        for step, (input, target) in enumerate(valid_queue):\n            input = Variable(input, volatile=True).cuda(non_blocking=True)\n            target = Variable(target, volatile=True).cuda(non_blocking=True)\n\n            logits = model(input)\n            loss = criterion(logits, target)\n\n            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))\n            n = input.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n            if step % args.report_freq == 0:\n                logging.info('valid %03d %e %f %f', step,\n                             objs.avg, top1.avg, top5.avg)\n\n        return top1.avg, objs.avg\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Perception_and_Learning/NeuEvo/utils.py",
    "content": "import json\nimport matplotlib.pyplot as plt\nfrom braincog.model_zoo.NeuEvo.genotypes import Genotype, PRIMITIVES\nimport os\nimport numpy as np\nimport torch\nimport shutil\nimport torchvision.transforms as transforms\nfrom torch.autograd import Variable\nfrom auto_augment import CIFAR10Policy\nfrom braincog.model_zoo.NeuEvo.genotypes import PRIMITIVES\n\nforward_edge_num = sum(1 for i in range(3) for n in range(2 + i))\nbackward_edge_num = sum(1 for i in range(3) for n in range(i))\nnum_ops = len(PRIMITIVES)\ntype_num = len(PRIMITIVES) // 2\n# edge_num = [2, 3, 4]\nedge_num = [2, 3, 4, 1, 2]\n\n\nclass AvgrageMeter(object):\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.avg = 0\n        self.sum = 0\n        self.cnt = 0\n\n    def update(self, val, n=1):\n        self.sum += val * n\n        self.cnt += n\n        self.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Compute the top1 and top5 accuracy\n\n\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    # Return the k largest elements of the given input tensor\n    # along a given dimension -> N * k\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        correct_k = correct[:k].reshape(-1).float().sum(0)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n\n\nclass Cutout(object):\n    def __init__(self, length):\n        self.length = length\n\n    def __call__(self, img):\n        h, w = img.size(1), img.size(2)\n        mask = np.ones((h, w), np.float32)\n        y = np.random.randint(h)\n        x = np.random.randint(w)\n\n        y1 = np.clip(y - self.length // 2, 0, h)\n        y2 = np.clip(y + self.length // 2, 0, h)\n        x1 = np.clip(x - self.length // 2, 0, w)\n        x2 = np.clip(x + self.length // 2, 0, w)\n\n        mask[y1: y2, x1: x2] = 0.\n        
mask = torch.from_numpy(mask)\n        mask = mask.expand_as(img)\n        img *= mask\n        return img\n\n\ndef _data_transforms_cifar(args):\n    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] if args.dataset == 'cifar10' else [0.50707519, 0.48654887,\n                                                                                         0.44091785]\n    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] if args.dataset == 'cifar10' else [0.26733428, 0.25643846,\n                                                                                        0.27615049]\n\n    normalize_transform = [\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]\n\n    random_transform = [\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip()]\n\n    if args.auto_aug:\n        random_transform += [CIFAR10Policy()]\n\n    if args.cutout:\n        cutout_transform = [Cutout(args.cutout_length)]\n    else:\n        cutout_transform = []\n\n    train_transform = transforms.Compose(\n        random_transform + normalize_transform + cutout_transform\n    )\n\n    valid_transform = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n    ])\n    return train_transform, valid_transform\n\n\ndef count_parameters_in_MB(model):\n    return np.sum(np.prod(v.size()) for v in model.parameters()) / 1e6\n\n\ndef save_checkpoint(state, is_best, save):\n    filename = os.path.join(save, 'checkpoint.pth.tar')\n    torch.save(state, filename)\n    if is_best:\n        best_filename = os.path.join(save, 'model_best.pth.tar')\n        shutil.copyfile(filename, best_filename)\n\n\ndef save(model, model_path):\n    torch.save(model.state_dict(), model_path)\n\n\ndef load(model, model_path):\n    model.load_state_dict(torch.load(model_path))\n\n\ndef drop_path(x, drop_prob):\n    if drop_prob > 0.:\n        keep_prob = 1. 
- drop_prob\n        mask = Variable(torch.cuda.FloatTensor(\n            x.size(0), 1, 1, 1).bernoulli_(keep_prob))\n        x.div_(keep_prob)\n        x.mul_(mask)\n    return x\n\n\ndef create_exp_dir(path, scripts_to_save=None):\n    if not os.path.exists(path):\n        os.makedirs(path)\n    print('Experiment dir : {}'.format(path))\n\n    if scripts_to_save is not None:\n        os.makedirs(os.path.join(path, 'scripts'))\n        for script in scripts_to_save:\n            dst_file = os.path.join(path, 'scripts', os.path.basename(script))\n            shutil.copyfile(script, dst_file)\n\n\ndef calc_time(seconds):\n    m, s = divmod(seconds, 60)\n    h, m = divmod(m, 60)\n    t, h = divmod(h, 24)\n    return {'day': t, 'hour': h, 'minute': m, 'second': int(s)}\n\n\ndef save_file(recoder, path='./', back_connection=False):\n    size = (forward_edge_num +\n            backward_edge_num if back_connection else forward_edge_num, num_ops)\n    fig, axs = plt.subplots(*size, figsize=(36, 98))\n    row = 0\n    col = 0\n    for (k, v) in recoder.items():\n        axs[row, col].set_title(k)\n        axs[row, col].plot(v, 'r+')\n        if col == num_ops - 1:\n            col = 0\n            row += 1\n        else:\n            col += 1\n    if not os.path.exists(path):\n        os.makedirs(path)\n    fig.savefig(os.path.join(path, 'output.png'), bbox_inches='tight')\n    plt.tight_layout()\n    print('save history weight in {}'.format(os.path.join(path, 'output.png')))\n    with open(os.path.join(path, 'history_weight.json'), 'w') as outf:\n        json.dump(recoder, outf)\n        print('save history weight in {}'.format(\n            os.path.join(path, 'history_weight.json')))\n"
  },
  {
    "path": "examples/Perception_and_Learning/QSNN/README.md",
    "content": "# Quantum superposition inspired spiking neural network\n\nThis repository contains code from our paper [**Quantum superposition inspired spiking neural network**] published in iScience. https://doi.org/10.1016/j.isci.2021.102880. If you use our code or refer to this project, please cite this paper.\n\n## Requirements\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n## Train\n\n```shell  \npython ./main.py\n```\n## Citation\n\nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{sun2021quantum,\n  title={Quantum superposition inspired spiking neural network},\n  author={Sun, Yinqian and Zeng, Yi and Zhang, Tielin},\n  journal={Iscience},\n  volume={24},\n  number={8},\n  pages={102880},\n  year={2021},\n  publisher={Elsevier}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```"
  },
  {
    "path": "examples/Perception_and_Learning/QSNN/main.py",
    "content": "import os\nimport copy\nimport tqdm\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom braincog.datasets.datasets import get_mnist_data\nfrom braincog.model_zoo.qsnn import Net\nfrom braincog.datasets.gen_input_signal import lambda_max\n\n\nLOG_DIR = os.path.expanduser('./results.txt')\n\nLEARNING_RATE = 0.01\n# learning dacay\nDECAY_STEPS = 1.0\nDECAY_RATE = 0.9\n# adam\nBETA1 = 0.9\nBETA2 = 0.999\nEPSIOLN = 1e-8\n\nEPOCHS = 20\nPRINT_PERIOD = 10000\nTEST_SIZE = 10000\n\nTEST_THETA = [0, 1 / 16, 2 / 16, 3 / 16, 4 / 16, 5 / 16, 6 / 16, 7 / 16, 8 / 16]\n# TEST_THETA = [0]\nNOISE_RATES = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]\n\n\nmnist_data = get_mnist_data(batch_size=1, skip_norm=True)\n\ntrain_loader, test_loader = mnist_data.get_data_loaders()\n\nNET_SIZE = [28 * 28, 500, 10]\n\n\ndef int2onehot(label, classes, factor):\n    label_one_hot = F.one_hot(label, classes)\n    label_one_hot = label_one_hot * (8 + factor) - 8\n    return label_one_hot\n\n\ndef train(net, epochs, lr):\n    with open(LOG_DIR, 'a+') as f:\n        for epoch in range(epochs):\n            lr_decay = lr * DECAY_RATE ** (epoch / DECAY_STEPS)\n            for x, y in tqdm.tqdm(train_loader):\n                label = int2onehot(y, 10, 8)\n                x = x.flatten().numpy()\n                label = label.cuda()\n                with torch.no_grad():\n                    net.routine(x, None, image_ori=None, image_ori_delta=None, shift=False,\n                                label=label, test=False, noise=False, noise_rate=None)\n                    net.update_weight(lr_decay, epoch + 1, (BETA1, BETA2), EPSIOLN)\n            with torch.no_grad():\n                for fac in TEST_THETA:\n                    acc_reve = 0\n                    for x_test, y_test in test_loader:\n                        image = x_test.flatten().numpy()\n                        image_shift = image * np.cos(fac * np.pi) + (lambda_max - image) * np.sin(fac * 
np.pi)\n                        image_delta = copy.copy(image)\n                        delta_idx = image_delta < (lambda_max - 0.001)\n                        image_delta[delta_idx] += 0.001\n                        image_delta_shift = image_delta * np.cos(fac * np.pi) + (lambda_max - image_delta) * np.sin(fac * np.pi)\n                        pred = net.predict(image_shift, image_delta_shift, image, image_delta, shift=True, noise=False, noise_rate=None)\n                        if pred == int(y_test):\n                            acc_reve += 1\n                    acc_reve = acc_reve / TEST_SIZE\n                    print('Test epoch {epoch}: Shift {theta:0.3f} pi: accuracy {acc}.'.format(\n                        epoch=epoch, theta=fac, acc=acc_reve))\n                    print('Test epoch {epoch}: Shift {theta:0.3f} pi: accuracy {acc}'.format(\n                        epoch=epoch, theta=fac, acc=acc_reve), file=f)\n                    print()\n                    print(file=f)\n\n\nif __name__ == '__main__':\n    net = Net(NET_SIZE).cuda()\n    train(net, EPOCHS, LEARNING_RATE)\n"
  },
  {
    "path": "examples/Perception_and_Learning/UnsupervisedSTDP/Readme.md",
    "content": "This is an example of training an unsupervised STDP-based spiking neural network. We used an STB-STDP algorithm to train the SNN, along with multiple adaptive mechanisms.\n \n# How to run\npython codef.py \n\n# Result\nWe train the model on MNIST and FashionMNIST, and the best accuracy for MNIST is 97.9%, for FashionMNIST is 87.0%.\n \n\n### Citation \n\nIf you find this package helpful, please consider citing the following papers:\n```BibTex\n@article{dong2022unsupervised,\n  title={An Unsupervised Spiking Neural Network Inspired By Biologically Plausible Learning Rules and Connections},\n  author={Dong, Yiting and Zhao, Dongcheng and Li, Yang and Zeng, Yi},\n  journal={arXiv preprint arXiv:2207.02727},\n  year={2022}\n}\n\n@article{zeng2022braincog,\n  title={BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={arXiv preprint arXiv:2207.08533},\n  year={2022}\n}\n\n```\n"
  },
  {
    "path": "examples/Perception_and_Learning/UnsupervisedSTDP/codef.py",
    "content": "import torch\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nfrom tqdm import tqdm \nimport matplotlib\nmatplotlib.use('Agg')  \nimport matplotlib.pyplot as plt\nimport cv2\nimport numpy as np\nfrom copy import deepcopy\nimport os, time, math,random\nfrom  braincog.base.node.node import *\nfrom braincog.base.connection .layer import *\nfrom braincog.base.strategy.LateralInhibition import *\nfrom sklearn.metrics import confusion_matrix\n\nseed = 0\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\nnp.random.seed(seed)\n  \ntorch.cuda.manual_seed(seed) #GPU随机种子确定 \n\ntorch.backends.cudnn.benchmark = False #模型卷积层预先优化关闭\ntorch.backends.cudnn.deterministic = True #确定为默认卷积算法\n\nrandom.seed(seed) \n\nos.environ[\"PYTHONHASHSEED\"] = str(seed) \n\nos.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\ndev = \"cuda\"\ndevice = torch.device(dev) if torch.cuda.is_available() else 'cpu'\ntorch.set_printoptions(precision=4, sci_mode=False)\n\n\n# ===========================================================================================================\n\nconvoff = 0.3\n\n\n# avgscale = 5\n\n\nclass STDPConv(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size, stride, padding,groups,\n                 tau_decay=torch.exp(-1.0 / torch.tensor(100.0)), offset=convoff, static=True, inh=6.5, avgscale=5):\n        super().__init__()\n        self.tau_decay = tau_decay\n        self.offset = offset\n        self.static = static\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,groups=groups,\n                              bias=False)\n        self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)\n        self.mem = self.spike = self.refrac_count = None\n        self.normweight()\n        self.inh = inh\n        self.avgscale = avgscale\n        self.onespike=True\n        
self.node=LIFSTDPNode(act_fun=STDPGrad,tau=tau_decay,mem_detach=True)\n        self.WTA=WTALayer( )\n        self.lateralinh=LateralInhibition(self.node,self.inh,mode=\"threshold\")\n        \n    def mem_update(self, x, onespike=True):  # b,c,h,w\n\n        x=self.node( x)\n               \n        if x.max() > 0:\n            x=self.WTA(x)\n        \n            self.lateralinh(x)\n\n        self.spike= x \n        return self.spike\n\n    def forward(self, x, T=None, onespike=True):\n\n        if not self.static:\n            batch, T, c, h, w = x.shape\n            x = x.reshape(-1, c, h, w)\n\n        x = self.conv(  x)\n\n        n = self.getthresh(x)\n        self.node.threshold.data = n\n\n        x=x.clamp(min=0)\n        x = n / (1 + torch.exp(-(x - 4 * n / 10) * (8 / n)))\n\n        if not self.static:\n            x = x.reshape(batch, T, c, h, w)\n            xsum = None\n            for i in range(T):\n                tmp = self.mem_update(x[:, i], onespike).unsqueeze(1)\n                if xsum is not None:\n                    xsum = torch.cat([xsum, tmp], 1)\n                else:\n                    xsum = tmp\n        else:\n            xsum = 0\n            for i in range(T):\n                xsum += self.mem_update(x, onespike)\n\n        return xsum\n\n    def reset(self):\n        #self.mem = self.spike = self.refrac_count = None\n        self.node.n_reset()\n    def normgrad(self, force=False):\n        if force:\n            min = self.conv.weight.grad.data.min(1, True)[0].min(2, True)[0].min(3, True)[0]\n            max = self.conv.weight.grad.data.min(1, True)[0].max(2, True)[0].max(3, True)[0]\n            self.conv.weight.grad.data -= min\n            tmp = self.offset * max\n        else:\n            tmp = self.offset * self.spike.mean(0, True).mean(2, True).mean(3, True).permute(1, 0, 2, 3)\n        self.conv.weight.grad.data -= tmp\n        self.conv.weight.grad.data = -self.conv.weight.grad.data\n\n    def normweight(self, 
clip=False):\n        if clip:\n            self.conv.weight.data = torch. \\\n                clamp(self.conv.weight.data, min=-3, max=1.0)\n        else:\n            c, i, w, h = self.conv.weight.data.shape\n\n            avg=self.conv.weight.data.mean(1, True).mean(2, True).mean(3, True)\n            self.conv.weight.data -=avg\n\n            tmp = self.conv.weight.data.reshape(c, 1, -1, 1)\n\n            self.conv.weight.data /= tmp.std(2, unbiased=False, keepdim=True)\n\n\n    def getthresh(self, scale):\n\n        tmp2= scale.max(0, True)[0].max(2, True)[0].max(3, True)[0]+0.0001\n\n        return tmp2\n\n\nclass STDPLinear(nn.Module):\n    def __init__(self, in_planes, out_planes,\n                 tau_decay=0.99, offset=0.05, static=True,inh=10):\n        super().__init__()\n        self.tau_decay = tau_decay\n        self.offset = offset\n        self.static = static\n        self.linear = nn.Linear(in_planes, out_planes, bias=False)\n        self.mem = self.spike = self.refrac_count = None\n        # torch.nn.init.xavier_uniform_(self.linear.weight, gain=1)\n        self.normweight(False)\n        self.threshold = torch.ones(out_planes, device=device) *20\n        \n        self.inh=inh\n        self.node=LIFSTDPNode(act_fun=STDPGrad,tau=tau_decay  ,mem_detach=True)\n        self.WTA=WTALayer( )\n        self.lateralinh=LateralInhibition(self.node,self.inh,mode=\"max\")\n        self.init=False \n    def mem_update(self, x, onespike=True):  # b,c,h,w\n        if not self.init: \n            self.node.threshold.data= (x.max(0)[0].detach()*3).to(device) \n            self.init=True\n\n        xori=x\n        x=self.node( x)\n        if x.max() > 0:\n            x=self.WTA(x)\n        \n            self.lateralinh(x,xori)\n\n        self.spike=x\n        return self.spike\n\n    def forward(self, x, T, onespike=True):\n\n        if not self.static:\n            batch, T, w = x.shape\n            x = x.reshape(-1, w)\n        x = x.detach()\n\n\n        \n   
     x = self.linear(x)\n        self.x=x.detach()\n\n        if not self.static:\n            x = x.reshape(batch, T, w)\n            xsum = None\n            for i in range(T):\n                tmp = self.mem_update(x[:, i], onespike).unsqueeze(1)\n                if xsum is not None:\n                    xsum = torch.cat([xsum, tmp], 1)\n                else:\n                    xsum = tmp\n        else:\n            xsum = 0\n            for i in range(T):\n                xsum += self.mem_update(x, onespike)\n        #print(xsum.mean())\n        return xsum\n\n    def reset(self):\n\n        self.node.n_reset()\n    def normgrad(self, force=False):\n        if force:\n\n            pass\n        else:\n            tmp = self.offset * self.spike.mean(0, True).permute(1, 0)\n\n\n        self.linear.weight.grad.data = -self.linear.weight.grad.data\n\n\n    def normweight(self, clip=False):\n\n        if clip:\n            self.linear.weight.data = torch. \\\n                clamp(self.linear.weight.data, min=0, max=1.0)\n        else:\n            self.linear.weight.data = torch. 
\\\n                clamp(self.linear.weight.data, min=0, max=1.0)\n            sumweight = self.linear.weight.data.sum(1, True)\n            sumweight += (~(sumweight.bool())).float()\n            # self.linear.weight.data *= 11.76  / sumweight\n            self.linear.weight.data /= self.linear.weight.data.max(1, True)[0] / 0.1\n\n    def getthresh(self, scale):\n        tmp = self.linear.weight.clamp(min=0) * scale\n        tmp2 = tmp.sum(1, True).reshape(1, -1)\n        return tmp2\n\n    def updatethresh(self, plus=0.05):\n\n        self.node.threshold += (plus*self.x * self.spike.detach()).sum(0)\n        tmp=self.node.threshold.max()-350\n        if tmp>0:\n            self.node.threshold-=tmp\n\nclass STDPFlatten(nn.Module):\n    def __init__(self, start_dim=0, end_dim=-1):\n        super().__init__()\n        self.flatten = nn.Flatten(start_dim=start_dim, end_dim=end_dim)\n\n    def forward(self, x, T):  # [batch,T,c,w,h]\n \n        return self.flatten(x)\n\n\nclass STDPMaxPool(nn.Module):\n    def __init__(self, kernel_size, stride, padding, static=True):\n        super().__init__()\n        self.static = static\n        self.pool = nn.MaxPool2d(kernel_size, stride, padding)\n\n    def forward(self, x, T):  # [batch,T,c,w,h]\n\n        if not self.static:\n            batch, T, c, h, w = x.shape\n            x = x.reshape(-1, c, h, w)\n        x = self.pool(x)\n        if not self.static:\n            x = x.reshape(batch, T, c, h, w)\n\n        return x\n\n\nalpha = 1.0\n\n\nclass Normliaze(nn.Module):\n    def __init__(self, static=True):\n        super().__init__()\n        self.static = static\n\n    def forward(self, x, T):  # [batch,T,c,w,h]\n        # print(x.shape)\n        x /= x.max(1, True)[0].max(2, True)[0].max(3, True)[0]\n        # x/=x.mean()/0.13\n\n        return x\n\n\nclass voting(nn.Module):\n\n    def __init__(self, shape):\n        super().__init__()\n        self.label = torch.zeros(shape) - 1\n        self.assignments=0\n    def 
assign_labels(self, spikes, labels, rates=None, n_labels=10, alpha=alpha):\n        # 根据最后一层的spikes 以及 label 对于最后一层的神经元赋予不同的label\n        # spikes 是 batch * time * in_size\n        # print(spikes.size())\n        n_neurons = spikes.size(2)\n        if rates is None:\n            rates = torch.zeros(n_neurons, n_labels, device=device)\n        self.n_labels = n_labels\n        spikes = spikes.cpu().sum(1).to(device)\n\n        for i in range(n_labels):\n            n_labeled = torch.sum(labels == i).float()\n            # 就是说上一次assign label计算的rates 拿过来滑动平均一下   #这里似乎可以改\n            if n_labeled > 0:\n                indices = torch.nonzero(labels == i).view(-1)\n                tmp = torch.sum(spikes[indices], 0) / n_labeled  # 平均脉冲数\n                rates[:, i] = alpha * rates[:, i] + tmp\n\n        # 此时的rates是 in_size * n_label, 对应哪个label的rates最高 该神经元就对应着该label\n        self.assignments = torch.max(rates, 1)[1]\n        return self.assignments, rates\n\n    def get_label(self, spikes):\n        # 根据最后一层的spike 计算得到label\n        n_samples = spikes.size(0)\n        spikes = spikes.cpu().sum(1).to(device)\n        rates = torch.zeros(n_samples, self.n_labels, device=device)\n\n        for i in range(self.n_labels):\n            n_assigns = torch.sum(self.assignments == i).float()  # 共有多少个该类别节点\n            if n_assigns > 0:\n                indices = torch.nonzero(self.assignments == i).view(-1)  # 找到该类别节点位置\n                rates[:, i] = torch.sum(spikes[:, indices], 1) / n_assigns  # 该类别平均所有该类别节点发放脉冲数\n\n        return torch.sort(rates, dim=1, descending=True)[1][:, 0]\n\ninh=25\ninh2=1.625\nchannel=12\nneuron=6400\nclass Conv_Net(nn.Module):\n    def __init__(self):\n        super(Conv_Net, self).__init__()\n        self.conv = nn.ModuleList([\n            STDPConv(1, channel, 3, 1, 1,1, static=True, inh=1.625, avgscale=5 ),\n            STDPMaxPool(2, 2, 0, static=True),\n            Normliaze(),\n            #STDPConv(12, 48, 3, 1, 1,1, static=True, inh=inh2, 
avgscale=10 ),\n            #STDPMaxPool(2, 2, 0, static=True),\n            #Normliaze(),\n\n            STDPFlatten(start_dim=1),\n            STDPLinear(196*channel, neuron, static=True,inh=inh)\n\n        ])\n\n        self.voting = voting(10)\n\n    def forward(self, x, inlayer, outlayer, T, onespike=True):  # [b,t,w,h]\n\n        for i in range(inlayer, outlayer + 1):\n            x = self.conv[i](x, T)\n        return x\n\n    def normgrad(self, layer, force=False):\n        self.conv[layer].normgrad(force)\n\n    def normweight(self, layer, clip=False):\n        self.conv[layer].normweight(clip)\n\n    def updatethresh(self, layer, plus=0.05):\n        self.conv[layer].updatethresh(plus)\n\n    def reset(self, layer):\n        if isinstance(layer, list):\n            for i in layer:\n                self.conv[i].reset()\n        else:\n            self.conv[layer].reset()\n\n\ndef plot_confusion_matrix(cm, classes, normalize=True, title='Test Confusion matrix', cmap=plt.cm.Blues):\n    if normalize:\n        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n        #print(\"Normalized confusion matrix\")\n    else:\n        print('Confusion matrix, without normalization')\n    plt.figure()\n    #print(cm)\n    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n    plt.title(title)\n    plt.colorbar()\n    tick_marks = np.arange(len(classes))\n    # plt.xticks(tick_marks, classes, rotation=45)\n    plt.xticks(tick_marks, classes)\n    plt.yticks(tick_marks, classes)\n\n    fmt = '.2f' if normalize else 'd'\n    thresh = cm.max() / 2.\n\n    for i in range(cm.shape[0]):\n        plt.text(i, i, format(cm[i, i], fmt), horizontalalignment=\"center\",\n                 color=\"white\" if cm[i, i] > thresh else \"black\")\n    plt.tight_layout()\n    #plt.savefig('confusestpf2'+str(channel)+\"_n\"+str(neuron)+\".pdf\")\n    #plt.show()\nif __name__ == '__main__':\n    print(23)\n    batch_size = 1024\n    T = 100\n    transform = transforms.Compose(\n      
  [transforms.Resize((28, 28)), transforms.Grayscale(num_output_channels=1), transforms.ToTensor()])\n    transform = transforms.Compose([transforms.ToTensor()])\n    # mnist_train = datasets.CIFAR10(root='/data/datasets/CIFAR10/', train=True, download=False, transform=transform )\n    # mnist_test = datasets.CIFAR10(root='/data/datasets/CIFAR10/', train=False, download=False, transform=transform )\n    #mnist_train = datasets.FashionMNIST(root='/data/dyt//', train=True, download=True, transform=transform )\n    #mnist_test = datasets.FashionMNIST(root='/data/dyt/', train=False, download=False, transform=transform )\n    mnist_train = datasets.MNIST(root='./', train=True, download=True, transform=transform)\n    mnist_test = datasets.MNIST(root='./', train=False, download=False, transform=transform)\n    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)\n    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)\n\n    model = Conv_Net().to(device)\n    convlist = [index for index, i in enumerate(model.conv) if isinstance(i, (STDPConv, STDPLinear))]\n    print(convlist)\n    #cap = torch.ones([100000, 1000, 30], device=device)\n    \n    for layer in range(len(convlist) - 1):\n        optimizer = torch.optim.SGD(list(model.parameters())[layer:layer + 1], lr=0.1)\n        for epoch in range(3):\n            for step, (x, y) in enumerate(tqdm(train_iter)):\n                x = x.to(device)\n                y = y.to(device)\n\n                spikes = model(x, 0, convlist[layer], T)\n\n                optimizer.zero_grad()\n                spikes.sum().backward(torch.tensor(1/  (spikes.shape[0] * spikes.shape[2] * spikes.shape[3])))\n                # spikes.sum().backward(  )\n                model.conv[convlist[layer]].spike = spikes.detach()\n                model.normgrad(convlist[layer], force=True)\n                optimizer.step()\n                
model.normweight(convlist[layer], clip=False)\n                # print(model.conv[convlist[layer]].conv.weight.data )\n                model.reset(convlist)\n\n            print(\"layer\", layer, \"epoch\", epoch, 'Done')\n        #model.conv[convlist[layer]].onespike=False\n    # ===========================================================================================================\n    # linear\n    #model.conv[convlist[-2]].onespike=True \n    cap = None\n    batch_size = 1024\n    T = 200\n    layer = len(convlist) - 1\n    plus = 0.002\n    lr = 0.0001\n    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)\n    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)\n    optimizer = torch.optim.SGD(list(model.parameters())[layer:], lr=lr)\n\n    rates = None\n    best = 0\n    accrecord=[]\n    for epoch in range(1000):\n        spikefull = None\n        labelfull = None\n        for step, (x, y) in enumerate(tqdm(train_iter)):\n            x = x.to(device)\n            y = y.to(device)\n\n            spiketime = 0\n\n            spikes = model(x, 0, convlist[layer], T)\n            # print(spikes.mean())\n            optimizer.zero_grad()\n            spikes.sum().backward()\n            model.conv[convlist[layer]].spike = spikes.detach()\n            model.normgrad(convlist[layer], force=False)\n            optimizer.step()\n            model.updatethresh(convlist[layer], plus=plus)\n            model.normweight(convlist[layer], clip=False)\n\n            spikes = spikes.reshape(spikes.shape[0], 1, -1).detach()\n            if spikefull is None:\n                spikefull = spikes\n                labelfull = y\n            else:\n                spikefull = torch.cat([spikefull, spikes], 0)\n                labelfull = torch.cat([labelfull, y], 0)\n\n            model.reset(convlist)\n\n        _, rates = model.voting.assign_labels(spikefull, 
labelfull, rates)\n        rates = rates.detach() * 0.5\n        result = model.voting.get_label(spikefull)\n        acc = (result == labelfull).float().mean()\n\n        print(epoch, acc, 'channel', channel, \"n\", neuron)\n        print(model.conv[-1].node.threshold.max(),model.conv[-1].node.threshold.mean(),model.conv[-1].node.threshold.min())\n        \n        # model.conv[-1].threshold*=0.98\n        spikefull = None\n        labelfull = None\n        result = None\n        for step2, (x, y) in enumerate(test_iter):\n            x = x.to(device)\n            y = y.to(device)\n\n            spiketime = 0\n            spikes = model(x, 0, convlist[layer], T)\n\n            spikes = spikes.reshape(spikes.shape[0], 1, -1).detach()\n\n            with torch.no_grad():\n                if spikefull is None:\n                    spikefull = spikes\n                    labelfull = y\n\n                else:\n                    spikefull = torch.cat([spikefull, spikes], 0)\n                    labelfull = torch.cat([labelfull, y], 0)\n\n            model.reset(convlist)\n\n        result = model.voting.get_label(spikefull)\n        acc = (result == labelfull).float().mean()\n        if best < acc: \n            best = acc \n            torch.save( model, \"modelftstp28_350_c\"+str(channel)+\"_n\"+str(neuron)+\"_p\"+str(acc)+\".pth\")\n            classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n            \n            cm = confusion_matrix(labelfull.cpu(), result.cpu())\n            plot_confusion_matrix(cm, classes)\n        print(\"test\", acc, \"best\", best)\n        accrecord.append(acc)\n        #torch.save(accrecord,\"accfstp28_350_c\"+str(channel)+\"_n\"+str(neuron)+\".pth\")"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/bp/README.md",
    "content": "# Script for training high-performance SNNs based on back propagation \nThis is an example of training high-performance SNNs using braincog.\nIt is able to train high performance SNNs on CIFAR10, DVS-CIFAR10, ImageNet and other datasets, and reach the advanced level. \n\n## Install braincog  \n\n```shell\ngit clone https://github.com/xxx/Brain-Cog.git\ncd braincog \npython setup.py install --user \n```\n\n## Examples of training\n\n```shell\ncd examples/Perception_and_Learning/img_cls/bp \npython main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGrad --device 0 \n```\n\n## Benchmark \n\nWe provide a benchmark of SNNs trained with braincog and the corresponding scripts. \nThis provides an open, fair platform for comparison of subsequent SNNs on classification tasks. \n\n**Note**: The results may vary due to random seeding and software version issues. \n\n### CIFAR10 \n\n| ID  | Dataset |   Node-type    | Config |     Model     | Batch Size |   Accuracy   | Script                                                                                                                                         |\n|:----|:-------:|:--------------:|:------:|:-------------:|:----------:|:------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | CIFAR10 |    IF+Atan     |   -    |    convnet    |    128     |    95.54     | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```        |\n| 1   | CIFAR10 |    LIF+Atan    |   -    |    convnet    |    128     |    91.92     | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   | CIFAR10 |   PLIF+Atan    |   -    |    convnet    |    128     |    93.32     | ```python main.py 
--model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```      |\n| 1   | CIFAR10 |    IF+Atan     |   -    |   resnet18    |    128     | 89.76/89.80  | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```             |\n| 1   | CIFAR10 |    LIF+Atan    |   -    |   resnet18    |    128     | 89.93/89.88  | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```            |\n| 1   | CIFAR10 |   PLIF+Atan    |   -    |   resnet18    |    128     | 92.64/ 90.65 | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```           |\n| 1   | CIFAR10 |  IF+QGateGrad  |   -    | cifar_convnet |    128     |    95.73     | ```python main.py --model cifar_convnet --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```   |\n| 1   | CIFAR10 | LIF+QGateGrad  |   -    | cifar_convnet |    128     |    96.04     | ```python main.py --model cifar_convnet --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```  |\n| 1   | CIFAR10 | PLIF+QGateGrad |   -    | cifar_convnet |    128     | 96.04/95.84  | ```python main.py --model cifar_convnet --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |\n| 1   | CIFAR10 |  IF+QGateGrad  |   -    |   resnet18    |    128     |    89.19     | ```python main.py --model resnet18 --node-type IFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```        |\n| 1   | CIFAR10 | LIF+QGateGrad  |   -    |   resnet18    |    128     | 90.95/90.68  | ```python main.py --model resnet18 --node-type LIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```       
|\n| 1   | CIFAR10 | PLIF+QGateGrad |   -    |   resnet18    |    128     | 90.97/91.02  | ```python main.py --model resnet18 --node-type PLIFNode --dataset cifar10 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```      |\n\n\n### CIFAR100 \n| ID  | Dataset  |   Node-type    | Config |    Model    | Batch Size | Accuracy | Script                                                                                                                                                            |\n|:----|:--------:|:--------------:|:------:|:-----------:|:----------:|:--------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | CIFAR100 |    IF+Atan     |   -    | dvs_convnet |    128     |  76.52   | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```        |\n| 1   | CIFAR100 |    LIF+Atan    |   -    | dvs_convnet |    128     |  71.89   | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   | CIFAR100 |   PLIF+Atan    |   -    | dvs_convnet |    128     |  72.82   | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```      |\n| 1   | CIFAR100 |    IF+Atan     |   -    |  resnet18   |    128     |  62.47   | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```             |\n| 1   | CIFAR100 |    LIF+Atan    |   -    |  resnet18   |    128     |  62.63   | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```            |\n| 1   | 
CIFAR100 |   PLIF+Atan    |   -    |  resnet18   |    128     |  62.71   | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun AtanGrad --device 0```           |\n| 1   | CIFAR100 |  IF+QGateGrad  |   -    | dvs_convnet |    128     |  76.44   | ```python main.py --num-classes 100 --model cifar_convnet --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```   |\n| 1   | CIFAR100 | LIF+QGateGrad  |   -    | dvs_convnet |    128     |  77.73   | ```python main.py --num-classes 100 --model cifar_convnet --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```  |\n| 1   | CIFAR100 | PLIF+QGateGrad |   -    | dvs_convnet |    128     |  77.25   | ```python main.py --num-classes 100 --model cifar_convnet --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0``` |\n| 1   | CIFAR100 |  IF+QGateGrad  |   -    |  resnet18   |    128     |  60.01   | ```python main.py --num-classes 100 --model resnet18 --node-type IFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```        |\n| 1   | CIFAR100 | LIF+QGateGrad  |   -    |  resnet18   |    128     |  61.33   | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```       |\n| 1   | CIFAR100 | PLIF+QGateGrad |   -    |  resnet18   |    128     |  62.32   | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset cifar100 --step 4 --batch-size 128 --act-fun QGateGradGrad --device 0```      |\n\n\n### DVS-CIFAR10\n\n| ID  |   Dataset   |   Node-type    | Config |    Model    | Batch Size | FLOPS |  Accuracy   | Script                                                                                                                                       
|\n|:----|:-----------:|:--------------:|:------:|:-----------:|:----------:|:-----:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | DVS-CIFAR10 |    IF+Atan     |   -    | dvs_convnet |    128     | 7503  |    65.90    | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```        |\n| 1   | DVS-CIFAR10 |    LIF+Atan    |   -    | dvs_convnet |    128     | 7503  |    82.10    | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   | DVS-CIFAR10 |   PLIF+Atan    |   -    | dvs_convnet |    128     | 7503  |    81.90    | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```      |\n| 1   | DVS-CIFAR10 |    IF+Atan     |   -    |  resnet18   |    128     | 3149  |    69.10    | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```           |\n| 1   | DVS-CIFAR10 |    LIF+Atan    |   -    |  resnet18   |    128     | 3149  |    78.50    | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```          |\n| 1   | DVS-CIFAR10 |   PLIF+Atan    |   -    |  resnet18   |    128     | 3149  |    77.70    | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun AtanGrad --device 0```         |\n| 1   | DVS-CIFAR10 |  IF+QGateGrad  |   -    | dvs_convnet |    128     | 7503  |    68.30    | ```python main.py --model dvs_convnet --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```   |\n| 1   | DVS-CIFAR10 | LIF+QGateGrad  |   -    | dvs_convnet |    128 
    | 7503  | 82.60/82.90 | ```python main.py --model dvs_convnet --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```  |\n| 1   | DVS-CIFAR10 | PLIF+QGateGrad |   -    | dvs_convnet |    128     | 7503  |    83.20    | ```python main.py --model dvs_convnet --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |\n| 1   | DVS-CIFAR10 |  IF+QGateGrad  |   -    |  resnet18   |    128     | 3149  | 65.70/66.80 | ```python main.py --model resnet18 --node-type IFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```      |\n| 1   | DVS-CIFAR10 | LIF+QGateGrad  |   -    |  resnet18   |    128     | 3149  | 79.00/79.40 | ```python main.py --model resnet18 --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```     |\n| 1   | DVS-CIFAR10 | PLIF+QGateGrad |   -    |  resnet18   |    128     | 3149  | 78.10/78.20 | ```python main.py --model resnet18 --node-type PLIFNode --dataset dvsc10 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```    |\n\n\n### DVS-Gesture\n\n| ID  | Dataset |   Node-type    | Config |    Model    | Batch Size |  Accuracy   | Script                                                                                                                                                      |\n|:----|:-------:|:--------------:|:------:|:-----------:|:----------:|:-----------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   |  DVS-G  |    IF+Atan     |   -    | dvs_convnet |    128     |    64.77    | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```        |\n| 1   |  DVS-G  |    LIF+Atan    |   -    | dvs_convnet |    128     |    91.28    | ```python main.py 
--num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```       |\n| 1   |  DVS-G  |   PLIF+Atan    |   -    | dvs_convnet |    128     |    91.67    | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```      |\n| 1   |  DVS-G  |    IF+Atan     |   -    |  resnet18   |    128     |    63.25    | ```python main.py --num-classes 11 --model resnet18 --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```           |\n| 1   |  DVS-G  |    LIF+Atan    |   -    |  resnet18   |    128     |    91.29    | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```          |\n| 1   |  DVS-G  |   PLIF+Atan    |   -    |  resnet18   |    128     |    90.15    | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun AtanGrad --device 0```         |\n| 1   |  DVS-G  |  IF+QGateGrad  |   -    | dvs_convnet |    128     |    48.48    | ```python main.py --num-classes 11 --model dvs_convnet --node-type IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```   |\n| 1   |  DVS-G  | LIF+QGateGrad  |   -    | dvs_convnet |    128     | 92.05/92.42 | ```python main.py --num-classes 11 --model dvs_convnet --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```  |\n| 1   |  DVS-G  | PLIF+QGateGrad |   -    | dvs_convnet |    128     |    91.28    | ```python main.py --num-classes 11 --model dvs_convnet --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |\n| 1   |  DVS-G  |  IF+QGateGrad  |   -    |  resnet18   |    128     |    57.95    | ```python main.py --num-classes 11 --model resnet18 --node-type 
IFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```      |\n| 1   |  DVS-G  | LIF+QGateGrad  |   -    |  resnet18   |    128     |    90.91    | ```python main.py --num-classes 11 --model resnet18 --node-type LIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```     |\n| 1   |  DVS-G  | PLIF+QGateGrad |   -    |  resnet18   |    128     |    92.42    | ```python main.py --num-classes 11 --model resnet18 --node-type PLIFNode --dataset dvsg --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```    |\n\n### NCALTECH101\n\n| ID  |   Dataset   |   Node-type    | Config |    Model    | Batch Size |  Accuracy   | Script                                                                                                                                                              |\n|:----|:-----------:|:--------------:|:------:|:-----------:|:----------:|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| 1   | NCALTECH101 |  IF+QGateGrad  |   -    | dvs_convnet |    128     | 23.09/51.15 | ```python main.py --num-classes 100 --model dvs_convnet --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```   |\n| 1   | NCALTECH101 | LIF+QGateGrad  |   -    | dvs_convnet |    128     | 72.78/75.09 | ```python main.py --num-classes 100 --model dvs_convnet --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```  |\n| 1   | NCALTECH101 | PLIF+QGateGrad |   -    | dvs_convnet |    128     | 74.61/76.79 | ```python main.py --num-classes 100 --model dvs_convnet --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0``` |\n| 1   | NCALTECH101 |  IF+QGateGrad  | -/mix  |  resnet18   |    128     | 61.24/60.87 | ```python main.py 
--num-classes 100 --model resnet18 --node-type IFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```      |\n| 1   | NCALTECH101 | LIF+QGateGrad  | -/mix  |  resnet18   |    128     | 66.22/70.84 | ```python main.py --num-classes 100 --model resnet18 --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```     |\n| 1   | NCALTECH101 | PLIF+QGateGrad | -/mix  |  resnet18   |    128     | 69.62/69.87 | ```python main.py --num-classes 100 --model resnet18 --node-type PLIFNode --dataset NCALTECH101 --step 10 --batch-size 128 --act-fun QGateGradGrad --device 0```    |\n\n### SHD\n\n| ID   | Dataset |   Node-type   | Config |  Model  | Batch Size | Accuracy | Script                                                       |\n| :--- | :-----: | :-----------: | :----: | :-----: | :--------: | :------: | :----------------------------------------------------------- |\n| 1    |   SHD   | LIF+QGateGrad |   -    | shd_snn |    256     |  88.47   | ```python main.py --model SHD_SNN --node-type LIFNode --dataset shd --step 15 --batch-size 256 --act-fun QGateGrad --device 1 --tau 10. --threshold 0.3 --lr 5e-3 --min-lr 1e-4 --loss-fn onehot-mse --num-classes 20 --amp --weight-decay 0.01 ``` |\n\nNote: \n\n1. resnet18 is used here by adding a maximum pooling after the initial convolution layer.\nHowever, in the final version of braincog, we remove this pooling layer.\n2. mix refers to the use of EventMix as a data augmentation method.\n3. 
We will continue to add other results.\n\n\n### Citation \nIf you find this package helpful, please consider citing it:\n\n```BibTex\n@misc{zengbraincogSpikingNeural2022,\n  title = {{{braincog}}: {{A Spiking Neural Network}} Based {{Brain-inspired Cognitive Intelligence Engine}} for {{Brain-inspired AI}} and {{Brain Simulation}}},\n  shorttitle = {{{braincog}}},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  year = {2022},\n  month = jul,\n  number = {arXiv:2207.08533},\n  eprint = {2207.08533},\n  eprinttype = {arxiv},\n  primaryclass = {cs},\n  publisher = {{arXiv}},\n  doi = {10.48550/arXiv.2207.08533}\n}\n```\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/bp/main.py",
    "content": "import argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport random as buildin_random\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN, SNN5\nfrom braincog.model_zoo.fc_snn import SHD_SNN\nfrom braincog.model_zoo.resnet19_snn import resnet19\nfrom braincog.model_zoo.sew_resnet import sew_resnet18, sew_resnet34, sew_resnet50\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix, plot_mem_distribution\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model, register_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\nfrom torch.utils.tensorboard import SummaryWriter\n\n# from ptflops import get_model_complexity_info\n# from thop import profile, clever_format\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', 
add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='mnist', type=str)\nparser.add_argument('--model', default='mnist_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=1e-4,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=200, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=0.,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=0.,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/data/floyed/BrainCog', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--tensorboard-dir', default='./runs', type=str)\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\nparser.add_argument('--conv-type', type=str, default='normal')\nparser.add_argument('--sew-cnf', type=str, default='ADD')\nparser.add_argument('--rand-step', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\nparser.add_argument('--n-encode-type', type=str, default='linear')\nparser.add_argument('--n-preact', action='store_true')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. 
'\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--tet-loss', action='store_true')\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=2.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. 
(default: False)')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\nparser.add_argument('--mem-dist', action='store_true')\nparser.add_argument('--adaptation-info', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\n@register_model\ndef resnet50d_pretrained(*args, **kwargs):\n    model = create_model('resnet50d', pretrained=True)\n    model.fc = nn.Linear(2048, 10)\n    # model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n    
return model\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    if args.local_rank == 0:\n        output_base = args.output if args.output else './output'\n        exp_name = '-'.join([\n            args.model,\n            args.dataset,\n            args.node_type,\n            str(args.step),\n            args.suffix,\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base, 'train', exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n        summary_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_dir, exp_name))\n        args.tensorboard_prefix = os.path.join(args.dataset, args.model)\n    else:\n        summary_writer = None\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with 
multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        adaptive_node=args.adaptive_node,\n        dataset=args.dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.threshold,\n        tau=args.tau,\n        sigmoid_thres=args.sigmoid_thres,\n        requires_thres_grad=args.requires_thres_grad,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        n_groups=args.n_groups,\n        n_encode_type=args.n_encode_type,\n        n_preact=args.n_preact,\n        tet_loss=args.tet_loss,\n        sew_cnf=args.sew_cnf,\n        conv_type=args.conv_type,\n    )\n\n    _logger.info('[MODEL ARCH]\\n{}'.format(model))\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        
assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    use_amp = None\n    if args.amp:\n        # for backwards compat, `--amp` arg tries apex before native amp\n        if has_apex:\n            args.apex_amp = True\n        elif has_native_amp:\n            args.native_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                        \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n    if args.num_gpu > 1:\n        if use_amp == 'apex':\n            _logger.warning(\n                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')\n            use_amp = None\n        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()\n        assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n    else:\n        model = model.cuda()\n        if args.channels_last:\n            model = model.to(memory_format=torch.channels_last)\n\n    optimizer = create_optimizer(args, model)\n\n    _logger.info('[OPTIMIZER]\\n{}'.format(optimizer))\n\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. 
Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. Training in float32.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume and args.eval_checkpoint == '':\n        args.eval_checkpoint = args.resume\n    if args.resume:\n        args.eval = True\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n        # print(model.get_attr('mu'))\n        # print(model.get_attr('sigma'))\n        if hasattr(model, 'set_threshold'):\n            model.set_threshold(args.threshold)\n\n    if args.critical_loss or args.spike_rate:\n        model.set_requires_fp(True)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume=args.resume)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    model_without_ddp = model\n    if args.distributed:\n        if args.sync_bn:\n            assert not args.split_bn\n            try:\n                if has_apex and use_amp != 'native':\n                    # Apex SyncBN preferred unless native amp is activated\n                    model = convert_syncbn_model(model)\n                else:\n                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                if args.local_rank == 0:\n                    
_logger.info(\n                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '\n                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n            except Exception as e:\n                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n        if has_apex and use_amp != 'native':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model.cuda(), device_ids=[args.local_rank],\n                              find_unused_parameters=True)  # can use device str in Torch >= 1.1\n        model_without_ddp = model.module\n    # NOTE: EMA model does not need to be wrapped by DDP\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        
mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n    )\n    # _logger.info('train_loader:\\n{}\\nval_loader:\\n{}'.format(loader_train, loader_eval))\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n    elif args.loss_fn == 'onehot-mse':\n        train_loss_fn = OnehotMse(args.num_classes)\n        validate_loss_fn = OnehotMse(args.num_classes)\n    else:\n        if args.jsd:\n            assert num_aug_splits > 1  # JSD only valid with aug splits set\n            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n        elif mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        elif args.smoothing:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    if args.tet_loss:\n        train_loss_fn = TetLoss(train_loss_fn)\n        validate_loss_fn = TetLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    if args.eval:  # evaluate the model\n        # if args.distributed:\n        #     raise NotImplementedError('eval not has not been verified for distributed')\n        # else:\n        #     
load_checkpoint(model, args.eval_checkpoint, args.model_ema)\n        model.eval()\n        for t in range(1, args.step * 3):\n        # for t in range(args.step, args.step + 1):\n            model.set_attr('step', t)\n            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,\n                                   visualize=args.visualize, spike_rate=args.spike_rate,\n                                   tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)\n            print(f\"[STEP:{t}], Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n        return\n\n    saver = None\n    if args.local_rank == 0:\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=3)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    try:  # train the model\n        if args.reset_drop:\n            model_without_ddp.reset_drop_path(0.0)\n        for epoch in range(start_epoch, args.epochs):\n            if epoch == 0 and args.reset_drop:\n                model_without_ddp.reset_drop_path(args.drop_path)\n\n            if args.distributed:\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler,\n                model_ema=model_ema, mixup_fn=mixup_fn, summary_writer=summary_writer\n            )\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm 
running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                ema_eval_metrics = validate(\n                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                    visualize=args.visualize, spike_rate=args.spike_rate,\n                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer\n                )\n                eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            # if saver is not None and epoch >= args.n_warm_up:\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, 
saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None, summary_writer=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    # closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    # t, k = adjust_surrogate_coeff(100, args.epochs)\n    # model.set_attr('t', t)\n    # model.set_attr('k', k)\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    iters_per_epoch = len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        if args.rand_step:\n            step = buildin_random.randint(1, args.step + 2)\n            model.set_attr('step', step)\n\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            loss = loss_fn(output, target)\n        if args.tet_loss:\n            output = output.mean(0)\n\n        if not (args.cut_mix | args.mix_up | args.event_mix | (args.cutmix != 0.) 
| (args.mixup != 0.)):\n            # print(output.shape, target.shape)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            # if args.opt == 'lamb':\n            #     optimizer.step(epoch=epoch)\n            # else:\n            optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n\n        if args.local_rank == 0:\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)\n\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            if args.distributed:\n                loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n\n            
losses_m.update(loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n                # closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                # if args.distributed:\n                _logger.info(\n                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'LR: {lr:.3e}  '\n                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                        epoch,\n                        batch_idx, len(loader),\n                        100. * batch_idx / last_idx,\n                        loss=losses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        batch_time=batch_time_m,\n                        rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                        lr=lr,\n                        data_time=data_time_m\n                    ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    if args.local_rank == 0:\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top1'), top1_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top5'), top5_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/loss'), losses_m.avg, epoch)\n\n    if args.rand_step:\n        model.set_attr('step', args.step)\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False, summary_writer=None):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    # closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n    spike_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec = []\n    mem_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    iters_per_epoch = len(loader)\n    with torch.no_grad():\n\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat or args.mem_dist) and not args.critical_loss:\n                    model.set_requires_fp(True)\n\n            with 
amp_autocast():\n                output = model(inputs)\n\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            # print(args.rank, output.shape, target.shape, max(target))\n            loss = loss_fn(output, target)\n            if args.tet_loss:\n                output = output.mean(0)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            # closses_m.update(closs, inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n\n            if args.local_rank == 0:\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)\n\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n            if not args.distributed and 
spike_rate:\n                spike_m.update(model.get_tot_spike() / output.size(0), output.size(0))\n\n                if not args.distributed and spike_rate:\n                    _logger.info(\n                        '[Spike Info]: {spike.val} ({spike.avg})'.format(\n                            spike=spike_m\n                        )\n                    )\n            if last_batch or batch_idx % args.log_interval == 0:\n                _logger.info(\n                    'Eval : {} '\n                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                        epoch,\n                        batch_idx,\n                        last_idx,\n                        batch_time=batch_time_m,\n                        loss=losses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if args.local_rank == 0:\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top1'), top1_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top5'), top5_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/loss'), losses_m.avg, epoch)\n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/bp/main_backei.py",
    "content": "import torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nimport time\nfrom braincog.model_zoo.backeinet import *\nimport argparse\nimport os\nimport json\n\n\n\nparser = argparse.ArgumentParser(\"description = train.py\")\nparser.add_argument('-seed', type=int, default=4150)\nparser.add_argument('-epoch', type=int, default=200)\nparser.add_argument('-batch_size', type=int, default=100)\nparser.add_argument('-learning_rate', type=float, default=1e-3)\nparser.add_argument('--dataset', type=str, default='fashion')\nparser.add_argument('--simulation_len', type=int, default=20)\nparser.add_argument('--Back', action='store_true', default=False)\nparser.add_argument('--EI', action='store_true', default=False)\nparser.add_argument('--device', type=int, default=1)\nparser.add_argument('--encode-type', type=str, default='direct')\nopt = parser.parse_args()\ntorch.cuda.set_device('cuda:%d' % opt.device)\ntorch.manual_seed(opt.seed)\ntorch.cuda.manual_seed(opt.seed)\n\ntest_scores = []\ntrain_scores = []\nsave_path = opt.dataset + '_' + str(opt.seed) + '_' + opt.encode_type\nif opt.Back:\n    save_path += '_Back'\nif opt.EI:\n    save_path += '_EI'\n\nif not os.path.exists(save_path):\n    os.mkdir(save_path)\nnormalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\nif opt.dataset == 'mnist':\n    train_dataset = datasets.MNIST(root='./data/mnist/', train=True, transform=transforms.ToTensor(), download=True)\n    test_dataset = datasets.MNIST(root='./data/mnist/', train=False, transform=transforms.ToTensor())\n    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)\n    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)\nelif opt.dataset == 'fashion':\n    train_dataset = datasets.FashionMNIST(root='./data/fashion/', train=True, transform=transforms.ToTensor(),\n                                   
       download=True)\n    test_dataset = datasets.FashionMNIST(root='./data/fashion/', train=False, transform=transforms.ToTensor())\n    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)\n    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)\n\nelif opt.dataset == 'cifar10':\n    train_dataset = datasets.CIFAR10(root='./data/cifar10/', train=True, transform=transforms.Compose(\n        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), normalize]),\n                                     download=True)\n    test_dataset = datasets.CIFAR10(root='./data/cifar10/', train=False,\n                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))\n\n    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)\n    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)\nif opt.dataset == 'cifar10':\n    snn = CIFARNet(step=opt.simulation_len, if_back=opt.Back, if_ei=opt.EI,  encode_type=opt.encode_type)\nelse:\n    snn = MNISTNet(step=opt.simulation_len, if_back=opt.Back, if_ei=opt.EI, data=opt.dataset, encode_type=opt.encode_type)\nsnn = snn.cuda()\ncriterion = nn.MSELoss()\noptimizer = torch.optim.Adam(snn.parameters(), lr=opt.learning_rate)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)\n\n\ndef train(epoch):\n    snn.train()\n\n    start_time = time.time()\n    total_loss = 0\n    correct = 0\n    total = 0\n    for i, (images, labels) in enumerate(train_loader):\n        optimizer.zero_grad()\n        images = images.cuda()\n        outputs = snn(images)\n        labels_ = torch.zeros(opt.batch_size, 10).scatter_(1, labels.view(-1, 1), 1).cuda()\n        loss = criterion(outputs, labels_)\n        total_loss += loss.item()\n        
loss.backward()\n        optimizer.step()\n        pred = outputs.max(1)[1]\n        total += labels.size(0)\n        correct += (pred.cpu() == labels).sum()\n        if (i + 1) % (60000 // (opt.batch_size * 6)) == 0:\n            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f, Time: %.2f' % (\n                epoch + 1, opt.epoch, i + 1, 60000 // opt.batch_size, total_loss,\n                time.time() - start_time))\n            start_time = time.time()\n            total_loss = 0\n    acc = 100.0 * correct.item() / total\n    train_scores.append(acc)\n\n\ndef eval(epoch):\n    snn.eval()\n    correct = 0\n    total = 0\n    with torch.no_grad():\n        for i, (images, labels) in enumerate(test_loader):\n            images = images.cuda()\n            outputs = snn(images)\n\n            pred = outputs.max(1)[1]\n            total += labels.size(0)\n            correct += (pred.cpu() == labels).sum()\n    acc = 100.0 * correct.item() / total\n    print('Test correct: %d Accuracy: %.2f%%' % (correct, acc))\n    test_scores.append(acc)\n    if acc >= max(test_scores):\n        save_file = str(epoch) + '.pt'\n        torch.save(snn, os.path.join(save_path, save_file))\n    return max(test_scores)\n\n\ndef main():\n    for epoch in range(opt.epoch):\n        train(epoch)\n        best_acc = eval(epoch)\n        scheduler.step()\n        print('Best Accuracy: %.2f%%' % (best_acc))\n\n\nif __name__ == '__main__':\n    main()\n    filename = \"train.json\"\n    filename = os.path.join(save_path, filename)\n    with open(filename, \"w\") as f:\n        json.dump(train_scores, f)\n    filename = \"test.json\"\n    filename = os.path.join(save_path, filename)\n    with open(filename, \"w\") as f:\n        json.dump(test_scores, f)\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/bp/main_simplified.py",
    "content": "# encoding: utf-8\n# Author    : Floyed<Floyed_Shen@outlook.com>\n# Datetime  : 2022/4/28 14:56\n# User      : Floyed\n# Product   : PyCharm\n# Project   : braincog\n# File      : main_simplified.py\n# explain   : Simplified training script. Remove support for DDP, IMAGENET, Augment, etc.\n\nimport argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.utils import save_feature_map\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\n\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\n# from ptflops import get_model_complexity_info\nfrom thop import profile, clever_format\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\nparser = argparse.ArgumentParser(description='SNN Training and 
Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='cifar10', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 10)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, 
metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n                    help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for 
Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--device', type=int, default=0)\nparser.add_argument('--output', default='/data/floyed/braincog', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\n# neuron type\nparser.add_argument('--node-type', type=str, default='PLIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='AtanGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--thresh', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--n_warm_up', type=int, default=0,\n                    help='Warm up epoch, replace all node to ReLU to warm up weights in network before (default: 0)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n (default: 
15) (0-30)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, 
default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    output_base = args.output if args.output else './output'\n    exp_name = '-'.join([\n        datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n        args.model,\n        args.dataset,\n        str(args.step),\n        args.suffix\n        # str(args.img_size)\n    ])\n    output_dir = get_outdir(output_base, 'train', exp_name)\n    args.output_dir = output_dir\n    setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    torch.cuda.set_device('cuda:%d' % args.device)\n\n    torch.manual_seed(args.seed)\n\n    model = create_model(\n        args.model,\n        num_classes=args.num_classes,\n        dataset=args.dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.thresh,\n        tau=args.tau,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n    )\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.img_size, args.img_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size / 1024.0\n    args.lr = linear_scaled_lr\n\n    model = model.cuda()\n\n    optimizer = create_optimizer(args, model)\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume:\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n       
 resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet fcvawefdadw\n    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n            batch_size=args.batch_size,\n            step=args.step,\n            size=args.event_size,\n            mix_up=args.mix_up,\n            cut_mix=args.cut_mix,\n            event_mix=args.event_mix,\n            beta=args.cutmix_beta,\n            prob=args.cutmix_prob,\n            num=args.cutmix_num,\n            noise=args.cutmix_noise,\n            num_classes=args.num_classes,\n            rand_aug=args.rand_aug,\n            randaug_n=args.randaug_n,\n            randaug_m=args.randaug_m,\n            temporal_flatten=args.temporal_flatten,\n            portion=args.train_portion)\n\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n\n    else:\n        if mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = 
MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    saver = CheckpointSaver(\n        model=model, optimizer=optimizer, args=args,\n        checkpoint_dir=output_dir, recovery_dir=output_dir)\n    with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n        f.write(args_text)\n\n    try:  # train the model\n        for epoch in range(start_epoch, args.epochs):\n\n            if args.visualize or args.spike_rate:\n                print('start to plot feature map / calc spike rate')\n                validate(model, loader_eval, validate_loss_fn, args,\n                         visualize=args.visualize, spike_rate=args.spike_rate)\n                exit(0)\n\n            train_metrics = train_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)\n\n            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            if saver is not None and epoch >= args.n_warm_up:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir=''):\n\n    
batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n\n        output = model(inputs)\n        loss = loss_fn(output, target)\n        if not (args.cut_mix | args.mix_up | args.event_mix):\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        losses_m.update(loss.item(), inputs.size(0))\n        top1_m.update(acc1.item(), inputs.size(0))\n        top5_m.update(acc5.item(), inputs.size(0))\n\n        optimizer.zero_grad()\n\n        loss.backward()\n        if args.noisy_grad != 0.:\n            random_gradient(model, args.noisy_grad)\n        if args.clip_grad is not None:\n            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n        optimizer.step()\n\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            _logger.info(\n                'Train: {} [{:>4d}/{} ({:>3.0f}%)]  ' \n                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'\n                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                'LR: {lr:.3e}  '\n                'Data: {data_time.val:.3f} 
({data_time.avg:.3f})'.format(\n                    epoch,\n                    batch_idx, len(loader),\n                    100. * batch_idx / last_idx,\n                    loss=losses_m,\n                    top1=top1_m, top5=top5_m,\n                    batch_time=batch_time_m,\n                    rate=inputs.size(0) / batch_time_m.val,\n                    rate_avg=inputs.size(0)  / batch_time_m.avg,\n                    lr=lr,\n                    data_time=data_time_m))\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(model, loader, loss_fn, args, log_suffix='', visualize=False, spike_rate=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            inputs = inputs.type(torch.FloatTensor).cuda()\n            target = target.cuda()\n\n            if visualize or spike_rate:\n                model.set_requires_fp(True)\n\n            output = model(inputs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            if visualize:\n                x = model.get_fp()\n                feature_path = os.path.join(args.output_dir, 'feature_map')\n                if os.path.exists(feature_path) is False:\n        
            os.mkdir(feature_path)\n                save_feature_map(x, feature_path)\n                model.set_requires_fp(False)\n\n            if spike_rate:\n                _logger.info(model.get_fire_rate_per_layer())\n                model.set_requires_fp(False)\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if last_batch or batch_idx % args.log_interval == 0:\n                log_name = 'Test' + log_suffix\n                _logger.info(\n                    '{0}: [{1:>4d}/{2}]  '\n                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  ' \n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                        log_name, batch_idx, last_idx, batch_time=batch_time_m,\n                        loss=losses_m, top1=top1_m, top5=top5_m))\n\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n\n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/glsnn/README.md",
    "content": "# SNN with global feedback connections\nTraining deep spiking neural network with the global \nfeedback connections and the local optimization learning rules. And is a little different from our original paper.\n\nGLSNN: A Multi-layer Spiking Neural Network based on Global Feedback Alignment and Local STDP Plasticity.\n\n## Results\n```shell\npython cls_glsnn.py\n```\nWe train the model for 100 epochs, and the best accuracy for MNIST is 98.23\\%, for FashionMNIST is 89.68\\%.\n![image](result_zdc.png)\n\n## Citation\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{zhao2020glsnn,\n  title={GLSNN: A Multi-Layer Spiking Neural Network Based on Global Feedback Alignment and Local STDP Plasticity},\n  author={Zhao, Dongcheng and Zeng, Yi and Zhang, Tielin and Shi, Mengting and Zhao, Feifei},\n  journal={Frontiers in Computational Neuroscience},\n  volume={14},\n  year={2020},\n  publisher={Frontiers Media SA}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n```\n## Contents\nFeedbacks and comments are welcome! Feel free to contact us via [zhaodongcheng2016@ia.ac.cn](zhaodongcheng2016@ia.ac.cn) \n\nEnjoy!"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/glsnn/cls_glsnn.py",
    "content": "import torch\nfrom torchvision import transforms\nimport torchvision.datasets as datasets\nfrom torch.utils.data import DataLoader\nfrom braincog.model_zoo.glsnn import BaseGLSNN\nimport argparse\nimport time\nimport os\n\nimport json\n\nos.environ['CUDA_VISIBLE_DEVICES'] = \"3\"\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nparser = argparse.ArgumentParser(\"description = GLSNN.py\")\nparser.add_argument('-seed', type=int, default=2122)\nparser.add_argument('-epoch', type=int, default=100)\nparser.add_argument('-batch_size', type=int, default=100)\nparser.add_argument('-lr_target', type=float, default=0.4)\nparser.add_argument('-lr_forward', type=float, default=0.001)\nparser.add_argument('-step', type=int, default=10)\nparser.add_argument('-encode_type', type=str, default='direct')\nparser.add_argument('--dataset', type=str, default='MNIST')\nopt = parser.parse_args()\n\ntorch.manual_seed(opt.seed)\ntorch.cuda.manual_seed(opt.seed)\n\ntest_scores = []\ntrain_scores = []\n\nsave_path = './' + 'GLSNN' + '_' + opt.dataset + '_' + str(opt.seed)\nif not os.path.exists(save_path):\n    os.mkdir(save_path)\nif opt.dataset == 'MNIST':\n    train_dataset = datasets.MNIST(root='./data/datasets/mnist/', train=True, transform=transforms.ToTensor(), download=True)\n    test_dataset = datasets.MNIST(root='./data/datasets/mnist/', train=False, transform=transforms.ToTensor())\n    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)\n    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)\nelif opt.dataset == 'Fashion-MNIST':\n    train_dataset = datasets.FashionMNIST(root='./data/fashion/', train=True, transform=transforms.ToTensor(),\n                                          download=True)\n    test_dataset = datasets.FashionMNIST(root='./data/fashion/', train=False, transform=transforms.ToTensor())\n    train_loader = 
torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True)\n    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batch_size, shuffle=False)\n\nsnn = BaseGLSNN(input_size=784, hidden_sizes=[800] * 3, output_size=10, opt=opt)\nsnn.to(device)\noptimizer = torch.optim.Adam(snn.forward_parameters(), lr=opt.lr_forward)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)\n\n\ndef train(epoch):\n    snn.train()\n\n    start_time = time.time()\n    total_loss = 0\n    correct = 0\n    total = 0\n    for i, (images, labels) in enumerate(train_loader):\n        optimizer.zero_grad()\n        images = images.to(device)\n        labels_ = torch.zeros(opt.batch_size, 10).scatter_(1, labels.view(-1, 1), 1).to(device)\n        labels = labels.to(device)\n\n        outputs, loss = snn.set_gradient(images, labels_)\n        optimizer.step()\n        total_loss += loss.item()\n        pred = outputs[-1].max(1)[1]\n        total += labels.size(0)\n        correct += (pred.cpu() == labels.cpu()).sum()\n        if (i + 1) % (60000 // (opt.batch_size * 6)) == 0:\n            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f, Time: %.2f' % (\n                epoch + 1, opt.epoch, i + 1, 60000 // opt.batch_size, total_loss,\n                time.time() - start_time))\n            start_time = time.time()\n            total_loss = 0\n    acc = 100.0 * correct.item() / total\n    train_scores.append(acc)\n\n\ndef eval(epoch):\n    snn.eval()\n    correct = 0\n    total = 0\n    with torch.no_grad():\n        for i, (images, labels) in enumerate(test_loader):\n            images = images.to(device)\n            outputs = snn(images)\n            pred = outputs[-1].max(1)[1]\n            total += labels.size(0)\n            correct += (pred.cpu() == labels).sum()\n    acc = 100.0 * correct.item() / total\n    print('Test correct: %d Accuracy: %.2f%%' % (correct, acc))\n    test_scores.append(acc)\n    if 
acc >= max(test_scores):\n        save_file = str(epoch) + '.pt'\n        torch.save(snn, os.path.join(save_path, save_file))\n    return max(test_scores)\n\n\ndef main():\n    for epoch in range(opt.epoch):\n        train(epoch)\n        best_acc = eval(epoch)\n        scheduler.step()\n        print('Best Accuracy: %.2f%%' % (best_acc))\n\n\nif __name__ == '__main__':\n    main()\n    filename = \"train.json\"\n    filename = os.path.join(save_path, filename)\n    with open(filename, \"w\") as f:\n        json.dump(train_scores, f)\n    filename = \"test.json\"\n    filename = os.path.join(save_path, filename)\n    with open(filename, \"w\") as f:\n        json.dump(test_scores, f)\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/spiking_capsnet/README.md",
    "content": "# Spiking capsnet: A spiking neural network with a biologically plausible routing rule between capsules\n\n## Run\n```shell\npython main.py\n```\n\n\n## Citation\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{zhao2022spiking,\n  title={Spiking capsnet: A spiking neural network with a biologically plausible routing rule between capsules},\n  author={Zhao, Dongcheng and Li, Yang and Zeng, Yi and Wang, Jihang and Zhang, Qian},\n  journal={Information Sciences},\n  volume={610},\n  pages={1--13},\n  year={2022},\n  publisher={Elsevier}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n```\n## Contents\nFeedback and comments are welcome! Feel free to contact us via [zhaodongcheng2016@ia.ac.cn](mailto:zhaodongcheng2016@ia.ac.cn) \n\nEnjoy!"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/spiking_capsnet/spikingcaps.py",
    "content": "import sys\nsys.path.append('../../../../')\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom torch.optim import Adam\nimport os\nimport math\nfrom tqdm import tqdm\nimport numpy as np\nfrom braincog.datasets.datasets import get_mnist_data\nfrom braincog.base.node import LIFNode\nfrom braincog.utils import setup_seed\n\n\nsetup_seed(1111)\nos.environ['CUDA_VISIBLE_DEVICES'] = \"4\"\n\n\nclass myLIFnode(LIFNode):\n    def __init__(self, threshold=0.5, tau=2., *args, **kwargs):\n        super().__init__(threshold, tau, *args, **kwargs)\n\n    def integral(self, inputs):\n        # self.mem = self.mem + (inputs - self.mem) / self.tau\n        self.mem = self.mem / self.tau + inputs\n\n\nclass ConvLayer(nn.Module):\n    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):\n        super(ConvLayer, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1)\n\n    def forward(self, x):\n        return F.relu(self.conv(x))\n\n\nclass PrimaryCaps(nn.Module):\n    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):\n        super(PrimaryCaps, self).__init__()\n\n        self.capsules = nn.ModuleList([\n            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)\n            for _ in range(num_capsules)])\n\n    def forward(self, x):\n        u = [capsule(x) for capsule in self.capsules]\n        u = torch.stack(u, dim=1)\n        u.permute(0,2,3,4,1)\n        u = u.view(x.size(0), 32 * 6 * 6, -1)\n        return u\n\n\nclass DigitCaps(nn.Module):\n    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):\n        super(DigitCaps, self).__init__()\n\n        self.in_channels = in_channels\n        self.num_routes = num_routes\n        self.num_capsules = num_capsules\n\n        
self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))\n        self.bias = nn.Parameter(torch.randn(out_channels, 1))\n\n        self.W.data.normal_(0, math.sqrt(3.0 / (in_channels * out_channels)))\n        self.bias.data.normal_(0, math.sqrt(3.0 / (in_channels * out_channels)))\n\n    def forward(self, x):\n        batch_size = x.size(0)\n        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)\n\n        W = torch.cat([self.W] * batch_size, dim=0)\n        u_hat = torch.matmul(W, x) + self.bias\n        return u_hat\n\n\nclass DigitCaps2(nn.Module):\n    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6):\n        super(DigitCaps2, self).__init__()\n\n        self.num_routes = num_routes\n        self.num_capsules = num_capsules\n        self.b_ij = Variable(torch.ones(1, self.num_routes, self.num_capsules, 1)/1152)\n        self.b_ij = self.b_ij.to(device)\n\n    def forward(self, u_hat):\n        c_ij = torch.cat([self.b_ij] * batch_size, dim=0).unsqueeze(4)\n        s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)\n        return s_j.squeeze(1)\n\n    def init_bij(self):\n        self.b_ij = Variable(torch.ones(1, self.num_routes, self.num_capsules, 1)/1152)\n        self.b_ij = self.b_ij.to(device)\n\n\nclass Decoder(nn.Module):\n    def __init__(self):\n        super(Decoder, self).__init__()\n        self.linear = nn.Linear(16, 1)\n\n    def forward(self, x):\n        classes = torch.sqrt((x ** 2).sum(2))\n        # classes = self.linear(x)\n        return classes\n\n\nclass CapsNet(nn.Module):\n    def __init__(self):\n        super(CapsNet, self).__init__()\n        self.conv_layer = ConvLayer()\n        self.primary_capsules = PrimaryCaps()\n        self.digit_capsules = DigitCaps()\n        self.digit_capsules2 = DigitCaps2()\n        self.decoder = Decoder()\n\n        self.conv_node = myLIFnode(tau=5)\n        self.primary_node = myLIFnode(tau=5)\n        self.digit_node = myLIFnode(tau=5)\n       
 self.digit2_node = myLIFnode(tau=5)\n\n    def forward(self, data, time_window=5, train=True):\n        self.init()\n        out_mem = 0.\n        self.digit_capsules2.init_bij()\n        self.trace_u = torch.zeros(batch_size, 1152, 10, 16, 1, device=device)\n\n        for step in range(time_window):\n            x = data\n\n            x = self.conv_node(self.conv_layer(x))\n            x = self.primary_node(self.primary_capsules(x))\n            x1 = self.digit_node(self.digit_capsules(x))\n            x = self.digit_capsules2(x1)\n            out_mem += x.squeeze(3)\n            y = self.digit2_node(x)\n\n            if train:\n                with torch.no_grad():\n                    self.digit_capsules2.b_ij = torch.clamp(self.digit_capsules2.b_ij, -0.05, 1)\n                    self.trace_u *= torch.exp(-1 / torch.tensor(1.5))\n                    self.trace_u.masked_fill_(x1 != 0, 1)\n                    self.digit_capsules2.b_ij += 0.0008 * torch.matmul(\n                        self.trace_u.transpose(3, 4) - 0.1,\n                        torch.stack([y] * 1152, dim=1)).squeeze(4).mean(dim=0, keepdim=True)\n\n        output = out_mem / time_window\n        output = self.decoder(output)\n        return output\n\n\n    def init(self):\n        self.conv_node.n_reset()\n        self.primary_node.n_reset()\n        self.digit_node.n_reset()\n        self.digit2_node.n_reset()\n\n\ndef evaluate(test_iter, net, device):\n    net.eval()\n\n    test_loss, test_acc, n_test = 0, 0.0, 0\n    for batch_id, (data, target) in tqdm(enumerate(test_iter)):\n        target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\n        data, target = Variable(data), Variable(target)\n        data, target = data.to(device), target.to(device)\n\n        classes = net(data)\n\n        test_acc += sum(np.argmax(classes.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))\n        n_test += data.shape[0]\n    net.train()\n\n    return test_acc / 
n_test\n\n\nif __name__ == '__main__':\n    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n    batch_size = 100\n\n    train_loader, test_loader, _, _ = get_mnist_data(batch_size)\n    capsule_net = CapsNet().to(device)\n    optimizer = Adam(capsule_net.parameters(), lr=0.0005)\n    loss_fn = nn.MSELoss()\n\n    n_epochs = 50\n    best, losses = 0, []\n\n    for epoch in range(n_epochs):\n        if epoch in [15, 25, 45]:\n            optimizer.param_groups[0]['lr'] *= 0.3\n\n        capsule_net.train()\n        train_loss, correct, n = 0, 0, 0\n        loss_rec = []\n        for batch_id, (data, target) in enumerate(train_loader):\n            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\n            data, target = Variable(data), Variable(target)\n\n            data, target = data.to(device), target.to(device)\n\n            optimizer.zero_grad()\n            classes = capsule_net(data)\n            loss = loss_fn(classes, target)\n            loss.backward()\n            loss_rec.append(loss.item())\n            optimizer.step()\n\n            train_loss += loss.item()\n            correct += sum(np.argmax(classes.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))\n            n += data.shape[0]\n\n            if batch_id % 100 == 0:\n                print(\"Epoch: {}, Batch: {}, train accuracy: {:.6f}, loss: {:.6f}\".format(\n                    epoch, batch_id + 1,\n                    sum(np.argmax(classes.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size),\n                    loss.item()))\n        losses.append(np.mean(np.array(loss_rec)))\n\n        print(\"Epoch: [{}/{}],  train accuracy: {:.6f}, loss: {:.6f}\".format(\n            epoch, n_epochs,\n            correct / float(n),\n            train_loss / len(train_loader)))\n\n\n        capsule_net.eval()\n        test_acc = evaluate(test_loader, capsule_net, device=device)\n        print(\"test 
accuracy: {:.6f}\".format(test_acc))\n\n        if test_acc > best:\n            best = test_acc\n            # torch.save(capsule_net, './checkpoints/spikingcaps_mnist.pkl')\n\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/transfer_for_dvs/GradCAM_visualization.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2023/2/14 11:52\n# Author : Regulus\n# FileName: main_visual_losslandscape.py\n# Explain:\n# Software: PyCharm\nimport sys\n\nimport tqdm\n\nfrom loss_landscape.plot_surface import *\n\nfrom Pytorch_Grad_Cam.cam import *\n\nimport argparse\nimport math\nimport time\nimport CKA\nimport numpy\nimport timm.models\nimport random as rd\nimport yaml\nimport os\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d import proj3d\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN\nfrom braincog.model_zoo.resnet19_snn import resnet19\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom rgb_hsv import RGB_HSV\nimport matplotlib.pyplot as plt\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\nfrom copy import deepcopy\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing 
key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--source-dataset', default='cifar10', type=str)\nparser.add_argument('--target-dataset', default='dvsc10', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_new_refined/', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='GateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, 
default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\nparser.add_argument('--DVS-DA', action='store_true',\n                    help='use DA on DVS')\n\n# train data used ratio\nparser.add_argument('--traindata-ratio', default=1.0, type=float,\n                    help='training data ratio')\n\n# snr value\nparser.add_argument('--snr', default=0, type=int,\n                    help='random noise amplitude controled by snr, 0 means no noise')\n\nparser.add_argument('--aug_smooth', action='store_true',\n                    help='Apply test time augmentation to smooth the CAM')\nparser.add_argument('--eigen_smooth', 
action='store_true', help='Reduce noise by taking the first principle componenet'\n         'of cam_weights*activations')\n\nimport os\nimport numpy as np\nimport torch\nfrom torchvision import transforms\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d import proj3d\nfrom tonic.datasets import NCALTECH101, CIFAR10DVS\nimport tonic\nfrom matplotlib import rcParams\nimport seaborn as sns\n\n\n# for matplotlib 3D\ndef get_proj(self):\n    \"\"\"\n     Create the projection matrix from the current viewing position.\n\n     elev stores the elevation angle in the z plane\n     azim stores the azimuth angle in the (x, y) plane\n\n     dist is the distance of the eye viewing point from the object point.\n    \"\"\"\n    # chosen for similarity with the initial view before gh-8896\n\n    relev, razim = np.pi * self.elev / 180, np.pi * self.azim / 180\n\n    # EDITED TO HAVE SCALED AXIS\n    xmin, xmax = np.divide(self.get_xlim3d(), self.pbaspect[0])\n    ymin, ymax = np.divide(self.get_ylim3d(), self.pbaspect[1])\n    zmin, zmax = np.divide(self.get_zlim3d(), self.pbaspect[2])\n\n    # transform to uniform world coordinates 0-1, 0-1, 0-1\n    worldM = proj3d.world_transformation(xmin, xmax,\n                                         ymin, ymax,\n                                         zmin, zmax)\n\n    # look into the middle of the new coordinates\n    R = self.pbaspect / 2\n\n    xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist\n    yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist\n    zp = R[2] + np.sin(relev) * self.dist\n    E = np.array((xp, yp, zp))\n\n    self.eye = E\n    self.vvec = R - E\n    self.vvec = self.vvec / np.linalg.norm(self.vvec)\n\n    if abs(relev) > np.pi / 2:\n        # upside down\n        V = np.array((0, 0, -1))\n    else:\n        V = np.array((0, 0, 1))\n    zfront, zback = -self.dist, self.dist\n\n    viewM = proj3d.view_transformation(E, R, V)\n    projM = 
self._projection(zfront, zback)\n    M0 = np.dot(viewM, worldM)\n    M = np.dot(projM, M0)\n    return M\n\n\ndef event_vis_raw(x):\n    sns.set_style('whitegrid')\n    # sns.set_palette('deep', desat=.6)\n    sns.set_context(\"notebook\", font_scale=1.5,\n                    rc={\"lines.linewidth\": 2.5})\n    Axes3D.get_proj = get_proj\n    x = np.array(x.tolist())  # x, y, t, p\n    mask = (x[:, 3] == 1)\n    x_pos = x[mask]\n    x_neg = x[mask == False]\n    pos_idx = np.random.choice(x_pos.shape[0], 10000)\n    neg_idx = np.random.choice(x_neg.shape[0], 10000)\n    # x_pos[pos_idx, 2] = 0\n    # x_neg[neg_idx, 2] = 0\n\n    fig = plt.figure(figsize=plt.figaspect(0.5) * 1.5)\n    ax = Axes3D(fig)\n    ax.pbaspect = np.array([2.0, 1.0, 0.5])\n    ax.view_init(elev=10, azim=-75)\n    # ax.view_init(elev=15, azim=15)\n    ax.set_xlabel('t (time step)')\n    ax.set_ylabel('w (pixel)')\n    ax.set_zlabel('h (pixel)')\n    # ax.set_xticks([])\n    # ax.set_yticks([])\n    # ax.set_zticks([])\n    # ax.scatter(x_pos[pos_idx, 2], 48 - x_pos[pos_idx, 0], 48 - x_pos[pos_idx, 1], color='red', alpha=0.3, s=1.)\n    # ax.scatter(x_neg[neg_idx, 2], 48 - x_neg[neg_idx, 0], 48 - x_neg[neg_idx, 1], color='blue', alpha=0.3, s=1.)\n    ax.scatter(x_pos[:, 0], 48 - x_pos[:, 1] * 0.375, 48 - x_pos[:, 2] * 0.375, color='red', alpha=0.3, s=1.)\n    # ax.scatter(x_neg[:, 0], 64 - x_neg[:, 1] // 2, 128 - x_neg[:, 2], color='blue', alpha=0.3, s=1.)\n    ax.scatter(18000, 48 - x_pos[:, 1] * 0.375, 48 - x_pos[:, 2] * 0.375, color='red', alpha=0.3, s=1.)\n    # ax.scatter(18000, 64 - x_pos[:, 1] // 2, 128 - x_pos[:, 2], color='blue', alpha=0.3, s=1.)\n\n\ndef get_dataloader_ncal(step, **kwargs):\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\n    transform = tonic.transforms.Compose([\n        # tonic.transforms.DropPixel(hot_pixel_frequency=.999),\n        # tonic.transforms.Denoise(500),\n        tonic.transforms.DropEvent(p=0.0),\n        # 
tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n        # lambda x: F.interpolate(torch.tensor(x, dtype=torch.float), size=[48, 48], mode='bilinear', align_corners=True),\n    ])\n    dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=transform)\n    # dataset = [dataset[5569], dataset[8196]]\n    # dataset = [dataset[5000], dataset[6000]] # 1958\n    # dataset = [dataset[0]]\n    # loader = torch.utils.data.DataLoader(\n    #     dataset, batch_size=1,\n    #     shuffle=False,\n    #     pin_memory=True, drop_last=True, num_workers=8\n    # )\n    return dataset\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    torch.set_num_threads(20)\n    os.environ[\"OMP_NUM_THREADS\"] = \"20\"  # 设置OpenMP计算库的线程数\n    os.environ[\"MKL_NUM_THREADS\"] = \"20\"  # 设置MKL-DNN CPU加速库的线程数。\n    args, args_text = _parse_args()\n    args.no_spike_output = True\n    torch.cuda.set_device('cuda:%d' % args.device)\n    args.prefetcher = not args.no_prefetcher\n    
args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        adaptive_node=args.adaptive_node,\n        dataset=args.target_dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.threshold,\n        tau=args.tau,\n        sigmoid_thres=args.sigmoid_thres,\n        requires_thres_grad=args.requires_thres_grad,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        n_groups=args.n_groups,\n    )\n\n    if 'dvs' in args.target_dataset:\n        args.channels = 2\n    elif 'mnist' in args.target_dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * 
args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    # source_loader_train, _, _, _ = eval('get_transfer_%s_data' % args.source_dataset)(\n    #     batch_size=args.batch_size,\n    #     step=args.step,\n    #     args=args,\n    #     _logge=_logger,\n    #     data_config=data_config,\n    #     size=args.event_size,\n    #     mix_up=args.mix_up,\n    #     cut_mix=args.cut_mix,\n    #     event_mix=args.event_mix,\n    #     beta=args.cutmix_beta,\n    #     prob=args.cutmix_prob,\n    #     gaussian_n=args.gaussian_n,\n    #     num=args.cutmix_num,\n    #     noise=args.cutmix_noise,\n    #     num_classes=args.num_classes,\n    #     rand_aug=args.rand_aug,\n    #     randaug_n=args.randaug_n,\n    #     randaug_m=args.randaug_m,\n    #     portion=args.train_portion,\n    #     _logger=_logger,\n    # )\n\n\n    origin_loader_train, _, _, _ = eval('get_origin_dvsc10_data')(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n    )\n\n    target_loader_train, target_loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % 
args.target_dataset)(\n        batch_size=args.batch_size,\n        dvs_da=args.DVS_DA,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n        train_data_ratio=args.traindata_ratio,\n        snr=args.snr,\n        data_mode=\"full\",\n        frames_num=12,\n        data_type=\"frequency\"\n    )\n\n    model_before = deepcopy(model)\n    if args.eval:  # evaluate the model\n        if args.distributed:\n            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n            new_state_dict = OrderedDict()\n            # add module prefix for DDP\n            for k, v in state_dict.items():\n                k = 'module.' 
+ k\n                new_state_dict[k] = v\n\n            model.load_state_dict(new_state_dict)\n        else:\n            model.load_state_dict(torch.load('/home/hexiang/TransferLearning_For_DVS/Results_lastest/train_TCKA_test/Transfer_VGG_SNN-dvsc10-10-bs_120-seed_42-DA_True-ls_0.0-lr_0.005-SNR_0-domainLoss_False-semanticLoss_False-domain_loss_coefficient1.0-semantic_loss_coefficient0.5-traindataratio_1.0-TETfirst_True-TETsecond_True/model_best.pth.tar', map_location='cpu')['state_dict'])\n            # pass\n            # print(\"no model load\")\n        # --------------------------------------------------------------------------\n        # Show Acc\n        # --------------------------------------------------------------------------\n        print(\"load model finished!\")\n\n\n    # \"\"\" python cam.py -image-path <path_to_image>\n    # Example usage of loading an image, and computing:\n    #     1. CAM\n    #     2. Guided Back Propagation\n    #     3. Combining both\n    # \"\"\"\n    #\n    # # Choose the target layer you want to compute the visualization for.\n    # # Usually this will be the last convolutional layer in the model.\n    # # Some common choices can be:\n    # # Resnet18 and 50: model.layer4\n    # # VGG, densenet161: model.features[-1]\n    # # mnasnet1_0: model.layers[-1]\n    # # You can print the model to help chose the layer\n    # # You can pass a list with several target layers,\n    # # in that case the CAMs will be computed per layer and then aggregated.\n    # # You can also try selecting all layers of a certain type, with e.g:\n    # # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive\n    # # find_layer_types_recursive(model, [torch.nn.ReLU])\n    # target_layers = [model.feature[-1]]\n    #\n    # if True:\n    #     # inputs = 0.0\n    #     # label = 0.0\n    #     # for batch_idx, (inputs_tmp, label_tmp) in tqdm.tqdm(enumerate(origin_loader_train)):\n    #     #     if batch_idx == choose_idx:\n    
#     #         inputs = inputs_tmp\n    #     #         label = label_tmp\n    #     #         break\n    #     #     else:\n    #     #         continue\n    #     inputs = 0.0\n    #     rgb_img = 0.0\n    #\n    #     #Using the with statement ensures the context is freed, and you can\n    #     #recreate different CAM objects in a loop.\n    #     plt.figure(figsize=(8, 6))\n    #     plt.xlabel('w (pixel)')\n    #     plt.ylabel('h (pixel)')\n    #     cam_algorithm = GradCAMPlusPlus\n    #     model = model.cuda()\n    #     with cam_algorithm(model=model,\n    #                        target_layers=target_layers,\n    #                        use_cuda=False) as cam:\n    #\n    #         # AblationCAM and ScoreCAM have batched implementations.\n    #         # You can override the internal batch size for faster computation.\n    #         cam.batch_size = 32\n    #\n    #         for batch_idx, (origin_loaer, target_loader) in tqdm.tqdm(enumerate(zip(origin_loader_train, target_loader_train))):\n    #\n    #             twodemension_inputs, labels = origin_loaer\n    #             plt.figure(figsize=(8, 6))\n    #             # plt.xlabel('w (pixel)')\n    #             # plt.ylabel('h (pixel)')\n    #             twodemension_inputs = twodemension_inputs[0]  # (1, 10, 2, 48, 48) -> (10, 2, 48, 48)\n    #             event_frame_plot_2d(twodemension_inputs)\n    #\n    #             inputs_tmp, label_tmp = target_loader\n    #             inputs = inputs_tmp\n    #             inputs = inputs.type(torch.FloatTensor).cuda()\n    #\n    #             grayscale_cam = cam(input_tensor=inputs,\n    #                                 targets=None,\n    #                                 aug_smooth=args.aug_smooth,\n    #                                 eigen_smooth=args.eigen_smooth)\n    #\n    #             # Here grayscale_cam has only one image in the batch\n    #             grayscale_cam = grayscale_cam[0, :]\n    #\n    #             # cam_image = 
show_cam_on_image(rgb_img.permute(1, 2, 0).numpy(), grayscale_cam, use_rgb=True, image_weight=0.0)\n    #             cam_image = show_cam_on_image(np.ones((48, 48, 3)), grayscale_cam, use_rgb=True,\n    #                                           image_weight=0.0)\n    #         # # cam_image is RGB encoded whereas \"cv2.imwrite\" requires BGR encoding.\n    #     # rgb_img = cv2.resize(rgb_img.permute(1, 2, 0).numpy(), (32, 32))\n    #\n    #             # cv2.imwrite(f'{args.method}_cam.jpg', cam_image)\n    #             plt.ylim(bottom=0.)\n    #             plt.axis('off')\n    #             plt.savefig('fig/gradcam_dvspic_origin/label_{}_id_{}.jpg'.format(labels.item(), 400 + batch_idx), bbox_inches='tight', pad_inches=0)\n    #             plt.imshow(cam_image, alpha=1.0)\n    #             # plt.show()\n    #             # plt.savefig('gradcam_pic/plot_id{}.jpg'.format(batch_idx), bbox_inches='tight')\n    #             plt.savefig('fig/gradcam_dvspic_withoutloss/label_{}_id_{}.jpg'.format(labels.item(), 400 + batch_idx), bbox_inches='tight', pad_inches=0)\n\n    # 第二次\n    print(\"load model again!\")\n    model = model_before\n    model.load_state_dict(torch.load(\n        '/home/hexiang/TransferLearning_For_DVS/Results_lastest/train_TCKA_test/Transfer_VGG_SNN-dvsc10-10-bs_120-seed_42-DA_True-ls_0.0-lr_0.005-SNR_0-domainLoss_True-semanticLoss_True-domain_loss_coefficient1.0-semantic_loss_coefficient0.5-traindataratio_1.0-TETfirst_True-TETsecond_True/model_best.pth.tar', map_location='cpu')['state_dict'])\n\n\n    \"\"\" python cam.py -image-path <path_to_image>\n    Example usage of loading an image, and computing:\n        1. CAM\n        2. Guided Back Propagation\n        3. 
Combining both\n    \"\"\"\n\n    # Choose the target layer you want to compute the visualization for.\n    # Usually this will be the last convolutional layer in the model.\n    # Some common choices can be:\n    # Resnet18 and 50: model.layer4\n    # VGG, densenet161: model.features[-1]\n    # mnasnet1_0: model.layers[-1]\n    # You can print the model to help chose the layer\n    # You can pass a list with several target layers,\n    # in that case the CAMs will be computed per layer and then aggregated.\n    # You can also try selecting all layers of a certain type, with e.g:\n    # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive\n    # find_layer_types_recursive(model, [torch.nn.ReLU])\n    target_layers = [model.feature[-1]]\n\n    if True:\n        # inputs = 0.0\n        # label = 0.0\n        # for batch_idx, (inputs_tmp, label_tmp) in tqdm.tqdm(enumerate(origin_loader_train)):\n        #     if batch_idx == choose_idx:\n        #         inputs = inputs_tmp\n        #         label = label_tmp\n        #         break\n        #     else:\n        #         continue\n        inputs = 0.0\n        rgb_img = 0.0\n\n        #Using the with statement ensures the context is freed, and you can\n        #recreate different CAM objects in a loop.\n        plt.figure(figsize=(8, 6))\n        plt.xlabel('w (pixel)')\n        plt.ylabel('h (pixel)')\n        cam_algorithm = GradCAMPlusPlus\n        model = model.cuda()\n        with cam_algorithm(model=model,\n                           target_layers=target_layers,\n                           use_cuda=False) as cam:\n\n            # AblationCAM and ScoreCAM have batched implementations.\n            # You can override the internal batch size for faster computation.\n            cam.batch_size = 32\n\n            for batch_idx, (origin_loaer, target_loader) in tqdm.tqdm(enumerate(zip(origin_loader_train, target_loader_train))):\n\n                twodemension_inputs, labels = origin_loaer\n  
              plt.figure(figsize=(8, 6))\n                # plt.xlabel('w (pixel)')\n                # plt.ylabel('h (pixel)')\n                twodemension_inputs = twodemension_inputs[0]  # (1, 10, 2, 48, 48) -> (10, 2, 48, 48)\n                event_frame_plot_2d(twodemension_inputs)\n\n                inputs_tmp, label_tmp = target_loader\n                inputs = inputs_tmp\n                inputs = inputs.type(torch.FloatTensor).cuda()\n\n                grayscale_cam = cam(input_tensor=inputs,\n                                    targets=None,\n                                    aug_smooth=args.aug_smooth,\n                                    eigen_smooth=args.eigen_smooth)\n\n                # Here grayscale_cam has only one image in the batch\n                grayscale_cam = grayscale_cam[0, :]\n\n                # cam_image = show_cam_on_image(rgb_img.permute(1, 2, 0).numpy(), grayscale_cam, use_rgb=True, image_weight=0.0)\n                cam_image = show_cam_on_image(np.ones((48, 48, 3)), grayscale_cam, use_rgb=True,\n                                              image_weight=0.0)\n            # # cam_image is RGB encoded whereas \"cv2.imwrite\" requires BGR encoding.\n        # rgb_img = cv2.resize(rgb_img.permute(1, 2, 0).numpy(), (32, 32))\n\n                # cv2.imwrite(f'{args.method}_cam.jpg', cam_image)\n                plt.ylim(bottom=0.)\n                plt.axis('off')\n                # plt.savefig('fig/gradcam_dvspic_origin/label_{}_id_{}.jpg'.format(labels.item(), batch_idx), bbox_inches='tight', pad_inches=0)\n                plt.imshow(cam_image, alpha=1.0)\n                # plt.show()\n                # plt.savefig('gradcam_pic/plot_id{}.jpg'.format(batch_idx), bbox_inches='tight')\n                plt.savefig('fig/gradcam_dvspic_withloss/label_{}_id_{}.jpg'.format(labels.item(), batch_idx), bbox_inches='tight', pad_inches=0)\n\ndef event_frame_plot_2d(event):\n\n    for t in range(event.shape[0]):\n        pos_idx = []\n        
neg_idx = []\n        for x in range(event.shape[2]):\n            for y in range(event.shape[3]):\n                if event[t, 0, x, y] > 0:\n                    pos_idx.append((x, y, event[t, 0, x, y]))\n                if event[t, 1, x, y] > 0:\n                    neg_idx.append((x, y, event[t, 0, x, y]))\n        if len(pos_idx) > 0:\n            # print(t)\n            pos_x, pos_y, pos_c = np.split(np.array(pos_idx), 3, axis=1)\n            # plt.scatter(48 - pos_x[:, 0] * 0.375, 48 - pos_y[:, 0] * 0.375, c='red', alpha=1, s=1)\n            plt.scatter(pos_x[:, 0] * 0.375, pos_y[:, 0] * 0.375, c='white', alpha=1, s=1)\n        if len(neg_idx) > 0:\n            neg_x, neg_y, neg_c = np.split(np.array(neg_idx), 3, axis=1)\n            # plt.scatter(48 - neg_x[:, 0] * 0.375, 48 - neg_y[:, 0] * 0.375, c='blue', alpha=1, s=1)\n            plt.scatter(neg_x[:, 0] * 0.375, neg_y[:, 0] * 0.375, c='blue', alpha=1, s=1)\n    # sys.exit()\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/transfer_for_dvs/README.md",
    "content": "# Script for all experiments\n\n## Baseline\n\n1. CIFAR10-DVS\n```shell \npython main.py --model VGG_SNN --node-type LIFNode --dataset dvsc10 --step 10 --batch-size 120 --act-fun QGateGrad --device 5 --seed 42 --DVS-DA --traindata-ratio 1.0 --smoothing 0.0 --TET-loss-first --TET-loss-second\n```\n\n2. N-Caltech 101\n\n```shell\npython main.py --model VGG_SNN --node-type LIFNode --dataset NCALTECH101 --step 10 --batch-size 120 --act-fun QGateGrad --device 7 --seed 42 --num-classes 101 --traindata-ratio 1.0 --smoothing 0.0 --TET-loss-first --TET-loss-second\n```\n\n3. Omniglot\n\n```shell\npython main.py --model SCNN --node-type LIFNode --dataset nomni --step 12 --batch-size 64 --num-classes 1623 --act-fun QGateGrad --epochs 200 --device 6 --log-interval 200 --smoothing 0.0 --seed 42 --lr 0.01 --min-lr 1e-5\n```\n\n\n\n## Our Method\n\n1. CIFAR10-DVS\n\n```shell\npython main_transfer.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset cifar10 --target-dataset dvsc10 --step 10 --batch-size 120 --act-fun QGateGrad --device 1 --seed 42 --traindata-ratio 1.0 --smoothing 0.0 --domain-loss --semantic-loss --DVS-DA --TET-loss-first --TET-loss-second\n```\n\n2. N-Caltech 101\n\n```shell\npython main_transfer.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset CALTECH101 --target-dataset NCALTECH101 --step 10 --batch-size 120 --act-fun QGateGrad --device 5 --seed 42 --num-classes 101 --traindata-ratio 1.0 --domain-loss --semantic-loss --semantic-loss-coefficient 0.001 --TET-loss-first --TET-loss-second&\n```\n\n3. 
N-Omniglot\n\n```shell\npython main_transfer.py --model Transfer_SCNN --node-type LIFNode --source-dataset omni --target-dataset nomni --step 12 --batch-size 64 --num-classes 1623 --act-fun QGateGrad --epochs 200 --device 6 --log-interval 200 --smoothing 0.0 --seed 42 --domain-loss --semantic-loss --semantic-loss-coefficient 0.5 --lr 0.01 --min-lr 1e-5\n```\n\n\n\n## Visualization Loss-landscape\n\nyou should git clone from https://github.com/tomgoldstein/loss-landscape first.\n\n```shell\nHDF5_USE_FILE_LOCKING=\"FALSE\" mpirun -n 4 -mca btl ^openib python main_visual_losslandscape.py --model VGG_SNN --node-type LIFNode --source-dataset cifar10 --target-dataset dvsc10 --step 10 --batch-size 1000 --eval --eval_checkpoint /home/TransferLearning_For_DVS/Resultes_new_compare/Baseline/VGG_SNN-dvsc10-10-seed_42-bs_120-DA_True-ls_0.0-traindataratio_0.1-TET_first_False-TET_second_False/last.pth.tar --mpi --x=-1.0:1.0:51 --y=-1.0:1.0:51 --dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn --plot --DVS-DA --smoothing 0.0 --traindata-ratio 0.1\n```\n\n\n\n```shell\npython main_visual_losslandscape.py --model Transfer_VGG_SNN --node-type LIFNode --source-dataset CALTECH101 --target-dataset NCALTECH101 --step 10 --batch-size 500 --eval --eval_checkpoint /home/TransferLearning_For_DVS/Results_new_compare/train_TCKA_test/Transfer_VGG_SNN-NCALTECH101-10-bs_120-seed_47-DA_False-ls_0.0-lr_0.005-SNR_0-domainLoss_True-semanticLoss_True-domain_loss_coefficient1.0-semantic_loss_coefficient0.001-traindataratio_0.1-TETfirst_True-TETsecond_True/last.pth.tar --mpi --x=-1.0:1.0:51 --y=-1.0:1.0:51 --dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn --plot --smoothing 0.0 --traindata-ratio 0.1 --num-classes 101 --device 5&\n```\n\n\n\n## Visualization Grad-cam++\n\nyou should git clone from https://github.com/jacobgil/pytorch-grad-cam first.\n\n```shell\npython GradCAM_visualization.py --model Transfer_VGG_SNN --node-type LIFNode 
--source-dataset cifar10 --target-dataset dvsc10 --step 10 --batch-size 1 --act-fun QGateGrad --device 6 --seed 42 --smoothing 0.0 --DVS-DA --eval --eval_checkpoint /home/TransferLearning_For_DVS/Results_lastest/train_TCKA_test/Transfer_VGG_SNN-dvsc10-10-bs_120-seed_42-DA_True-ls_0.0-lr_0.005-SNR_0-domainLoss_True-semanticLoss_True-domain_loss_coefficient1.0-semantic_loss_coefficient0.5-traindataratio_1.0-TETfirst_True-TETsecond_True/model_best.pth.tar\n```\n\n\n\n## Note: Dataset\n\nIn order to work with the source and target domain data, the datastes file is tailored, please use `datasets.py` here to replace and override `braincog/datasets/datasets.py` if you want to run transfer learning in this project.\n\n\n\n## Citation\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{he2023improving,\n  title={Improving the Performance of Spiking Neural Networks on Event-based Datasets with Knowledge Transfer},\n  author={He, Xiang and Zhao, Dongcheng and Li, Yang and Shen, Guobin and Kong, Qingqun and Zeng, Yi},\n  journal={arXiv preprint arXiv:2303.13077},\n  year={2023}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n```\n\n\n## Contents\n\nIf you are confused about using it or have other feedback and comments, please feel free to contact us via [hexiang2021@ia.ac.cn](hexiang2021@ia.ac.cn).\n\nHave a good day!"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/transfer_for_dvs/datasets.py",
    "content": "import os, warnings\nimport torchvision.datasets\ntry:\n    import tonic\n    from tonic import DiskCachedDataset\nexcept:\n    warnings.warn(\"tonic should be installed, 'pip install git+https://github.com/BrainCog-X/tonic_braincog.git'\")\nimport torch\nimport torch.nn.functional as F\nimport torch.utils\nimport torchvision.datasets as datasets\nfrom timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.data import create_transform\nfrom einops import rearrange, repeat\nfrom torchvision import transforms\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\nfrom torch.utils.data import ConcatDataset\nfrom braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull\nfrom braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot\nfrom braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet\nfrom braincog.datasets.ESimagenet.ES_imagenet import ESImagenet_Dataset\nfrom braincog.datasets.ESimagenet.reconstructed_ES_imagenet import ESImagenet2D_Dataset\nfrom braincog.datasets.CUB2002011 import CUB2002011\nfrom braincog.datasets.TinyImageNet import TinyImageNet\nfrom braincog.datasets.StanfordDogs import StanfordDogs \nfrom random import sample\nfrom .cut_mix import CutMix, EventMix, MixUp\nfrom .rand_aug import *\nfrom .utils import dvs_channel_check_expend, rescale\nfrom PIL import Image\nimport cv2\nimport math\nDVSCIFAR10_MEAN_16 = [0.3290, 0.4507]\nDVSCIFAR10_STD_16 = [1.8398, 1.6549]\n\nDATA_DIR = '/data/datasets'\n\nDEFAULT_CROP_PCT = 0.875\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\nIMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)\nIMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)\n\nCIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)\nCIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)\n\n\nclass 
TransferSampler(torch.utils.data.sampler.Sampler):\n    r\"\"\"Samples elements randomly from a given list of indices, without replacement.\n    Arguments:\n        indices (sequence): a sequence of indices\n    \"\"\"\n\n    def __init__(self, indices):\n        self.indices = indices\n\n    def __iter__(self):\n        return (self.indices[i] for i in range(len(self.indices)))\n\n    def __len__(self):\n        return len(self.indices)\n\nclass Transfer_DataSet(torchvision.datasets.VisionDataset):\n    def __init__(self, data, label):\n        self.data = data\n        self.label = label\n        self.length = data.shape[0]\n\n    def __getitem__(self, mask):\n        data = self.data[mask]\n        label = self.label[mask]\n        return data, label\n\n    def __len__(self):\n        return self.length\n\n\n# 自定义HSV空间 transform\nclass ConvertHSV(object):\n    \"\"\"计算边缘梯度\n    Args:\n        None\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    # transform 会调用该方法\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): PIL Image\n        Returns:\n            PIL Image: PIL image, v channel.\n        \"\"\"\n        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)\n        return Image.fromarray(img.astype('uint8'))\n\n\ndef unpack_mix_param(args):\n    mix_up = args['mix_up'] if 'mix_up' in args else False\n    cut_mix = args['cut_mix'] if 'cut_mix' in args else False\n    event_mix = args['event_mix'] if 'event_mix' in args else False\n    beta = args['beta'] if 'beta' in args else 1.\n    prob = args['prob'] if 'prob' in args else .5\n    num = args['num'] if 'num' in args else 1\n    num_classes = args['num_classes'] if 'num_classes' in args else 10\n    noise = args['noise'] if 'noise' in args else 0.\n    gaussian_n = args['gaussian_n'] if 'gaussian_n' in args else None\n    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, 
gaussian_n\n\n\ndef build_transform(is_train, img_size, use_hsv=True):\n    \"\"\"\n    构建数据增强, 适用于static data\n    :param is_train: 是否训练集\n    :param img_size: 输出的图像尺寸\n    :return: 数据增强策略\n    \"\"\"\n    resize_im = img_size > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=img_size,\n            is_training=True,\n            color_jitter=0.4,\n            auto_augment='rand-m9-mstd0.5-inc1',\n            interpolation='bicubic',\n            re_prob=0.25,\n            re_mode='pixel',\n            re_count=1,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(\n                img_size, padding=4)\n        return transform\n\n    t = []\n    # if resize_im:\n    #     size = int((256 / 224) * img_size)\n    #     t.append(\n    #         # to maintain same ratio w.r.t. 
224 images\n    #         transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),\n    #     )\n    #     t.append(transforms.CenterCrop(img_size))\n\n    # t.append(transforms.RandomAffine(degrees=0, translate=))\n    # if Gradient:\n    #     print(\"Used Gradient!\")\n    #     t.append(ComputeLaplacian())\n        # t.append(ConvertHSV())\n        # t.append(AddGaussianNoise())\n    t.append(transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BILINEAR))\n    if use_hsv:\n        print(\"Used V-channel!\")\n        t.append(ConvertHSV())\n    t.append(transforms.ToTensor())\n    # t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    return transforms.Compose(t)\n\n\ndef build_dataset(is_train, img_size, dataset, path, same_da=False, use_hsv=True):\n    \"\"\"\n    构建带有增强策略的数据集\n    :param is_train: 是否训练集\n    :param img_size: 输出图像尺寸\n    :param dataset: 数据集名称\n    :param path: 数据集路径\n    :param same_da: 为训练集使用测试集的增广方法\n    : param use_hsv: 是否采用HSV\n    :return: 增强后的数据集\n    \"\"\"\n    # transform = build_transform(False, img_size) if same_da else build_transform(is_train, img_size)\n    transform = build_transform(False, img_size, use_hsv) if same_da else build_transform(False, img_size, use_hsv)\n    if dataset == 'CIFAR10':\n        dataset = datasets.CIFAR10(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 10\n    elif dataset == 'CIFAR100':\n        dataset = datasets.CIFAR100(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 100\n    elif dataset == 'CALTECH101':\n        dataset = datasets.Caltech101(\n            path, transform=transform, download=True\n        )\n        nb_classes = 101\n    else:\n        raise NotImplementedError\n\n    return dataset, nb_classes\n\n\nclass MNISTData(object):\n    \"\"\"\n    Load MNIST datesets.\n    \"\"\"\n\n    def __init__(self,\n                 data_path: str,\n    
             batch_size: int,\n                 train_trans: Sequence[torch.nn.Module] = None,\n                 test_trans: Sequence[torch.nn.Module] = None,\n                 pin_memory: bool = True,\n                 drop_last: bool = True,\n                 shuffle: bool = True,\n                 ) -> None:\n        self._data_path = data_path\n        self._batch_size = batch_size\n        self._pin_memory = pin_memory\n        self._drop_last = drop_last\n        self._shuffle = shuffle\n        self._train_transform = transforms.Compose(train_trans) if train_trans else None\n        self._test_transform = transforms.Compose(test_trans) if test_trans else None\n\n    def get_data_loaders(self):\n        print('Batch size: ', self._batch_size)\n        train_datasets = datasets.MNIST(root=self._data_path, train=True, transform=self._train_transform, download=True)\n        test_datasets = datasets.MNIST(root=self._data_path, train=False, transform=self._test_transform, download=True)\n        train_loader = torch.utils.data.DataLoader(\n            train_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=self._drop_last, shuffle=self._shuffle\n        )\n        test_loader = torch.utils.data.DataLoader(\n            test_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=False\n        )\n        return train_loader, test_loader\n\n    def get_standard_data(self):\n        MNIST_MEAN = 0.1307\n        MNIST_STD = 0.3081\n        self._train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                                    transforms.ToTensor(),\n                                                    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        self._test_transform = transforms.Compose([transforms.ToTensor(),\n                                                   transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        return 
self.get_data_loaders()\n\n\ndef get_mnist_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"s\n    获取MNIST数据\n    http://data.pymvpa.org/datasets/mnist/\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    MNIST_MEAN = 0.1307\n    MNIST_STD = 0.3081\n    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:\n        train_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n        test_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n    else:\n        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                              # transforms.RandomRotation(10),\n                                              transforms.ToTensor(),\n                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        test_transform = transforms.Compose([transforms.ToTensor(),\n                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n\n    train_datasets = datasets.MNIST(\n        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.MNIST(\n        root=DATA_DIR, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_fashion_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    
\"\"\"\n    获取fashion MNIST数据\n    http://arxiv.org/abs/1708.07747\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                          transforms.RandomHorizontalFlip(),\n                                          transforms.RandomRotation(10),\n                                          transforms.ToTensor()])\n    test_transform = transforms.Compose([transforms.ToTensor()])\n\n    train_datasets = datasets.FashionMNIST(\n        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.FashionMNIST(\n        root=DATA_DIR, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取CIFAR10数据\n     https://www.cs.toronto.edu/~kriz/cifar.html\n    :param batch_size: batch size\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True\n    train_datasets, _ = build_dataset(True, 32, 'CIFAR10', DATA_DIR, same_da, False)\n    test_datasets, _ = build_dataset(False, 32, 'CIFAR10', DATA_DIR, same_da, False)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True,\n        
num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False,\n        num_workers=num_workers\n    )\n    return train_loader, test_loader, None, None\n\n\ndef get_cifar100_data(batch_size, num_workers=8, same_data=False, *args, **kwargs):\n    \"\"\"\n    获取CIFAR100数据\n    https://www.cs.toronto.edu/~kriz/cifar.html\n    :param batch_size: batch size\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_datasets, _ = build_dataset(True, 32, 'CIFAR100', DATA_DIR, same_data)\n    test_datasets, _ = build_dataset(False, 32, 'CIFAR100', DATA_DIR, same_data)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n    return train_loader, test_loader, False, None\n\n\ndef get_transfer_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True\n    train_datasets, _ = build_dataset(True, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)  # 原来是48\n    test_datasets, _ = build_dataset(False, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)\n\n    concat_dataset = ConcatDataset([train_datasets, test_datasets])  # concat dataset\n\n    img_index = [[] for i in range(10)]\n    label_index = [0] * 60000\n    for idx, (img, label) in enumerate(concat_dataset):\n        img_index[label].append(img)\n    for i in range(10):\n        img_index[i] = torch.stack(img_index[i], 0)\n        label_index[i * 6000:2 * i * 6000] = [i] * 6000\n    source_datasets = Transfer_DataSet(data=rearrange(torch.stack(img_index, dim=0), 'l b c w h -> (l b) c w h'),\n        
                               label=label_index)\n\n    source_loader = torch.utils.data.DataLoader(\n        source_datasets, batch_size=60000,\n        sampler=TransferSampler(torch.arange(0, 60000).tolist()),\n        pin_memory=True, drop_last=False, num_workers=16\n    )\n    return source_loader, None, None, None\n\n\ndef get_combined_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True\n    train_datasets, _ = build_dataset(True, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)\n    test_datasets, _ = build_dataset(False, 48, 'CIFAR10', DATA_DIR, same_da, use_hsv)\n\n    concat_dataset = ConcatDataset([train_datasets, test_datasets])  # concat dataset\n\n    source_loader = torch.utils.data.DataLoader(\n        concat_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=8, shuffle=True\n    )\n    return source_loader, None, None, None\n\n\ndef get_transfer_CALTECH101_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True\n    datasets, _ = build_dataset(False, 48, 'CALTECH101', DATA_DIR, same_da, use_hsv)\n    dataset_length = 8299\n\n    train_loader = torch.utils.data.DataLoader(\n        datasets, batch_size=10000,\n        sampler=TransferSampler(torch.arange(0, dataset_length).tolist()),\n        pin_memory=True, drop_last=False, num_workers=4\n    )\n\n    return train_loader, None, None, None\n\n\ndef get_combined_CALTECH101_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    
:param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    use_hsv = not kwargs['no_use_hsv'] if 'no_use_hsv' in kwargs else True\n    datasets, _ = build_dataset(False, 48, 'CALTECH101', DATA_DIR, same_da, use_hsv)\n    dataset_length = 8299\n\n    train_loader = torch.utils.data.DataLoader(\n        datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False,\n        num_workers=4, shuffle=True\n    )\n\n    return train_loader, None, None, None\n\n\ndef get_TinyImageNet_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):\n    size=kwargs[\"size\"] if \"size\" in kwargs else 224\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(size),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(size*8//7),\n        transforms.CenterCrop(size),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(DATA_DIR, 'TinyImageNet')\n    train_datasets = TinyImageNet(\n        root=root, split=\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = TinyImageNet(\n        root=root, split=\"val\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_transfer_imnet_data(args, _logger, data_config, num_aug_splits, 
**kwargs):\n    '''\n    load imagenet 2012\n    we use images in train/ for training, and use images in val/ for testing\n    https://github.com/pytorch/examples/tree/master/imagenet\n    '''\n    IMAGENET_PATH = '/data/datasets/ILSVRC2012/'\n    traindir = os.path.join(IMAGENET_PATH, 'train')\n    valdir = os.path.join(IMAGENET_PATH, 'val')\n    batch_size = kwargs['batch_size']\n\n    train_dataset = datasets.ImageFolder(\n        traindir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(224),\n            transforms.RandomHorizontalFlip(),\n            ConvertHSV(),\n            transforms.ToTensor()]))\n\n    # val_dataset = datasets.ImageFolder(\n    #     valdir,\n    #     transforms.Compose([\n    #         transforms.Resize(256),\n    #         transforms.CenterCrop(224),\n    #         ConvertHSV(),\n    #         transforms.ToTensor()]))\n\n    # train_loader = torch.utils.data.DataLoader(\n    #     train_dataset,\n    #     batch_size=batch_size, shuffle=False,\n    #     num_workers=4, pin_memory=True, sampler=TransferSampler([0, 1300, 2599, 2600]))\n    #\n    # val_loader = torch.utils.data.DataLoader(\n    #     val_dataset,\n    #     batch_size=batch_size, shuffle=False,\n    #     num_workers=4, pin_memory=True)\n    return train_dataset, None, None, None\n\n\ndef get_dvsg_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS Gesture数据\n    DOI: 10.1109/CVPR.2017.781\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.DVSGesture.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 48\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # 
tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),\n                                              transform=train_transform, train=True)\n    test_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),\n                                             transform=test_transform, train=False)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n        transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 
'DVS/DVSGesture/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_dvsc10_data(batch_size, step, dvs_da=False, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n  
  size = kwargs['size'] if 'size' in kwargs else 48\n    snr = kwargs['snr'] if 'snr' in kwargs else 0\n    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)\n    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)\n\n    if dvs_da is True:\n        print(\"use dvs_da\")\n        if snr > 0:\n            train_transform = transforms.Compose([\n                lambda x: torch.tensor(x, dtype=torch.float),\n                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n                lambda x: x + torch.randn(x.shape) * math.sqrt(torch.mean(torch.pow(x, 2)) / math.pow(10, snr / 10)),\n                transforms.RandomCrop(size, padding=size // 12),\n                transforms.RandomHorizontalFlip(),\n                transforms.RandomRotation(15)\n            ])\n        else:\n            train_transform = transforms.Compose([\n                lambda x: torch.tensor(x, dtype=torch.float),\n                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n                transforms.RandomCrop(size, padding=size // 12),\n                transforms.RandomHorizontalFlip(),\n                transforms.RandomRotation(15)\n            ])\n    else:\n        train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, 
dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])   # 这里lambda返回的是地址, 注意不要用List复用.\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            sample(list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))), int(num_per_cls * portion * train_data_ratio)))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                   
              beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n        pin_memory=True, drop_last=False, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\n\ndef get_transfer_dvsc10_data(batch_size, step, dvs_da=False, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    snr = kwargs['snr'] if 'snr' in kwargs else 0\n    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = 
transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)\n    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n    lambda x: torch.tensor(x, dtype=torch.float),\n    lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),])\n\n\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])   # 这里lambda返回的是地址, 注意不要用List复用.\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        
train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=len(indices_train),\n        sampler=TransferSampler(indices_train),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    return train_loader, None, mixup_active, None\n\n\ndef get_NCALTECH101_data(batch_size, step, dvs_da=False, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.NCALTECH101.sensor_size\n    cls_count = tonic.datasets.NCALTECH101.cls_count\n    dataset_length = tonic.datasets.NCALTECH101.length\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    snr 
= kwargs['snr'] if 'snr' in kwargs else 0\n    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += int(train_sample * train_data_ratio)\n        train_sample_weight.extend(\n            [sample_weight] * int(train_sample * train_data_ratio)\n        )\n        train_sample_weight.extend(\n            [0.] * (train_sample - int(train_sample * train_data_ratio))\n        )\n        train_sample_weight.extend(\n            [0.] * test_sample\n        )\n        train_sample_index.extend(\n            sample(list(range(idx_begin, idx_begin + train_sample)), int(train_sample * train_data_ratio))\n        )\n        test_sample_index.extend(\n            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)\n    test_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)\n\n   
 if dvs_da is True:\n        print(\"use dvs_da\")\n        train_transform = transforms.Compose([\n            lambda x: torch.tensor(x, dtype=torch.float),\n            lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n            transforms.RandomCrop(size, padding=size // 12),\n            transforms.RandomHorizontalFlip(),\n            transforms.RandomRotation(15)\n        ])\n    else:\n        if snr > 0:\n            train_transform = transforms.Compose([\n                lambda x: torch.tensor(x, dtype=torch.float),\n                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n                lambda x: x + torch.randn(x.shape) * math.sqrt(torch.mean(torch.pow(x, 2)) / math.pow(10, snr / 10)),\n            ])\n        else:\n            train_transform = transforms.Compose([\n                lambda x: torch.tensor(x, dtype=torch.float),\n                lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n                transforms.RandomCrop(size, padding=size // 12),\n            ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])  # 这里lambda返回的是地址, 注意不要用List复用.\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | 
event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_transfer_NCALTECH101_data(batch_size, step, dvs_da=False, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = 
tonic.datasets.NCALTECH101.sensor_size\n    cls_count = tonic.datasets.NCALTECH101.cls_count\n    dataset_length = tonic.datasets.NCALTECH101.length\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    snr = kwargs['snr'] if 'snr' in kwargs else 0\n    train_data_ratio = kwargs['train_data_ratio'] if 'train_data_ratio' in kwargs else 1.0\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += int(train_sample * train_data_ratio)\n        train_sample_weight.extend(\n            [sample_weight] * int(train_sample * train_data_ratio)\n        )\n        train_sample_weight.extend(\n            [0.] * (train_sample - int(train_sample * train_data_ratio))\n        )\n        train_sample_weight.extend(\n            [0.] 
* test_sample\n        )\n        train_sample_index.extend(\n            sample(list(range(idx_begin, idx_begin + train_sample)), int(train_sample * train_data_ratio))\n        )\n        test_sample_index.extend(\n            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)\n    test_dataset = tonic.datasets.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])  # 这里lambda返回的是地址, 注意不要用List复用.\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 
'DVS/NCALTECH101/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=len(train_sample_index),\n        sampler=TransferSampler(train_sample_index),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, None, None, None\n\n\ndef get_NCARS_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取N-Cars数据\n    https://ieeexplore.ieee.org/document/8578284/\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.NCARS.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 48\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=train_transform, train=True)\n    test_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=test_transform, train=False)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: 
dvs_channel_check_expend(x),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                   
              prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_nomni_data(batch_size, train_portion=1., **kwargs):\n    \"\"\"\n    获取N-Omniglot数据\n    :param batch_size:batch的大小\n    :param data_mode:一共full nkks pair三种模式\n    :param frames_num:一个样本帧的个数\n    :param data_type:event frequency两种模式\n    \"\"\"\n    data_mode = kwargs[\"data_mode\"] if \"data_mode\" in kwargs else \"full\"\n    frames_num = kwargs[\"frames_num\"] if \"frames_num\" in kwargs else 4\n    data_type = kwargs[\"data_type\"] if \"data_type\" in kwargs else \"event\"\n\n    train_transform = transforms.Compose([\n        transforms.Resize((28, 28))])\n    test_transform = transforms.Compose([\n        transforms.Resize((28, 28))])\n    if data_mode == \"full\":\n        train_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=True, frames_num=frames_num,\n                                       data_type=data_type,\n                                       transform=train_transform, use_npz=True)\n        test_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), 
train=False, frames_num=frames_num,\n                                      data_type=data_type,\n                                      transform=test_transform, use_npz=True)\n\n    elif data_mode == \"nkks\":\n        train_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),\n                                            n_way=kwargs[\"n_way\"],\n                                            k_shot=kwargs[\"k_shot\"],\n                                            k_query=kwargs[\"k_query\"],\n                                            train=True,\n                                            frames_num=frames_num,\n                                            data_type=data_type,\n                                            transform=train_transform)\n        test_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),\n                                           n_way=kwargs[\"n_way\"],\n                                           k_shot=kwargs[\"k_shot\"],\n                                           k_query=kwargs[\"k_query\"],\n                                           train=False,\n                                           frames_num=frames_num,\n                                           data_type=data_type,\n                                           transform=test_transform)\n    elif data_mode == \"pair\":\n        train_datasets = NOmniglotTrainSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), use_frame=True,\n                                           frames_num=frames_num, data_type=data_type,\n                                           use_npz=False, resize=105)\n        test_datasets = NOmniglotTestSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), time=2000, way=kwargs[\"n_way\"],\n                                         shot=kwargs[\"k_shot\"], use_frame=True,\n                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)\n\n    else:\n        pass\n\n    train_loader = 
torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size, num_workers=4,\n        pin_memory=True, drop_last=True, shuffle=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size, num_workers=4,\n        pin_memory=True, drop_last=False\n    )\n    return train_loader, test_loader, None, None\n\n\n\ndef get_transfer_omni_data(batch_size, train_portion=1., **kwargs):\n    \"\"\"\n    获取Omniglot数据\n    :param batch_size:batch的大小\n    :param data_mode:一共full nkks pair三种模式\n    :param frames_num:一个样本帧的个数\n    :param data_type:event frequency两种模式\n    \"\"\"\n\n    transform = transforms.Compose([\n        transforms.Resize((28, 28)),\n        transforms.ToTensor()])\n\n    train_dataset = datasets.Omniglot(\n        root=\"/data/datasets/\", background=True, download=True, transform=transform\n    )\n    test_dataset = datasets.Omniglot(\n        root=\"/data/datasets/\", background=False, download=True, transform=transform\n    )\n    dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])\n    dataset_length = len(dataset)\n    train_loader = torch.utils.data.DataLoader(\n        dataset, batch_size=35000, num_workers=12,\n        pin_memory=True, drop_last=False,\n        sampler=TransferSampler(torch.arange(0, dataset_length).tolist())\n    )\n\n    return train_loader, None, None, None\n\n\ndef get_esimnet_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取ES imagenet数据\n    DOI: 10.3389/fnins.2021.726582\n    :param batch_size: batch size\n    :param step: 仿真步长，固定为8\n    :param reconstruct: 重构则时间步为1, 否则为8\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    :note: 没有自动下载, 下载及md5请参考spikingjelly, sampler默认为DistributedSampler\n    \"\"\"\n\n    reconstruct = kwargs[\"reconstruct\"] if \"reconstruct\" in kwargs else False\n\n    train_transform = transforms.Compose([\n        transforms.RandomHorizontalFlip(),\n        
transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: dvs_channel_check_expend(x),\n    ])\n\n    if reconstruct:\n        assert step == 1\n        train_dataset = ESImagenet2D_Dataset(mode='train',\n                                            data_set_path=os.path.join(DATA_DIR, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                            transform=train_transform)\n\n        test_dataset = ESImagenet2D_Dataset(mode='test',\n                                            data_set_path=os.path.join(DATA_DIR, 'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                            transform=test_transform)\n    else:\n        assert step == 8\n        train_dataset = ESImagenet_Dataset(mode='train',\n                                             data_set_path=os.path.join(DATA_DIR,\n                                                                        'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                             transform=train_transform)\n\n        test_dataset = ESImagenet_Dataset(mode='test',\n                                            data_set_path=os.path.join(DATA_DIR,\n                                                                       'DVS/ES-imagenet-0.18/extract/ES-imagenet-0.18/'),\n                                            transform=test_transform)\n\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n  
                               prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)\n    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        sampler=train_sampler\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=8,\n        sampler=test_sampler\n    )\n\n    # train_loader = torch.utils.data.DataLoader(\n    #     train_dataset, batch_size=batch_size,\n    #     pin_memory=True, drop_last=True, num_workers=8,\n    #     shuffle=True\n    # )\n    #\n    # test_loader = torch.utils.data.DataLoader(\n    #     test_dataset, batch_size=batch_size,\n    #     pin_memory=True, drop_last=False, num_workers=1,\n    #     shuffle=False\n    # )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_CUB2002011_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        
transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(DATA_DIR, 'CUB2002011')\n    train_datasets = CUB2002011(\n        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = CUB2002011(\n        root=root, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\ndef get_StanfordCars_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(DATA_DIR, 'StanfordCars')\n    train_datasets = datasets.StanfordCars(\n        root=root, split =\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.StanfordCars(\n        root=root, split =\"test\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        
pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\ndef get_StanfordDogs_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(DATA_DIR, 'StanfordDogs')\n    train_datasets = StanfordDogs(\n        root=root, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = StanfordDogs(\n        root=root, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_FGVCAircraft_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(DATA_DIR, 
'FGVCAircraft')\n    train_datasets = datasets.FGVCAircraft(\n        root=root, split=\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.FGVCAircraft(\n        root=root, split=\"test\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_Flowers102_data(batch_size, num_workers=8, same_da=False, *args, **kwargs):\n    train_transform = transforms.Compose([\n        transforms.RandomResizedCrop(224),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    test_transform = transforms.Compose([\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n    root=os.path.join(DATA_DIR, 'Flowers102')\n    train_datasets = datasets.Flowers102(\n        root=root, split=\"train\", transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.Flowers102(\n        root=root, split=\"test\", transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/transfer_for_dvs/main.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2023/4/19 14:58\n# Author : Regulus\n# FileName: main.py\n# Explain: \n# Software: PyCharm\nimport argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN\nfrom braincog.model_zoo.resnet19_snn import resnet19\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\n# from ptflops import get_model_complexity_info\n# from thop import profile, clever_format\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    
help='YAML config file specifying default arguments')\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='cifar10', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_lastest/', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='GateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, 
default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\n# for reconstructing es-imagenet\nparser.add_argument('--reconstructed', action='store_true',\n                    help='for ES-imagenet dataset')\n\nparser.add_argument('--DVS-DA', action='store_true',\n                    help='use DA on DVS')\n\n# train data used ratio\nparser.add_argument('--traindata-ratio', default=1.0, type=float,\n                    help='training data ratio')\n\n# use TET loss or not (all default False, do not use)\nparser.add_argument('--TET-loss-first', action='store_true',\n                    help='use TET loss one 
part')\n\nparser.add_argument('--TET-loss-second', action='store_true',\n                    help='use TET loss two part')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    if args.local_rank == 0:\n        output_base = args.output if args.output else './output'\n        exp_name = '-'.join([\n            args.model,\n            args.dataset,\n            str(args.step),\n            \"seed_{}\".format(args.seed),\n            \"bs_{}\".format(args.batch_size),\n            \"DA_{}\".format(args.DVS_DA),\n            \"ls_{}\".format(args.smoothing),\n            \"lr_{}\".format(args.lr),\n            \"traindataratio_{}\".format(args.traindata_ratio),\n            \"TET_first_{}\".format(args.TET_loss_first),\n            \"TET_second_{}\".format(args.TET_loss_second),\n        ])\n        output_dir = get_outdir(output_base, 'Baseline', 
exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    else:\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' 
% args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        adaptive_node=args.adaptive_node,\n        dataset=args.dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.threshold,\n        tau=args.tau,\n        sigmoid_thres=args.sigmoid_thres,\n        requires_thres_grad=args.requires_thres_grad,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        n_groups=args.n_groups,\n        reconstruct=args.reconstructed,\n        TET_loss=args.TET_loss_first or args.TET_loss_second\n    )\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    use_amp = None\n    if args.amp:\n        # for backwards compat, 
`--amp` arg tries apex before native amp\n        if has_apex:\n            args.apex_amp = True\n        elif has_native_amp:\n            args.native_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                        \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n    if args.num_gpu > 1:\n        if use_amp == 'apex':\n            _logger.warning(\n                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')\n            use_amp = None\n        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()\n        assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n    else:\n        model = model.cuda()\n        if args.channels_last:\n            model = model.to(memory_format=torch.channels_last)\n\n    optimizer = create_optimizer(args, model)\n\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. 
Training in float32.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume and args.eval_checkpoint == '':\n        args.eval_checkpoint = args.resume\n    if args.resume:\n        args.eval = True\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n        # print(model.get_attr('mu'))\n        # print(model.get_attr('sigma'))\n\n    if args.critical_loss or args.spike_rate:\n        model.set_requires_fp(True)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume=args.resume)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    model_without_ddp = model\n    if args.distributed:\n        if args.sync_bn:\n            assert not args.split_bn\n            try:\n                if has_apex and use_amp != 'native':\n                    # Apex SyncBN preferred unless native amp is activated\n                    model = convert_syncbn_model(model)\n                else:\n                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                if args.local_rank == 0:\n                    _logger.info(\n                        'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using '\n                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n            except Exception as e:\n                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n        if has_apex and use_amp != 'native':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model, device_ids=[args.local_rank],\n                              find_unused_parameters=True)  # can use device str in Torch >= 1.1\n        model_without_ddp = model.module\n    # NOTE: EMA model does not need to be wrapped by DDP\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        dvs_da=args.DVS_DA,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        
event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        reconstruct=args.reconstructed,\n        _logger=_logger,\n        train_data_ratio=args.traindata_ratio,\n        data_mode=\"full\",\n        frames_num=12,\n        data_type=\"frequency\"\n    )\n\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n\n    else:\n        if args.jsd:\n            assert num_aug_splits > 1  # JSD only valid with aug splits set\n            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n        elif mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        elif args.smoothing:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    if args.eval:  # evaluate the model\n        if args.distributed:\n            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n            new_state_dict = OrderedDict()\n            # add module prefix for DDP\n            for k, v in state_dict.items():\n                k = 'module.' 
+ k\n                new_state_dict[k] = v\n\n            model.load_state_dict(new_state_dict)\n        else:\n            model.load_state_dict(torch.load(args.eval_checkpoint)['state_dict'])\n        for i in range(1):\n            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,\n                                   visualize=args.visualize, spike_rate=args.spike_rate,\n                                   tsne=args.tsne, conf_mat=args.conf_mat)\n            print(f\"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n        return\n\n    saver = None\n    if args.local_rank == 0:\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=1)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    try:  # train the model\n        if args.reset_drop:\n            model_without_ddp.reset_drop_path(0.0)\n        for epoch in range(start_epoch, args.epochs):\n            if epoch == 0 and args.reset_drop:\n                model_without_ddp.reset_drop_path(args.drop_path)\n\n            if args.distributed:\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn == 
'reduce')\n\n            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat)\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                ema_eval_metrics = validate(\n                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                    visualize=args.visualize, spike_rate=args.spike_rate,\n                    tsne=args.tsne, conf_mat=args.conf_mat)\n                eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            # if saver is not None and epoch >= args.n_warm_up:\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n            # if epoch == 299:  # 临时的\n            #     break\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n   
 if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    # t, k = adjust_surrogate_coeff(100, args.epochs)\n    # model.set_attr('t', t)\n    # model.set_attr('k', k)\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            tet_loss = 0.0\n            loss = 0.0\n            lamb = 1e-3\n            if args.TET_loss_first or args.TET_loss_second:  # 第一项必须有，也就是测两个，第一个何第一个加第二个\n                for i in range(len(output)):\n                    tet_loss += loss_fn(output[i], target)\n                tet_loss /= len(output)\n                loss = (1 - lamb) * tet_loss\n            else:\n                loss = loss_fn(output, target)\n            if args.TET_loss_second:\n                y = torch.zeros_like(output[-1]).fill_(args.threshold)\n                secondLoss = torch.nn.MSELoss()\n                tet_loss_second = secondLoss(output[-1], y)\n                loss += lamb * 
tet_loss_second\n            if args.TET_loss_first or args.TET_loss_second:\n                output = sum(output) / len(output)\n        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':\n            # print(output.shape, target.shape)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        closs = torch.tensor([0.], device=loss.device)\n\n\n        loss = loss + .1 * closs\n\n        spike_rate_avg_layer_str = ''\n        threshold_str = ''\n        if not args.distributed:\n            losses_m.update(loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), inputs.size(0))\n            top5_m.update(acc5.item(), inputs.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n\n            spike_rate_avg_layer = model.get_fire_rate().tolist()\n            spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n            threshold = model.get_threshold()\n            threshold_str = ['{:.3f}'.format(i) for i in threshold]\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            if args.opt == 'lamb':\n                optimizer.step(epoch=epoch)\n            else:\n                optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        
if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            mu_str = ''\n            sigma_str = ''\n            if not args.distributed:\n                if 'Noise' in args.node_type:\n                    mu, sigma = model.get_noise_param()\n                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                losses_m.update(reduced_loss.item(), inputs.size(0))\n                closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                if args.distributed:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                            epoch,\n                            batch_idx, len(loader),\n                            100. 
* batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m\n                        ))\n                else:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\\n'\n                        'Fire_rate: {spike_rate}\\n'\n                        'Thres: {threshold}\\n'\n                        'Mu: {mu_str}\\n'\n                        'Sigma: {sigma_str}\\n'.format(\n                            epoch,\n                            batch_idx, len(loader),\n                            100. 
* batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            threshold=threshold_str,\n                            mu_str=mu_str,\n                            sigma_str=sigma_str\n                        ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec 
= []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:\n                    model.set_requires_fp(True)\n                    # if not args.critical_loss:\n                    #     model.set_requires_fp(False)\n\n            with amp_autocast():\n                output = model(inputs)\n            if args.TET_loss_first or args.TET_loss_second:\n                output = sum(output) / len(output)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            if not args.distributed:\n                if visualize:\n                    x = model.get_fp()\n                    feature_path = os.path.join(args.output_dir, 'feature_map')\n                    if os.path.exists(feature_path) is False:\n                        os.mkdir(feature_path)\n                    save_feature_map(x, feature_path)\n                    # if not args.critical_loss:\n                    #     model_config.set_requires_fp(False)\n\n                if tsne:\n                    x = model.get_fp(temporal_info=False)[-1]\n                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)\n                    x = x.reshape(x.shape[0], -1)\n                    feature_vec.append(x)\n                    feature_cls.append(target)\n\n                if conf_mat:\n                    logits_vec.append(output)\n                    labels_vec.append(target)\n\n     
           if spike_rate:\n                    avg, var, spike, avg_per_step = model.get_spike_info()\n                    save_spike_info(\n                        os.path.join(args.output_dir, 'spike_info.csv'),\n                        epoch, batch_idx,\n                        args.step, avg, var,\n                        spike, avg_per_step)\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n\n            closs = torch.tensor([0.], device=loss.device)\n\n            if not args.distributed:\n                spike_rate_avg_layer = model.get_fire_rate().tolist()\n                threshold = model.get_threshold()\n                threshold_str = ['{:.3f}'.format(i) for i in threshold]\n                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                tot_spike = model.get_tot_spike()\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' 
+ log_suffix\n\n                mu_str = ''\n                sigma_str = ''\n                if not args.distributed:\n                    if 'Noise' in args.node_type:\n                        mu, sigma = model.get_noise_param()\n                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n                if args.distributed:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            ))\n                else:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\\n'\n                        'Fire_rate: {spike_rate}\\n'\n                        'Tot_spike: {tot_spike}\\n'\n                        'Thres: {threshold}\\n'\n                        'Mu: {mu_str}\\n'\n                        'Sigma: {sigma_str}\\n'.format(\n         
                   log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            tot_spike=tot_spike,\n                            threshold=threshold_str,\n                            mu_str=mu_str,\n                            sigma_str=sigma_str\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if not args.distributed:\n        if tsne:\n            feature_vec = torch.cat(feature_vec)\n            feature_cls = torch.cat(feature_cls)\n            plot_tsne(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-2d.eps'))\n            plot_tsne_3d(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-3d.eps'))\n        if conf_mat:\n            logits_vec = torch.cat(logits_vec)\n            labels_vec = torch.cat(labels_vec)\n            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(args.output_dir, 'confusion_matrix.eps'))\n\n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/transfer_for_dvs/main_transfer.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2022/9/29 15:27\n# Author : Regulus\n# FileName: main_transfer.py\n# Explain:\n# Software: PyCharm\n\nimport argparse\nimport math\nimport time\nimport CKA\nimport numpy\nimport timm.models\nimport random as rd\nimport yaml\nimport os\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN\nfrom braincog.model_zoo.resnet19_snn import resnet19\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom rgb_hsv import RGB_HSV\nimport matplotlib.pyplot as plt\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\n# from ptflops import get_model_complexity_info\n# from thop import profile, clever_format\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training 
Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--source-dataset', default='cifar10', type=str)\nparser.add_argument('--target-dataset', default='dvsc10', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_lastest/', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\")')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: LIFNode)')\nparser.add_argument('--act-fun', type=str, default='GateGrad',\n                    help='Surrogate Function in node. 
Only for Surrogate nodes (default: GateGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometimes will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prob for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, 
default=15,\n                    help='Rand Augment magnitude m (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate (default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\n# Transfer Learning loss choice\nparser.add_argument('--domain-loss', action='store_true',\n                    help='add domain loss')\nparser.add_argument('--semantic-loss', action='store_true',\n                    help='add semantic loss')\nparser.add_argument('--domain-loss-coefficient', type=float, default=1.0,\n                    help='domain loss coefficient (default: 1.0)')\nparser.add_argument('--semantic-loss-coefficient', type=float, default=1.0,\n                    help='semantic loss coefficient (default: 1.0)')\n\n# use TET loss or not (all default 
False, do not use)\n\nparser.add_argument('--TET-loss-first', action='store_true',\n                    help='use TET loss one part')\n\nparser.add_argument('--TET-loss-second', action='store_true',\n                    help='use TET loss two part')\n\nparser.add_argument('--DVS-DA', action='store_true',\n                    help='use DA on DVS')\n\n# train data used ratio\nparser.add_argument('--traindata-ratio', default=1.0, type=float,\n                    help='training data ratio')\n\n# snr value\nparser.add_argument('--snr', default=0, type=int,\n                    help='random noise amplitude controled by snr, 0 means no noise')\n\n# margin m\nparser.add_argument('--m', default=-1.0, type=float,\n                    help='margin')\n\nsource_input_list, source_label_list = [], []\nCALTECH101_list, ImageNet_list = [], []\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    torch.set_num_threads(20)\n    os.environ[\"OMP_NUM_THREADS\"] = \"20\"  # 设置OpenMP计算库的线程数\n    
os.environ[\"MKL_NUM_THREADS\"] = \"20\"  # 设置MKL-DNN CPU加速库的线程数。\n    args, args_text = _parse_args()\n    args.no_spike_output = True\n    output_dir = ''\n    if args.local_rank == 0:\n        output_base = args.output if args.output else './output'\n        exp_name = '-'.join([\n            args.model,\n            args.target_dataset,\n            str(args.step),\n            \"bs_{}\".format(args.batch_size),\n            \"seed_{}\".format(args.seed),\n            \"DA_{}\".format(args.DVS_DA),\n            \"ls_{}\".format(args.smoothing),\n            \"lr_{}\".format(args.lr),\n            \"m_{}\".format(args.m),\n            \"domainLoss_{}\".format(args.domain_loss),\n            \"semanticLoss_{}\".format(args.semantic_loss),\n            \"domain_loss_coefficient{}\".format(args.domain_loss_coefficient),\n            \"semantic_loss_coefficient{}\".format(args.semantic_loss_coefficient),\n            \"traindataratio_{}\".format(args.traindata_ratio),\n            \"TETfirst_{}\".format(args.TET_loss_first),\n            \"TETsecond_{}\".format(args.TET_loss_second),\n        ])\n        output_dir = get_outdir(output_base, 'train_TCKA_test_nop', exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    else:\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n     
   torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        adaptive_node=args.adaptive_node,\n        dataset=args.target_dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.threshold,\n        tau=args.tau,\n        sigmoid_thres=args.sigmoid_thres,\n        requires_thres_grad=args.requires_thres_grad,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        n_groups=args.n_groups,\n    )\n\n    if 'dvs' in args.target_dataset:\n        args.channels = 2\n    elif 'mnist' in args.target_dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param 
count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    use_amp = None\n    if args.amp:\n        # for backwards compat, `--amp` arg tries apex before native amp\n        if has_apex:\n            args.apex_amp = True\n        elif has_native_amp:\n            args.native_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                        \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n    if args.num_gpu > 1:\n        if use_amp == 'apex':\n            _logger.warning(\n                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')\n            use_amp = None\n        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()\n        assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n    else:\n        model = model.cuda()\n        if args.channels_last:\n            model = model.to(memory_format=torch.channels_last)\n\n    optimizer = create_optimizer(args, model)\n\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. 
Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. Training in float32.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume and args.eval_checkpoint == '':\n        args.eval_checkpoint = args.resume\n    if args.resume:\n        args.eval = True\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n\n    if args.critical_loss or args.spike_rate:\n        model.set_requires_fp(True)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume=args.resume)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    model_without_ddp = model\n    if args.distributed:\n        if args.sync_bn:\n            assert not args.split_bn\n            try:\n                if has_apex and use_amp != 'native':\n                    # Apex SyncBN preferred unless native amp is activated\n                    model = convert_syncbn_model(model)\n                else:\n                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                if args.local_rank == 0:\n                    _logger.info(\n                        'Converted model to use Synchronized 
BatchNorm. WARNING: You may have issues if using '\n                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n            except Exception as e:\n                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n        if has_apex and use_amp != 'native':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model, device_ids=[args.local_rank],\n                              find_unused_parameters=True)  # can use device str in Torch >= 1.1\n        model_without_ddp = model.module\n    # NOTE: EMA model does not need to be wrapped by DDP\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    source_loader_train, _, _, _ = eval('get_transfer_%s_data' % args.source_dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n  
      beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n    )\n\n\n    target_loader_train, target_loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.target_dataset)(\n        batch_size=args.batch_size,\n        dvs_da=args.DVS_DA,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n        train_data_ratio=args.traindata_ratio,\n        snr=args.snr,\n        data_mode=\"full\",\n        frames_num=12,\n        data_type=\"frequency\"\n    )\n\n    global source_input_list, source_label_list, CALTECH101_list, ImageNet_list\n    if args.target_dataset == \"dvsc10\" or args.target_dataset == \"NCALTECH101\" or args.target_dataset == \"nomni\":  # ImageNet中回来的loader其实是数据集,在后面处理\n        source_input_list, source_label_list = next(iter(source_loader_train))\n        # for i in range(30001, 30005):\n        #     # vis origin picture\n        #     plt.figure()\n        #     plt.imshow(source_input_list[i].permute(1, 2, 0).numpy())\n        #     plt.savefig(\"./origin_image.jpg\")\n        #     plt.show()\n        # vis HSV picture\n        # for i in range(30001, 
30005):  # 30001.i\n        #     convertor = RGB_HSV()\n        #     plt.figure()\n        #     plt.imshow(convertor.rgb_to_hsv(source_input_list)[i, :, :, :].permute(1, 2, 0).numpy())\n        #     plt.title(\"HSV image\")\n        #     plt.show()\n\n    if args.source_dataset == \"CALTECH101\":\n        cls_count = [438, 435, 200, 791, 49, 800, 41, 34, 45, 50, 45, 32, 128, 84, 38, 81, 86, 47, 40, 0, 45, 58, 61,\n                     105, 47, 64, 70, 68, 50, 51, 54, 67, 51, 64, 65, 72, 62, 52, 60, 83, 65, 67, 45, 31, 34, 49, 99,\n                     100, 42, 54, 86, 80, 30, 62, 86, 110, 61, 79, 77, 40, 65, 42, 35, 77, 31, 74, 49, 32, 39, 47, 35,\n                     43, 52, 34, 54, 69, 58, 45, 38, 57, 34, 84, 57, 31, 54, 45, 82, 56, 63, 35, 85, 43, 82, 74, 239,\n                     37, 53, 33, 55, 29, 42]\n        CALTECH101_list = [0] * 102  # 多开了一类, 方便计算\n        for i in range(1, len(cls_count) + 1):\n            CALTECH101_list[i] = CALTECH101_list[i - 1] + cls_count[i - 1]\n\n    if args.source_dataset == \"NCALTECH101\":\n        cls_count = tonic.datasets.NCALTECH101.cls_count\n        CALTECH101_list = [0] * 102  # 多开了一类, 方便计算\n        for i in range(1, len(cls_count) + 1):\n            CALTECH101_list[i] = CALTECH101_list[i - 1] + cls_count[i - 1]\n\n    if args.source_dataset == \"imnet\":\n        cls_count = [1300] * 1000  # 1000类\n        cls_count_idx = [1117, 1266, 1071, 1141, 1272, 1150, 772, 860, 1136, 732, 1025, 754, 1290, 738, 1258, 1273, 977,\n                         936, 1156, 1218, 969, 954, 1070, 755, 1206, 1165, 969, 1292, 1236, 1199, 1209, 1176, 1186,\n                         1194,\n                         1067, 1029, 1154, 1216, 1187, 889, 1211, 1136, 1153, 1222, 1282, 1283, 980, 1034, 891, 1285,\n                         986,\n                         1137, 1272, 1155, 1097, 1149, 1155, 1159, 1133, 1180, 1120, 1005, 1152, 1156, 962, 1157, 1282,\n                         1117, 1118, 1270, 1069, 1053, 1254, 908, 1247, 1253, 
1029, 1259, 1267, 1249, 1162, 1045, 1004,\n                         1238, 1153, 1084, 1217, 931, 1264, 976, 1250, 1053, 1160, 1062, 1137, 1299, 1055, 1213, 1206,\n                         1154,\n                         1207, 1149, 1239, 1125, 1193]\n        cls_idx = [43, 51, 62, 98, 103, 147, 152, 158, 164, 165, 166, 167, 168, 175, 181, 183, 188, 190, 194, 206, 221,\n                   252, 262, 268, 335, 390, 392, 409, 418, 426, 439, 465, 481, 491, 499, 501, 503, 507, 521, 531, 536,\n                   550, 551, 567, 577, 583, 585, 590, 596, 610, 623, 630, 631, 635, 653, 662, 663, 675, 676, 678, 686,\n                   689, 706, 708, 712, 714, 722, 723, 724, 727, 728, 729, 731, 740, 747, 753, 771, 772, 782, 789, 798,\n                   810, 811, 812, 821, 826, 838, 841, 854, 857, 860, 869, 872, 885, 891, 892, 901, 906, 914, 921, 925,\n                   926, 940, 946, 969]\n        for i in range(len(cls_count)):\n            if i in cls_idx:\n                cls_count[i] = cls_count_idx[cls_idx.index(i)]\n        ImageNet_list = [0] * 1001  # 多开了一类, 方便计算\n        for i in range(1, 1000 + 1):\n            ImageNet_list[i] = ImageNet_list[i - 1] + cls_count[i - 1]\n\n\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n\n    else:\n        if args.jsd:\n            assert num_aug_splits > 1  # JSD only valid with aug splits set\n            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n        elif mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        elif args.smoothing:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = 
MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n\n    if args.eval:  # evaluate the model\n        if args.distributed:\n            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n            new_state_dict = OrderedDict()\n            # add module prefix for DDP\n            for k, v in state_dict.items():\n                k = 'module.' + k\n                new_state_dict[k] = v\n\n            model.load_state_dict(new_state_dict)\n        else:\n            model.load_state_dict(torch.load(args.eval_checkpoint)['state_dict'])\n        for i in range(1):\n            val_metrics = validate(start_epoch, model, target_loader_eval, validate_loss_fn, args,\n                                   visualize=args.visualize, spike_rate=args.spike_rate,\n                                   tsne=args.tsne, conf_mat=args.conf_mat)\n            print(f\"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n        return\n\n    saver = None\n    if args.local_rank == 0:\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=1)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    eval_top1 = 0.0\n    try:  # train the model\n        if args.reset_drop:\n            model_without_ddp.reset_drop_path(0.0)\n        for epoch in range(start_epoch, args.epochs):\n            if epoch == 0 and args.reset_drop:\n                model_without_ddp.reset_drop_path(args.drop_path)\n\n            if args.distributed:\n                target_loader_train.sampler.set_epoch(epoch)\n            train_metrics = train_epoch(\n                epoch, model, 
source_loader_train, target_loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n            eval_metrics = validate(epoch, model, target_loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat)\n            eval_top1 = eval_metrics[\"top1\"]\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n\n                    ema_eval_metrics = validate(\n                        epoch, model_ema.ema, target_loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                        visualize=args.visualize, spike_rate=args.spike_rate,\n                        tsne=args.tsne, conf_mat=args.conf_mat)\n                    eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            # if saver is not None and epoch >= args.n_warm_up:\n            if saver is not None:\n                # save proper 
checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n            # if epoch == 299:  # 临时的\n            #     break\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, source_loader, target_loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and target_loader.mixup_enabled:\n            target_loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    domain_losses_m = AverageMeter()\n    semantic_losses_m = AverageMeter()\n    rgb_losses_m = AverageMeter()\n    dvs_losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    end = time.time()\n    last_idx = len(target_loader) - 1\n    num_updates = epoch * len(target_loader)\n    convertor = RGB_HSV()\n\n    batch_len = len(target_loader)\n    if args.target_dataset == \"dvsc10\":\n        set_MaxReplacement_epoch = 0.5 * args.epochs\n    else:\n        set_MaxReplacement_epoch = 0.5 * args.epochs\n    P_Replacement = 0.0\n\n    global source_input_list, source_label_list, CALTECH101_list, ImageNet_list\n    for batch_idx, (inputs, label) in enumerate(target_loader):\n        P_Replacement = ((batch_idx + epoch * batch_len) / (set_MaxReplacement_epoch * batch_len)) ** 3\n        P_Replacement = 
P_Replacement if P_Replacement <= 1.0 else 1.0\n        sampler_list = label.tolist()\n        if args.target_dataset == \"dvsc10\" and args.source_dataset == \"cifar10\":\n            sampler_list = torch.tensor(sampler_list) * 6000 + torch.randint(0, 6000, (len(sampler_list),))\n        elif args.target_dataset == \"dvsc10\" and args.source_dataset == \"dvsc10\":\n            sampler_list = torch.tensor(sampler_list) * 900 + torch.randint(0, 900, (len(sampler_list),))\n        elif args.target_dataset == \"NCALTECH101\":\n            tmp_sampler_list = []\n            idx_list = []\n            for idx, label_sampler in enumerate(sampler_list):\n                if label_sampler == 19:\n                    tmp_sampler_list.append(0)\n                    idx_list.append(idx)\n                else:\n                    tmp_sampler_list.append(torch.randint(CALTECH101_list[label_sampler],\n                                                          CALTECH101_list[label_sampler + 1], (1,)).item())\n        elif args.target_dataset == \"esimnet\":\n            tmp_sampler_list = []\n            for idx, label_sampler in enumerate(sampler_list):  # 这里的label_sampler是一个列表\n                tmp_sampler_list.append(torch.randint(ImageNet_list[label_sampler],\n                                                      ImageNet_list[label_sampler + 1], (1,)).item())\n        elif args.target_dataset == \"nomni\":\n            sampler_list = torch.tensor(sampler_list) * 20 + torch.randint(0, 20, (len(sampler_list),))\n\n        source_input, source_label = [], []\n        if args.target_dataset == \"dvsc10\":\n            source_input, source_label = source_input_list[sampler_list], source_label_list[sampler_list]\n        if args.target_dataset == \"NCALTECH101\":\n            source_input, source_label = source_input_list[tmp_sampler_list], source_label_list[tmp_sampler_list]\n        if args.target_dataset == \"esimnet\":\n            train_dataset = source_loader  # 给传回来的重新命个名儿\n 
           source_loader_used = torch.utils.data.DataLoader(\n                train_dataset,\n                batch_size=args.batch_size, shuffle=False,\n                num_workers=8, pin_memory=True, sampler=TransferSampler(tmp_sampler_list))\n            source_input, source_label = next(iter(source_loader_used))\n        if args.target_dataset == \"nomni\":\n            source_input, source_label = source_input_list[sampler_list], source_label_list[sampler_list]\n        # for i in range(128):\n        #     # vis origin picture\n        #     plt.figure()\n        #     plt.imshow(source_input[i].permute(1, 2, 0))\n        #     plt.title(\"origin image\")\n        #     plt.show()\n\n        # # vis HSV picture\n        # plt.figure()\n        # plt.imshow(convertor.rgb_to_hsv(inputs)[7, :, :, :].permute(1, 2, 0).numpy())\n        # plt.title(\"HSV image\")\n        # plt.show()\n\n        # source_input = convertor.rgb_to_hsv(source_input)[:, -1, :, :].unsqueeze(1).repeat(1, args.step * 2, 1, 1)\n        if args.source_dataset == \"dvsc10\" or args.source_dataset == \"NCALTECH101\":\n            pass\n        else:\n            source_input = source_input[:, -1, :, :].unsqueeze(1).repeat(1, args.step * 2, 1, 1)\n            source_input = rearrange(source_input, 'b (t c) h w -> b t c h w', t=args.step)\n\n        for b in range(source_input.shape[0]):\n            if rd.uniform(0, 1) <= P_Replacement:\n                source_input[b] = inputs[b, :, :, :, :]\n\n        # for i in range(10):\n        #     # vis HSV picture for v channel\n        #     plt.figure()\n        #     plt.imshow(source_input[i][0].permute(1, 2, 0)[:, :, -1].unsqueeze(2))\n        #     plt.title(\"HSV image for v channel\")\n        #     plt.show()\n\n        if args.target_dataset == \"NCALTECH101\" and len(idx_list) > 0:\n            for i in range(len(idx_list)):\n                source_input[idx_list[i]] = inputs[idx_list[i]]\n        last_batch = batch_idx == last_idx\n       
 data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.target_dataset != 'imnet':\n            inputs, label = inputs.type(torch.FloatTensor).cuda(), label.cuda()\n            source_input, source_label = source_input.type(torch.FloatTensor).cuda(), label.cuda()\n            if mixup_fn is not None:\n                inputs, label = mixup_fn(inputs, label)\n                source_input, source_label = mixup_fn(source_input, source_label)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n            source_input = source_input.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            domain_rbg_list, domain_dvs_list, output_rgb, output_dvs = model(source_input, inputs)\n\n            # compute semantic loss\n            label_idx = [[] for i in range(args.num_classes)]\n            semantic_label_list = []\n            for idx, i in enumerate(label):\n                label_idx[i.item()].append(idx)\n            for i in label:\n                while True:\n                    label_tmp = torch.randint(0, args.num_classes, (1,)).item()\n                    if i.item() != label_tmp and len(label_idx[label_tmp]) > 0:  # NCALTECH101有空列表, 需要判断\n                        break\n                semantic_label_list.append(int(np.random.choice(label_idx[label_tmp], 1)))\n            semantic_rbg_list = []\n            semantic_loss = 0.\n            for i in range(len(domain_rbg_list)):\n                semantic_rbg_list.append(domain_rbg_list[i][semantic_label_list])\n            for i in range(len(domain_rbg_list)):\n                semantic_loss += torch.abs(CKA.linear_CKA(domain_dvs_list[i].view(args.batch_size, -1), semantic_rbg_list[i].view(args.batch_size, -1)))\n            semantic_loss /= len(domain_rbg_list)\n            if args.target_dataset == \"dvsc10\":\n                m = 0.1\n            elif args.target_dataset == \"NCALTECH101\":\n                
m = 0.3\n            else:\n                m = 0.2\n            if args.m >= 0.0:\n                m = args.m\n            if semantic_loss.item() - m <= 0:\n                semantic_loss = torch.tensor(0., device=semantic_loss.device)\n\n            # if args.domain_loss_after:\n            #     # compute domain loss\n            #     for b in range(source_input.shape[0]):\n            #         if rd.uniform(0, 1) <= P_Replacement:\n            #             for i in range(len(domain_rbg_list)):\n            #                 domain_rbg_list[i][b] = domain_dvs_list[i][b, :, :, :]\n\n            domain_loss = 0.\n            for i in range(len(domain_rbg_list)):\n                domain_loss += 1 - torch.abs(CKA.linear_CKA(domain_rbg_list[i].view(args.batch_size, -1), domain_dvs_list[i].view(args.batch_size, -1)))\n            domain_loss /= len(domain_rbg_list)\n\n            # compute cls loss\n            lamb = 1e-3\n            if args.TET_loss_first or args.TET_loss_second:  # 第一项必须有，也就是测两个，第一个何第一个加第二个\n                loss_rgb = 0\n                tet_loss_first = 0\n                tet_loss_second = 0\n                assert len(output_rgb) == len(output_dvs)\n                for i in range(len(output_rgb)):\n                    loss_rgb += loss_fn(output_rgb[i], label)\n                    tet_loss_first += loss_fn(output_dvs[i], label)\n                loss_rgb /= len(output_rgb)\n                tet_loss_first /= len(output_dvs)\n\n                if args.TET_loss_second:\n                    y = torch.zeros_like(output_dvs[-1]).fill_(args.threshold)\n                    secondLoss = torch.nn.MSELoss()\n                    tet_loss_second = secondLoss(output_dvs[-1], y)\n                else:\n                    lamb = 0.0\n                loss_dvs = (1 - lamb) * tet_loss_first + lamb * tet_loss_second\n                output_rgb = sum(output_rgb) / len(output_rgb)\n                output_dvs = sum(output_dvs) / len(output_dvs)\n            else:\n   
             output_rgb = sum(output_rgb) / len(output_rgb)\n                output_dvs = sum(output_dvs) / len(output_dvs)\n                loss_rgb = loss_fn(output_rgb, label)\n                loss_dvs = loss_fn(output_dvs, label)\n\n            loss = 0 * loss_rgb + loss_dvs\n            if args.domain_loss:\n                loss += args.domain_loss_coefficient * domain_loss\n            if args.semantic_loss and epoch <= set_MaxReplacement_epoch:\n                if args.target_dataset == \"NCALTECH101\" and epoch <= set_MaxReplacement_epoch * 0.66:\n                    # loss += args.semantic_loss_coefficient * semantic_loss * math.pow(10, -1.0 * float(set_MaxReplacement_epoch / (epoch+1)))\n                    pass\n                else:\n                    loss += args.semantic_loss_coefficient * semantic_loss\n\n        if not (args.cut_mix | args.mix_up | args.event_mix) and args.target_dataset != 'imnet':\n            acc1, acc5 = accuracy(output_dvs, label, topk=(1, 5))\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        closs = torch.tensor([0.], device=loss.device)\n\n        loss = loss + .1 * closs\n\n        spike_rate_avg_layer_str = ''\n        threshold_str = ''\n        if not args.distributed:\n            losses_m.update(loss.item(), inputs.size(0))\n            domain_losses_m.update(domain_loss.item(), inputs.size(0))\n            semantic_losses_m.update(semantic_loss.item(), inputs.size(0))\n            rgb_losses_m.update(loss_rgb.item(), inputs.size(0))\n            dvs_losses_m.update(loss_dvs.item(), inputs.size(0))\n            top1_m.update(acc1.item(), inputs.size(0))\n            top5_m.update(acc5.item(), inputs.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n\n            spike_rate_avg_layer = model.get_fire_rate().tolist()\n            spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n            threshold = model.get_threshold()\n       
     threshold_str = ['{:.3f}'.format(i.item()) for i in threshold]\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            if args.opt == 'lamb':\n                optimizer.step(epoch=epoch)\n            else:\n                optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            mu_str = ''\n            sigma_str = ''\n            if not args.distributed:\n                if 'Noise' in args.node_type:\n                    mu, sigma = model.get_noise_param()\n                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                losses_m.update(reduced_loss.item(), inputs.size(0))\n                closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                if args.distributed:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n      
                  'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                            epoch,\n                            batch_idx, len(target_loader),\n                            100. * batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m\n                        ))\n                else:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\\n'\n                        'Fire_rate: {spike_rate}\\n'\n                        'Thres: {threshold}\\n'\n                        'Mu: {mu_str}\\n'\n                        'Sigma: {sigma_str}\\n'\n                        'P_Replacement: 
{P_Replacement}\\n'.format(\n                            epoch,\n                            batch_idx, len(target_loader),\n                            100. * batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            threshold=threshold_str,\n                            mu_str=mu_str,\n                            sigma_str=sigma_str,\n                            P_Replacement=P_Replacement,\n                        ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg), ('domainLoss', domain_losses_m.avg), ('semanticLoss', semantic_losses_m.avg),\n                        ('rgbLoss', rgb_losses_m.avg), ('dvsLoss', dvs_losses_m.avg)])\n\ndef validate(epoch, model, loader, loss_fn, 
args, amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.target_dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:\n                    model.set_requires_fp(True)\n                    # if not args.critical_loss:\n                    #     model.set_requires_fp(False)\n\n            with amp_autocast():\n                _, _, output_rbg, output_dvs = model(inputs, inputs)\n                output = sum(output_dvs) / len(output_dvs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            if not args.distributed:\n                if visualize:\n                    x = model.get_fp()\n                    feature_path = os.path.join(args.output_dir, 'feature_map')\n                    if os.path.exists(feature_path) is False:\n                        os.mkdir(feature_path)\n                    save_feature_map(x, feature_path)\n                    # if not args.critical_loss:\n                    #     model_config.set_requires_fp(False)\n\n                if tsne:\n                    x = 
model.get_fp(temporal_info=False)[-1]\n                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)\n                    x = x.reshape(x.shape[0], -1)\n                    feature_vec.append(x)\n                    feature_cls.append(target)\n\n                if conf_mat:\n                    logits_vec.append(output)\n                    labels_vec.append(target)\n\n                if spike_rate:\n                    avg, var, spike, avg_per_step = model.get_spike_info()\n                    save_spike_info(\n                        os.path.join(args.output_dir, 'spike_info.csv'),\n                        epoch, batch_idx,\n                        args.step, avg, var,\n                        spike, avg_per_step)\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n\n            closs = torch.tensor([0.], device=loss.device)\n\n            if not args.distributed:\n                spike_rate_avg_layer = model.get_fire_rate().tolist()\n                threshold = model.get_threshold()\n                threshold_str = ['{:.3f}'.format(i) for i in threshold]\n                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                tot_spike = model.get_tot_spike()\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n        
    top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n                mu_str = ''\n                sigma_str = ''\n                if not args.distributed:\n                    if 'Noise' in args.node_type:\n                        mu, sigma = model.get_noise_param()\n                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n                if args.distributed:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            ))\n                else:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: 
{top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\\n'\n                        'Fire_rate: {spike_rate}\\n'\n                        'Tot_spike: {tot_spike}\\n'\n                        'Thres: {threshold}\\n'\n                        'Mu: {mu_str}\\n'\n                        'Sigma: {sigma_str}\\n'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            tot_spike=tot_spike,\n                            threshold=threshold_str,\n                            mu_str=mu_str,\n                            sigma_str=sigma_str\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if not args.distributed:\n        if tsne:\n            feature_vec = torch.cat(feature_vec)\n            feature_cls = torch.cat(feature_cls)\n            plot_tsne(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-2d.eps'))\n            plot_tsne_3d(feature_vec, feature_cls, os.path.join(args.output_dir, 't-sne-3d.eps'))\n        if conf_mat:\n            logits_vec = torch.cat(logits_vec)\n            labels_vec = torch.cat(labels_vec)\n            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(args.output_dir, 'confusion_matrix.eps'))\n\n    return metrics\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "examples/Perception_and_Learning/img_cls/transfer_for_dvs/main_visual_losslandscape.py",
    "content": "# -*- coding: utf-8 -*-            \n# Time : 2023/2/14 11:52\n# Author : Regulus\n# FileName: main_visual_losslandscape.py\n# Explain: \n# Software: PyCharm\n\nfrom loss_landscape.plot_surface import *\n\n\nimport argparse\nimport math\nimport time\nimport CKA\nimport numpy\nimport timm.models\nimport random as rd\nimport yaml\nimport os\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN\nfrom braincog.model_zoo.resnet19_snn import resnet19\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom rgb_hsv import RGB_HSV\nimport matplotlib.pyplot as plt\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', 
add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--source-dataset', default='cifar10', type=str)\nparser.add_argument('--target-dataset', default='dvsc10', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/home/hexiang/TransferLearning_For_DVS/Results_new_refined/', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='GateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, 
default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\nparser.add_argument('--DVS-DA', action='store_true',\n                    help='use DA on DVS')\n\n# train data used ratio\nparser.add_argument('--traindata-ratio', default=1.0, type=float,\n                    help='training data ratio')\n\n# snr value\nparser.add_argument('--snr', default=0, type=int,\n                    help='random noise amplitude controled by snr, 0 means no noise')\n\n# --------------------------------------------------------------------------\n# Start the loss-landscape\n# 
--------------------------------------------------------------------------\n\nparser.add_argument('--mpi', '-m', action='store_true', help='use mpi')\nparser.add_argument('--threads', default=2, type=int, help='number of threads')\nparser.add_argument('--ngpu', type=int, default=1,\n                    help='number of GPUs to use for each rank, useful for data parallel evaluation')\n\n\n# model parameters\nparser.add_argument('--model_folder', default='',\n                    help='the common folder that contains model_file and model_file2')\nparser.add_argument('--model_file', default='', help='path to the trained model file')\nparser.add_argument('--model_file2', default='', help='use (model_file2 - model_file) as the xdirection')\nparser.add_argument('--model_file3', default='', help='use (model_file3 - model_file) as the ydirection')\nparser.add_argument('--loss_name', '-l', default='crossentropy', help='loss functions: crossentropy | mse')\n\n# direction parameters\nparser.add_argument('--dir_file', default='',\n                    help='specify the name of direction file, or the path to an eisting direction file')\nparser.add_argument('--dir_type', default='weights',\n                    help='direction type: weights | states (including BN\\'s running_mean/var)')\nparser.add_argument('--x', default='-1:1:51', help='A string with format xmin:x_max:xnum')\nparser.add_argument('--y', default=None, help='A string with format ymin:ymax:ynum')\nparser.add_argument('--xnorm', default='', help='direction normalization: filter | layer | weight')\nparser.add_argument('--ynorm', default='', help='direction normalization: filter | layer | weight')\nparser.add_argument('--xignore', default='', help='ignore bias and BN parameters: biasbn')\nparser.add_argument('--yignore', default='', help='ignore bias and BN parameters: biasbn')\nparser.add_argument('--same_dir', action='store_true', default=False,\n                    help='use the same random direction for both x-axis 
and y-axis')\nparser.add_argument('--idx', default=0, type=int, help='the index for the repeatness experiment')\nparser.add_argument('--surf_file', default='',\n                    help='customize the name of surface file, could be an existing file.')\n\n# plot parameters\nparser.add_argument('--proj_file', default='', help='the .h5 file contains projected optimization trajectory.')\nparser.add_argument('--loss_max', default=5, type=float, help='Maximum value to show in 1D plot')\nparser.add_argument('--vmax', default=10, type=float, help='Maximum value to map')\nparser.add_argument('--vmin', default=0.1, type=float, help='Miminum value to map')\nparser.add_argument('--vlevel', default=1.0, type=float, help='plot contours every vlevel')\nparser.add_argument('--show', action='store_true', default=False, help='show plotted figures')\nparser.add_argument('--log', action='store_true', default=False, help='use log scale for loss values')\nparser.add_argument('--plot', action='store_true', default=False, help='plot figures after computation')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = 
yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    torch.set_num_threads(20)\n    os.environ[\"OMP_NUM_THREADS\"] = \"20\"  # 设置OpenMP计算库的线程数\n    os.environ[\"MKL_NUM_THREADS\"] = \"20\"  # 设置MKL-DNN CPU加速库的线程数。\n    args, args_text = _parse_args()\n    args.no_spike_output = True\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' 
% args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        adaptive_node=args.adaptive_node,\n        dataset=args.target_dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.threshold,\n        tau=args.tau,\n        sigmoid_thres=args.sigmoid_thres,\n        requires_thres_grad=args.requires_thres_grad,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        n_groups=args.n_groups,\n        TET_loss=False\n    )  # 注意这里的TET_loss，选择losslandscape在谁上看.\n\n    if 'dvs' in args.target_dataset:\n        args.channels = 2\n    elif 'mnist' in args.target_dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    source_loader_train, _, _, _ = eval('get_transfer_%s_data' % args.source_dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        size=args.event_size,\n        mix_up=args.mix_up,\n  
      cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n    )\n\n\n    target_loader_train, target_loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.target_dataset)(\n        batch_size=args.batch_size,\n        dvs_da=args.DVS_DA,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n        train_data_ratio=args.traindata_ratio,\n        snr=args.snr,\n        data_mode=\"full\",\n        frames_num=12,\n        data_type=\"frequency\"\n    )\n\n    if args.eval:  # evaluate the model\n        if args.distributed:\n            state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n            new_state_dict = OrderedDict()\n            # add module prefix for DDP\n            for k, v in state_dict.items():\n                k = 'module.' 
+ k\n                new_state_dict[k] = v\n\n            model.load_state_dict(new_state_dict)\n        else:\n            model.load_state_dict(torch.load(args.eval_checkpoint, map_location=torch.device('cpu'))['state_dict'])\n        # --------------------------------------------------------------------------\n        # Show Acc\n        # --------------------------------------------------------------------------\n        print(\"load model finished!\")\n        # train_loss_fn = nn.CrossEntropyLoss()\n        # for i in range(1):\n        #     _, val_acc = validate(model, target_loader_train, train_loss_fn, args)\n        #     print(f\"Top-1 accuracy of the model is: {val_acc:.2f}%\")\n\n        # --------------------------------------------------------------------------\n        # Environment setup\n        # --------------------------------------------------------------------------\n        if args.mpi:\n            comm = mpi.setup_MPI()\n            rank, nproc = comm.Get_rank(), comm.Get_size()\n        else:\n            comm, rank, nproc = None, 0, 1\n\n        if True:\n            if not torch.cuda.is_available():\n                raise Exception('User selected cuda option, but cuda is not available on this machine')\n            gpu_count = torch.cuda.device_count()\n            # torch.cuda.set_device(rank % gpu_count)\n            torch.cuda.set_device(\"cuda:{}\".format(args.device))\n            print('Rank %d use GPU %d of %d GPUs on %s' %\n                  (rank, torch.cuda.current_device(), gpu_count, socket.gethostname()))\n\n        # --------------------------------------------------------------------------\n        # Check plotting resolution\n        # --------------------------------------------------------------------------\n        try:\n            args.xmin, args.xmax, args.xnum = [float(a) for a in args.x.split(':')]\n            args.ymin, args.ymax, args.ynum = (None, None, None)\n            if args.y:\n                
args.ymin, args.ymax, args.ynum = [float(a) for a in args.y.split(':')]\n                assert args.ymin and args.ymax and args.ynum, \\\n                    'You specified some arguments for the y axis, but not all'\n        except:\n            raise Exception('Improper format for x- or y-coordinates. Try something like -1:1:51')\n\n        # --------------------------------------------------------------------------\n        # Load models and extract parameters\n        # --------------------------------------------------------------------------\n\n        net = model\n        w = net_plotter.get_weights(net)  # initial parameters\n        s = copy.deepcopy(net.state_dict())  # deepcopy since state_dict are references\n        if args.ngpu > 1:\n            # data parallel with multiple GPUs on a single node\n            net = nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))\n\n        # --------------------------------------------------------------------------\n        # Setup the direction file and the surface file\n        # --------------------------------------------------------------------------\n        dir_file = net_plotter.name_direction_file(args)  # name the direction file\n        dir_file = os.path.join(os.path.split(args.eval_checkpoint)[0], dir_file)\n        if rank == 0:\n            net_plotter.setup_direction(args, dir_file, net)\n\n        surf_file = name_surface_file(args, dir_file)\n        if rank == 0:\n            setup_surface_file(args, surf_file, dir_file)\n\n        # wait until master has setup the direction file and surface file\n        mpi.barrier(comm)\n\n        # load directions\n        d = net_plotter.load_directions(dir_file)\n        # calculate the consine similarity of the two directions\n        if len(d) == 2 and rank == 0:\n            similarity = proj.cal_angle(proj.nplist_to_tensor(d[0]), proj.nplist_to_tensor(d[1]))\n            print('cosine similarity between x-axis and y-axis: %f' % 
similarity)\n\n        mpi.barrier(comm)\n\n        # --------------------------------------------------------------------------\n        # Start the computation\n        # --------------------------------------------------------------------------\n        trainloader = target_loader_train\n        crunch(surf_file, net, w, s, d, trainloader, 'train_loss', 'train_acc', comm, rank, args)\n\n        # --------------------------------------------------------------------------\n        # Plot figures\n        # --------------------------------------------------------------------------\n        if args.plot and rank == 0:\n            if args.y and args.proj_file:\n                plot_2D.plot_contour_trajectory(surf_file, dir_file, args.proj_file, 'train_loss', args.show)\n            elif args.y:\n                plot_2D.plot_2d_contour(surf_file, 'train_loss', args.vmin, args.vmax, args.vlevel, args.show)\n            else:\n                plot_1D.plot_1d_loss_err(surf_file, args.xmin, args.xmax, args.loss_max, args.log, args.show)\n        return\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "examples/Snn_safety/DPSNN/Readme.txt",
    "content": "The code for the differential private spiking neural network(DPSNN).\n"
  },
  {
    "path": "examples/Snn_safety/DPSNN/load_data.py",
    "content": "import numpy as np\r\nfrom torchvision import datasets, transforms\r\nimport torch\r\nfrom torch.utils.data import Dataset\r\nimport tonic\r\nfrom tonic import DiskCachedDataset\r\nimport torch.nn.functional as F\r\nimport os\r\n\r\n\r\nMNIST_MEAN = 0.1307\r\nMNIST_STD = 0.3081\r\nCIFAR10_MEAN = (0.4914, 0.4822, 0.4465)\r\nCIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)\r\ncifar100_mean = [0.5071, 0.4865, 0.4409]\r\ncifar100_std = [0.2673, 0.2563, 0.2761]\r\nDVSCIFAR10_MEAN_16 = [0.3290, 0.4507]\r\nDVSCIFAR10_STD_16 = [1.8398, 1.6549]\r\n\r\nDATA_DIR = '/data/datasets'\r\n\r\n\r\nclass CustomDataset(Dataset):\r\n    \"\"\"An abstract Dataset class wrapped around Pytorch Dataset class.\r\n    \"\"\"\r\n\r\n    def __init__(self, dataset, indices):\r\n        self.dataset = dataset\r\n        self.indices = [int(i) for i in indices]\r\n\r\n    def __len__(self):\r\n        return len(self.indices)\r\n\r\n    def __getitem__(self, item):\r\n        x, y = self.dataset[self.indices[item]]\r\n        return x, y\r\n\r\n\r\ndef load_static_data(data_root, batch_size, dataset):\r\n    if dataset == 'cifar10':\r\n        transform_train = transforms.Compose([\r\n            transforms.ToTensor(),\r\n            transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV)])\r\n\r\n        transform_test = transforms.Compose([\r\n            transforms.ToTensor(),\r\n            transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV)])\r\n\r\n        train_data = datasets.CIFAR10(data_root, train=True, transform=transform_train, download=True)\r\n        test_data = datasets.CIFAR10(data_root, train=False, transform=transform_test, download=True)\r\n\r\n        train_loader = torch.utils.data.DataLoader(\r\n            train_data,\r\n            batch_size=batch_size,\r\n            shuffle=True\r\n        )\r\n        test_loader = torch.utils.data.DataLoader(\r\n            test_data,\r\n            batch_size=batch_size,\r\n        )\r\n    elif dataset == 'MNIST':\r\n      
  transform_train = transforms.Compose([\r\n            transforms.ToTensor(),\r\n            transforms.Normalize(MNIST_MEAN, MNIST_STD)])\r\n\r\n        transform_test = transforms.Compose([\r\n            transforms.ToTensor(),\r\n            transforms.Normalize(MNIST_MEAN, MNIST_STD)])\r\n\r\n        train_data = datasets.MNIST(data_root, train=True, transform=transform_train, download=True)\r\n        test_data = datasets.MNIST(data_root, train=False, transform=transform_test, download=True)\r\n\r\n        train_loader = torch.utils.data.DataLoader(\r\n            train_data,\r\n            batch_size=batch_size,\r\n            shuffle=True\r\n        )\r\n        test_loader = torch.utils.data.DataLoader(\r\n            test_data,\r\n            batch_size=batch_size,\r\n        )\r\n    elif dataset == 'FashionMNIST':\r\n        transform_train = transforms.Compose([\r\n            transforms.ToTensor(),\r\n            transforms.Normalize(MNIST_MEAN, MNIST_STD)])\r\n\r\n        transform_test = transforms.Compose([\r\n            transforms.ToTensor(),\r\n            transforms.Normalize(MNIST_MEAN, MNIST_STD)])\r\n\r\n        train_data = datasets.FashionMNIST(data_root, train=True, transform=transform_train, download=True)\r\n        test_data = datasets.FashionMNIST(data_root, train=False, transform=transform_test, download=True)\r\n\r\n        train_loader = torch.utils.data.DataLoader(\r\n            train_data,\r\n            batch_size=batch_size,\r\n            shuffle=True\r\n        )\r\n        test_loader = torch.utils.data.DataLoader(\r\n            test_data,\r\n            batch_size=batch_size,\r\n        )\r\n\r\n    return train_data, test_data, train_loader, test_loader\r\n\r\n\r\ndef load_dvs10_data(batch_size, step, **kwargs):\r\n    size = kwargs['size'] if 'size' in kwargs else 48\r\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\r\n    train_transform = transforms.Compose([\r\n        # 
tonic.transforms.Denoise(filter_time=10000),\r\n        # tonic.transforms.DropEvent(p=0.1),\r\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\r\n    test_transform = transforms.Compose([\r\n        # tonic.transforms.Denoise(filter_time=10000),\r\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\r\n    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)\r\n    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)\r\n\r\n    train_transform = transforms.Compose([\r\n        lambda x: torch.tensor(x, dtype=torch.float),\r\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\r\n    ])\r\n    test_transform = transforms.Compose([\r\n        lambda x: torch.tensor(x, dtype=torch.float),\r\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\r\n    ])\r\n\r\n    train_dataset = DiskCachedDataset(train_dataset,\r\n                                      cache_path=f'./dataset/dvs_cifar10/train_cache_{step}',\r\n                                      transform=train_transform)\r\n    test_dataset = DiskCachedDataset(train_dataset,\r\n                                     cache_path=f'./dataset/dvs_cifar10/test_cache_{step}',\r\n                                     transform=test_transform)\r\n\r\n    num_train = len(train_dataset)\r\n    num_per_cls = num_train // 10\r\n    indices_train, indices_test = [], []\r\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\r\n    for i in range(10):\r\n        indices_train.extend(\r\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\r\n        indices_test.extend(\r\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\r\n    train_dataset = CustomDataset(train_dataset, 
np.array(indices_train))\r\n    test_dataset = CustomDataset(test_dataset, np.array(indices_test))\r\n\r\n    train_loader = torch.utils.data.DataLoader(\r\n        train_dataset, batch_size=batch_size, shuffle=True,\r\n        pin_memory=True, drop_last=False, num_workers=1\r\n    )\r\n\r\n    test_loader = torch.utils.data.DataLoader(\r\n        test_dataset, batch_size=batch_size,\r\n        pin_memory=True, drop_last=False, num_workers=1\r\n    )\r\n\r\n    return train_loader, test_loader, train_dataset, test_dataset\r\n\r\n\r\ndef load_nmnist_data(batch_size, step, **kwargs):\r\n    size = kwargs['size'] if 'size' in kwargs else 28\r\n    sensor_size = tonic.datasets.NMNIST.sensor_size\r\n    train_transform = transforms.Compose([\r\n        # tonic.transforms.Denoise(filter_time=10000),\r\n        # tonic.transforms.DropEvent(p=0.1),\r\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\r\n    test_transform = transforms.Compose([\r\n        # tonic.transforms.Denoise(filter_time=10000),\r\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\r\n    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform, train=True)\r\n    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform, train=False)\r\n\r\n    train_transform = transforms.Compose([\r\n        lambda x: torch.tensor(x, dtype=torch.float),\r\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\r\n\r\n    ])\r\n    test_transform = transforms.Compose([\r\n        lambda x: torch.tensor(x, dtype=torch.float),\r\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\r\n    ])\r\n\r\n    train_dataset = DiskCachedDataset(train_dataset,\r\n                                      cache_path=f'./dataset/NMNIST/train_cache_{step}',\r\n                                      
transform=train_transform)\r\n    test_dataset = DiskCachedDataset(test_dataset,\r\n                                     cache_path=f'./dataset/NMNIST/test_cache_{step}',\r\n                                     transform=test_transform)\r\n\r\n    train_loader = torch.utils.data.DataLoader(\r\n        train_dataset, batch_size=batch_size, shuffle=True,\r\n        pin_memory=True, drop_last=False, num_workers=1\r\n    )\r\n\r\n    test_loader = torch.utils.data.DataLoader(\r\n        test_dataset, batch_size=batch_size,\r\n        pin_memory=True, drop_last=False, num_workers=1\r\n    )\r\n\r\n    return train_loader, test_loader, train_dataset, test_dataset\r\n"
  },
  {
    "path": "examples/Snn_safety/DPSNN/main_dpsnn.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch.optim as optim\r\nfrom torchvision import datasets, transforms\r\nfrom opacus import PrivacyEngine\r\n\r\nfrom model import *\r\nfrom braincog.base.node.node import *\r\nimport warnings\r\nfrom load_data import *\r\nfrom opacus.utils.batch_memory_manager import BatchMemoryManager\r\n\r\nwarnings.simplefilter(\"ignore\")\r\n\r\n# Precomputed characteristics of the dataset dataset\r\ntorch.cuda.manual_seed(3154)\r\n\r\nbatch_size = 512\r\nMAX_PHYSICAL_BATCH_SIZE = 32\r\ntarget_ep = 8\r\nc = 6\r\nepochs = 40\r\n\r\nstep = 10\r\ndelta = 1e-5\r\ndevices = 4\r\nr = 5\r\ndevice = torch.device(f'cuda:{devices}' if torch.cuda.is_available() else 'cpu')\r\n# device = 'cpu'\r\ndisable_noise = False\r\ndata_root = \"./dataset\"\r\nkwargs = {\"num_workers\": 1, \"pin_memory\": True}\r\ndataset = 'dvs_cifar10'\r\n# NMNIST, cifar10, dvs_cifar10, MNIST, FashionMNIST\r\n\r\n\r\ndef train(model, device, train_loader, optimizer, epoch, privacy_engine):\r\n    criterion = nn.CrossEntropyLoss().to(device)\r\n    losses = []\r\n    model.train()\r\n    correct = 0\r\n    for _batch_idx, (data, target) in enumerate(train_loader):\r\n        # print(target)\r\n        data, target = data.to(device), target.to(device)\r\n        optimizer.zero_grad()\r\n        output = model(data)\r\n        loss = criterion(output, target)\r\n        loss.backward()\r\n        optimizer.step()\r\n        losses.append(loss.item())\r\n        pred = output.argmax(\r\n            dim=1, keepdim=True\r\n        )  # get the index of the max log-probability\r\n        correct += pred.eq(target.view_as(pred)).sum().item()\r\n\r\n    if not disable_noise:\r\n        epsilon = privacy_engine.get_epsilon(delta=delta)\r\n        print(\r\n            f\"Train Epoch: {epoch} \\t\"\r\n            f\"Loss: {np.mean(losses):.6f} \"\r\n        )\r\n        print(\"Accuracy: {}/{} ({:.2f}%)\\n\".format(\r\n          
  correct,\r\n            len(train_loader.dataset),\r\n            100.0 * correct / len(train_loader.dataset), ))\r\n        print(\r\n              f\"(ε = {epsilon:.2f}, δ = {delta})\"\r\n              )\r\n    else:\r\n        print(f\"Train Epoch: {epoch} \\t Loss: {np.mean(losses):.6f}\")\r\n\r\n    return 100.0 * correct / len(train_loader.dataset)\r\n\r\n\r\ndef test(model, device, test_loader):\r\n    model.eval()\r\n    criterion = nn.CrossEntropyLoss().to(device)\r\n    test_loss = 0\r\n    correct = 0\r\n    with torch.no_grad():\r\n        for data, target in test_loader:\r\n            data, target = data.to(device), target.to(device)\r\n            output = model(data)\r\n            test_loss += criterion(output, target).item()  # sum up batch loss\r\n            pred = output.argmax(\r\n                dim=1, keepdim=True\r\n            )  # get the index of the max log-probability\r\n            correct += pred.eq(target.view_as(pred)).sum().item()\r\n\r\n    test_loss /= len(test_loader)\r\n\r\n    print(\r\n        \"\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n\".format(\r\n            test_loss,\r\n            correct,\r\n            len(test_loader.dataset),\r\n            100.0 * correct / len(test_loader.dataset),\r\n        )\r\n    )\r\n    return correct / len(test_loader.dataset)\r\n\r\n\r\ndef run():\r\n    if dataset == 'dvs_cifar10':\r\n        train_loader, test_loader, train_data, test_data = load_dvs10_data(batch_size=batch_size, step=step)\r\n        # train_loader, test_loader, _, _ = get_dvsc10_data(batch_size=batch_size, step=step)\r\n    elif dataset == 'NMNIST':\r\n        train_loader, test_loader, train_data, test_data = load_nmnist_data(batch_size=batch_size, step=step)\r\n    else:\r\n        train_data, test_data, train_loader, test_loader = load_static_data(data_root, batch_size, dataset)\r\n\r\n    result = []\r\n    result_train = []\r\n    for _ in range(r):\r\n        if dataset == 'cifar10':\r\n 
           model = cifar_convnet(\r\n                step=step,\r\n                encode_type='direct',\r\n                node_type=LIFNode,\r\n                num_classes=10,\r\n                spike_output=False,\r\n                layer_by_layer=True,\r\n                act_fun=QGateGrad\r\n            )\r\n            model.to(device)\r\n            optimizer = optim.AdamW(model.parameters(), lr=0.001)\r\n            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40], gamma=0.1, last_epoch=-1)\r\n\r\n        elif dataset == 'dvs_cifar10':\r\n            model = dvs_convnet(\r\n                step=step,\r\n                encode_type='direct',\r\n                node_type=LIFNode,\r\n                num_classes=10,\r\n                spike_output=False,\r\n                layer_by_layer=True,\r\n                act_fun=QGateGrad\r\n            )\r\n            model.to(device)\r\n            optimizer = optim.AdamW(model.parameters(), lr=0.001)\r\n            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15], gamma=1, last_epoch=-1)\r\n        elif dataset == 'NMNIST':\r\n            model = SimpleSNN(\r\n                channel=2,\r\n                step=step,\r\n                node_type=LIFNode,\r\n                act_fun=QGateGrad,\r\n                layer_by_layer=True,\r\n            )\r\n            model.to(device)\r\n\r\n            optimizer = optim.AdamW(model.parameters(), lr=0.005)\r\n            scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1, last_epoch=-1)\r\n\r\n        elif dataset == 'MNIST' or dataset == 'FashionMNIST':\r\n            model = SimpleSNN(\r\n                channel=1,\r\n                step=step,\r\n                node_type=LIFNode,\r\n                act_fun=QGateGrad,\r\n                layer_by_layer=True,\r\n            )\r\n            model.to(device)\r\n            optimizer = optim.AdamW(model.parameters(), lr=0.005)\r\n            scheduler = 
optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1, last_epoch=-1)\r\n        if not disable_noise:\r\n            privacy_engine = PrivacyEngine()\r\n            model, optimizer, data_loader = privacy_engine.make_private_with_epsilon(\r\n                module=model,\r\n                optimizer=optimizer,\r\n                data_loader=train_loader,\r\n                max_grad_norm=c,\r\n                epochs=epochs,\r\n                target_delta=delta,\r\n                target_epsilon=target_ep\r\n            )\r\n            with BatchMemoryManager(\r\n                    data_loader=data_loader,\r\n                    max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,\r\n                    optimizer=optimizer\r\n            ) as memory_safe_data_loader:\r\n                # if 1:\r\n                for epoch in range(1, epochs + 1):\r\n                    result_train.append(train(model, device, memory_safe_data_loader, optimizer, epoch, privacy_engine))\r\n                    result.append(test(model, device, test_loader))\r\n                    scheduler.step()\r\n        else:\r\n            privacy_engine = PrivacyEngine()\r\n            model, optimizer, data_loader = privacy_engine.make_private(\r\n                module=model,\r\n                optimizer=optimizer,\r\n                data_loader=train_loader,\r\n                max_grad_norm=c,\r\n                noise_multiplier=0.0,\r\n            )\r\n            with BatchMemoryManager(\r\n                    data_loader=data_loader,\r\n                    max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,\r\n                    optimizer=optimizer\r\n            ) as memory_safe_data_loader:\r\n                for epoch in range(1, epochs + 1):\r\n                    train(model, device, memory_safe_data_loader, optimizer, epoch, privacy_engine)\r\n                    result.append(test(model, device, test_loader))\r\n                    scheduler.step()\r\n    result = 
np.array(result).reshape((r, -1))\r\n    result_train = np.array(result_train).reshape((r, -1))\r\n    best_acc = np.mean(np.max(result, axis=1))\r\n    print(best_acc)\r\n    np.save(file=f'./{dataset}/MP_test.npy', arr=result)\r\n    np.save(file=f'./{dataset}/MP_train.npy', arr=result_train)\r\n\r\nif __name__ == \"__main__\":\r\n    run()\r\n"
  },
  {
    "path": "examples/Snn_safety/DPSNN/model.py",
    "content": "import abc\r\nfrom functools import partial\r\nimport torch\r\nfrom torch.nn import functional as F\r\nimport torchvision\r\nfrom timm.models import register_model\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.connection.layer import *\r\nfrom braincog.base.encoder.encoder import *\r\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\r\n\r\n\r\nclass TEP(nn.Module):\r\n    def __init__(self, step, channel, device=None, dtype=None):\r\n        factory_kwargs = {'device': device, 'dtype': dtype}\r\n        super(TEP, self).__init__()\r\n        self.step = step\r\n        self.gn = nn.GroupNorm(channel, channel)\r\n\r\n\r\n    def forward(self, x):\r\n\r\n        x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)\r\n        fire_rate = torch.mean(x, dim=0)\r\n        fire_rate = self.gn(fire_rate) + 1\r\n\r\n        x = x * fire_rate\r\n        x = rearrange(x, 't b c w h -> (t b) c w h')\r\n\r\n        return x\r\n\r\n\r\nclass BaseConvNet(BaseModule, abc.ABC):\r\n    def __init__(self,\r\n                 step,\r\n                 input_channels,\r\n                 num_classes,\r\n                 encode_type,\r\n                 spike_output: bool,\r\n                 out_channels: list,\r\n                 block_depth: list,\r\n                 node_list: list,\r\n                 *args,\r\n                 **kwargs):\r\n        super().__init__(step, encode_type, *args, **kwargs)\r\n        self.num_cls = num_classes\r\n        self.spike_output = spike_output\r\n        self.groups = kwargs['n_groups'] if 'n_groups' in kwargs else 1\r\n        if not spike_output:\r\n            node_list.append(nn.Identity)\r\n            out_channels.append(self.num_cls)\r\n            self.vote = nn.Identity()\r\n            # self.vote = nn.Sequential(\r\n            #     nn.Linear(self.step, 32),\r\n            #     nn.ReLU(),\r\n            #     nn.Linear(32, 1)\r\n            # )\r\n        
else:\r\n            out_channels.append(10 * self.num_cls)\r\n            self.vote = VotingLayer(10)\r\n\r\n        # check list length\r\n        if len(node_list) != len(out_channels):\r\n            raise ValueError\r\n        self.input_channels = input_channels\r\n        self.out_channels = out_channels\r\n        self.block_depth = block_depth\r\n        self.node_list = node_list\r\n        self.feature = self._create_feature()\r\n        self.fc = self._create_fc()\r\n        if self.layer_by_layer:\r\n            self.flatten = nn.Flatten(start_dim=1)\r\n        else:\r\n            self.flatten = nn.Flatten()\r\n\r\n    @staticmethod\r\n    def _create_feature(self):\r\n        raise NotImplementedError\r\n\r\n    @staticmethod\r\n    def _create_fc(self):\r\n        raise NotImplementedError\r\n\r\n    def forward(self, inputs):\r\n        inputs = self.encoder(inputs)\r\n        self.reset()\r\n        if not self.training:\r\n            self.fire_rate.clear()\r\n\r\n        if not self.layer_by_layer:\r\n            outputs = []\r\n            if self.warm_up:\r\n                step = 1\r\n            else:\r\n                step = self.step\r\n\r\n            for t in range(step):\r\n                x = inputs[t]\r\n                x = self.feature(x)\r\n                x = self.flatten(x)\r\n                x = self.fc(x)\r\n                x = self.vote(x)\r\n                outputs.append(x)\r\n\r\n            return sum(outputs) / len(outputs)\r\n            # outputs = torch.stack(outputs)\r\n            # outputs = rearrange(outputs, 't b c -> b c t')\r\n            # outputs = self.vote(outputs).squeeze()\r\n            # return outputs\r\n\r\n        else:\r\n            x = self.feature(inputs)\r\n            x = self.flatten(x)\r\n            x = self.fc(x)\r\n            if self.groups == 1:\r\n                x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\r\n            else:\r\n                x = rearrange(x, 'b (c t) -> 
t b c', t=self.step).mean(0)\r\n            x = self.vote(x)\r\n            return x\r\n\r\n\r\nclass LayerWiseConvModule(nn.Module):\r\n    \"\"\"\r\n    SNN卷积模块\r\n    :param in_channels: 输入通道数\r\n    :param out_channels: 输出通道数\r\n    :param kernel_size: kernel size\r\n    :param stride: stride\r\n    :param padding: padding\r\n    :param bias: Bias\r\n    :param node: 神经元类型\r\n    :param kwargs:\r\n    \"\"\"\r\n\r\n    def __init__(self,\r\n                 in_channels: int,\r\n                 out_channels: int,\r\n                 kernel_size=(3, 3),\r\n                 stride=(1, 1),\r\n                 padding=(1, 1),\r\n                 bias=False,\r\n                 node=LIFNode,\r\n                 step=6,\r\n                 **kwargs):\r\n\r\n        super().__init__()\r\n\r\n        if node is None:\r\n            raise TypeError\r\n\r\n        self.groups = kwargs['groups'] if 'groups' in kwargs else 1\r\n        self.conv = nn.Conv2d(in_channels=in_channels * self.groups,\r\n                              out_channels=out_channels * self.groups,\r\n                              kernel_size=kernel_size,\r\n                              padding=padding,\r\n                              stride=stride,\r\n                              bias=bias)\r\n        self.gn = nn.GroupNorm(16, out_channels * self.groups)\r\n        self.node = partial(node, **kwargs)()\r\n        self.step = step\r\n        self.activation = nn.Identity()\r\n\r\n    def forward(self, x):\r\n        x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)\r\n        outputs = []\r\n\r\n        for t in range(self.step):\r\n            outputs.append(self.gn(self.conv(x[t])))\r\n        outputs = torch.stack(outputs)  # t b c w h\r\n        outputs = rearrange(outputs, 't b c w h -> (t b) c w h')\r\n        outputs = self.node(outputs)\r\n\r\n        return outputs\r\n\r\n\r\nclass LayerWiseLinearModule(nn.Module):\r\n    \"\"\"\r\n    线性模块\r\n    :param in_features: 输入尺寸\r\n    
:param out_features: 输出尺寸\r\n    :param bias: 是否有Bias, 默认 ``False``\r\n    :param node: 神经元类型, 默认 ``LIFNode``\r\n    :param args:\r\n    :param kwargs:\r\n    \"\"\"\r\n\r\n    def __init__(self,\r\n                 in_features: int,\r\n                 out_features: int,\r\n                 bias=True,\r\n                 node=LIFNode,\r\n                 step=6,\r\n                 spike=False,\r\n                 *args,\r\n                 **kwargs):\r\n        super().__init__()\r\n        if node is None:\r\n            raise TypeError\r\n\r\n        self.groups = kwargs['groups'] if 'groups' in kwargs else 1\r\n        if self.groups == 1:\r\n            self.fc = nn.Linear(in_features=in_features,\r\n                                out_features=out_features, bias=bias)\r\n        else:\r\n            self.fc = nn.ModuleList()\r\n            for i in range(self.groups):\r\n                self.fc.append(nn.Linear(\r\n                    in_features=in_features,\r\n                    out_features=out_features,\r\n                    bias=bias\r\n                ))\r\n        self.node = partial(node, **kwargs)()\r\n        self.step = step\r\n        self.spike = spike\r\n\r\n    def forward(self, x):\r\n        if self.groups == 1:  # (t b) c\r\n            x = rearrange(x, '(t b) c -> t b c', t=self.step)\r\n            outputs = []\r\n            for t in range(self.step):\r\n                outputs.append(self.fc(x[t]))\r\n            outputs = torch.stack(outputs)  # t b c\r\n            outputs = rearrange(outputs, 't b c -> (t b) c')\r\n\r\n        else:  # b (c t)\r\n            x = rearrange(x, 'b (c t) -> t b c', t=self.groups)\r\n            outputs = []\r\n            for i in range(self.groups):\r\n                outputs.append(self.fc[i](x[i]))\r\n            outputs = torch.stack(outputs)  # t b c\r\n            outputs = rearrange(outputs, 't b c -> b (c t)')\r\n        if self.spike:\r\n            return self.node(outputs)\r\n        
else:\r\n            return outputs\r\n\r\n\r\nclass LayWiseConvNet(BaseConvNet):\r\n    def __init__(self,\r\n                 step,\r\n                 input_channels,\r\n                 num_classes,\r\n                 encode_type,\r\n                 spike_output: bool,\r\n                 out_channels: list,\r\n                 node_list: list,\r\n                 block_depth: list,\r\n                 *args,\r\n                 **kwargs):\r\n        super().__init__(step,\r\n                         input_channels,\r\n                         num_classes,\r\n                         encode_type,\r\n                         spike_output,\r\n                         out_channels,\r\n                         block_depth,\r\n                         node_list,\r\n                         *args,\r\n                         **kwargs)\r\n\r\n    def _create_feature(self):\r\n        feature_depth = len(self.node_list) - 1\r\n\r\n        feature = [LayerWiseConvModule(\r\n            self.input_channels * self.init_channel_mul, self.out_channels[0], node=self.node_list[0],\r\n            groups=self.groups, step=self.step)]\r\n        if self.block_depth[0] != 1:\r\n            feature.extend(\r\n                [LayerWiseConvModule(self.out_channels[0], self.out_channels[0], node=self.node_list[0],\r\n                                     groups=self.groups, step=self.step)] * (\r\n                        self.block_depth[0] - 1),\r\n            )\r\n        feature.append(TEP(channel=self.out_channels[0], step=self.step))\r\n        feature.append(nn.AvgPool2d(kernel_size=4, stride=2))\r\n        for i in range(1, feature_depth - 1):\r\n            feature.append(LayerWiseConvModule(\r\n                self.out_channels[i - 1], self.out_channels[i], node=self.node_list[i], groups=self.groups,\r\n                step=self.step))\r\n            if self.block_depth[i] != 1:\r\n                feature.extend(\r\n                    
[LayerWiseConvModule(self.out_channels[i], self.out_channels[i], node=self.node_list[i],\r\n                                         groups=self.groups,\r\n                                         step=self.step)] * (\r\n                            self.block_depth[i] - 1),\r\n                )\r\n            feature.append(TEP(channel=self.out_channels[i], step=self.step))\r\n            feature.append(nn.AvgPool2d(kernel_size=4, stride=2))\r\n        feature.append(LayerWiseConvModule(\r\n            self.out_channels[-3], self.out_channels[-2], node=self.node_list[-2], groups=self.groups,\r\n            step=self.step))\r\n        if self.block_depth[feature_depth - 1] != 1:\r\n            feature.extend(\r\n                [LayerWiseConvModule(self.out_channels[-2], self.out_channels[-2], node=self.node_list[-2],\r\n                                     groups=self.groups,\r\n                                     step=self.step)] * (\r\n                        self.block_depth[feature_depth - 1] - 1),\r\n            )\r\n        feature.append(nn.AdaptiveAvgPool2d(1))\r\n\r\n        return nn.Sequential(*feature)\r\n\r\n    def _create_fc(self):\r\n        fc = nn.Sequential(\r\n            # NDropout(.5),\r\n            LayerWiseLinearModule(\r\n                self.out_channels[-2], self.out_channels[-1], node=self.node_list[-1], groups=self.groups,\r\n                step=self.step, spike=False)\r\n        )\r\n        return fc\r\n\r\n\r\n@register_model\r\ndef cifar_convnet(step,\r\n                encode_type,\r\n                spike_output: bool,\r\n                node_type,\r\n                *args,\r\n                **kwargs):\r\n    # out_channels = [256, 256, 512, 1024]\r\n    out_channels = [64, 128, 128, 256]\r\n    block_depth = [2, 2, 2, 2]\r\n    # print(kwargs)\r\n    node_cls = partial(node_type, step=step, **kwargs)\r\n    # print(node_cls)\r\n    if spike_output:\r\n        node_list = [node_cls] * (len(out_channels) + 1)\r\n    else:\r\n   
     node_list = [node_cls] * (len(out_channels))\r\n\r\n    return LayWiseConvNet(step=step,\r\n                          input_channels=3,\r\n                          encode_type=encode_type,\r\n                          node_list=node_list,\r\n                          block_depth=block_depth,\r\n                          out_channels=out_channels,\r\n                          spike_output=spike_output,\r\n                          **kwargs)\r\n\r\n\r\n@register_model\r\ndef dvs_convnet(step,\r\n                encode_type,\r\n                spike_output: bool,\r\n                node_type,\r\n                num_classes,\r\n                *args,\r\n                **kwargs):\r\n    out_channels = [64, 128, 256, 512, 1024]\r\n    block_depth = [2, 1, 2, 1, 2]\r\n\r\n    node_cls = partial(node_type, step=step, **kwargs)\r\n    if spike_output:\r\n        node_list = [node_cls] * (len(out_channels) + 1)\r\n        # node_list[-2] = partial(DoubleSidePLIFNode, step=step, **kwargs)\r\n    else:\r\n        node_list = [node_cls] * (len(out_channels))\r\n        # node_list[-1] = partial(DoubleSidePLIFNode, step=step, **kwargs)\r\n\r\n    return LayWiseConvNet(step=step,\r\n                          input_channels=2,\r\n                          num_classes=num_classes,\r\n                          encode_type=encode_type,\r\n                          node_list=node_list,\r\n                          block_depth=block_depth,\r\n                          out_channels=out_channels,\r\n                          spike_output=spike_output,\r\n                          **kwargs)\r\n\r\n\r\n@register_model\r\nclass SimpleSNN(BaseModule, abc.ABC):\r\n    def __init__(self,\r\n                 channel=1,\r\n                 num_classes=10,\r\n                 step=8,\r\n                 node_type=LIFNode,\r\n                 encode_type='direct',\r\n                 *args,\r\n                 **kwargs):\r\n        super().__init__(step, encode_type, *args, **kwargs)\r\n    
    self.num_classes = num_classes\r\n\r\n        self.node = node_type\r\n        init_channel = channel\r\n\r\n        self.feature = nn.Sequential(\r\n            LayerWiseConvModule(init_channel, 32, kernel_size=7, padding=0, node=self.node, step=step),\r\n            TEP(step=step, channel=32),\r\n            nn.AvgPool2d(kernel_size=2, stride=2),\r\n            LayerWiseConvModule(32, 64, kernel_size=4, padding=0, node=self.node, step=step),\r\n            TEP(step=step, channel=64),\r\n            nn.AvgPool2d(kernel_size=2, stride=2),\r\n        )\r\n        self.fc = nn.Sequential(\r\n            nn.Flatten(),\r\n            LayerWiseLinearModule(64 * 4 * 4, self.num_classes, node=self.node, spike=False, step=step),\r\n        )\r\n\r\n    def forward(self, inputs):\r\n        inputs = self.encoder(inputs)\r\n        self.reset()\r\n\r\n        if self.layer_by_layer:\r\n            x = self.feature(inputs)\r\n            x = self.fc(x)\r\n            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)\r\n            return x\r\n\r\n        else:\r\n            outputs = []\r\n            for t in range(self.step):\r\n                x = inputs[t]\r\n                x = self.feature(x)\r\n                x = self.fc(x)\r\n                outputs.append(x)\r\n\r\n            return sum(outputs) / len(outputs)\r\n"
  },
  {
    "path": "examples/Snn_safety/RandHet-SNN/README.md",
    "content": "* To train a SNN with AT on CIFAR-10:\n```\npython train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type LIF --save_dir AT --device cuda:1 --time_step 8 --dataset cifar10\n```\n\n* To train a RHSNN with AT on CIFAR-10:\n```\npython train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type RHLIF --save_dir AT_RH_1 --device cuda:1 --time_step 8 --dataset cifar10\n```\n\n* To train a RHSNN with RAT on CIFAR-10:\n```\npython train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type RHLIF --save_dir RAT_RH_1 --device cuda:1 --parseval --beta 0.004 --time_step 8 --dataset cifar10\n```\n\n* To train a RHSNN with SR on CIFAR-10:\n```\npython train.py --adv_training --attack_iters 1 --epsilon 4 --alpha 4 --network ResNet18 --batch_size 64 --worker 4 --node_type RHLIF --save_dir SR_RH_1 --device cuda:1 --SR --time_step 8 --dataset cifar10\n```\n\n* To evaluate the performance of RHSNN on CIFAR-10:\n\n```\npython evaluate.py --network ResNet18 --attack_type all --batch_size 32 --worker 4 --node_type RHLIF --pretrain RAT_RH_1/weight_r.pth --save_dir RAT --device cuda:1 --time_step 8 --dataset cifar10\n```\n\n"
  },
  {
    "path": "examples/Snn_safety/RandHet-SNN/evaluate.py",
    "content": "import argparse\nimport copy\nimport logging\nimport os\nimport sys\nimport time\nfrom my_node import RHLIFNode, RHLIFNode2\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sew_resnet import SEWResNet19, BasicBlock\nfrom braincog.base.node.node import *\n\nfrom utils import evaluate_standard\n\nfrom utils import get_loaders\n\nimport torchattacks\n\nfrom tqdm import tqdm\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--batch_size', default=32, type=int)\n    parser.add_argument('--data_dir', default='/mnt/data/datasets', type=str)\n    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'])\n    parser.add_argument('--network', default='ResNet18', type=str)\n    parser.add_argument('--worker', default=4, type=int)\n    parser.add_argument('--epsilon', default=8, type=int)\n    parser.add_argument('--device', default='cuda:1', type=str)\n    parser.add_argument('--pretrain', default=None, type=str, help='path to load the pretrained model')\n    parser.add_argument('--save_dir', default=None, type=str, help='path to save log')\n    parser.add_argument('--attack_type', default='pgd')\n    parser.add_argument('--time_step', default=8, type=int)\n    parser.add_argument('--node_type', default='LIF', type=str)\n    return parser.parse_args()\n\ndef evaluate_attack(model, test_loader, args, atk, atk_name, logger):\n    test_loss = 0\n    test_acc = 0\n    n = 0\n    model.eval()\n    device = args.device\n\n    test_loader = iter(test_loader)\n\n    bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'\n    pbar = tqdm(range(len(test_loader)), file=sys.stdout, bar_format=bar_format, ncols=80)\n    for i in pbar:\n        X, y = next(test_loader)\n        X, y = X.to(device), y.to(device)\n        X_adv = atk(X, y)  # advtorch\n        with torch.no_grad():\n            output = model(X_adv)\n      
  loss = F.cross_entropy(output, y)\n        test_loss += loss.item() * y.size(0)\n        test_acc += (output.max(1)[1] == y).sum().item()\n        n += y.size(0)\n\n    pgd_acc = test_acc / n\n    pgd_loss = test_loss / n\n\n    logger.info(atk_name)\n    logger.info('adv: %.4f \\t', pgd_acc)\n\n    return pgd_loss, pgd_acc\n\ndef main():\n    args = get_args()\n\n    args.save_dir = os.path.join('logs', args.save_dir)\n\n    if not os.path.exists(args.save_dir):\n        os.makedirs(args.save_dir)\n    logfile = os.path.join(args.save_dir, 'output.log')\n    if os.path.exists(logfile):\n        os.remove(logfile)\n\n\n    log_path = os.path.join(args.save_dir, 'output_test.log')\n\n    handlers = [logging.FileHandler(log_path, mode='a+'),\n                logging.StreamHandler()]\n\n    logging.basicConfig(\n        format='[%(asctime)s] - %(message)s',\n        datefmt='%Y/%m/%d %H:%M:%S',\n        level=logging.INFO,\n        handlers=handlers)\n\n    logger.info(args)\n\n    # assert type(args.pretrain) == str and os.path.exists(args.pretrain)\n\n    if args.dataset == 'cifar10':\n        args.num_classes = 10\n    elif args.dataset == 'cifar100':\n        args.num_classes = 100\n    else:\n        print('Wrong dataset:', args.dataset)\n        exit()\n\n    logger.info('Dataset: %s', args.dataset)\n\n    train_loader, test_loader, dataset_normalization = get_loaders(args.data_dir, args.batch_size, dataset=args.dataset,\n                                                                   worker=args.worker, norm=False)\n    node = LIFNode\n    if args.node_type == 'RHLIF':\n        node = RHLIFNode\n    elif args.node_type == 'RHLIF2':\n        node = RHLIFNode2\n\n    # setup network\n    model = SEWResNet19(BasicBlock, [3, 3, 2], cnf='ADD', node_type=node, step=args.time_step, num_classes=args.num_classes,\n                        layer_by_layer=True, act_fun=AtanGrad, data_norm=dataset_normalization)\n    # print(model)\n\n    # load pretrained model\n    
path = os.path.join('./ckpt', args.dataset, args.network)\n    args.pretrain = os.path.join(path, args.pretrain)\n    pretrained_model = torch.load(args.pretrain, map_location=args.device, weights_only=False)\n    model.load_state_dict(pretrained_model, strict=False)\n    model.to(args.device)\n    model.eval()\n    # for name, param in model.named_parameters():\n    #     if 'sigma' in name:  # 查找包含 'sigma' 的参数\n    #         param.data = torch.as_tensor(0.0, device=args.device)\n    #     if 'alpha' in name:  # 查找包含 'sigma' 的参数\n    #         param.data = torch.as_tensor(2.0, device=args.device)\n\n    logger.info('Evaluating with standard images...')\n    _, nature_acc = evaluate_standard(test_loader, model, args)\n    logger.info('Nature Acc: %.4f \\t', nature_acc)\n\n    if args.attack_type == 'eotpgd':\n        atk = torchattacks.EOTPGD(model, eps=8 / 255, alpha=(16/50) / 255, steps=50, random_start=True, eot_iter=10)\n        evaluate_attack(model, test_loader, args, atk, 'eotpgd', logger)\n    elif args.attack_type[0:3] == 'pgd':\n        steps = int(''.join(filter(str.isdigit, args.attack_type)))\n        atk = torchattacks.PGD(model, eps=8 / 255, alpha=(16/steps) / 255, steps=steps, random_start=True)\n        evaluate_attack(model, test_loader, args, atk, args.attack_type, logger)\n    elif args.attack_type == 'apgd':\n        atk = torchattacks.APGD(model, eps=8 / 255, steps=50, eot_iter=10)\n        evaluate_attack(model, test_loader, args, atk, 'apgd', logger)\n    elif args.attack_type == 'fgsm':\n        atk = torchattacks.FGSM(model, eps=8/255)\n        evaluate_attack(model, test_loader, args, atk, 'fgsm', logger)\n    elif args.attack_type == 'mifgsm':\n        atk = torchattacks.MIFGSM(model, eps=8 / 255, alpha=2 / 255, steps=5, decay=1.0)\n        evaluate_attack(model, test_loader, args, atk, 'mifgsm', logger)\n    elif args.attack_type == 'autoattack':\n        atk = torchattacks.AutoAttack(model, norm='Linf', eps=8/255, version='standard', 
n_classes=args.num_classes)\n        evaluate_attack(model, test_loader, args, atk, 'autoattack', logger)\n    elif args.attack_type == 'all':\n        atk = torchattacks.FGSM(model, eps=8 / 255)\n        evaluate_attack(model, test_loader, args, atk, 'fgsm', logger)\n        atk = torchattacks.APGD(model, eps=8 / 255, steps=10)\n        evaluate_attack(model, test_loader, args, atk, 'apgd', logger)\n        atk = torchattacks.PGD(model, eps=8 / 255, alpha=1.6 / 255, steps=10, random_start=True)\n        evaluate_attack(model, test_loader, args, atk, 'pgd', logger)\n        atk = torchattacks.MIFGSM(model, eps=8 / 255, alpha=2 / 255, steps=5, decay=1.0)\n        evaluate_attack(model, test_loader, args, atk, 'mifgsm', logger)\n        atk = torchattacks.AutoAttack(model, norm='Linf', eps=8 / 255, version='standard',\n                                      n_classes=args.num_classes)\n        evaluate_attack(model, test_loader, args, atk, 'autoattack', logger)\n    elif args.attack_type == 'step_test':\n        for steps in [10,30,50,70,90,110]:\n            atk = torchattacks.PGD(model, eps=8 / 255, alpha=(16 / steps) / 255, steps=steps, random_start=True)\n            pgd_loss, pgd_acc = evaluate_attack(model, test_loader, args, atk, f'pgd{steps}', logger)\n            atk = torchattacks.APGD(model, eps=8 / 255, steps=steps)\n            apgd_loss, apgd_acc = evaluate_attack(model, test_loader, args, atk, f'apgd{steps}', logger)\n\n    elif args.attack_type == 'eot_test':\n\n        for steps in [1,10,20,30]:\n            atk = torchattacks.EOTPGD(model, eps=8 / 255, alpha=(16 / 10) / 255, steps=10, random_start=True, eot_iter=steps)\n            pgd_loss, pgd_acc = evaluate_attack(model, test_loader, args, atk, f'eot{steps}_pgd', logger)\n            atk = torchattacks.APGD(model, eps=8 / 255, steps=10, eot_iter=steps)\n            apgd_loss, apgd_acc = evaluate_attack(model, test_loader, args, atk, f'eot{steps}_apgd', logger)\n\n    elif args.attack_type == 
'intensity_test':\n\n        for intensity in [2, 4, 6, 8, 10, 12, 14, 16]:\n            atk = torchattacks.APGD(model, eps=intensity / 255, steps=10)\n            pgd_loss, pgd_acc = evaluate_attack(model, test_loader, args, atk, f'{intensity}_apgd', logger)\n\n\n    logger.info('Testing done.')\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/Snn_safety/RandHet-SNN/my_node.py",
    "content": "import torch\nfrom braincog.base.node.node import *\n\nclass RHLIFNode(BaseNode):\n    \"\"\"\n    Parametric LIF， 其中的 ```tau``` 会被backward过程影响\n    Reference：https://arxiv.org/abs/2007.05785\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=0.5, tau=0., sigma=1.0, act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.sigma = nn.Parameter(torch.as_tensor(sigma), requires_grad=False)\n        self.w = nn.Parameter(torch.as_tensor(init_w), requires_grad=False)\n        self.flag = 0\n        self.rd = 0\n\n\n    def integral(self, inputs):\n        self.rd = self.sigma * torch.normal(0., 1., size=(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3]), device=inputs.device)\n        self.mem = self.rd.sigmoid() * self.mem + (1 - self.rd.sigmoid()) * inputs\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        # self.mem = self.mem - self.spike.detach() * self.threshold\n        self.mem = self.mem * (1 - self.spike.detach())\n\n    def n_reset(self):\n        self.mem = self.v_reset\n        self.spike = 0.\n        self.feature_map = []\n        
self.mem_collect = []\n        self.flag = 0\n\nclass RHLIFNode2(BaseNode):\n    \"\"\"\n    Parametric LIF， 其中的 ```tau``` 会被backward过程影响\n    Reference：https://arxiv.org/abs/2007.05785\n    :param threshold: 神经元发放脉冲需要达到的阈值\n    :param v_reset: 静息电位\n    :param dt: 时间步长\n    :param step: 仿真步\n    :param tau: 膜电位时间常数, 用于控制膜电位衰减\n    :param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``\n    :param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``\n    :param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``\n    :param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``\n    :param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``\n    :param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组\n    :param args: 其他的参数\n    :param kwargs: 其他的参数\n    \"\"\"\n\n    def __init__(self, threshold=0.5, tau=0., sigma=1.0, act_fun=AtanGrad, *args, **kwargs):\n        super().__init__(threshold, *args, **kwargs)\n        init_w = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n\n        self.act_fun = act_fun(alpha=2., requires_grad=False)\n        self.sigma = nn.Parameter(torch.as_tensor(sigma), requires_grad=False)\n        self.w = nn.Parameter(torch.as_tensor(init_w), requires_grad=False)\n        self.flag = 0\n        self.rd = 0\n        self.resample = 1\n\n\n    def integral(self, inputs):\n        if self.flag == 0:\n            self.rd = self.sigma * torch.normal(0., 1., size=(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3]), device=inputs.device)\n            self.flag = 1\n        self.mem = self.rd.sigmoid() * self.mem + (1 - self.rd.sigmoid()) * inputs\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        # self.mem = self.mem - self.spike.detach() * self.threshold\n        self.mem = self.mem * (1 - self.spike.detach())\n\n    def n_reset(self):\n        self.mem = self.v_reset\n        
self.spike = 0.\n        self.feature_map = []\n        self.mem_collect = []\n        if self.resample == 1:\n            self.flag = 0\n        else:\n            self.flag = 1"
  },
  {
    "path": "examples/Snn_safety/RandHet-SNN/sew_resnet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom copy import deepcopy\nimport random\n\ntry:\n    from torchvision.models.utils import load_state_dict_from_url\nexcept ImportError:\n    from torchvision._internally_replaced_utils import load_state_dict_from_url\nfrom braincog.base.node import *\nfrom braincog.model_zoo.base_module import *\nfrom braincog.datasets import is_dvs_data\nfrom timm.models import register_model\n\n\ndef sew_function(x: torch.Tensor, y: torch.Tensor, cnf: str):\n    if cnf == 'ADD':\n        return x + y\n    elif cnf == 'AND':\n        return x * y\n    elif cnf == 'IAND':\n        return x * (1. - y)\n    else:\n        raise NotImplementedError\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.node1 = node()\n        self.conv2 = conv3x3(planes, planes)\n        
self.bn2 = norm_layer(planes)\n        self.node2 = node()\n        self.downsample = downsample\n        if downsample is not None:\n            self.downsample_sn = node()\n        self.stride = stride\n        self.cnf = cnf\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.node1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.node2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample_sn(self.downsample(x))\n\n        out = sew_function(identity, out, self.cnf)\n\n        return out\n\n    def extra_repr(self) -> str:\n        return super().extra_repr() + f'cnf={self.cnf}'\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None, cnf: str = None, node: callable = None, **kwargs):\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.node1 = node()\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.node2 = node()\n        self.conv3 = 
conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.node3 = node()\n        self.downsample = downsample\n        if downsample is not None:\n            self.downsample_sn = node()\n        self.stride = stride\n        self.cnf = cnf\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.node1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.node2(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n        out = self.node3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample_sn(self.downsample(x))\n\n        out = sew_function(out, identity, self.cnf)\n\n        return out\n\n    def extra_repr(self) -> str:\n        return super().extra_repr() + f'cnf={self.cnf}'\n\n\nclass SEWResNet19(BaseModule):\n    def __init__(self, block, layers, num_classes=1000, step=8, encode_type=\"direct\", zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None, data_norm=None,\n                 norm_layer=None, cnf: str = None, *args, **kwargs):\n        super().__init__(\n            step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.num_classes = num_classes\n        self.normalize = data_norm\n        self.node = kwargs['node_type']\n        if issubclass(self.node, BaseNode):\n            # self.node = partial(self.node, **kwargs, step=step)\n            self.node1 = partial(self.node, **kwargs, step=step)()\n            self.node2 = partial(self.node, **kwargs, step=step)\n            self.node3 = partial(self.node, **kwargs, step=step)\n            self.node4 = partial(self.node, **kwargs, step=step)\n        self.once = kwargs[\"once\"] if \"once\" 
in kwargs else False\n        self.sum_output = kwargs[\"sum_output\"] if \"sum_output\" in kwargs else True\n\n        init_channel = 3\n\n        self.inplanes = 128\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n\n        self.conv1 = nn.Conv2d(init_channel, self.inplanes, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        # self.node1 = self.node()\n        self.layer1 = self._make_layer(block, 128, layers[0], cnf=cnf, node=self.node2, **kwargs)\n        self.layer2 = self._make_layer(block, 256, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node3, **kwargs)\n        self.layer3 = self._make_layer(block, 512, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node4, **kwargs)\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc1 = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                n = m.weight.size(1)\n                m.weight.data.normal_(0, 1.0 / float(n))\n                m.bias.data.zero_()\n\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, node: callable = None,\n                    **kwargs):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer, cnf, node, **kwargs))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, inputs):\n        # See note [TorchScript super()]\n        if self.normalize is not None:\n            self.normalize.mean = self.normalize.mean.to(inputs.device)\n            self.normalize.std = self.normalize.std.to(inputs.device)\n            inputs = self.normalize(inputs)\n        self.reset()\n\n        if self.layer_by_layer:\n            inputs = repeat(inputs, 'b c w h -> t b c w h', t=self.step)\n            inputs = rearrange(inputs, 't b c w h -> (t b) c w h')\n  
          x = self.conv1(inputs)\n            x = self.bn1(x)\n            x = self.node1(x)\n            x = self.layer1(x)\n            x = self.layer2(x)\n            x = self.layer3(x)\n\n            x = self.avgpool(x)\n\n            x = torch.flatten(x, 1)\n\n            x = self.fc1(x)\n            # x = self.node2(x)\n            # x = self.fc2(x)\n\n            x = rearrange(x, '(t b) c -> t b c', t=self.step)\n            # print(x)\n            if self.sum_output: x = x.mean(0)\n\n            return x\n\n    def _forward_once(self, x):\n        # inputs = self.encoder(inputs)\n        # x = inputs[t]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.node1(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n\n        x = self.fc(x)\n        return x\n\n    def forward(self, x):\n        if self.once: return self._forward_once(x)\n        return self._forward_impl(x)\n"
  },
  {
    "path": "examples/Snn_safety/RandHet-SNN/train.py",
    "content": "import argparse\nimport copy\nimport logging\nimport os\nimport sys\nimport time\nfrom evaluate import evaluate_attack\nimport torchattacks\nfrom my_node import RHLIFNode, RHLIFNode2\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom sew_resnet import SEWResNet19, BasicBlock, Bottleneck\nfrom braincog.base.node.node import *\n\nfrom utils import (evaluate_standard, cifar10_std, cifar10_mean,\n                   orthogonal_retraction)\n\nfrom utils import (clamp, get_norm_stat,\n                   get_loaders)\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--batch_size', default=128, type=int)\n    parser.add_argument('--data_dir', default='/mnt/data/datasets', type=str)\n    parser.add_argument('--dataset', default='cifar10', type=str)\n    parser.add_argument('--epochs', default=100, type=int)\n    parser.add_argument('--network', default='ResNet18', type=str)\n    parser.add_argument('--device', default='cuda:3', type=str)\n    parser.add_argument('--worker', default=4, type=int)\n    parser.add_argument('--lr_schedule', default='cosine', choices=['cyclic', 'multistep', 'cosine'])\n    parser.add_argument('--lr_min', default=0., type=float)\n    parser.add_argument('--lr_max', default=0.1, type=float)\n    parser.add_argument('--weight_decay', default=1e-4, type=float)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--epsilon', default=4, type=int)\n    parser.add_argument('--alpha', default=4, type=float, help='Step size')\n    parser.add_argument('--save_dir', default='ckpt', type=str, help='Output directory')\n    parser.add_argument('--seed', default=0, type=int, help='Random seed')\n\n    parser.add_argument('--attack_iters', default=1, type=int, help='Attack iterations')\n\n    parser.add_argument('--pretrain', default=None, type=str, help='path to load the pretrained 
model')\n\n    parser.add_argument('--beta', default=0.004, type=float)\n    parser.add_argument('--adv_training', action='store_true',\n                        help='if adv training')\n\n    parser.add_argument('--time_step', default=8, type=int)\n    parser.add_argument('--SR', action='store_true')\n    parser.add_argument('--node_type', default='LIF', type=str)\n    parser.add_argument('--parseval', action='store_true', help='if use different norm for different layers')\n\n    return parser.parse_args()\n\n\ndef main():\n    args = get_args()\n    device = args.device\n    torch.cuda.set_device(device)\n    if args.dataset == 'cifar10' or args.dataset == 'svhn':\n        args.num_classes = 10\n    elif args.dataset == 'cifar100':\n        args.num_classes = 100\n    mu, std, upper_limit, lower_limit = get_norm_stat(cifar10_mean, cifar10_std)\n\n    path = os.path.join('./ckpt', args.dataset, args.network)\n    args.save_dir = os.path.join(path, args.save_dir)\n\n    if not os.path.exists(args.save_dir):\n        os.makedirs(args.save_dir)\n    logfile = os.path.join(args.save_dir, 'output.log')\n    if os.path.exists(logfile):\n        os.remove(logfile)\n\n    handlers = [logging.FileHandler(logfile, mode='a+'),\n                logging.StreamHandler()]\n\n    logging.basicConfig(\n        format='[%(asctime)s] - %(message)s',\n        datefmt='%Y/%m/%d %H:%M:%S',\n        level=logging.INFO,\n        handlers=handlers)\n    logger.info(args)\n\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed(args.seed)\n\n    # get data loader\n    train_loader, test_loader, dataset_normalization = get_loaders(args.data_dir, args.batch_size, dataset=args.dataset,\n                                                                   worker=args.worker)\n    train_loader_e, test_loader_e, dataset_normalization_e = get_loaders(args.data_dir, args.batch_size, dataset=args.dataset,\n                                                           
        worker=args.worker, norm=False)\n\n    # adv training attack setting\n    epsilon = ((args.epsilon / 255.) / std).to(device)\n    alpha = ((args.alpha / 255.) / std).to(device)\n\n    node = LIFNode\n    if args.node_type == 'RHLIF':\n        node = RHLIFNode\n    elif args.node_type == 'RHLIF2':\n        node = RHLIFNode2\n    # setup network\n\n    model = SEWResNet19(BasicBlock, [3, 3, 2], cnf='ADD', node_type=node, step=args.time_step, num_classes=args.num_classes,\n                            layer_by_layer=True, act_fun=AtanGrad)\n    model.to(device)\n\n    # model = torch.nn.DataParallel(model)\n    # logger.info(model)\n\n    # setup optimizer, loss function, LR scheduler\n    # opt = torch.optim.AdamW(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)\n    if args.parseval:\n        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=0)\n    else:\n        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)\n\n\n    criterion = nn.CrossEntropyLoss()\n\n    if args.lr_schedule == 'cyclic':\n        lr_steps = args.epochs\n        scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=args.lr_min, max_lr=args.lr_max,\n                                                      step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)\n    elif args.lr_schedule == 'multistep':\n        lr_steps = args.epochs\n        scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1)\n    elif args.lr_schedule == 'cosine':\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)\n\n    best_pgd_acc = 0\n    best_clean_acc = 0\n    test_acc_best_pgd = 0\n\n    start_epoch = 0\n\n    # Start training\n    start_train_time = time.time()\n\n    for epoch in range(start_epoch, args.epochs):\n\n        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])\n        model.train()\n   
     train_loss = 0\n        train_acc = 0\n        train_n = 0\n        for i, (X, y) in enumerate(train_loader):\n\n            _iters = epoch * len(train_loader) + i\n\n            X, y = X.to(device), y.to(device)\n            if args.adv_training:\n                # init delta\n                delta = torch.zeros_like(X).to(device)\n                for j in range(len(epsilon)):\n                    delta[:, j, :, :].uniform_((-epsilon[j][0][0] / 10).item(), (epsilon[j][0][0]/10).item())\n                delta.data = clamp(delta, lower_limit.to(device) - X, upper_limit.to(device) - X)\n                delta.requires_grad = True\n\n                # pgd attack\n                for _ in range(args.attack_iters):\n                    output = model(X + delta)\n                    # model.random_reset_step = 0\n                    loss = criterion(output, y)\n\n                    loss.backward()\n\n                    grad = delta.grad.detach()\n\n                    delta.data = clamp(delta + alpha * torch.sign(grad), -epsilon, epsilon)\n                    delta.data = clamp(delta, lower_limit.to(device) - X, upper_limit.to(device) - X)\n                    delta.grad.zero_()\n\n                delta = delta.detach()\n                X_adv = X + delta[:X.size(0)]\n            else:\n                X_adv = X\n\n            if args.SR:\n                X_adv.requires_grad_(True)\n\n                outputs = model(X_adv)\n                out = outputs.gather(1, y.unsqueeze(1)).squeeze()  # choose\n                batch = []\n                inds = []\n                for j in range(len(outputs)):\n                    mm, ind = torch.cat([outputs[j, :y[j]], outputs[j, y[j] + 1:]], dim=0).max(0)\n                    f = torch.exp(out[j]) / (torch.exp(out[j]) + torch.exp(mm))\n                    batch.append(f)\n                    inds.append(ind.item())\n                f1 = torch.stack(batch, dim=0)\n\n                loss1 = criterion(outputs, y)\n\n             
   dx = torch.autograd.grad(f1, X_adv, grad_outputs=torch.ones_like(f1, device=device), retain_graph=True)[0]\n                X_adv.requires_grad_(False)\n\n                v = dx.detach().sign()\n\n                x2 = X_adv + 0.01 * v\n\n                outputs2 = model(x2)\n\n                out = outputs2.gather(1, y.unsqueeze(1)).squeeze()  # choose\n                batch = []\n                for j in range(len(outputs2)):\n                    mm = torch.cat([outputs2[j, :y[j]], outputs2[j, y[j] + 1:]], dim=0)[inds[j]]\n                    f = torch.exp(out[j]) / (torch.exp(out[j]) + torch.exp(mm))\n                    batch.append(f)\n                f2 = torch.stack(batch, dim=0)\n\n                dl = (f2 - f1) / 0.01\n                loss2 = dl.pow(2).mean()\n                loss = loss1 + 0.001 * loss2\n                loss = loss.mean()\n            else:\n                output = model(X_adv)\n                loss = criterion(output, y)\n\n            opt.zero_grad()\n            loss.backward()\n\n            opt.step()\n            if args.parseval:\n                orthogonal_retraction(model, args.beta)\n            # for name, param in model.named_parameters():\n            #     if 'sigma' in name:  # 查找包含 'sigma' 的参数\n            #         param.data = torch.clamp(param.data, 0.0, 1.5)  # 约束参数范围\n            #         if i==0:\n            #             print(param)\n            train_loss += loss.item() * y.size(0)\n            train_acc += (output.max(1)[1] == y).sum().item()\n            train_n += y.size(0)\n\n            if i % 50 == 0:\n                logger.info(\"Iter: [{:d}][{:d}/{:d}]\\t\"\n                            \"Loss {:.3f} ({:.3f})\\t\"\n                            \"Prec@1 {:.3f} ({:.3f})\\t\".format(\n                    epoch,\n                    i,\n                    len(train_loader),\n                    loss.item(),\n                    train_loss / train_n,\n                    (output.max(1)[1] == 
y).sum().item() / y.size(0),\n                    train_acc / train_n)\n                )\n\n        scheduler.step()\n\n        logger.info('Evaluating with standard images...')\n        test_loss, test_acc = evaluate_standard(test_loader, model, args)\n        logger.info(\n            'Test Loss: %.4f  \\t Test Acc: %.4f',\n            test_loss, test_acc)\n        if test_acc > best_clean_acc:\n            best_clean_acc = (\n                test_acc)\n\n            torch.save(model.state_dict(), os.path.join(args.save_dir, 'weight_c.pth'))\n\n        # pgd_loss, pgd_acc = evaluate_pgd(test_loader, model, 5, 1, args)\n        if epoch > args.epochs - 10:\n            logger.info('Evaluating with APGD Attack...')\n            model.normalize = dataset_normalization_e\n            atk = torchattacks.APGD(model, norm='Linf', eps=8 / 255, steps=10)\n            pgd_loss, pgd_acc = evaluate_attack(model, test_loader_e, args, atk, 'APGD', logger)\n            model.normalize = dataset_normalization\n\n            if pgd_acc > best_pgd_acc:\n                best_pgd_acc = pgd_acc\n                test_acc_best_pgd = test_acc\n\n                torch.save(model.state_dict(), os.path.join(args.save_dir, 'weight_r.pth'))\n            logger.info(\n                    'PGD Loss: %.4f \\t PGD Acc: %.4f \\n Best PGD Acc: %.4f \\t Test Acc of best PGD ckpt: %.4f',\n                    pgd_loss, pgd_acc, best_pgd_acc, test_acc_best_pgd)\n\n\n    train_time = time.time()\n    logger.info('Total train time: %.4f minutes', (train_time - start_train_time) / 60)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/Snn_safety/RandHet-SNN/utils.py",
    "content": "# import apex.amp as amp\nimport os.path\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision import datasets, transforms\nimport numpy as np\n\n\ncifar10_mean = (0.4914, 0.4822, 0.4465)\ncifar10_std = (0.2471, 0.2435, 0.2616)\n\ndef get_norm_stat(mean, std):\n    mu = torch.tensor(mean).view(3, 1, 1)\n    std = torch.tensor(std).view(3, 1, 1)\n\n    upper_limit = ((1 - mu) / std)\n    lower_limit = ((0 - mu) / std)\n\n    return mu, std, upper_limit, lower_limit\n\ndef clamp(X, lower_limit, upper_limit):\n    return torch.max(torch.min(X, upper_limit), lower_limit)\n\n\ndef normalize_fn(tensor, mean, std):\n    \"\"\"Differentiable version of torchvision.functional.normalize\"\"\"\n    # here we assume the color channel is in at dim=1\n    mean = mean[None, :, None, None]\n    std = std[None, :, None, None]\n    return tensor.sub(mean).div(std)\n\n\nclass NormalizeByChannelMeanStd(nn.Module):\n    def __init__(self, mean, std):\n        super(NormalizeByChannelMeanStd, self).__init__()\n        if not isinstance(mean, torch.Tensor):\n            mean = torch.tensor(mean)\n        if not isinstance(std, torch.Tensor):\n            std = torch.tensor(std)\n        self.register_buffer(\"mean\", mean)\n        self.register_buffer(\"std\", std)\n\n    def forward(self, tensor):\n        return normalize_fn(tensor, self.mean, self.std)\n\n    def extra_repr(self):\n        return 'mean={}, std={}'.format(self.mean, self.std)\n\n\ndef get_loaders(dir_, batch_size, dataset='cifar10', worker=4, norm=True):\n\n    if norm:\n        train_transform = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(cifar10_mean, cifar10_std),\n        ])\n        test_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(cifar10_mean, cifar10_std),\n  
      ])\n        dataset_normalization = None\n\n    else:\n        train_transform = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n        ])\n        test_transform = transforms.Compose([\n            transforms.ToTensor(),\n        ])\n        dataset_normalization = NormalizeByChannelMeanStd(\n            mean=cifar10_mean, std=cifar10_std)\n\n    if dataset == 'cifar10':\n        train_dataset = datasets.CIFAR10(\n            dir_, train=True, transform=train_transform, download=True)\n        test_dataset = datasets.CIFAR10(\n            dir_, train=False, transform=test_transform, download=True)\n    elif dataset == 'cifar100':\n        train_dataset = datasets.CIFAR100(\n            dir_, train=True, transform=train_transform, download=True)\n        test_dataset = datasets.CIFAR100(\n            dir_, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        dataset=train_dataset,\n        batch_size=batch_size,\n        shuffle=True,\n        pin_memory=True,\n        num_workers=worker,\n    )\n    test_loader = torch.utils.data.DataLoader(\n        dataset=test_dataset,\n        batch_size=batch_size,\n        shuffle=False,\n        pin_memory=True,\n        num_workers=worker,\n    )\n    return train_loader, test_loader, dataset_normalization\n\n# evaluate on clean images with single norm\ndef evaluate_standard(test_loader, model, args):\n    test_loss = 0\n    test_acc = 0\n    n = 0\n    model.eval()\n    device = args.device\n\n    with torch.no_grad():\n        for i, (X, y) in enumerate(test_loader):\n            X, y = X.to(device), y.to(device)\n            output = model(X)\n            loss = F.cross_entropy(output, y)\n            test_loss += loss.item() * y.size(0)\n            test_acc += (output.max(1)[1] == y).sum().item()\n            n += y.size(0)\n    return 
test_loss/n, test_acc/n\n\ndef orthogonal_retraction(model, beta=0.002):\n    with torch.no_grad():\n        for module in model.modules():\n            if isinstance(module, (nn.Conv2d, nn.Linear)):\n                if isinstance(module, nn.Conv2d):\n                    weight_ = module.weight.data\n                    sz = weight_.shape\n                    weight_ = weight_.reshape(sz[0],-1)\n                    rows = list(range(module.weight.data.shape[0]))\n                elif isinstance(module, nn.Linear):\n                    if module.weight.data.shape[0] < 200: # set a sample threshold for row number\n                        weight_ = module.weight.data\n                        sz = weight_.shape\n                        weight_ = weight_.reshape(sz[0], -1)\n                        rows = list(range(module.weight.data.shape[0]))\n                    else:\n                        rand_rows = np.random.permutation(module.weight.data.shape[0])\n                        rows = rand_rows[: int(module.weight.data.shape[0] * 0.3)]\n                        weight_ = module.weight.data[rows,:]\n                        sz = weight_.shape\n                module.weight.data[rows,:] = ((1 + beta) * weight_ - beta * weight_.matmul(weight_.t()).matmul(weight_)).reshape(sz)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/algorithms/ToM_class.py",
    "content": "import torch\nimport torch.distributions as td\nfrom utils.networks import MLPNetwork, SNNNetwork, LSTMClassifier\nfrom utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax\n\n\nclass ToM1(object):\n    \"\"\"\n    tom factory (Simplification of ToM's model)\n    init ToM0 and ToM1 net\n    train ToM0 and ToM1 net\n    \"\"\"\n\n    def __init__(self, tom_base, alg_types, agent_types, num_lm, device, hidden_dim=64):\n        self.device = device\n        self.alg_types = alg_types\n        self.agent_types = agent_types\n        self.num_good_agents = len(self._get_index1(self.agent_types, 'agent'))\n        self.nagents = len(alg_types)\n        self.num_lm = num_lm\n        '''\n        Assume that ToM0 and ToM1 are equivalent\n        '''\n        self.tom1 = tom_base\n        self.other_tom1 = [0] * self.nagents\n        self._agent_tom1_init()\n        '''\n        ToM0_policy\n        '''\n        # self.tom_PHI = []   #TODO\n\n        self.hidden = None\n\n    def _agent_tom1_init(self):\n        other_alg_types_ = self.alg_types.copy()\n        other_agent_types_ = self.agent_types.copy()\n        for agent_i in range(self.nagents):\n            other_alg_types = other_alg_types_.copy()\n            other_agent_types = other_agent_types_.copy()\n            other_alg_types.pop(agent_i)\n            other_agent_types.pop(agent_i)\n\n            adv_indx = self._get_index1(other_agent_types, 'adversary')\n            good_indx = self._get_index1(other_agent_types, 'agent')\n            self.other_tom1[agent_i] = [self.tom1['adversary'][self.agent_types[agent_i]]] * len(adv_indx)     #TODO\n            self.other_tom1[agent_i] += [self.tom1['agent'][self.agent_types[agent_i]]] * len(good_indx)\n\n    def _get_index1(self, lst=None, item=''):\n        return [index for (index, value) in enumerate(lst) if value == item]\n\n    def c_function(self, tom0_actions_q, tom1_actions_q):\n        c1 = 0.7\n\n        # 
tom0_actions = torch.stack([gumbel_softmax(action_i, hard=True)\n        #                    for action_i in tom0_actions_prob], 0)\n        # tom1_actions = torch.stack([gumbel_softmax(action_i, hard=True)\n        #                  for action_i in tom1_actions_prob], 0)\n        '''\n        batch, num_agent, ep, 1\n        '''\n        tom0_actions = (tom0_actions_q == tom0_actions_q.max(dim=-1, keepdim=True)[0]).to(dtype=torch.int32)\n        tom1_actions = (tom1_actions_q.unsqueeze(1) == tom1_actions_q.unsqueeze(0).max(dim=-1, keepdim=True)[0]).to(\n            dtype=torch.int32)\n        alig = tom0_actions.long().detach() & tom1_actions.long().detach()\n        I_belief = tom0_actions_q * (1 - c1) + alig * c1\n        # I_belief = [prob_i * (1 -c1) + alig[i] * c1 for i, prob_i in enumerate(tom0_actions_prob)]\n\n        return I_belief\n\n    def tom1_output(self, agent_i, adv_indx, good_indx, obs_, acs_pre_):\n        \"\"\"\n        ToM1 <--> ToM1\n        obs_self : obs of self, need to convert\n        tom0_out : predict other-action (episode_num * self.args.episode_limit * 2, -1), need to convert\n        tom0_out_q : predict other-action q_value (episode_num * self.args.episode_limit * 2, -1)\n        device : interact with env (cpu)  train (cuda)\n\n        ToM0_policy\n        \"\"\"\n        actions = []\n        actions += [\n            # gumbel_softmax(\n                self.other_tom1[agent_i][j].to(self.device)(\n                    torch.cat((obs_[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               acs_pre_[:, :5]), 1))#.detach()   #, hard=True\n             for j in adv_indx\n        ]\n        actions += [\n            # gumbel_softmax(\n                self.other_tom1[agent_i][j].to(self.device)(\n                    torch.cat((obs_[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               acs_pre_[:, :5]), 1))#.detach()   #, 
hard=True\n             for j in good_indx\n        ]\n        # E_action = torch.cat(actions, 1)\n        E_action = actions\n        return E_action\n\n\n\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/algorithms/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/FOToM/algorithms/maddpg.py",
    "content": "import torch\nfrom torch.optim import Adam\nimport torch.nn.functional as F\nfrom gym.spaces import Box, Discrete, MultiDiscrete\nfrom multiagent.multi_discrete import MultiDiscrete\nfrom utils.networks import MLPNetwork, SNNNetwork, LSTMClassifier\nfrom utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax\nfrom utils.agents import DDPGAgent, DDPGAgent_RNN, DDPGAgent_SNN, DDPGAgent_ToM\n# from commom.distributions import make_pdtype\nfrom thop import profile\nfrom thop import clever_format\n\nimport  time\nMSELoss = torch.nn.MSELoss()\n\n# reference:https://github.com/starry-sky6688/MADDPG.git\nclass MADDPG(object):\n    def __init__(self, agent_init_params, alg_types, device,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.device = device\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent(lr=lr, discrete_action=discrete_action,\n                                 hidden_dim=hidden_dim,\n                                 **params)\n                       for params in agent_init_params]\n 
       self.agent_init_params = agent_init_params\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def target_policies(self):\n        return [a.target_policy for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, explore=False):\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                                 observations)]\n\n    def update(self, sample, agent_i, parallel=False, logger=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. 
Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        obs, acs, rews, next_obs, dones = sample\n        curr_agent = self.agents[agent_i]\n\n        curr_agent.critic_optimizer.zero_grad()\n        if self.alg_types[agent_i] == 'MADDPG':\n            if self.discrete_action: # one-hot encode action\n                all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                                zip(self.target_policies, next_obs)]\n            else:\n                all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,\n                                                             next_obs)]\n            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n        else:  # DDPG\n            if self.discrete_action:\n                trgt_vf_in = torch.cat((next_obs[agent_i],\n                                        onehot_from_logits(\n                                            curr_agent.target_policy(\n                                                next_obs[agent_i]))),\n                                       dim=1)\n            else:\n                trgt_vf_in = torch.cat((next_obs[agent_i],\n                                        curr_agent.target_policy(next_obs[agent_i])),\n                                       dim=1)\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        if self.alg_types[agent_i] == 'MADDPG':\n            vf_in = torch.cat((*obs, *acs), dim=1)\n        else:  # DDPG\n            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)\n        actual_value = 
curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. Regardless, discrete policies don't seem to learn properly without it.\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        if self.alg_types[agent_i] == 'MADDPG':\n            all_pol_acs = []\n            for i, pi, ob in zip(range(self.nagents), self.policies, obs):\n                if i == agent_i:\n                    all_pol_acs.append(curr_pol_vf_in)\n                elif self.discrete_action:\n                    all_pol_acs.append(onehot_from_logits(pi(ob)))\n                else:\n                    all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n        else:  # DDPG\n            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in),\n                              dim=1)\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out**2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        
curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n           
 for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents]}\n        torch.save(save_dict, filename)\n\n    @classmethod\n    def init_from_env(cls, env, device, agent_alg=\"MADDPG\", adversary_alg=\"MADDPG\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            if algtype == \"MADDPG\":\n                num_in_critic = 0\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    if isinstance(oacsp, Box):\n                        discrete_action = False\n                        get_shape = lambda x: x.shape[0]\n                    elif isinstance(oacsp, Discrete):  
# Discrete\n                        discrete_action = True\n                        get_shape = lambda x: x.n\n                    elif isinstance(oacsp, MultiDiscrete):\n                        discrete_action = True\n                        get_shape = lambda x: sum(x.high - x.low + 1)\n                    num_in_critic += get_shape(oacsp)\n            else:\n                num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'device': device}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        return instance\n\nclass MADDPG_RNN(object):\n    \"\"\"\n    Wrapper class for DDPG-esque (i.e. 
also MADDPG) agents in multi-agent task\n    \"\"\"\n    def __init__(self, agent_init_params, alg_types,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent_RNN(lr=lr, discrete_action=discrete_action,\n                                 hidden_dim=hidden_dim,\n                                 **params)\n                       for params in agent_init_params]\n        self.agent_init_params = agent_init_params\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.niter = 0\n\n    def _init_agent(self, n_rollout_threads):\n        for agent_i in self.agents:\n            agent_i.init_hidden(n_rollout_threads, policy_hidden=True, policy_target_hidden=True, \\\n                    
critic_hidden=True, critic_target_hidden=True)\n\n    # @property\n    def policies(self, len_ep):\n        for a in self.agents: a.init_hidden(len_ep, policy_hidden=True, policy_target_hidden=False, \\\n                    critic_hidden=False, critic_target_hidden=False)\n        return [a.policy for a in self.agents], [a.policy_hidden for a in self.agents]\n\n    # @property\n    def target_policies(self, len_ep):\n        for a in self.agents: a.init_hidden(len_ep, policy_hidden=False, policy_target_hidden=True, \\\n                    critic_hidden=False, critic_target_hidden=False)\n        return [a.target_policy for a in self.agents], [a.policy_target_hidden for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, explore=False):\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                                 observations)]\n\n    def _compute_rnn(self, fn, hidden, inputs, logit):\n        num_ep = inputs.shape[1]\n        outputs = []\n        hidden = hidden.to(torch.device('cuda:4'))\n        for step_id in range(num_ep):\n            output, hidden = fn(inputs[:, step_id, :], hidden)\n            if logit == onehot_from_logits:\n                outputs.append(logit(output))\n            elif logit == gumbel_softmax:\n                outputs.append(logit(output, True))\n 
           else:\n                outputs.append(output)\n        outputs = torch.stack(outputs,1)\n        return outputs\n\n    def update(self, sample, agent_i, parallel=False, logger=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        obs, acs, rews, next_obs, dones = sample\n        curr_agent = self.agents[agent_i]\n        len_ep = obs[0].shape[0]\n        curr_agent.init_hidden(len_ep, True, True, True, True)\n\n        curr_agent.critic_optimizer.zero_grad()\n        if self.alg_types[agent_i] == 'MADDPG_RNN':\n            if self.discrete_action: # one-hot encode action\n                all_trgt_acs = [self._compute_rnn(pi, hidden, nobs, onehot_from_logits) for pi, hidden, nobs in \\\n                                zip(self.target_policies(len_ep)[0], self.target_policies(len_ep)[1], next_obs)]\n            else:\n                all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,\n                                                             next_obs)]\n            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)\n\n        target_critic = self._compute_rnn(curr_agent.target_critic, curr_agent.critic_target_hidden, trgt_vf_in, None)\n\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        target_critic.view(-1, 1) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        if 
self.alg_types[agent_i] == 'MADDPG_RNN':\n            vf_in = torch.cat((*obs, *acs), dim=-1)\n\n        actual_value = self._compute_rnn(curr_agent.critic, curr_agent.critic_hidden, vf_in, None)\n        vf_loss = MSELoss(actual_value.view(-1,1), target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. Regardless, discrete policies don't seem to learn properly without it.\n            curr_pol_out = self._compute_rnn(curr_agent.policy, curr_agent.policy_hidden, obs[agent_i], None)\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n\n        if self.alg_types[agent_i] == 'MADDPG_RNN':\n            all_pol_acs = []\n            for i, pi, policy_hidden, ob in zip(range(self.nagents), self.policies(len_ep)[0], self.policies(len_ep)[1], obs):\n                if i == agent_i:\n                    all_pol_acs.append(curr_pol_vf_in)\n                elif self.discrete_action:\n                    all_pol_acs.append(self._compute_rnn(pi, policy_hidden, ob, onehot_from_logits))\n                    # all_pol_acs.append(onehot_from_logits(pi(ob)))\n                else:\n                    all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=-1)\n\n        curr_agent.init_hidden(len_ep, False, False, True, False)\n        pol_loss = -self._compute_rnn(curr_agent.critic, policy_hidden, 
vf_in, None).mean()\n        # pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out**2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm(curr_agent.policy.parameters(), 0.5)\n        curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device('cuda:4'))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n\n    def prep_rollouts(self, 
device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device('cuda:4'))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents]}\n        torch.save(save_dict, filename)\n\n    @classmethod\n    def init_from_env(cls, env, agent_alg=\"MADDPG\", adversary_alg=\"MADDPG_RNN\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            else:  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            num_out_pol = get_shape(acsp)\n            if algtype == \"MADDPG_RNN\":\n                num_in_critic = 0\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    num_in_critic += get_shape(oacsp)\n            else:\n        
        num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        return instance\n\n\n\n\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/algorithms/tom11.py",
    "content": "import torch\nfrom torch.optim import Adam\nimport torch.nn.functional as F\nfrom gym.spaces import Box, Discrete, MultiDiscrete\nfrom multiagent.multi_discrete import MultiDiscrete\nfrom utils.networks import MLPNetwork, SNNNetwork, LSTMClassifier\nfrom utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax\nfrom utils.agents import DDPGAgent\nfrom algorithms.ToM_class import ToM1\n# from commom.distributions import make_pdtype\nfrom thop import profile\nfrom thop import clever_format\n\nimport  time\nMSELoss = torch.nn.MSELoss()\nKL_criterion = torch.nn.KLDivLoss(reduction='sum')\nCE_criterion = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n\n\nclass ToM_decision11(object):\n\n    def __init__(self, agent_init_params, alg_types, agent_types, num_lm,\n                 output_style, device, config, gamma=0.95, tau=0.01, lr=0.01,\n                 hidden_dim=64, discrete_action=False):\n        self.config = config\n        self.device = device\n        self.num_lm = num_lm\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agent_types = agent_types\n        self.num_good_agents = len(self._get_index1(self.agent_types, 'agent'))\n        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,\n                                     hidden_dim=hidden_dim,\n                                     **params, output_style=output_style,\n                                     num_agents=self.nagents,\n                                     device=self.device)\n                       for params in agent_init_params]\n        self.agent_init_params = agent_init_params\n        # tom0\n        self.mle_base = [MLPNetwork(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5,\n                                  self.agent_init_params[-1]['num_out_pol'],\n                                  hidden_dim=5, norm_in=False),    # infer good agent\n                         
MLPNetwork(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5,\n                                    self.agent_init_params[-1]['num_out_pol'],\n                                    hidden_dim=5, norm_in=False),  # infer adversary\n                         ]\n\n        self.tom_base = {\n            'agent': {'agent': self.mle_base[0],\n                      'adversary': self.mle_base[1]\n                      },\n            'adversary': {'agent': self.mle_base[0],\n                          'adversary': self.mle_base[1]\n                          }\n        }\n\n        self.tom_PHI = [LSTMClassifier(self.num_good_agents * 2 + self.num_lm * 2 +\n                                    (self.nagents - 1) * 2 + 5 * (self.nagents - 1),\n                                  self.agent_init_params[-1]['num_out_pol'],\n                                  hidden_size=64),    # infer good agent\n                         LSTMClassifier(self.num_good_agents * 2 + self.num_lm * 2 +\n                                    (self.nagents - 1) * 2 + 5 * (self.nagents - 1),\n                                    self.agent_init_params[-1]['num_out_pol'],\n                                    hidden_size=64),  # infer adversary\n                         ]   #TODO\n        self._agent_tom_init()  #TODO\n\n        self.tom1 = ToM1(self.tom_base, alg_types, agent_types, num_lm, device)\n        self.actions_tom0 = []\n        self.next_actions_tom0 = []\n        self.actions_tom1 = []\n        self.next_actions_tom1 = []\n        self.mle_opts = [Adam(i.parameters(), lr=1e-4) for i in self.mle_base]\n        self.PHI_opts = [Adam(i.parameters(), lr=1e-4) for i in self.tom_PHI]\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        
self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.mle_dev = 'cpu'\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def target_policies(self):\n        return [a.target_policy for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def _get_index1(self, lst=None, item=''):\n        return [index for (index, value) in enumerate(lst) if value == item]\n\n    def _agent_tom_init(self):\n        # other_alg_types_ = self.alg_types.copy()\n        other_agent_types_ = self.agent_types.copy()\n        for agent_i in range(self.nagents):\n            # other_alg_types = other_alg_types_.copy()\n            other_agent_types = other_agent_types_.copy()\n            # other_alg_types.pop(agent_i)\n            other_agent_types.pop(agent_i)\n\n            adv_indx = self._get_index1(other_agent_types, 'adversary')\n            good_indx = self._get_index1(other_agent_types, 'agent')\n            self.agents[agent_i].mle += [self.tom_base[self.agent_types[agent_i]]['adversary']] * len(adv_indx)     #TODO\n            self.agents[agent_i].mle += [self.tom_base[self.agent_types[agent_i]]['agent']] * len(good_indx)\n\n    def step(self, observations, actions_pre, explore=False):    #simple_tag\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        # t1 = time.time()\n        observations_ = observations.copy()\n        
actions_pre_ = actions_pre.copy()\n        # other_alg_types_ = self.alg_types.copy()\n        other_agent_types_ = self.agent_types.copy()\n        adv_agent_indx = self._get_index1(self.agent_types, 'adversary')\n        good_agent_indx = self._get_index1(self.agent_types, 'agent')\n        '''\n        tom0\n        '''\n        actions_tom0 = []\n        actions_tom0 += [\n            gumbel_softmax(\n                self.mle_base[1].to(self.device)(\n                    torch.cat((observations[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               actions_pre[j][:, :5]), 1).to(self.device)).detach(), hard=True\n            ) for j in adv_agent_indx\n        ]\n        actions_tom0 += [\n            gumbel_softmax(\n                self.mle_base[0].to(self.device)(\n                    torch.cat((observations[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               actions_pre[j][:, :5]), 1).to(self.device)).detach(), hard=True\n            ) for j in good_agent_indx\n        ]\n\n        '''\n        tom1\n        '''\n        actions_tom1 = []\n        for agent_i, obs in enumerate(observations):\n            obs_ = observations_.copy()\n            acs_other = actions_tom0.copy()\n            other_agent_types = other_agent_types_.copy()\n            obs_.pop(agent_i)\n            acs_other.pop(agent_i)\n            other_agent_types.pop(agent_i)\n\n            if agent_i in adv_agent_indx:\n                actions_tom1.append(\n                    gumbel_softmax(\n                        self.tom_PHI[1].to(self.device)(\n                    torch.cat((obs[:, -(self.num_good_agents * 2 + self.num_lm * 2 +\n                    (self.nagents - 1) * 2):].to(self.device), torch.cat(acs_other, 1)), 1)), hard=True\n                    ).cpu()\n                )\n\n            elif agent_i in good_agent_indx:\n                actions_tom1.append(\n        
            gumbel_softmax(\n                        self.tom_PHI[0].to(self.device)(\n                    torch.cat((obs[:, -(self.num_good_agents * 2 + self.num_lm * 2 +\n                    (self.nagents - 1) * 2):].to(self.device), torch.cat(acs_other, 1)), 1)), hard=True\n                    ).cpu()\n                )\n\n\n        observations = self._get_obs(observations, actions_tom1)\n\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                               observations)]\n\n    def _get_obs(self, observations, action_tom):\n        observations_ = []\n        other_actions_tom_ = action_tom.copy()\n        for agent_i, obs in enumerate(observations):\n            other_action_tom = other_actions_tom_.copy()\n            other_action_tom.pop(agent_i)\n            actions = other_action_tom\n            observations_.append(torch.cat((obs, torch.cat(actions, 1)), 1))\n        return observations_\n\n    def train_tom0(self, sample, agent_i):\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n\n        adv_agent_indx = self._get_index1(self.agent_types, 'adversary')\n        good_agent_indx = self._get_index1(self.agent_types, 'agent')\n        # self.agent_types[tom_agent_indx]\n        '''\n        data\n        for with_tom\n        for without_tom\n        '''\n        if adv_agent_indx != []:\n            adv_input = torch.cat([torch.cat((obs[i], acs_pre[i][:, :5]), 1) for i in adv_agent_indx])\n            label_adv_output = torch.cat([acs[i][:, :5] for i in adv_agent_indx])\n            self.mle_base[1].zero_grad()\n            adv_output = self.mle_base[1](adv_input[:,\n                      -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5):])\n            loss_adv = F.mse_loss(adv_output.float(), label_adv_output.float())\n            loss_adv.backward(retain_graph=True)\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n 
           self.mle_opts[1].step()\n\n        if good_agent_indx != []:\n            good_input = torch.cat([torch.cat((obs[i], acs_pre[i][:, :5]), 1) for i in good_agent_indx])\n            label_good_output = torch.cat([acs[i][:, :5] for i in good_agent_indx])\n            self.mle_base[0].zero_grad()\n            good_output = self.mle_base[0](good_input[:,\n                      -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2 + 5):])\n            loss_good = F.mse_loss(good_output.float(), label_good_output.float())\n            loss_good.backward(retain_graph=True)\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n        '''\n        Only train agents with ToM (adversarys)\n        '''\n        '''\n        adv-adv\n        adv-good\n        '''\n\n    def tom1_infer_other(self, sample):\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n        self.actions_tom1 = []\n        self.next_actions_tom1 = []\n        actions_tom1 = []\n        next_actions_tom1 = []\n        other_actions_tom0_ = self.actions_tom0.copy()\n        other_next_actions_tom0_ = self.next_actions_tom0.copy()\n        adv_agent_indx = self._get_index1(self.agent_types, 'adversary')\n        good_agent_indx = self._get_index1(self.agent_types, 'agent')\n        good_in = []\n        adv_in =[]\n        good_in_next = []\n        adv_in_next =[]\n        for agent_i, (obs_i, next_obs_i) in enumerate(zip(obs, next_obs)):\n            other_action_tom0 = other_actions_tom0_.copy()\n            other_next_actions_tom0 = other_next_actions_tom0_.copy()\n            other_action_tom0.pop(agent_i)\n            other_next_actions_tom0.pop(agent_i)\n            if agent_i in adv_agent_indx:\n                adv_in.append(torch.cat(\n                    (obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                     torch.cat(other_action_tom0, 1)), 1))\n      
          adv_in_next.append(torch.cat(\n                    (next_obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                     torch.cat(other_next_actions_tom0, 1)), 1))\n\n            elif agent_i in good_agent_indx:\n                good_in.append(torch.cat(\n                    (obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                     torch.cat(other_action_tom0, 1)), 1))\n                good_in_next.append(torch.cat(\n                    (next_obs_i[:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                     torch.cat(other_next_actions_tom0, 1)), 1))\n\n        if adv_agent_indx != []:\n            adv_in = torch.cat(adv_in, 0)\n            adv_in_next = torch.cat(adv_in_next, 0)\n            actions_tom1.append(gumbel_softmax(\n                self.tom_PHI[1].to(self.device)(\n                    adv_in), hard=True\n            ))\n\n            next_actions_tom1.append(gumbel_softmax(\n                self.tom_PHI[1].to(self.device)(\n                    adv_in_next), hard=True\n            ))  # adv\n            label_adv_output = torch.cat([self.actions_tom0[i] for i in adv_agent_indx]).detach()\n            # label_adv_output = torch.cat([acs[i][:, :5] for i in adv_agent_indx])\n            adv_output = actions_tom1[0]\n            loss_adv = F.mse_loss(adv_output.float(), label_adv_output.float())\n            loss_adv.backward(retain_graph=True)\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.PHI_opts[1].step()\n        if good_agent_indx != []:\n            good_in = torch.cat(good_in, 0)\n            good_in_next = torch.cat(good_in_next, 0)\n            actions_tom1.append(gumbel_softmax(\n                self.tom_PHI[0].to(self.device)(\n                    good_in), hard=True\n            ))\n            next_actions_tom1.append(gumbel_softmax(\n                
self.tom_PHI[0].to(self.device)(\n                    good_in_next), hard=True\n            ))  # agent\n            label_good_output = torch.cat([self.actions_tom0[i] for i in good_agent_indx]).detach()\n            # label_good_output = torch.cat([acs[i][:, :5] for i in good_agent_indx])\n            if self.config.env_id == 'simple_spread' or self.config.env_id == 'hetero_spread':\n                good_output = actions_tom1[0]\n            else:\n                good_output = actions_tom1[1]\n            loss_good = F.mse_loss(good_output.float(), label_good_output.float())\n            loss_good.backward(retain_graph=True)\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.PHI_opts[0].step()\n\n        actions_tom1 = torch.cat(actions_tom1)#.detach()\n        next_actions_tom1 = torch.cat(next_actions_tom1).detach()\n        for i in range(self.nagents):\n            self.actions_tom1.append(actions_tom1[i*self.config.batch_size:(i+1)*self.config.batch_size, :])\n            self.next_actions_tom1.append(next_actions_tom1[i*self.config.batch_size:(i+1)*self.config.batch_size, :])\n        # print(self.actions_tom1)\n\n    def tom0_output(self, sample):\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n        self.actions_tom0 = []\n        self.next_actions_tom0 = []\n        adv_indx = self._get_index1(self.agent_types, 'adversary')\n        good_indx = self._get_index1(self.agent_types, 'agent')\n        self.actions_tom0 += [\n            gumbel_softmax(\n                self.mle_base[1].to(self.device)(\ntorch.cat((obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               acs_pre[j][:, :5]), 1)).detach(), hard=True\n            ) for j in adv_indx\n        ]\n        self.actions_tom0 += [\n            gumbel_softmax(\n                self.mle_base[0].to(self.device)(\ntorch.cat((obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + 
(self.nagents - 1) * 2):],\n                               acs_pre[j][:, :5]), 1)).detach(), hard=True\n            ) for j in good_indx\n        ]\n        self.next_actions_tom0 += [\n            gumbel_softmax(\n                self.mle_base[1].to(self.device)(\ntorch.cat((next_obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               acs[j][:, :5]), 1)).detach(), hard=True\n            ) for j in adv_indx\n        ]\n        self.next_actions_tom0 += [\n            gumbel_softmax(\n                self.mle_base[0].to(self.device)(\ntorch.cat((next_obs[j][:, -(self.num_good_agents * 2 + self.num_lm * 2 + (self.nagents - 1) * 2):],\n                               acs[j][:, :5]), 1)).detach(), hard=True\n            ) for j in good_indx\n        ]\n        # actions_tom0 = torch.cat(actions_tom0, 1)\n        # return actions_tom0, next_actions_tom0\n\n    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):\n\n        # other_alg_types = self.alg_types.copy()\n        other_agent_types = self.agent_types.copy()\n        # other_alg_types.pop(agent_i)\n        other_agent_types.pop(agent_i)\n        adv_indx = self._get_index1(other_agent_types, 'adversary')\n        good_indx = self._get_index1(other_agent_types, 'agent')\n        agent_i_alg = self.alg_types[agent_i]\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n\n        next_obs_ = self._get_obs(next_obs, self.next_actions_tom1)\n        obs_ = self._get_obs(obs, self.actions_tom1)\n        curr_agent = self.agents[agent_i]\n\n        '''\n        distance between self and other\n        '''\n        Euclidean_D = []\n        for i in range(len(other_agent_types)):\n            Euclidean_D.append(obs[agent_i][:,\n               -len(other_agent_types)*2:][:, i: i+2].pow(2).sum(1).sqrt())\n        Euclidean_D_ = torch.stack(Euclidean_D, 1)\n        '''\n        distance between self and landmark\n        '''\n       
 Euclidean_L = []\n        for i in range(self.num_lm):\n            Euclidean_L.append(obs[agent_i][:, -len(other_agent_types)*2-self.num_lm*2:\n                   -len(other_agent_types)*2][:, i: i+2].pow(2).sum(1).sqrt())\n        Euclidean_L = torch.stack(Euclidean_L, 1)\n\n        close_agent_index = (Euclidean_D_ == Euclidean_D_.min(dim=1, keepdim=True)[0])\\\n            .to(dtype=torch.int32)    #run11/run12 self-orgnization\n        # close_agent_index = torch.ones((self.config.batch_size, len(other_agent_types))) \\\n        #         .to(dtype=torch.int32).to(self.device)  #run13\n\n        if agent_i == 0:\n            self.train_tom0(sample, agent_i)\n\n        E_action = self.tom1.tom1_output(agent_i, adv_indx,\n          good_indx, obs[agent_i], acs_pre[agent_i])\n\n        if agent_i_alg == 'with_tom':\n            acs_other = acs.copy()\n            acs_other.pop(agent_i)\n\n            # KL loss\n            # adv_loss = sum([KL_criterion(E_action[j], acs_other[j][:, :5].float()) for j in adv_indx])  #TODO\n            # good_loss = sum([KL_criterion(E_action[j], acs_other[j][:, :5].float()) for j in good_indx])\n\n            # L2 loss\n            action_loss = torch.norm(acs[agent_i][:, :5] - E_action[0], p=2, dim=1)\n            loss_other = 0.1 * torch.stack([action_loss]*len(other_agent_types), 1)\n\n            if agent_i in adv_indx:\n                '''\n                adv_loss : decrease\n                good_loss : increase\n                '''\n                close_agent_index[:, good_indx] *= -1\n                close_agent_index[:, adv_indx] *= 0.1\n                intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1)\n                # intri_rew = close_agent_index.mul(loss_other).sum(1)\n            else:\n                '''\n                adv_loss : increase\n                good_loss : decrease\n                '''\n                close_agent_index[:, adv_indx] *= 1\n                close_agent_index[:, 
good_indx] *= 0.1\n                # if self.config.env_id == 'simple_adversary':\n                #     intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1) - \\\n                #                 obs[agent_i][:, :2].pow(2).sum(1)\n                # elif self.config.env_id == 'simple_spread_pre':\n                #     intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1) - \\\n                #                 Euclidean_L.min(dim=1, keepdim=True)[0][:,0]\n                # else:\n                intri_rew = close_agent_index.mul(loss_other).mul(Euclidean_D_).sum(1)\n                    # intri_rew = close_agent_index.mul(loss_other).sum(1)\n            rews[agent_i] = rews[agent_i] + intri_rew.detach()\n\n\n        # center critic\n        curr_agent.critic_optimizer.zero_grad()\n        all_trgt_acs = []\n        if self.discrete_action:  # one-hot encode action\n            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                            zip(self.target_policies, next_obs_)]\n        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        vf_in = torch.cat((*obs, *acs), dim=1)\n\n        actual_value = curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. 
The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. Regardless, discrete policies don't seem to learn properly without it.\n\n            curr_pol_out = curr_agent.policy(obs_[agent_i].detach())\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        all_pol_acs = []\n        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):\n            ob = ob.detach()\n            if i == agent_i:\n                all_pol_acs.append(curr_pol_vf_in)\n            elif self.discrete_action:\n                all_pol_acs.append(onehot_from_logits(pi(ob)))\n            else:\n                all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out ** 2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        # actor\n        curr_agent.policy_optimizer.step()\n        # print('c_loss:',vf_loss, 'p_loss:', pol_loss)\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, 
self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        # for i in self.tom_base.values():\n        #     for mle in i.values():\n        #         mle.train()\n        for mle in self.mle_base:\n            mle.train()\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n            for mle_i in a.mle:\n                mle_i.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n        if not self.mle_dev == device:\n            # for i in self.tom_base.keys():\n            #     for j in self.tom_base[i].keys():\n            #         self.tom_base[i][j] = fn(mle)\n            for i, mle in enumerate(self.mle_base):\n                self.mle_base[i] = fn(mle)\n            for a in self.agents:\n                for i, mle_i in enumerate(a.mle):\n                    a.mle[i] = fn(mle_i)\n            self.mle_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # 
only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents],\n                     'tom_params': [self.get_params()],}\n        torch.save(save_dict, filename)\n\n    @classmethod\n    def init_from_env(cls, env, config, device, agent_alg, adversary_alg,\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        num_lm = env.num_lm\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            num_in_mle = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            # if algtype == \"with_tom\":\n            num_in_critic = 0\n            num_in_pol += (len(env.agent_types)-1) * 5\n            for oobsp in env.observation_space:\n                num_in_critic += 
oobsp.shape[0]\n            for oacsp in env.action_space:\n                if isinstance(oacsp, Box):\n                    discrete_action = False\n                    get_shape = lambda x: x.shape[0]\n                elif isinstance(oacsp, Discrete):  # Discrete\n                    discrete_action = True\n                    get_shape = lambda x: x.n\n                elif isinstance(oacsp, MultiDiscrete):\n                    discrete_action = True\n                    get_shape = lambda x: sum(x.high - x.low + 1)\n                num_in_critic += get_shape(oacsp)\n            # else:\n            #     num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic,\n                                      'num_in_mle': num_in_mle,})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'device': device,\n                     'config' : config,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_types' : env.agent_types,\n                     'num_lm' : num_lm,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'output_style': output_style}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        for a, params in 
zip([instance], save_dict['tom_params']):\n            a.load_params(params)\n        return instance\n\n\n    def get_params(self):\n        params = {\n                }\n        for i in range(len(self.mle_base)):\n            params['mle%d'%i] = self.mle_base[i].state_dict()\n            params['mle_optimizer%d'%i] = self.mle_opts[i].state_dict()\n            params['tom_phi%d'%i] = self.tom_PHI[i].state_dict()\n            params['phi_opt%d' % i] = self.PHI_opts[i].state_dict()\n        return params\n\n    def load_params(self, params):\n        for i in range(len(self.mle_base)):\n            self.mle_base[i].load_state_dict(params['mle%d'%i])\n            self.mle_opts[i].load_state_dict(params['mle_optimizer%d'%i])\n            self.tom_PHI[i].load_state_dict(params['tom_phi%d'%i])\n            self.PHI_opts[i].load_state_dict(params['phi_opt%d' % i])\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/common/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/FOToM/common/distributions.py",
    "content": "# import tensorflow as tf\nimport tensorflow.compat.v1 as tf\ntf.compat.v1.disable_eager_execution()\nimport numpy as np\nimport maddpg.common.tf_util as U\nfrom tensorflow.python.ops import math_ops\nfrom multiagent.multi_discrete import MultiDiscrete\nfrom tensorflow.python.ops import nn\n\nclass Pd(object):\n    \"\"\"\n    A particular probability distribution\n    \"\"\"\n    def flatparam(self):\n        raise NotImplementedError\n    def mode(self):\n        raise NotImplementedError\n    def logp(self, x):\n        raise NotImplementedError\n    def kl(self, other):\n        raise NotImplementedError\n    def entropy(self):\n        raise NotImplementedError\n    def sample(self):\n        raise NotImplementedError\n\nclass PdType(object):\n    \"\"\"\n    Parametrized family of probability distributions\n    \"\"\"\n    def pdclass(self):\n        raise NotImplementedError\n    def pdfromflat(self, flat):\n        return self.pdclass()(flat)\n    def param_shape(self):\n        raise NotImplementedError\n    def sample_shape(self):\n        raise NotImplementedError\n    def sample_dtype(self):\n        raise NotImplementedError\n\n    def param_placeholder(self, prepend_shape, name=None):\n        return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)\n    def sample_placeholder(self, prepend_shape, name=None):\n        return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)\n\nclass CategoricalPdType(PdType):\n    def __init__(self, ncat):\n        self.ncat = ncat\n    def pdclass(self):\n        return CategoricalPd\n    def param_shape(self):\n        return [self.ncat]\n    def sample_shape(self):\n        return []\n    def sample_dtype(self):\n        return tf.int32\n\nclass SoftCategoricalPdType(PdType):\n    def __init__(self, ncat):\n        self.ncat = ncat\n    def pdclass(self):\n        return SoftCategoricalPd\n    def param_shape(self):\n      
  return [self.ncat]\n    def sample_shape(self):\n        return [self.ncat]\n    def sample_dtype(self):\n        return tf.float32\n\nclass MultiCategoricalPdType(PdType):\n    def __init__(self, low, high):\n        self.low = low\n        self.high = high\n        self.ncats = high - low + 1\n    def pdclass(self):\n        return MultiCategoricalPd\n    def pdfromflat(self, flat):\n        return MultiCategoricalPd(self.low, self.high, flat)\n    def param_shape(self):\n        return [sum(self.ncats)]\n    def sample_shape(self):\n        return [len(self.ncats)]\n    def sample_dtype(self):\n        return tf.int32\n\nclass SoftMultiCategoricalPdType(PdType):\n    def __init__(self, low, high):\n        self.low = low\n        self.high = high\n        self.ncats = high - low + 1\n    def pdclass(self):\n        return SoftMultiCategoricalPd\n    def pdfromflat(self, flat):\n        return SoftMultiCategoricalPd(self.low, self.high, flat)\n    def param_shape(self):\n        return [sum(self.ncats)]\n    def sample_shape(self):\n        return [sum(self.ncats)]\n    def sample_dtype(self):\n        return tf.float32\n\nclass DiagGaussianPdType(PdType):\n    def __init__(self, size):\n        self.size = size\n    def pdclass(self):\n        return DiagGaussianPd\n    def param_shape(self):\n        return [2*self.size]\n    def sample_shape(self):\n        return [self.size]\n    def sample_dtype(self):\n        return tf.float32\n\nclass BernoulliPdType(PdType):\n    def __init__(self, size):\n        self.size = size\n    def pdclass(self):\n        return BernoulliPd\n    def param_shape(self):\n        return [self.size]\n    def sample_shape(self):\n        return [self.size]\n    def sample_dtype(self):\n        return tf.int32\n\n# WRONG SECOND DERIVATIVES\n# class CategoricalPd(Pd):\n#     def __init__(self, logits):\n#         self.logits = logits\n#         self.ps = tf.nn.softmax(logits)\n#     @classmethod\n#     def fromflat(cls, flat):\n#      
   return cls(flat)\n#     def flatparam(self):\n#         return self.logits\n#     def mode(self):\n#         return U.argmax(self.logits, axis=1)\n#     def logp(self, x):\n#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)\n#     def kl(self, other):\n#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \\\n#                 - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)\n#     def entropy(self):\n#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)\n#     def sample(self):\n#         u = tf.random_uniform(tf.shape(self.logits))\n#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)\n\nclass CategoricalPd(Pd):\n    def __init__(self, logits):\n        self.logits = logits\n    def flatparam(self):\n        return self.logits\n    def mode(self):\n        return U.argmax(self.logits, axis=1)\n    def logp(self, x):\n        return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)\n    def kl(self, other):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        ea1 = tf.exp(a1)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        z1 = U.sum(ea1, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)\n    def entropy(self):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (tf.log(z0) - a0), axis=1)\n    def sample(self):\n        u = tf.random_uniform(tf.shape(self.logits))\n        return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass SoftCategoricalPd(Pd):\n    def __init__(self, logits):\n        self.logits = 
logits\n    def flatparam(self):\n        return self.logits\n    def mode(self):\n        return U.softmax(self.logits, axis=-1)\n    def logp(self, x):\n        return -tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=x)\n    def kl(self, other):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        ea1 = tf.exp(a1)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        z1 = U.sum(ea1, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)\n    def entropy(self):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (tf.log(z0) - a0), axis=1)\n    def sample(self):\n        u = tf.random_uniform(tf.shape(self.logits))\n        return U.softmax(self.logits - tf.log(-tf.log(u)), axis=-1)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass MultiCategoricalPd(Pd):\n    def __init__(self, low, high, flat):\n        self.flat = flat\n        self.low = tf.constant(low, dtype=tf.int32)\n        self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))\n    def flatparam(self):\n        return self.flat\n    def mode(self):\n        return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)\n    def logp(self, x):\n        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])\n    def kl(self, other):\n        return tf.add_n([\n                p.kl(q) for p, q in zip(self.categoricals, other.categoricals)\n            ])\n    def entropy(self):\n        return tf.add_n([p.entropy() for p in self.categoricals])\n    def sample(self):\n        return 
self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass SoftMultiCategoricalPd(Pd):  # doesn't work yet\n    def __init__(self, low, high, flat):\n        self.flat = flat\n        self.low = tf.constant(low, dtype=tf.float32)\n        self.categoricals = list(map(SoftCategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))\n    def flatparam(self):\n        return self.flat\n    def mode(self):\n        x = []\n        for i in range(len(self.categoricals)):\n            x.append(self.low[i] + self.categoricals[i].mode())\n        return tf.concat(x, axis=-1)\n    def logp(self, x):\n        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])\n    def kl(self, other):\n        return tf.add_n([\n                p.kl(q) for p, q in zip(self.categoricals, other.categoricals)\n            ])\n    def entropy(self):\n        return tf.add_n([p.entropy() for p in self.categoricals])\n    def sample(self):\n        x = []\n        for i in range(len(self.categoricals)):\n            x.append(self.low[i] + self.categoricals[i].sample())\n        return tf.concat(x, axis=-1)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass DiagGaussianPd(Pd):\n    def __init__(self, flat):\n        self.flat = flat\n        mean, logstd = tf.split(axis=1, num_or_size_splits=2, value=flat)\n        self.mean = mean\n        self.logstd = logstd\n        self.std = tf.exp(logstd)\n    def flatparam(self):\n        return self.flat\n    def mode(self):\n        return self.mean\n    def logp(self, x):\n        return - 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=1) \\\n               - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) \\\n               - U.sum(self.logstd, axis=1)\n    def kl(self, other):\n        assert 
isinstance(other, DiagGaussianPd)\n        return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=1)\n    def entropy(self):\n        return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), 1)\n    def sample(self):\n        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass BernoulliPd(Pd):\n    def __init__(self, logits):\n        self.logits = logits\n        self.ps = tf.sigmoid(logits)\n    def flatparam(self):\n        return self.logits\n    def mode(self):\n        return tf.round(self.ps)\n    def logp(self, x):\n        return - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=1)\n    def kl(self, other):\n        return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)\n    def entropy(self):\n        return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)\n    def sample(self):\n        p = tf.sigmoid(self.logits)\n        u = tf.random_uniform(tf.shape(p))\n        return tf.to_float(math_ops.less(u, p))\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\ndef make_pdtype(ac_space):\n    from gym import spaces\n    if isinstance(ac_space, spaces.Box):\n        assert len(ac_space.shape) == 1\n        return DiagGaussianPdType(ac_space.shape[0])\n    elif isinstance(ac_space, spaces.Discrete):\n        # return CategoricalPdType(ac_space.n)\n        return SoftCategoricalPdType(ac_space.n)\n    elif isinstance(ac_space, MultiDiscrete):\n        #return MultiCategoricalPdType(ac_space.low, ac_space.high)\n        return SoftMultiCategoricalPdType(ac_space.low, ac_space.high)\n    elif isinstance(ac_space, 
spaces.MultiBinary):\n        return BernoulliPdType(ac_space.n)\n    else:\n        raise NotImplementedError\n\ndef shape_el(v, i):\n    maybe = v.get_shape()[i]\n    if maybe is not None:\n        return maybe\n    else:\n        return tf.shape(v)[i]\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/common/tile_images.py",
    "content": "import numpy as np\n\ndef tile_images(img_nhwc):\n    \"\"\"\n    Tile N images into one big PxQ image\n    (P,Q) are chosen to be as close as possible, and if N\n    is square, then P=Q.\n\n    input: img_nhwc, list or array of images, ndim=4 once turned into array\n        n = batch index, h = height, w = width, c = channel\n    returns:\n        bigim_HWc, ndarray with ndim=3\n    \"\"\"\n    img_nhwc = np.asarray(img_nhwc)\n    N, h, w, c = img_nhwc.shape\n    H = int(np.ceil(np.sqrt(N)))\n    W = int(np.ceil(float(N)/H))\n    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])\n    img_HWhwc = img_nhwc.reshape(H, W, h, w, c)\n    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)\n    img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)\n    return img_Hh_Ww_c\n\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/common/vec_env/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/FOToM/common/vec_env/vec_env.py",
    "content": "import contextlib\nimport os\nfrom abc import ABC, abstractmethod\n\nfrom common.tile_images import tile_images\n\nclass AlreadySteppingError(Exception):\n    \"\"\"\n    Raised when an asynchronous step is running while\n    step_async() is called again.\n    \"\"\"\n\n    def __init__(self):\n        msg = 'already running an async step'\n        Exception.__init__(self, msg)\n\n\nclass NotSteppingError(Exception):\n    \"\"\"\n    Raised when an asynchronous step is not running but\n    step_wait() is called.\n    \"\"\"\n\n    def __init__(self):\n        msg = 'not running an async step'\n        Exception.__init__(self, msg)\n\n\nclass VecEnv(ABC):\n    \"\"\"\n    An abstract asynchronous, vectorized environment.\n    Used to batch data from multiple copies of an environment, so that\n    each observation becomes an batch of observations, and expected action is a batch of actions to\n    be applied per-environment.\n    \"\"\"\n    closed = False\n    viewer = None\n\n    metadata = {\n        'render.modes': ['human', 'rgb_array']\n    }\n\n    def __init__(self, num_envs, observation_space, action_space):\n        self.num_envs = num_envs\n        self.observation_space = observation_space\n        self.action_space = action_space\n\n    @abstractmethod\n    def reset(self):\n        \"\"\"\n        Reset all the environments and return an array of\n        observations, or a dict of observation arrays.\n\n        If step_async is still doing work, that work will\n        be cancelled and step_wait() should not be called\n        until step_async() is invoked again.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def step_async(self, actions):\n        \"\"\"\n        Tell all the environments to start taking a step\n        with the given actions.\n        Call step_wait() to get the results of the step.\n\n        You should not call this if a step_async run is\n        already pending.\n        \"\"\"\n        pass\n\n    
@abstractmethod\n    def step_wait(self):\n        \"\"\"\n        Wait for the step taken with step_async().\n\n        Returns (obs, rews, dones, infos):\n         - obs: an array of observations, or a dict of\n                arrays of observations.\n         - rews: an array of rewards\n         - dones: an array of \"episode done\" booleans\n         - infos: a sequence of info objects\n        \"\"\"\n        pass\n\n    def close_extras(self):\n        \"\"\"\n        Clean up the  extra resources, beyond what's in this base class.\n        Only runs when not self.closed.\n        \"\"\"\n        pass\n\n    def close(self):\n        if self.closed:\n            return\n        if self.viewer is not None:\n            self.viewer.close()\n        self.close_extras()\n        self.closed = True\n\n    def step(self, actions):\n        \"\"\"\n        Step the environments synchronously.\n\n        This is available for backwards compatibility.\n        \"\"\"\n        self.step_async(actions)\n        return self.step_wait()\n\n    def render(self, mode='human'):\n        imgs = self.get_images()\n        bigimg = tile_images(imgs)\n        if mode == 'human':\n            self.get_viewer().imshow(bigimg)\n            return self.get_viewer().isopen\n        elif mode == 'rgb_array':\n            return bigimg\n        else:\n            raise NotImplementedError\n\n    def get_images(self):\n        \"\"\"\n        Return RGB images from each environment\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def unwrapped(self):\n        if isinstance(self, VecEnvWrapper):\n            return self.venv.unwrapped\n        else:\n            return self\n\n    def get_viewer(self):\n        if self.viewer is None:\n            from gym.envs.classic_control import rendering\n            self.viewer = rendering.SimpleImageViewer()\n        return self.viewer\n\nclass VecEnvWrapper(VecEnv):\n    \"\"\"\n    An environment wrapper that applies to 
an entire batch\n    of environments at once.\n    \"\"\"\n\n    def __init__(self, venv, observation_space=None, action_space=None):\n        self.venv = venv\n        super().__init__(num_envs=venv.num_envs,\n                        observation_space=observation_space or venv.observation_space,\n                        action_space=action_space or venv.action_space)\n\n    def step_async(self, actions):\n        self.venv.step_async(actions)\n\n    @abstractmethod\n    def reset(self):\n        pass\n\n    @abstractmethod\n    def step_wait(self):\n        pass\n\n    def close(self):\n        return self.venv.close()\n\n    def render(self, mode='human'):\n        return self.venv.render(mode=mode)\n\n    def get_images(self):\n        return self.venv.get_images()\n\n    def __getattr__(self, name):\n        if name.startswith('_'):\n            raise AttributeError(\"attempted to get missing private attribute '{}'\".format(name))\n        return getattr(self.venv, name)\n\nclass VecEnvObservationWrapper(VecEnvWrapper):\n    @abstractmethod\n    def process(self, obs):\n        pass\n\n    def reset(self):\n        obs = self.venv.reset()\n        return self.process(obs)\n\n    def step_wait(self):\n        obs, rews, dones, infos = self.venv.step_wait()\n        return self.process(obs), rews, dones, infos\n\nclass CloudpickleWrapper(object):\n    \"\"\"\n    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)\n    \"\"\"\n\n    def __init__(self, x):\n        self.x = x\n\n    def __getstate__(self):\n        import cloudpickle\n        return cloudpickle.dumps(self.x)\n\n    def __setstate__(self, ob):\n        import pickle\n        self.x = pickle.loads(ob)\n\n\n@contextlib.contextmanager\ndef clear_mpi_env_vars():\n    \"\"\"\n    from mpi4py import MPI will call MPI_Init by default.  
If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.\n    This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing\n    Processes.\n    \"\"\"\n    removed_environment = {}\n    for k, v in list(os.environ.items()):\n        for prefix in ['OMPI_', 'PMI_']:\n            if k.startswith(prefix):\n                removed_environment[k] = v\n                del os.environ[k]\n    try:\n        yield\n    finally:\n        os.environ.update(removed_environment)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/evaluate.py",
    "content": "import argparse\nimport torch\nimport time\nimport imageio\nimport numpy as np\nfrom pathlib import Path\nfrom torch.autograd import Variable\nfrom utils.make_env import make_env\nfrom algorithms.tom11 import ToM_decision11\nfrom algorithms.maddpg import MADDPG\nimport matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom utils.env_wrappers import SubprocVecEnv, DummyVecEnv\n\ndef display_frames_as_gif(frames):\n    patch = plt.imshow(frames[1])\n    plt.axis('off')\n\n    plt.savefig('./images/comm2', bbox_inches='tight')\n\ndef make_parallel_env(env_id, n_rollout_threads, discrete_action, num_good_agents, num_adversaries):\n    def get_env_fn(rank):\n        def init_env():\n            env = make_env(env_id, num_good_agents=num_good_agents, num_adversaries=num_adversaries, discrete_action=discrete_action)\n            # env.seed(seed + rank * 1000)\n            # np.random.seed(seed + rank * 1000)\n            return env\n        return init_env\n    if n_rollout_threads == 1:\n        return DummyVecEnv([get_env_fn(0)])\n    else:\n        return SubprocVecEnv([get_env_fn(i) for i in range(n_rollout_threads)])\n\ndef run(config):\n    rew_ep = []\n    for i in range(config.num):\n        for run_num in config.run_num:\n            pbar = tqdm(config.n_episodes)\n            model_path = (Path('./models') / config.env_id / config.model_name /\n                          ('run%i' % (run_num)))\n            if config.incremental is not None:\n                model_path = model_path / 'incremental' / ('model_ep%i.pt' %\n                                                           config.incremental)\n            else:\n                model_path = model_path / 'model.pt'\n\n            # if config.save_gifs:\n            #     gif_path = model_path.parent / 'gifs'\n            #     gif_path.mkdir(exist_ok=True)\n            if config.alg == 'ToM1':\n                maddpg = ToM_decision11.init_from_save(model_path)\n            elif config.alg == 
'ToM_SB01' or config.alg== 'ToM_SA01':\n                maddpg = ToM_decision01.init_from_save(model_path)#.eval()\n            elif config.alg == 'ToM_SBN1' or config.alg== 'ToM_SAN1':\n                maddpg = ToM_decisionN1.init_from_save(model_path)#.eval()\n            elif config.alg == 'MADDPG':\n                maddpg = MADDPG.init_from_save(model_path)\n\n            # env = make_env(config.env_id, num_good_agents=config.num_good_agents,\n            #                num_adversaries=config.num_adversaries, discrete_action=maddpg.discrete_action)\n            env = make_parallel_env(config.env_id, config.n_rollout_threads,\n                                    config.discrete_action, config.num_good_agents, config.num_adversaries)\n            maddpg.prep_rollouts(device='cpu')\n            ifi = 1 / config.fps  # inter-frame interval\n\n            for ep_i in range(0, config.n_episodes, config.n_rollout_threads):\n                rew = np.zeros((config.n_rollout_threads, config.num_good_agents + config.num_adversaries))\n                torch_agent_actions = [torch.zeros((config.n_rollout_threads, 5))\n                                       for i in range(maddpg.nagents)]\n                # print(\"Episode %i of %i\" % (ep_i + 1, config.n_episodes))\n                obs = env.reset()\n\n                for t_i in range(config.episode_length):\n                    # calc_start = time.time()\n                    # rearrange observations to be per agent, and convert to torch Variable\n                    torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])),\n                                          requires_grad=False)\n                                 for i in range(maddpg.nagents)]\n                    # get actions as torch Variables\n                    # t1 = time.time()\n                    if config.alg == 'MADDPG':\n                        torch_actions = maddpg.step(torch_obs, explore=False)\n\n                    else:\n                        
torch_actions = maddpg.step(torch_obs, torch_agent_actions, explore=False)\n\n                    actions = [ac.data.cpu().numpy() for ac in torch_actions]\n                    actions = [[ac[i] for ac in actions] for i in range(config.n_rollout_threads)]\n                    obs, rewards, dones, infos = env.step(actions)\n                    rew += rewards\n                rew_ep.append(rew)\n                pbar.update(config.n_rollout_threads)\n\n            pbar.close()\n    rew_ep = np.concatenate(rew_ep, 0)\n    rew_ep_agent = rew_ep.mean(0)\n    std_ep_agent = rew_ep.std(0)\n    print('mean:', rew_ep_agent, 'std:', std_ep_agent)\n    rew_ep_good = rew_ep[:, -config.num_good_agents:].sum(1).mean()\n    rew_ep_adv = rew_ep[:, :config.num_adversaries].sum(1).mean()\n    std_ep_good = rew_ep[:, -config.num_good_agents:].sum(1).std()\n    std_ep_adv = rew_ep[:, :config.num_adversaries].sum(1).std()\n    print('good:', rew_ep_good, 'std:', std_ep_good)\n    print('adv:', rew_ep_adv, 'std:', std_ep_adv)\n    rew_ep_all = rew_ep.sum(1).mean()\n    std_ep_all = rew_ep.sum(1).std()\n    print('all:', rew_ep_all, 'std:', std_ep_all)\n    env.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--alg\",\n                        default=\"ToM_SAN1\", type=str,\n                        choices=['MADDPG', 'ToM_SB01', 'ToM_SA01', 'ToM_SBN1', 'ToM_SAN1',\n                            'ToM1'])\n    parser.add_argument(\"--env_id\", default='simple_adversary', type=str, help=\"Name of environment\",\n                        choices=['simple_tag', 'simple_world_comm', 'hetero_spread',\n                                 'simple_adversary', 'simple_spread'\n                                 ])\n    parser.add_argument(\"--num_good_agents\", default=None, type=int,\n                        help=\"Num of Agent\")\n    parser.add_argument(\"--num_adversaries\", default=None, type=int,\n                        help=\"Num of 
Adversary\")\n    parser.add_argument(\"--model_name\", default='4VS2_tomaAN1', type=str,\n                        help=\"Name of model\")   #ma2c, maddpg, maddpg_rnn\n    parser.add_argument(\"--run_num\", default=2, type=int, nargs='+')\n    parser.add_argument(\"--save_gifs\", default=True, action=\"store_true\",\n                        help=\"Saves gif of each episode into model directory\")\n    parser.add_argument(\"--incremental\", default= None, type=int,\n                        help=\"Load incremental policy from given episode \" +\n                             \"rather than final policy\")\n    parser.add_argument(\"--n_episodes\", default=1000, type=int)\n    parser.add_argument(\"--episode_length\", default=25, type=int)\n    parser.add_argument(\"--fps\", default=30, type=int)\n    parser.add_argument(\"--eval\",\n                        default=True, type=bool,\n                        )\n    parser.add_argument(\"--num\", default=3, type=int )\n    parser.add_argument(\"--n_rollout_threads\", default=20, type=int)\n    parser.add_argument(\"--discrete_action\",\n                        # default=False, type=bool,\n                        action='store_true')\n\n\n    config = parser.parse_args()\n\n\n    if 'ToM_SB' in config.alg:\n        config.agent_alg = 'without_tom'\n        config.adversary_alg = 'with_tom'\n    elif 'ToM_SA' in config.alg:\n        config.agent_alg = 'with_tom'\n        config.adversary_alg = 'without_tom'\n    elif config.alg == 'ToM1':\n        config.agent_alg = 'with_tom'\n        config.adversary_alg = 'with_tom'\n    else:\n        config.agent_alg = config.alg\n        config.adversary_alg = config.alg\n\n    if config.num_good_agents == None and config.num_adversaries == None:\n        if config.env_id == 'simple_adversary':\n            config.num_good_agents = 2\n            config.num_adversaries = 1\n        elif config.env_id == 'simple_tag':\n            config.num_good_agents = 2\n            
config.num_adversaries = 2\n        elif config.env_id == 'simple_world_comm':\n            config.num_good_agents = 2\n            config.num_adversaries = 4\n        elif config.env_id == 'simple_spread':\n            config.num_good_agents = 3\n            config.num_adversaries = 0\n    run(config)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/main.py",
    "content": "import argparse\nimport torch\nimport time\nimport os\nimport numpy as np\nfrom gym.spaces import Box, Discrete, MultiDiscrete\nfrom pathlib import Path\nfrom torch.autograd import Variable\nfrom tensorboardX import SummaryWriter\nfrom utils.make_env import make_env\nfrom utils.buffer import ReplayBuffer, ReplayBuffer_pre\nfrom utils.env_wrappers import SubprocVecEnv, DummyVecEnv\nfrom algorithms.tomN1 import ToM_decisionN1\nfrom algorithms.tom01 import ToM_decision01\nfrom algorithms.tom11 import ToM_decision11\nfrom algorithms.maddpg import MADDPG\nfrom tqdm import tqdm\n\nfrom thop import profile\nfrom thop import clever_format\n\ndef get_common_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--env_id\", default='simple_tag', type=str,\n                        choices=['simple_tag', 'simple_adversary', 'hetero_spread', 'simple_world_comm',\n                                 'simple_spread'],\n                        help=\"Name of environment\")\n    parser.add_argument(\"--num_good_agents\", default=None, type=int,\n                        help=\"Num of Agent\")\n    parser.add_argument(\"--num_adversaries\", default=None, type=int,\n                        help=\"Num of Adversary\")\n    parser.add_argument(\"--model_name\", default='ann', type=str,\n                        help=\"Name of directory to store \" +\n                             \"model/training contents\") #ToM_SA\n    parser.add_argument(\"--seed\",\n                        default=1, type=int,\n                        help=\"Random seed\")\n    parser.add_argument(\"--cuda_num\",\n                        default=5, type=int,\n                        help=\"device\")\n    parser.add_argument(\"--output_style\",\n                        default='sum', type=str,\n                        choices=['sum', 'voltage'])\n    parser.add_argument(\"--n_rollout_threads\", default=20, type=int)\n    parser.add_argument(\"--n_training_threads\", default=6, 
type=int)\n    parser.add_argument(\"--buffer_length\", default=int(1e6), type=int)  #1e6\n    parser.add_argument(\"--n_episodes\", default=25000, type=int)#25000\n    parser.add_argument(\"--episode_length\", default=25, type=int)\n    parser.add_argument(\"--steps_per_update\", default=100, type=int)\n    parser.add_argument(\"--load_para\", type=bool, default=False)\n    parser.add_argument(\"--batch_size\",\n                        default=1024, type=int,#4\n                        help=\"Batch size for model training\")\n    parser.add_argument(\"--n_exploration_eps\", default=25000, type=int)\n    parser.add_argument(\"--init_noise_scale\", default=0.3, type=float)\n    parser.add_argument(\"--final_noise_scale\", default=0.0, type=float)\n    parser.add_argument(\"--save_interval\", default=1000, type=int)#1000\n    parser.add_argument(\"--hidden_dim\", default=64, type=int)\n    parser.add_argument(\"--lr\", default=1e-3, type=float)   #0.01 #1e-3\n    parser.add_argument(\"--tau\", default=0.01, type=float)\n    parser.add_argument(\"--alg\",\n                        default=\"ToM1\", type=str,\n                        choices=['MADDPG',\n                            'ToM1', 'ToM_SB01', 'ToM_SA01', 'ToM_SBN1', 'ToM_SAN1'])\n\n    parser.add_argument(\"--discrete_action\",\n                        # default=False, type=bool,\n                        action='store_true')\n    args = parser.parse_args()\n    parser.add_argument('--device', type=str, default='cuda:{}'.format(args.cuda_num), help='whether to use the GPU')  #'cuda:1'\n    parser = parser.parse_args()\n    return parser\n\nUSE_CUDA = torch.cuda.is_available()\n\ndef make_parallel_env(env_id, n_rollout_threads, seed, discrete_action, num_good_agents, num_adversaries):\n    def get_env_fn(rank):\n        def init_env():\n            env = make_env(env_id, num_good_agents=num_good_agents, num_adversaries=num_adversaries, discrete_action=discrete_action)\n            env.seed(seed + rank * 1000)\n    
        np.random.seed(seed + rank * 1000)\n            return env\n        return init_env\n    if n_rollout_threads == 1:\n        return DummyVecEnv([get_env_fn(0)])\n    else:\n        return SubprocVecEnv([get_env_fn(i) for i in range(n_rollout_threads)])\n\ndef run(config):\n    pbar = tqdm(config.n_episodes)\n    model_dir = Path('./models') / config.env_id / config.model_name\n    if not model_dir.exists():\n        curr_run = 'run1'\n    else:\n        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in\n                         model_dir.iterdir() if\n                         str(folder.name).startswith('run')]\n        if len(exst_run_nums) == 0:\n            curr_run = 'run1'\n        else:\n            curr_run = 'run%i' % (max(exst_run_nums) + 1)\n    run_dir = model_dir / curr_run\n    log_dir = run_dir / 'logs'\n    os.makedirs(log_dir)\n    logger = SummaryWriter(str(log_dir))\n\n    save_path = log_dir\n    print(os.path.exists(save_path))\n    argsDict = config.__dict__\n    with open(save_path / 'args_{}'.format(max(exst_run_nums) + 1), 'w') as f:\n        f.writelines('------------------ start ------------------' + '\\n')\n        for eachArg, value in argsDict.items():\n            f.writelines(eachArg + ' : ' + str(value) + '\\n')\n        f.writelines('------------------- end -------------------')\n\n    torch.manual_seed(config.seed)\n    np.random.seed(config.seed)\n    if not USE_CUDA:\n        torch.set_num_threads(config.n_training_threads)\n    env = make_parallel_env(config.env_id, config.n_rollout_threads, config.seed,\n                            config.discrete_action, config.num_good_agents, config.num_adversaries)\n    if config.alg == 'ToM1':\n        print('_______Alg: ' + config.alg + '_______')\n        if  config.load_para == False:\n            maddpg = ToM_decision11.init_from_env(env, config, agent_alg=config.agent_alg,\n                                          adversary_alg=config.adversary_alg,\n       
                                   tau=config.tau,\n                                          lr=config.lr,\n                                          hidden_dim=config.hidden_dim,\n                                          output_style=config.output_style,\n                                          device=config.device)\n        else:\n            # model_path = (Path('./models') / config.env_id / 'self1' /  #'self1' tag, maddpg_self_11 adv\n            #               ('run%i' % 1)) / 'model.pt'\n            model_path = (Path('./models') / config.env_id / 'rs1' /  #'self1' tag, maddpg_self_11 adv\n                          ('run%i' % 1)) / 'model.pt'\n            maddpg = ToM_decision11.init_from_save(model_path)\n    elif config.alg == 'ToM_SB01' or config.alg== 'ToM_SA01':\n        print('_______Alg: ' + config.alg + '_______')\n        if config.load_para == False:\n            maddpg = ToM_decision01.init_from_env(env, config,  agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                      device=config.device)\n        else:\n            model_path = (Path('./models') / config.env_id / 'am' /  #'self1' tag, maddpg_self_11 adv\n                          ('run%i' % 3)) / 'model.pt'\n            maddpg = ToM_decision01.init_from_save(model_path)\n    elif config.alg == 'ToM_SBN1' or config.alg== 'ToM_SAN1':\n        print('_______Alg: ' + config.alg + '_______')\n        if config.load_para == False:\n            maddpg = ToM_decisionN1.init_from_env(env, config,  agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                  
                    lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                      device=config.device)\n        else:\n            model_path = (Path('./models') / config.env_id / 'am' /  #'self1' tag, maddpg_self_11 adv\n                          ('run%i' % 3)) / 'model.pt'\n            maddpg = ToM_decisionN1.init_from_save(model_path)\n    elif config.alg == 'MADDPG':\n        print('_______Alg: ' + config.alg + '_______')\n        if  config.load_para == False:\n            maddpg = MADDPG.init_from_env(env, agent_alg=config.agent_alg,\n                                          adversary_alg=config.adversary_alg,\n                                          tau=config.tau,\n                                          lr=config.lr,\n                                          hidden_dim=config.hidden_dim,\n                                          device=config.device)\n\n\n    if config.alg == 'ToM1' or config.alg == 'ToM_SB01' or config.alg == 'ToM_SA01' \\\n            or config.alg == 'ToM_SBN1' or config.alg == 'ToM_SAN1':\n        replay_buffer = ReplayBuffer_pre(config.buffer_length, maddpg.nagents,\n                                     [obsp.shape[0] for obsp in env.observation_space],\n                                     [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)\n                                      for acsp in env.action_space],\n                                 device=config.device)\n    else:\n        replay_buffer = ReplayBuffer(config.buffer_length, maddpg.nagents,\n                                     [obsp.shape[0] for obsp in env.observation_space],\n                                     [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)\n                                      for acsp in env.action_space],\n                                 device=config.device)\n\n   
 t = 0\n    total_reward = []\n    for agent_i in range(maddpg.nagents):\n        total_reward.append([])\n    for ep_i in range(0, config.n_episodes, config.n_rollout_threads):\n        # print(\"Episodes %i-%i of %i\" % (ep_i + 1,\n        #                                 ep_i + 1 + config.n_rollout_threads,\n        #                                 config.n_episodes))\n        obs = env.reset()\n        # obs.shape = (n_rollout_threads, nagent)(nobs), nobs differs per agent so not tensor\n        maddpg.prep_rollouts(device='cpu')\n        torch_agent_actions = [torch.zeros((config.n_rollout_threads, 5)) for i in range(maddpg.nagents)]\n        explr_pct_remaining = max(0, config.n_exploration_eps - ep_i) / config.n_exploration_eps\n        maddpg.scale_noise(config.final_noise_scale + (config.init_noise_scale - config.final_noise_scale) * explr_pct_remaining)\n        if config.alg == 'rnnD' or config.alg == 'rnn':\n            maddpg._init_agent(config.n_rollout_threads)\n        maddpg.reset_noise()\n        obs_ep = []\n        agent_actions_ep = []\n        rewards_ep = []\n        next_obs_ep = []\n        dones_ep = []\n\n        for et_i in range(config.episode_length):\n            torch_agent_actions_pre = torch_agent_actions\n            torch_agent_actions_pre = [ac.data.numpy() for ac in torch_agent_actions_pre]\n            # rearrange observations to be per agent, and convert to torch Variable\n            torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])),\n                                  requires_grad=False)\n                         for i in range(maddpg.nagents)]    #\n            # get actions as torch Variables\n            # t1 = time.time()\n            if config.alg == 'ToM1' or config.alg == 'ToM_SB01' or config.alg == 'ToM_SA01' \\\n                    or config.alg == 'ToM_SBN1' or config.alg == 'ToM_SAN1':\n                torch_agent_actions = maddpg.step(torch_obs, torch_agent_actions, explore=True)\n            else:\n    
            torch_agent_actions = maddpg.step(torch_obs, explore=True)\n            # t2 = time.time()\n            # print('maddpg.step time:', t2-t1)\n\n            agent_actions = [ac.data.numpy() for ac in torch_agent_actions] #\n            # rearrange actions to be per environment\n            actions = [[ac[i] for ac in agent_actions] for i in range(config.n_rollout_threads)]\n            # t3 = time.time()\n            next_obs, rewards, dones, infos = env.step(actions)\n            # t4 = time.time()\n            # print('env.step')\n            obs_ep.append(obs)                  #episode_id,process, n_agents, dim\n            agent_actions_ep.append(actions)    #episode_id, n_agents, process, dim\n            rewards_ep.append(rewards)          #episode_id,process, n_agents,\n            next_obs_ep.append(next_obs)            #episode_id,process, n_agents, dim\n            dones_ep.append(dones)              #episode_id,process, n_agents,\n            if config.alg == 'ToM1' or config.alg == 'ToM_SB01' or config.alg == 'ToM_SA01' \\\n                    or config.alg == 'ToM_SBN1' or config.alg == 'ToM_SAN1':\n                replay_buffer.push(torch_agent_actions_pre, obs, agent_actions, rewards, next_obs, dones)\n            else:\n                replay_buffer.push(obs, agent_actions, rewards, next_obs, dones)\n            obs = next_obs\n            t += config.n_rollout_threads\n\n            if (len(replay_buffer) >= config.batch_size and\n                (t % config.steps_per_update) < config.n_rollout_threads):\n                if USE_CUDA:\n                    maddpg.prep_training(device='gpu')\n                else:\n                    maddpg.prep_training(device='cpu')\n                if config.n_episodes >300:\n                    rollout = 2\n                else:\n                    rollout = config.n_rollout_threads\n                for u_i in range(rollout):\n                    sample = replay_buffer.sample(config.batch_size,\n       
                                           to_gpu=USE_CUDA)\n                    if '1' in config.alg:\n                        maddpg.tom0_output(sample)\n                        maddpg.tom1_infer_other(sample)\n                    for a_i in range(maddpg.nagents):\n                        # t1 = time.time()\n                        maddpg.update(sample, a_i, logger=logger)\n                        # t2 = time.time()\n                        # print('trian_time:', t2-t1, u_i, a_i)\n                    maddpg.update_all_targets()\n                maddpg.prep_rollouts(device='cpu')\n            if config.alg == 'rnnD' or config.alg == 'rnn':\n                maddpg._init_agent(config.n_rollout_threads)\n        ep_rews = replay_buffer.get_average_rewards(\n            config.episode_length * config.n_rollout_threads)\n        for a_i, a_ep_rew in enumerate(ep_rews):\n            logger.add_scalar('agent%i/mean_episode_rewards' % a_i,\n                              a_ep_rew,\n                              ep_i)\n        logger.add_scalar('agent_mean/mean_episode_rewards',\n                          np.mean(ep_rews),\n                          ep_i)\n\n        if ep_i % config.save_interval < config.n_rollout_threads:\n            os.makedirs(run_dir / 'incremental', exist_ok=True)\n            maddpg.save(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1)))\n            maddpg.save(run_dir / 'model.pt')\n        pbar.update(config.n_rollout_threads)\n\n    pbar.close()\n    maddpg.save(run_dir / 'model.pt')\n    env.close()\n    logger.export_scalars_to_json(str(log_dir / 'summary.json'))\n    logger.close()\n    for a_i, reward in enumerate(total_reward):\n        reward_dir = str(log_dir) + '/agent{}/mean_episode_rewards'.format(a_i) + '/episode_rewards_{}'.format(config.cuda_num)\n        os.makedirs(reward_dir)\n        np.save(reward_dir, reward)\n\n\nif __name__ == '__main__':\n    config = get_common_args()\n    # config.env_id = 
'simple_world_comm'#'simple_adversary'#'simple_tag'\n    # # config.model_name = 'ma2c'\n\n\n\n    if 'ToM_SB' in config.alg:\n        config.agent_alg = 'without_tom'\n        config.adversary_alg = 'with_tom'\n    elif 'ToM_SA' in config.alg:\n        config.agent_alg = 'with_tom'\n        config.adversary_alg = 'without_tom'\n    elif config.alg == 'ToM1':\n        config.agent_alg = 'with_tom'\n        config.adversary_alg = 'with_tom'\n    else:\n        config.agent_alg = config.alg\n        config.adversary_alg = config.alg\n\n\n    # debug\n    if config.num_good_agents == None and config.num_adversaries == None:\n        if config.env_id == 'simple_adversary':\n            config.num_good_agents = 2\n            config.num_adversaries = 1\n        elif config.env_id == 'simple_tag':\n            config.num_good_agents = 1\n            config.num_adversaries = 3\n        elif config.env_id == 'simple_world_comm':\n            config.num_good_agents = 2\n            config.num_adversaries = 4\n        elif config.env_id == 'hetero_spread':\n            config.num_good_agents = 4\n            config.num_adversaries = 0\n        elif config.env_id == 'simple_spread':  #coop\n            config.num_good_agents = 3\n            config.num_adversaries = 0\n\n    run(config)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/__init__.py",
    "content": "from gym.envs.registration import register\n\n# Multiagent envs\n# ----------------------------------------\n\nregister(\n    id='MultiagentSimple-v0',\n    entry_point='multiagent.envs:SimpleEnv',\n    # FIXME(cathywu) currently has to be exactly max_path_length parameters in\n    # rllab run script\n    max_episode_steps=100,\n)\n\nregister(\n    id='MultiagentSimpleSpeakerListener-v0',\n    entry_point='multiagent.envs:SimpleSpeakerListenerEnv',\n    max_episode_steps=100,\n)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/core.py",
    "content": "import numpy as np\n\n# physical/external base state of all entites\nclass EntityState(object):\n    def __init__(self):\n        # physical position\n        self.p_pos = None\n        # physical velocity\n        self.p_vel = None\n\n# state of agents (including communication and internal/mental state)\nclass AgentState(EntityState):\n    def __init__(self):\n        super(AgentState, self).__init__()\n        # communication utterance\n        self.c = None\n\n# action of the agent\nclass Action(object):\n    def __init__(self):\n        # physical action\n        self.u = None\n        # communication action\n        self.c = None\n\n# properties and state of physical world entity\nclass Entity(object):\n    def __init__(self):\n        # name \n        self.name = ''\n        # properties:\n        self.size = 0.050\n        # entity can move / be pushed\n        self.movable = False\n        # entity collides with others\n        self.collide = True\n        # material density (affects mass)\n        self.density = 25.0\n        # color\n        self.color = None\n        # max speed and accel\n        self.max_speed = None\n        self.accel = None\n        # state\n        self.state = EntityState()\n        # mass\n        self.initial_mass = 1.0\n\n    @property\n    def mass(self):\n        return self.initial_mass\n\n# properties of landmark entities\nclass Landmark(Entity):\n     def __init__(self):\n        super(Landmark, self).__init__()\n\n# properties of agent entities\nclass Agent(Entity):\n    def __init__(self):\n        super(Agent, self).__init__()\n        # agents are movable by default\n        self.movable = True\n        # cannot send communication signals\n        self.silent = False\n        # cannot observe the world\n        self.blind = False\n        # physical motor noise amount\n        self.u_noise = None\n        # communication noise amount\n        self.c_noise = None\n        # control range\n        
self.u_range = 1.0\n        # state\n        self.state = AgentState()\n        # action\n        self.action = Action()\n        # script behavior to execute\n        self.action_callback = None\n\n# multi-agent world\nclass World(object):\n    def __init__(self):\n        # list of agents and entities (can change at execution-time!)\n        self.agents = []\n        self.landmarks = []\n        # communication channel dimensionality\n        self.dim_c = 0\n        # position dimensionality\n        self.dim_p = 2\n        # color dimensionality\n        self.dim_color = 3\n        # simulation timestep\n        self.dt = 0.1\n        # physical damping\n        self.damping = 0.25\n        # contact response parameters\n        self.contact_force = 1e+2\n        self.contact_margin = 1e-3\n\n    # return all entities in the world\n    @property\n    def entities(self):\n        return self.agents + self.landmarks\n\n    # return all agents controllable by external policies\n    @property\n    def policy_agents(self):\n        return [agent for agent in self.agents if agent.action_callback is None]\n\n    # return all agents controlled by world scripts\n    @property\n    def scripted_agents(self):\n        return [agent for agent in self.agents if agent.action_callback is not None]\n\n    # update state of the world\n    def step(self):\n        # set actions for scripted agents \n        for agent in self.scripted_agents:\n            agent.action = agent.action_callback(agent, self)\n        # gather forces applied to entities\n        p_force = [None] * len(self.entities)\n        # apply agent physical controls\n        p_force = self.apply_action_force(p_force)\n        # apply environment forces\n        p_force = self.apply_environment_force(p_force)\n        # integrate physical state\n        self.integrate_state(p_force)\n        # update agent state\n        for agent in self.agents:\n            self.update_agent_state(agent)\n\n    # gather agent 
action forces\n    def apply_action_force(self, p_force):\n        # set applied forces\n        for i,agent in enumerate(self.agents):\n            if agent.movable:\n                noise = np.random.randn(*agent.action.u.shape) * agent.u_noise if agent.u_noise else 0.0\n                p_force[i] = agent.action.u + noise                \n        return p_force\n\n    # gather physical forces acting on entities\n    def apply_environment_force(self, p_force):\n        # simple (but inefficient) collision response\n        for a,entity_a in enumerate(self.entities):\n            for b,entity_b in enumerate(self.entities):\n                if(b <= a): continue\n                [f_a, f_b] = self.get_collision_force(entity_a, entity_b)\n                if(f_a is not None):\n                    if(p_force[a] is None): p_force[a] = 0.0\n                    p_force[a] = f_a + p_force[a] \n                if(f_b is not None):\n                    if(p_force[b] is None): p_force[b] = 0.0\n                    p_force[b] = f_b + p_force[b]        \n        return p_force\n\n    # integrate physical state\n    def integrate_state(self, p_force):\n        for i,entity in enumerate(self.entities):\n            if not entity.movable: continue\n            entity.state.p_vel = entity.state.p_vel * (1 - self.damping)\n            if (p_force[i] is not None):\n                entity.state.p_vel += (p_force[i] / entity.mass) * self.dt\n            if entity.max_speed is not None:\n                speed = np.sqrt(np.square(entity.state.p_vel[0]) + np.square(entity.state.p_vel[1]))\n                if speed > entity.max_speed:\n                    entity.state.p_vel = entity.state.p_vel / np.sqrt(np.square(entity.state.p_vel[0]) +\n                                                                  np.square(entity.state.p_vel[1])) * entity.max_speed\n            entity.state.p_pos += entity.state.p_vel * self.dt\n\n    def update_agent_state(self, agent):\n        # set communication 
state (directly for now)\n        if agent.silent:\n            agent.state.c = np.zeros(self.dim_c)\n        else:\n            noise = np.random.randn(*agent.action.c.shape) * agent.c_noise if agent.c_noise else 0.0\n            agent.state.c = agent.action.c + noise      \n\n    # get collision forces for any contact between two entities\n    def get_collision_force(self, entity_a, entity_b):\n        if (not entity_a.collide) or (not entity_b.collide):\n            return [None, None] # not a collider\n        if (entity_a is entity_b):\n            return [None, None] # don't collide against itself\n        # compute actual distance between entities\n        delta_pos = entity_a.state.p_pos - entity_b.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        # minimum allowable distance\n        dist_min = entity_a.size + entity_b.size\n        # softmax penetration\n        k = self.contact_margin\n        penetration = np.logaddexp(0, -(dist - dist_min)/k)*k\n        force = self.contact_force * delta_pos / dist * penetration\n        force_a = +force if entity_a.movable else None\n        force_b = -force if entity_b.movable else None\n        return [force_a, force_b]"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/environment.py",
    "content": "import gym\nfrom gym import spaces\nfrom gym.envs.registration import EnvSpec\nimport numpy as np\nfrom multiagent.multi_discrete import MultiDiscrete\n\n# environment for all agents in the multiagent world\n# currently code assumes that no agents will be created/destroyed at runtime!\nclass MultiAgentEnv(gym.Env):\n    metadata = {\n        'render.modes' : ['human', 'rgb_array']\n    }\n\n    def __init__(self, world, reset_callback=None, reward_callback=None,\n                 observation_callback=None, info_callback=None,\n                 done_callback=None, shared_viewer=True):\n\n        self.world = world\n        self.agents = self.world.policy_agents\n        # set required vectorized gym env property\n        self.n = len(world.policy_agents)\n        # scenario callbacks\n        self.reset_callback = reset_callback\n        self.reward_callback = reward_callback\n        self.observation_callback = observation_callback\n        self.info_callback = info_callback\n        self.done_callback = done_callback\n        # environment parameters\n        self.discrete_action_space = True\n        # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector\n        self.discrete_action_input = False\n        # if true, even the action is continuous, action will be performed discretely\n        self.force_discrete_action = world.discrete_action if hasattr(world, 'discrete_action') else False\n        # if true, every agent has the same reward\n        self.shared_reward = world.collaborative if hasattr(world, 'collaborative') else False\n        self.time = 0\n\n        # configure spaces\n        self.action_space = []\n        self.observation_space = []\n        for agent in self.agents:\n            total_action_space = []\n            # physical action space\n            if self.discrete_action_space:\n                u_action_space = spaces.Discrete(world.dim_p * 2 + 1)\n            else:\n                
u_action_space = spaces.Box(low=-agent.u_range, high=+agent.u_range, shape=(world.dim_p,), dtype=np.float32)\n            if agent.movable:\n                total_action_space.append(u_action_space)\n            # communication action space\n            if self.discrete_action_space:\n                c_action_space = spaces.Discrete(world.dim_c)\n            else:\n                c_action_space = spaces.Box(low=0.0, high=1.0, shape=(world.dim_c,), dtype=np.float32)\n            if not agent.silent:\n                total_action_space.append(c_action_space)\n            # total action space\n            if len(total_action_space) > 1:\n                # all action spaces are discrete, so simplify to MultiDiscrete action space\n                if all([isinstance(act_space, spaces.Discrete) for act_space in total_action_space]):\n                    act_space = MultiDiscrete([[0, act_space.n - 1] for act_space in total_action_space])\n                else:\n                    act_space = spaces.Tuple(total_action_space)\n                self.action_space.append(act_space)\n            else:\n                self.action_space.append(total_action_space[0])\n            # observation space\n            obs_dim = len(observation_callback(agent, self.world))\n            self.observation_space.append(spaces.Box(low=-np.inf, high=+np.inf, shape=(obs_dim,), dtype=np.float32))\n            agent.action.c = np.zeros(self.world.dim_c)\n\n        # rendering\n        self.shared_viewer = shared_viewer\n        if self.shared_viewer:\n            self.viewers = [None]\n        else:\n            self.viewers = [None] * self.n\n        self._reset_render()\n\n    def step(self, action_n):\n        obs_n = []\n        reward_n = []\n        done_n = []\n        info_n = {'n': []}\n        self.agents = self.world.policy_agents\n        # set action for each agent\n        for i, agent in enumerate(self.agents):\n            self._set_action(action_n[i], agent, 
self.action_space[i])\n        # advance world state\n        self.world.step()\n        # record observation for each agent\n        for agent in self.agents:\n            obs_n.append(self._get_obs(agent))\n            reward_n.append(self._get_reward(agent))\n            done_n.append(self._get_done(agent))\n\n            info_n['n'].append(self._get_info(agent))\n\n        # all agents get total reward in cooperative case\n        reward = np.sum(reward_n)\n        if self.shared_reward:\n            reward_n = [reward] * self.n\n\n        return obs_n, reward_n, done_n, info_n\n\n    def reset(self):\n        # reset world\n        self.reset_callback(self.world)\n        # reset renderer\n        self._reset_render()\n        # record observations for each agent\n        obs_n = []\n        self.agents = self.world.policy_agents\n        for agent in self.agents:\n            obs_n.append(self._get_obs(agent))\n        return obs_n\n\n    # get info used for benchmarking\n    def _get_info(self, agent):\n        if self.info_callback is None:\n            return {}\n        return self.info_callback(agent, self.world)\n\n    # get observation for a particular agent\n    def _get_obs(self, agent):\n        if self.observation_callback is None:\n            return np.zeros(0)\n        return self.observation_callback(agent, self.world)\n\n    # get dones for a particular agent\n    # unused right now -- agents are allowed to go beyond the viewing screen\n    def _get_done(self, agent):\n        if self.done_callback is None:\n            return False\n        return self.done_callback(agent, self.world)\n\n    # get reward for a particular agent\n    def _get_reward(self, agent):\n        if self.reward_callback is None:\n            return 0.0\n        return self.reward_callback(agent, self.world)\n\n    # set env action for a particular agent\n    def _set_action(self, action, agent, action_space, time=None):\n        agent.action.u = 
np.zeros(self.world.dim_p)\n        agent.action.c = np.zeros(self.world.dim_c)\n        # process action\n        if isinstance(action_space, MultiDiscrete):\n            act = []\n            size = action_space.high - action_space.low + 1\n            index = 0\n            for s in size:\n                act.append(action[index:(index+s)])\n                index += s\n            action = act\n        else:\n            action = [action]\n\n        if agent.movable:\n            # physical action\n            if self.discrete_action_input:\n                agent.action.u = np.zeros(self.world.dim_p)\n                # process discrete action\n                if action[0] == 1: agent.action.u[0] = -1.0\n                if action[0] == 2: agent.action.u[0] = +1.0\n                if action[0] == 3: agent.action.u[1] = -1.0\n                if action[0] == 4: agent.action.u[1] = +1.0\n            else:\n                if self.force_discrete_action:\n                    d = np.argmax(action[0])\n                    action[0][:] = 0.0\n                    action[0][d] = 1.0\n                if self.discrete_action_space:\n                    agent.action.u[0] += action[0][1] - action[0][2]\n                    agent.action.u[1] += action[0][3] - action[0][4]\n                else:\n                    agent.action.u = action[0]\n            sensitivity = 5.0\n            if agent.accel is not None:\n                sensitivity = agent.accel\n            agent.action.u *= sensitivity\n            action = action[1:]\n        if not agent.silent:\n            # communication action\n            if self.discrete_action_input:\n                agent.action.c = np.zeros(self.world.dim_c)\n                agent.action.c[action[0]] = 1.0\n            else:\n                agent.action.c = action[0]\n            action = action[1:]\n        # make sure we used all elements of action\n        assert len(action) == 0\n\n    # reset rendering assets\n    def 
_reset_render(self):\n        self.render_geoms = None\n        self.render_geoms_xform = None\n\n    # render environment\n    def render(self, mode='human'):\n        if mode == 'human':\n            alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'\n            message = ''\n            for agent in self.world.agents:\n                comm = []\n                for other in self.world.agents:\n                    if other is agent: continue\n                    if np.all(other.state.c == 0):\n                        word = '_'\n                    else:\n                        word = alphabet[np.argmax(other.state.c)]\n                    message += (other.name + ' to ' + agent.name + ': ' + word + '   ')\n            print(message)\n\n        for i in range(len(self.viewers)):\n            # create viewers (if necessary)\n            if self.viewers[i] is None:\n                # import rendering only if we need it (and don't import for headless machines)\n                #from gym.envs.classic_control import rendering\n                from multiagent import rendering\n                self.viewers[i] = rendering.Viewer(700,700)\n\n        # create rendering geometry\n        if self.render_geoms is None:\n            # import rendering only if we need it (and don't import for headless machines)\n            #from gym.envs.classic_control import rendering\n            from multiagent import rendering\n            self.render_geoms = []\n            self.render_geoms_xform = []\n            for entity in self.world.entities:\n                geom = rendering.make_circle(entity.size)\n                xform = rendering.Transform()\n                if 'agent' in entity.name:\n                    geom.set_color(*entity.color, alpha=0.5)\n                else:\n                    geom.set_color(*entity.color)\n                geom.add_attr(xform)\n                self.render_geoms.append(geom)\n                self.render_geoms_xform.append(xform)\n\n            # add geoms 
to viewer\n            for viewer in self.viewers:\n                viewer.geoms = []\n                for geom in self.render_geoms:\n                    viewer.add_geom(geom)\n\n        results = []\n        for i in range(len(self.viewers)):\n            from multiagent import rendering\n            # update bounds to center around agent\n            cam_range = 1\n            if self.shared_viewer:\n                pos = np.zeros(self.world.dim_p)\n            else:\n                pos = self.agents[i].state.p_pos\n            self.viewers[i].set_bounds(pos[0]-cam_range,pos[0]+cam_range,pos[1]-cam_range,pos[1]+cam_range)\n            # update geometry positions\n            for e, entity in enumerate(self.world.entities):\n                self.render_geoms_xform[e].set_translation(*entity.state.p_pos)\n            # render to display or array\n            results.append(self.viewers[i].render(return_rgb_array = mode=='rgb_array'))\n\n        return results\n\n    # create receptor field locations in local coordinate frame\n    def _make_receptor_locations(self, agent):\n        receptor_type = 'polar'\n        range_min = 0.05 * 2.0\n        range_max = 1.00\n        dx = []\n        # circular receptive field\n        if receptor_type == 'polar':\n            for angle in np.linspace(-np.pi, +np.pi, 8, endpoint=False):\n                for distance in np.linspace(range_min, range_max, 3):\n                    dx.append(distance * np.array([np.cos(angle), np.sin(angle)]))\n            # add origin\n            dx.append(np.array([0.0, 0.0]))\n        # grid receptive field\n        if receptor_type == 'grid':\n            for x in np.linspace(-range_max, +range_max, 5):\n                for y in np.linspace(-range_max, +range_max, 5):\n                    dx.append(np.array([x,y]))\n        return dx\n\n\n# vectorized wrapper for a batch of multi-agent environments\n# assumes all environments have the same observation and action space\nclass 
BatchMultiAgentEnv(gym.Env):\n    metadata = {\n        'runtime.vectorized': True,\n        'render.modes' : ['human', 'rgb_array']\n    }\n\n    def __init__(self, env_batch):\n        self.env_batch = env_batch\n\n    @property\n    def n(self):\n        return np.sum([env.n for env in self.env_batch])\n\n    @property\n    def action_space(self):\n        return self.env_batch[0].action_space\n\n    @property\n    def observation_space(self):\n        return self.env_batch[0].observation_space\n\n    def step(self, action_n, time):\n        obs_n = []\n        reward_n = []\n        done_n = []\n        info_n = {'n': []}\n        i = 0\n        for env in self.env_batch:\n            obs, reward, done, _ = env.step(action_n[i:(i+env.n)], time)\n            i += env.n\n            obs_n += obs\n            # reward = [r / len(self.env_batch) for r in reward]\n            reward_n += reward\n            done_n += done\n        return obs_n, reward_n, done_n, info_n\n\n    def reset(self):\n        obs_n = []\n        for env in self.env_batch:\n            obs_n += env.reset()\n        return obs_n\n\n    # render environment\n    def render(self, mode='human', close=True):\n        results_n = []\n        for env in self.env_batch:\n            results_n += env.render(mode, close)\n        return results_n\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/multi_discrete.py",
    "content": "# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)\n# (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)\n\nimport numpy as np\n\nimport gym\n# from gym.spaces import prng\n\nclass MultiDiscrete(gym.Space):\n    \"\"\"\n    - The multi-discrete action space consists of a series of discrete action spaces with different parameters\n    - It can be adapted to both a Discrete action space or a continuous (Box) action space\n    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space\n    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space\n       where the discrete action space can take any integers from `min` to `max` (both inclusive)\n    Note: A value of 0 always need to represent the NOOP action.\n    e.g. Nintendo Game Controller\n    - Can be conceptualized as 3 discrete action spaces:\n        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4\n        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n    - Can be initialized as\n        MultiDiscrete([ [0,4], [0,1], [0,1] ])\n    \"\"\"\n    def __init__(self, array_of_param_array):\n        self.low = np.array([x[0] for x in array_of_param_array])\n        self.high = np.array([x[1] for x in array_of_param_array])\n        self.num_discrete_space = self.low.shape[0]\n\n    def sample(self):\n        \"\"\" Returns a array with one sample from each discrete action space \"\"\"\n        # For each row: round(random .* (max - min) + min, 0)\n        # random_array = prng.np_random.rand(self.num_discrete_space)\n        random_array = np.random.RandomState().rand(self.num_discrete_space)\n        return [int(x) for x in 
np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)]\n    def contains(self, x):\n        return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all()\n\n    @property\n    def shape(self):\n        return self.num_discrete_space\n    def __repr__(self):\n        return \"MultiDiscrete\" + str(self.num_discrete_space)\n    def __eq__(self, other):\n        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/policy.py",
    "content": "import numpy as np\nfrom pyglet.window import key\n\n# individual agent policy\nclass Policy(object):\n    def __init__(self):\n        pass\n    def action(self, obs):\n        raise NotImplementedError()\n\n# interactive policy based on keyboard input\n# hard-coded to deal only with movement, not communication\nclass InteractivePolicy(Policy):\n    def __init__(self, env, agent_index):\n        super(InteractivePolicy, self).__init__()\n        self.env = env\n        # hard-coded keyboard events\n        self.move = [False for i in range(4)]\n        self.comm = [False for i in range(env.world.dim_c)]\n        # register keyboard events with this environment's window\n        # env.viewers[agent_index].window.on_key_press = self.key_press\n        # env.viewers[agent_index].window.on_key_release = self.key_release\n\n    def action(self, obs):\n        # ignore observation and just act based on keyboard events\n        if self.env.discrete_action_input:\n            u = 0\n            if self.move[0]: u = 1\n            if self.move[1]: u = 2\n            if self.move[2]: u = 4\n            if self.move[3]: u = 3\n        else:\n            u = np.zeros(5) # 5-d because of no-move action\n            if self.move[0]: u[1] += 1.0\n            if self.move[1]: u[2] += 1.0\n            if self.move[3]: u[3] += 1.0\n            if self.move[2]: u[4] += 1.0\n            if True not in self.move:\n                u[0] += 1.0\n        return np.concatenate([u, np.zeros(self.env.world.dim_c)])\n\n    # keyboard event callbacks\n    def key_press(self, k, mod):\n        if k==key.LEFT:  self.move[0] = True\n        if k==key.RIGHT: self.move[1] = True\n        if k==key.UP:    self.move[2] = True\n        if k==key.DOWN:  self.move[3] = True\n    def key_release(self, k, mod):\n        if k==key.LEFT:  self.move[0] = False\n        if k==key.RIGHT: self.move[1] = False\n        if k==key.UP:    self.move[2] = False\n        if k==key.DOWN:  self.move[3] = 
False\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/rendering.py",
    "content": "\"\"\"\n2D rendering framework\n\"\"\"\nfrom __future__ import division\nimport os\nimport six\nimport sys\n\nif \"Apple\" in sys.version:\n    if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:\n        os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'\n        # (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite\n\n# from gym.utils import reraise\nfrom gym import error\nimport pyglet\nfrom pyglet.gl import *\n# try:\n    # import pyglet\n# except ImportError as e:\n#     reraise(suffix=\"HINT: you can install pyglet directly via 'pip install pyglet'. But if you really just want to install all Gym dependencies and not have to think about it, 'pip install -e .[all]' or 'pip install gym[all]' will do it.\")\n\n# try:\n    # from pyglet.gl import *\n# except ImportError as e:\n#     reraise(prefix=\"Error occured while running `from pyglet.gl import *`\",suffix=\"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \\\"-screen 0 1400x900x24\\\" python <your_script.py>'\")\n\nimport math\nimport numpy as np\n\nRAD2DEG = 57.29577951308232\n\ndef get_display(spec):\n    \"\"\"Convert a display specification (such as :0) into an actual Display\n    object.\n\n    Pyglet only supports multiple Displays on Linux.\n    \"\"\"\n    if spec is None:\n        return None\n    elif isinstance(spec, six.string_types):\n        return pyglet.canvas.Display(spec)\n    else:\n        raise error.Error('Invalid display specification: {}. 
(Must be a string like :0 or None.)'.format(spec))\n\nclass Viewer(object):\n    def __init__(self, width, height, display=None):\n        display = get_display(display)\n\n        self.width = width\n        self.height = height\n\n        self.window = pyglet.window.Window(width=width, height=height, display=display)\n        self.window.on_close = self.window_closed_by_user\n        self.geoms = []\n        self.onetime_geoms = []\n        self.transform = Transform()\n\n        glEnable(GL_BLEND)\n        # glEnable(GL_MULTISAMPLE)\n        glEnable(GL_LINE_SMOOTH)\n        # glHint(GL_LINE_SMOOTH_HINT, GL_DONT_CARE)\n        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)\n        glLineWidth(2.0)\n        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)\n\n    def close(self):\n        self.window.close()\n\n    def window_closed_by_user(self):\n        self.close()\n\n    def set_bounds(self, left, right, bottom, top):\n        assert right > left and top > bottom\n        scalex = self.width/(right-left)\n        scaley = self.height/(top-bottom)\n        self.transform = Transform(\n            translation=(-left*scalex, -bottom*scaley),\n            scale=(scalex, scaley))\n\n    def add_geom(self, geom):\n        self.geoms.append(geom)\n\n    def add_onetime(self, geom):\n        self.onetime_geoms.append(geom)\n\n    def render(self, return_rgb_array=False):\n        glClearColor(1,1,1,1)\n        self.window.clear()\n        self.window.switch_to()\n        self.window.dispatch_events()\n        self.transform.enable()\n        for geom in self.geoms:\n            geom.render()\n        for geom in self.onetime_geoms:\n            geom.render()\n        self.transform.disable()\n        arr = None\n        if return_rgb_array:\n            buffer = pyglet.image.get_buffer_manager().get_color_buffer()\n            image_data = buffer.get_image_data()\n            arr = np.fromstring(image_data._current_data, dtype=np.uint8, sep='')\n            # In 
https://github.com/openai/gym-http-api/issues/2, we\n            # discovered that someone using Xmonad on Arch was having\n            # a window of size 598 x 398, though a 600 x 400 window\n            # was requested. (Guess Xmonad was preserving a pixel for\n            # the boundary.) So we use the buffer height/width rather\n            # than the requested one.\n            arr = arr.reshape(buffer.height, buffer.width, 4)\n            arr = arr[::-1,:,0:3]\n        self.window.flip()\n        self.onetime_geoms = []\n        return arr\n\n    # Convenience\n    def draw_circle(self, radius=10, res=30, filled=True, **attrs):\n        geom = make_circle(radius=radius, res=res, filled=filled)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def draw_polygon(self, v, filled=True, **attrs):\n        geom = make_polygon(v=v, filled=filled)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def draw_polyline(self, v, **attrs):\n        geom = make_polyline(v=v)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def draw_line(self, start, end, **attrs):\n        geom = Line(start, end)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def get_array(self):\n        self.window.flip()\n        image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()\n        self.window.flip()\n        arr = np.fromstring(image_data.data, dtype=np.uint8, sep='')\n        arr = arr.reshape(self.height, self.width, 4)\n        return arr[::-1,:,0:3]\n\ndef _add_attrs(geom, attrs):\n    if \"color\" in attrs:\n        geom.set_color(*attrs[\"color\"])\n    if \"linewidth\" in attrs:\n        geom.set_linewidth(attrs[\"linewidth\"])\n\nclass Geom(object):\n    def __init__(self):\n        self._color=Color((0, 0, 0, 1.0))\n        self.attrs = [self._color]\n    def render(self):\n        for 
attr in reversed(self.attrs):\n            attr.enable()\n        self.render1()\n        for attr in self.attrs:\n            attr.disable()\n    def render1(self):\n        raise NotImplementedError\n    def add_attr(self, attr):\n        self.attrs.append(attr)\n    def set_color(self, r, g, b, alpha=1):\n        self._color.vec4 = (r, g, b, alpha)\n\nclass Attr(object):\n    def enable(self):\n        raise NotImplementedError\n    def disable(self):\n        pass\n\nclass Transform(Attr):\n    def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1,1)):\n        self.set_translation(*translation)\n        self.set_rotation(rotation)\n        self.set_scale(*scale)\n    def enable(self):\n        glPushMatrix()\n        glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint\n        glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)\n        glScalef(self.scale[0], self.scale[1], 1)\n    def disable(self):\n        glPopMatrix()\n    def set_translation(self, newx, newy):\n        self.translation = (float(newx), float(newy))\n    def set_rotation(self, new):\n        self.rotation = float(new)\n    def set_scale(self, newx, newy):\n        self.scale = (float(newx), float(newy))\n\nclass Color(Attr):\n    def __init__(self, vec4):\n        self.vec4 = vec4\n    def enable(self):\n        glColor4f(*self.vec4)\n\nclass LineStyle(Attr):\n    def __init__(self, style):\n        self.style = style\n    def enable(self):\n        glEnable(GL_LINE_STIPPLE)\n        glLineStipple(1, self.style)\n    def disable(self):\n        glDisable(GL_LINE_STIPPLE)\n\nclass LineWidth(Attr):\n    def __init__(self, stroke):\n        self.stroke = stroke\n    def enable(self):\n        glLineWidth(self.stroke)\n\nclass Point(Geom):\n    def __init__(self):\n        Geom.__init__(self)\n    def render1(self):\n        glBegin(GL_POINTS) # draw point\n        glVertex3f(0.0, 0.0, 0.0)\n        glEnd()\n\nclass FilledPolygon(Geom):\n    def 
__init__(self, v):\n        Geom.__init__(self)\n        self.v = v\n    def render1(self):\n        if   len(self.v) == 4 : glBegin(GL_QUADS)\n        elif len(self.v)  > 4 : glBegin(GL_POLYGON)\n        else: glBegin(GL_TRIANGLES)\n        for p in self.v:\n            glVertex3f(p[0], p[1],0)  # draw each vertex\n        glEnd()\n\n        color = (self._color.vec4[0] * 0.5, self._color.vec4[1] * 0.5, self._color.vec4[2] * 0.5, self._color.vec4[3] * 0.5)\n        glColor4f(*color)\n        glBegin(GL_LINE_LOOP)\n        for p in self.v:\n            glVertex3f(p[0], p[1],0)  # draw each vertex\n        glEnd()\n\ndef make_circle(radius=10, res=30, filled=True):\n    points = []\n    for i in range(res):\n        ang = 2*math.pi*i / res\n        points.append((math.cos(ang)*radius, math.sin(ang)*radius))\n    if filled:\n        return FilledPolygon(points)\n    else:\n        return PolyLine(points, True)\n\ndef make_polygon(v, filled=True):\n    if filled: return FilledPolygon(v)\n    else: return PolyLine(v, True)\n\ndef make_polyline(v):\n    return PolyLine(v, False)\n\ndef make_capsule(length, width):\n    l, r, t, b = 0, length, width/2, -width/2\n    box = make_polygon([(l,b), (l,t), (r,t), (r,b)])\n    circ0 = make_circle(width/2)\n    circ1 = make_circle(width/2)\n    circ1.add_attr(Transform(translation=(length, 0)))\n    geom = Compound([box, circ0, circ1])\n    return geom\n\nclass Compound(Geom):\n    def __init__(self, gs):\n        Geom.__init__(self)\n        self.gs = gs\n        for g in self.gs:\n            g.attrs = [a for a in g.attrs if not isinstance(a, Color)]\n    def render1(self):\n        for g in self.gs:\n            g.render()\n\nclass PolyLine(Geom):\n    def __init__(self, v, close):\n        Geom.__init__(self)\n        self.v = v\n        self.close = close\n        self.linewidth = LineWidth(1)\n        self.add_attr(self.linewidth)\n    def render1(self):\n        glBegin(GL_LINE_LOOP if self.close else GL_LINE_STRIP)\n      
  for p in self.v:\n            glVertex3f(p[0], p[1],0)  # draw each vertex\n        glEnd()\n    def set_linewidth(self, x):\n        self.linewidth.stroke = x\n\nclass Line(Geom):\n    def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):\n        Geom.__init__(self)\n        self.start = start\n        self.end = end\n        self.linewidth = LineWidth(1)\n        self.add_attr(self.linewidth)\n\n    def render1(self):\n        glBegin(GL_LINES)\n        glVertex2f(*self.start)\n        glVertex2f(*self.end)\n        glEnd()\n\nclass Image(Geom):\n    def __init__(self, fname, width, height):\n        Geom.__init__(self)\n        self.width = width\n        self.height = height\n        img = pyglet.image.load(fname)\n        self.img = img\n        self.flip = False\n    def render1(self):\n        self.img.blit(-self.width/2, -self.height/2, width=self.width, height=self.height)\n\n# ================================================================\n\nclass SimpleImageViewer(object):\n    def __init__(self, display=None):\n        self.window = None\n        self.isopen = False\n        self.display = display\n    def imshow(self, arr):\n        if self.window is None:\n            height, width, channels = arr.shape\n            self.window = pyglet.window.Window(width=width, height=height, display=self.display)\n            self.width = width\n            self.height = height\n            self.isopen = True\n        assert arr.shape == (self.height, self.width, 3), \"You passed in an image with the wrong number shape\"\n        image = pyglet.image.ImageData(self.width, self.height, 'RGB', arr.tobytes(), pitch=self.width * -3)\n        self.window.clear()\n        self.window.switch_to()\n        self.window.dispatch_events()\n        image.blit(0,0)\n        self.window.flip()\n    def close(self):\n        if self.isopen:\n            self.window.close()\n            self.isopen = False\n    def __del__(self):\n        self.close()"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenario.py",
    "content": "import numpy as np\n\n# defines scenario upon which the world is built\nclass BaseScenario(object):\n    # create elements of the world\n    def make_world(self):\n        raise NotImplementedError()\n    # create initial conditions of the world\n    def reset_world(self, world):\n        raise NotImplementedError()\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/__init__.py",
    "content": "import imp\nimport os.path as osp\n\n\ndef load(name):\n    pathname = osp.join(osp.dirname(__file__), name)\n    return imp.load_source('', pathname)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/hetero_spread.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, num_good_agents=2, num_adversaries=0):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        world.max_steps = 25\n        num_agents = num_good_agents\n        self.n_agent_a = num_agents // 2 # 2\n        self.n_agent_b = num_agents // 2 # 2\n        num_landmarks = num_good_agents\n        world.collaborative = True\n        self.agent_size = 0.10\n        self.n_others = 3\n        self.n_group = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.id = i\n            if i < self.n_agent_a:\n                agent.size = self.agent_size\n                agent.accel = 3.0\n                agent.max_speed = 1.0\n            else:\n                agent.size = self.agent_size / 2\n                agent.accel = 4.0\n                agent.max_speed = 1.3\n\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        world.num_steps = 0\n        self.end_steps = world.max_steps\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            if i < self.n_agent_a:\n                agent.color = np.array([0.35, 0.35, 0.85])\n            else:\n                agent.color = np.array([0.35, 0.85, 0.35])\n        # random properties for landmarks\n    
    for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        rew = 0\n        collisions = 0\n        occupied_landmarks = 0\n        min_dists = 0\n        for l in world.landmarks:\n            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n            min_dists += min(dists)\n            rew -= min(dists)\n            if min(dists) < 0.1:\n                occupied_landmarks += 1\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n                    collisions += 1\n        return (rew, collisions, min_dists, occupied_landmarks)\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions\n        rew = 0\n        shaped_reward = False\n        if shaped_reward:  # distance-based reward\n            for l in world.landmarks:\n                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n                rew -= min(dists)\n            if agent.collide:\n                for a in world.agents:\n                    if 
self.is_collision(a, agent):\n                        rew -= 1\n            return rew\n        else:\n            win_agents = []\n            for land in world.landmarks:\n                for a in world.agents:\n                    if self.is_collision(a, land):\n                        win_agents.append(a)\n                        break\n            rew += 2 * len(set(win_agents))\n\n            def bound(x):\n                if x > 1.0:\n                    return min(np.exp(2 * x - 2), 10)\n                else:\n                    return 0.0\n            bound_rew = 0.0\n            for p in range(world.dim_p):\n                x = abs(agent.state.p_pos[p])\n                bound_rew -= bound(x)\n            rew += bound_rew\n            return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        other_vel = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent:\n                other_vel.append([0, 0])\n                continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + comm +\n                              other_vel + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # add agents\n        world.agents = [Agent() for i in range(1)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.silent = True\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(1)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25,0.25,0.25])\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.75,0.75,0.75])\n        world.landmarks[0].color = np.array([0.75,0.25,0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        dist2 = np.sum(np.square(agent.state.p_pos - world.landmarks[0].state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            
entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + entity_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_adversary.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n\n    def make_world(self, num_good_agents=2, num_adversaries=1):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = num_adversaries + num_good_agents\n        num_good_agents = num_good_agents\n        num_adversaries = 1\n        world.num_agents = num_agents\n\n        num_landmarks = num_agents - 1\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.silent = True\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.15\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.08\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        world.agents[0].color = np.array([0.85, 0.35, 0.35])\n        for i in range(1, world.num_agents):\n            world.agents[i].color = np.array([0.35, 0.35, 0.85])\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.15, 0.15, 0.15])\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        goal.color = np.array([0.15, 0.65, 0.15])\n        for agent in world.agents:\n            agent.goal_a = goal\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = 
np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        if agent.adversary:\n            return np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))\n        else:\n            dists = []\n            for l in world.landmarks:\n                dists.append(np.sum(np.square(agent.state.p_pos - l.state.p_pos)))\n            dists.append(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))\n            return tuple(dists)\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # Rewarded based on how close any good agent is to the goal landmark, and how far the adversary is from it\n        shaped_reward = False\n        shaped_adv_reward = False\n\n        # Calculate negative reward for adversary\n        adversary_agents = self.adversaries(world)\n        if shaped_adv_reward:  # distance-based adversary reward\n            adv_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in adversary_agents])\n        else:  # proximity-based adversary reward (binary)\n            adv_rew = 0\n            for a in 
adversary_agents:\n                if np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) < 2 * a.goal_a.size:\n                    adv_rew -= 5\n\n        # Calculate positive reward for agents\n        good_agents = self.good_agents(world)\n        if shaped_reward:  # distance-based agent reward\n            pos_rew = -min(\n                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])\n        else:  # proximity-based agent reward (binary)\n            pos_rew = 0\n            if min([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents]) \\\n                    < 2 * agent.goal_a.size:\n                pos_rew += 5\n            pos_rew -= min(\n                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])\n        return pos_rew + adv_rew\n\n    def adversary_reward(self, agent, world):\n        # Rewarded based on proximity to the goal landmark\n        shaped_reward = False\n        if shaped_reward:  # distance-based reward\n            return -np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))\n        else:  # proximity-based reward (binary)\n            adv_rew = 0\n            if np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))) < 2 * agent.goal_a.size:\n                adv_rew += 5\n            return adv_rew\n\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        other_vel = []\n        for entity in world.landmarks:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        other_pos = []\n        for other in world.agents:\n            if other is agent and not other.adversary:\n          
      other_vel.append([0, 0])\n                continue\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not other.adversary:\n                other_vel.append([0, 0])\n\n        if not agent.adversary:\n            return np.concatenate(\n                [agent.goal_a.state.p_pos - agent.state.p_pos] +\n                other_vel + entity_pos + other_pos)\n        else:\n            return np.concatenate([np.zeros(2)] +\n              other_vel + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_crypto.py",
    "content": "\"\"\"\nScenario:\n1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from\nadversary to goal. Adversary is rewarded for its distance to the goal.\n\"\"\"\n\n\nimport numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\nimport random\n\n\nclass CryptoAgent(Agent):\n    def __init__(self):\n        super(CryptoAgent, self).__init__()\n        self.key = None\n\nclass Scenario(BaseScenario):\n\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        num_agents = 3\n        num_adversaries = 1\n        num_landmarks = 2\n        world.dim_c = 4\n        # add agents\n        world.agents = [CryptoAgent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.adversary = True if i < num_adversaries else False\n            agent.speaker = True if i == 2 else False\n            agent.movable = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25, 0.25, 0.25])\n            if agent.adversary:\n                agent.color = np.array([0.75, 0.25, 0.25])\n            agent.key = None\n        # random properties for landmarks\n        color_list = [np.zeros(world.dim_c) for i in world.landmarks]\n        for i, color in enumerate(color_list):\n            color[i] += 1\n        for color, landmark in 
zip(color_list, world.landmarks):\n            landmark.color = color\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        world.agents[1].color = goal.color\n        world.agents[2].key = np.random.choice(world.landmarks).color\n\n        for agent in world.agents:\n            agent.goal_a = goal\n\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return (agent.state.c, agent.goal_a.color)\n\n    # return all agents that are not adversaries\n    def good_listeners(self, world):\n        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot\n        good_listeners = self.good_listeners(world)\n        adversaries = self.adversaries(world)\n        good_rew = 0\n        adv_rew = 0\n        for a in good_listeners:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                
good_rew -= np.sum(np.square(a.state.c - agent.goal_a.color))\n        for a in adversaries:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.color))\n                adv_rew += adv_l1\n        return adv_rew + good_rew\n\n    def adversary_reward(self, agent, world):\n        # Adversary (Eve) is rewarded if it can reconstruct original goal\n        rew = 0\n        if not (agent.state.c == np.zeros(world.dim_c)).all():\n            rew -= np.sum(np.square(agent.state.c - agent.goal_a.color))\n        return rew\n\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = np.zeros(world.dim_color)\n        if agent.goal_a is not None:\n            goal_color = agent.goal_a.color\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None) or not other.speaker: continue\n            comm.append(other.state.c)\n\n        confer = np.array([0])\n\n        if world.agents[2].key is None:\n            confer = np.array([1])\n            key = np.zeros(world.dim_c)\n            goal_color = np.zeros(world.dim_c)\n        else:\n            key = world.agents[2].key\n\n        prnt = False\n        # speaker\n        if agent.speaker:\n            if prnt:\n                print('speaker')\n                print(agent.state.c)\n                print(np.concatenate([goal_color] + [key] + [confer] + [np.random.randn(1)]))\n            return np.concatenate([goal_color] + [key])\n        # listener\n        if not agent.speaker and not agent.adversary:\n            if prnt:\n                print('listener')\n             
   print(agent.state.c)\n                print(np.concatenate([key] + comm + [confer]))\n            return np.concatenate([key] + comm)\n        if not agent.speaker and agent.adversary:\n            if prnt:\n                print('adversary')\n                print(agent.state.c)\n                print(np.concatenate(comm + [confer]))\n            return np.concatenate(comm)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_push.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self, num_good_agents=2, num_adversaries=2):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = num_good_agents + num_adversaries\n        num_adversaries = num_adversaries\n        num_landmarks = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            if i < num_adversaries:\n                agent.adversary = True\n            else:\n                agent.adversary = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.1, 0.1, 0.1])\n            landmark.color[i + 1] += 0.8\n            landmark.index = i\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        for i, agent in enumerate(world.agents):\n            agent.goal_a = goal\n            agent.color = np.array([0.25, 0.25, 0.25])\n            if agent.adversary:\n                agent.color = np.array([0.75, 0.25, 0.25])\n            else:\n                j = goal.index\n                agent.color[j + 1] += 0.5\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, 
world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return dist < dist_min\n\n    def agent_reward(self, agent, world):\n        '''\n        Rewrite\n        '''\n        shaped_reward = False\n        if shaped_reward:  # distance-based reward\n        # the distance to the goal\n            return -np.sqrt(np.sum(np.square(\n                agent.state.p_pos - agent.goal_a.state.p_pos)))\n        else:\n            pos_rew, adv_rew = 0.0, 0.0\n            for a in self.adversaries(world):\n                if self.is_collision(a, agent):\n                    adv_rew -= 5.0\n            if self.is_collision(agent, agent.goal_a):\n                pos_rew += 5.0\n            rew = pos_rew + adv_rew\n            def bound(x):\n                if x > 1.0:\n                    return min(np.exp(2 * x - 2), 10)\n                else:\n                    return 0.0\n            bound_rew = 0.0\n            for p in range(world.dim_p):\n                x = abs(agent.state.p_pos[p])\n                bound_rew -= bound(x)\n            rew 
+= bound_rew\n            return rew\n\n    def adversary_reward(self, agent, world):\n        '''\n        Rewrite\n        '''\n        shaped_reward = False\n        if shaped_reward:  # distance-based reward\n            # keep the nearest good agents away from the goal\n            agent_dist = [np.sqrt(np.sum(np.square(a.state.p_pos -\n                       a.goal_a.state.p_pos))) for a in world.agents if not a.adversary]\n            pos_rew = min(agent_dist)\n            #nearest_agent = world.good_agents[np.argmin(agent_dist)]\n            #neg_rew = np.sqrt(np.sum(np.square(nearest_agent.state.p_pos - agent.state.p_pos)))\n            neg_rew = np.sqrt(np.sum(np.square(agent.goal_a.state.p_pos - agent.state.p_pos)))\n            #neg_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in world.good_agents])\n            return pos_rew - neg_rew\n        else:\n            rew = 0.0\n            for a in self.good_agents(world):\n                if self.is_collision(a, a.goal_a):\n                    rew -= 5.0\n                # if self.is_collision(a, agent):\n                #     rew += 5.0\n            if self.is_collision(agent, agent.goal_a):\n                rew += 5.0\n            return rew\n               \n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        other_vel = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent and not other.adversary:\n                other_vel.append([0, 0])\n                continue\n            
comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not other.adversary:\n                other_vel.append([0, 0])\n\n        if not agent.adversary:\n            return np.concatenate([agent.state.p_vel] +\n            [agent.goal_a.state.p_pos - agent.state.p_pos] +\n          [agent.color] + entity_color +\n                  other_vel + entity_pos + other_pos)\n        else:\n            #other_pos = list(reversed(other_pos)) if random.uniform(0,1) > 0.5 else other_pos  # randomize position of other agents in adversary network\n            return np.concatenate([agent.state.p_vel] +\n                  other_vel + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_reference.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self, num_good_agents=2, num_adversaries=0):\n        world = World()\n        # set any world properties first\n        world.dim_c = 10\n        world.collaborative = True  # whether agents share rewards\n        # add agents\n        world.agents = [Agent() for i in range(num_good_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_good_agents+1)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # assign goals to agents\n        for agent in world.agents:\n            agent.goal_a = None\n            agent.goal_b = None\n        # want other agent to go to the goal landmark\n        world.agents[0].goal_a = world.agents[1]\n        world.agents[0].goal_b = np.random.choice(world.landmarks)\n        world.agents[1].goal_a = world.agents[0]\n        world.agents[1].goal_b = np.random.choice(world.landmarks)\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25,0.25,0.25])               \n        # random properties for landmarks\n        world.landmarks[0].color = np.array([0.75,0.25,0.25]) \n        world.landmarks[1].color = np.array([0.25,0.75,0.25]) \n        world.landmarks[2].color = np.array([0.25,0.25,0.75]) \n        # special colors for goals\n        world.agents[0].goal_a.color = world.agents[0].goal_b.color                \n        world.agents[1].goal_a.color = 
world.agents[1].goal_b.color                               \n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        if agent.goal_a is None or agent.goal_b is None:\n            return 0.0\n        dist2 = np.sum(np.square(agent.goal_a.state.p_pos - agent.goal_b.state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)]\n        if agent.goal_b is not None:\n            goal_color[1] = agent.goal_b.color \n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n        return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm)\n            "
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_speaker_listener.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        world.dim_c = 3\n        num_landmarks = 3\n        world.collaborative = True\n        # add agents\n        world.agents = [Agent() for i in range(2)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.size = 0.075\n        # speaker\n        world.agents[0].movable = False\n        # listener\n        world.agents[1].silent = True\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.04\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # assign goals to agents\n        for agent in world.agents:\n            agent.goal_a = None\n            agent.goal_b = None\n        # want listener to go to the goal landmark\n        world.agents[0].goal_a = world.agents[1]\n        world.agents[0].goal_b = np.random.choice(world.landmarks)\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25,0.25,0.25])               \n        # random properties for landmarks\n        world.landmarks[0].color = np.array([0.65,0.15,0.15])\n        world.landmarks[1].color = np.array([0.15,0.65,0.15])\n        world.landmarks[2].color = np.array([0.15,0.15,0.65])\n        # special colors for goals\n        world.agents[0].goal_a.color = world.agents[0].goal_b.color + np.array([0.45, 0.45, 0.45])\n        # set random initial 
states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return self.reward(agent, world)\n\n    def reward(self, agent, world):\n        # squared distance from listener to landmark\n        a = world.agents[0]\n        dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = np.zeros(world.dim_color)\n        if agent.goal_b is not None:\n            goal_color = agent.goal_b.color\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None): continue\n            comm.append(other.state.c)\n        \n        # speaker\n        if not agent.movable:\n            return np.concatenate([goal_color])\n        # listener\n        if agent.silent:\n            return np.concatenate([agent.state.p_vel] + entity_pos + comm)\n            \n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_spread.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, num_good_agents=2, num_adversaries=2):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = num_good_agents\n        num_landmarks = num_good_agents\n        world.collaborative = True\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.size = 0.15\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.35, 0.35, 0.85])\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        rew = 0\n        collisions = 0\n        occupied_landmarks = 0\n        min_dists = 
0\n        for l in world.landmarks:\n            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n            min_dists += min(dists)\n            rew -= min(dists)\n            if min(dists) < 0.1:\n                occupied_landmarks += 1\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n                    collisions += 1\n        return (rew, collisions, min_dists, occupied_landmarks)\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions\n        rew = 0\n        shaped_reward = False\n        if shaped_reward:  # distance-based reward\n            for l in world.landmarks:\n                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n                rew -= min(dists)\n            if agent.collide:\n                for a in world.agents:\n                    if self.is_collision(a, agent):\n                        rew -= 1\n            return rew\n        else:\n            win_agents = []\n            for land in world.landmarks:\n                for a in world.agents:\n                    if self.is_collision(a, land):\n                        win_agents.append(a)\n                        break\n            rew += 2 * len(set(win_agents))\n\n            def bound(x):\n                if x > 1.0:\n                    return min(np.exp(2 * x - 2), 10)\n                else:\n                    return 0.0\n            bound_rew = 0.0\n            for p in range(world.dim_p):\n                x = abs(agent.state.p_pos[p])\n                
bound_rew -= bound(x)\n            rew += bound_rew\n            return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        other_vel = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent:\n                other_vel.append([0, 0])\n                continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + comm +\n                              other_vel + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_tag.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, num_good_agents=1, num_adversaries=3):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_good_agents = num_good_agents\n        num_adversaries = num_adversaries\n        num_agents = num_adversaries + num_good_agents\n        num_landmarks = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.075 if agent.adversary else 0.05\n            agent.accel = 3.0 if agent.adversary else 4.0\n            #agent.accel = 20.0 if agent.adversary else 25.0\n            agent.max_speed = 1.0 if agent.adversary else 1.3\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = True\n            landmark.movable = False\n            landmark.size = 0.2\n            landmark.boundary = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.35, 0.85, 0.35]) if not agent.adversary else np.array([0.85, 0.35, 0.35])\n            # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = 
np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            if not landmark.boundary:\n                landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n                landmark.state.p_vel = np.zeros(world.dim_p)\n\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        if agent.adversary:\n            collisions = 0\n            for a in self.good_agents(world):\n                if self.is_collision(a, agent):\n                    collisions += 1\n            return collisions\n        else:\n            return 0\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n        return main_reward\n\n    def agent_reward(self, agent, world):\n        # Agents are negatively rewarded if caught by adversaries\n        rew = 0\n        shape = False\n        adversaries = self.adversaries(world)\n        if shape:  # reward can optionally be shaped (increased reward for increased distance from adversary)\n            for adv in adversaries:\n                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - 
adv.state.p_pos)))\n        if agent.collide:\n            for a in adversaries:\n                if self.is_collision(a, agent):\n                    rew -= 10\n\n        # agents are penalized for exiting the screen, so that they can be caught by the adversaries\n        def bound(x):\n            if x < 0.9:\n                return 0\n            if x < 1.0:\n                return (x - 0.9) * 10\n            return min(np.exp(2 * x - 2), 10)\n        for p in range(world.dim_p):\n            x = abs(agent.state.p_pos[p])\n            rew -= bound(x)\n\n        return rew\n\n    def adversary_reward(self, agent, world):\n        # Adversaries are rewarded for collisions with agents\n        rew = 0\n        shape = False\n        agents = self.good_agents(world)\n        adversaries = self.adversaries(world)\n        if shape:  # reward can optionally be shaped (decreased reward for increased distance from agents)\n            for adv in adversaries:\n                rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos))) for a in agents])\n        if agent.collide:\n            for ag in agents:\n                for adv in adversaries:\n                    if self.is_collision(ag, adv):\n                        rew += 10\n        return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent and not other.adversary:\n                other_vel.append(other.state.p_vel)\n                continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not 
other.adversary:\n                other_vel.append(other.state.p_vel)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + other_vel +\n                              entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/multiagent/scenarios/simple_world_comm.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, num_good_agents=2, num_adversaries=4):\n        world = World()\n        # set any world properties first\n        world.dim_c = 4\n        #world.damping = 1\n        num_good_agents = num_good_agents\n        num_adversaries = num_adversaries\n        num_agents = num_adversaries + num_good_agents\n        num_landmarks = 1\n        num_food = 2\n        num_forests = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.leader = True if i == 0 else False\n            agent.silent = True if i > 0 else False\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.075 if agent.adversary else 0.045\n            agent.accel = 3.0 if agent.adversary else 4.0\n            #agent.accel = 20.0 if agent.adversary else 25.0\n            agent.max_speed = 1.0 if agent.adversary else 1.3\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = True\n            landmark.movable = False\n            landmark.size = 0.2\n            landmark.boundary = False\n        world.food = [Landmark() for i in range(num_food)]\n        for i, landmark in enumerate(world.food):\n            landmark.name = 'food %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.03\n            landmark.boundary = False\n        world.forests = [Landmark() for i in range(num_forests)]\n        for i, landmark in enumerate(world.forests):\n            landmark.name = 
'forest %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.3\n            landmark.boundary = False\n        world.landmarks += world.food\n        world.landmarks += world.forests\n        #world.landmarks += self.set_boundaries(world)  # world boundaries now penalized with negative reward\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def set_boundaries(self, world):\n        boundary_list = []\n        landmark_size = 1\n        edge = 1 + landmark_size\n        num_landmarks = int(edge * 2 / landmark_size)\n        for x_pos in [-edge, edge]:\n            for i in range(num_landmarks):\n                l = Landmark()\n                l.state.p_pos = np.array([x_pos, -1 + i * landmark_size])\n                boundary_list.append(l)\n\n        for y_pos in [-edge, edge]:\n            for i in range(num_landmarks):\n                l = Landmark()\n                l.state.p_pos = np.array([-1 + i * landmark_size, y_pos])\n                boundary_list.append(l)\n\n        for i, l in enumerate(boundary_list):\n            l.name = 'boundary %d' % i\n            l.collide = True\n            l.movable = False\n            l.boundary = True\n            l.color = np.array([0.75, 0.75, 0.75])\n            l.size = landmark_size\n            l.state.p_vel = np.zeros(world.dim_p)\n\n        return boundary_list\n\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.45, 0.95, 0.45]) if not agent.adversary else np.array([0.95, 0.45, 0.45])\n            agent.color -= np.array([0.3, 0.3, 0.3]) if agent.leader else np.array([0, 0, 0])\n            # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        for i, landmark in enumerate(world.food):\n       
     landmark.color = np.array([0.15, 0.15, 0.65])\n        for i, landmark in enumerate(world.forests):\n            landmark.color = np.array([0.6, 0.9, 0.6])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n        for i, landmark in enumerate(world.food):\n            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n        for i, landmark in enumerate(world.forests):\n            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        if agent.adversary:\n            collisions = 0\n            for a in self.good_agents(world):\n                if self.is_collision(a, agent):\n                    collisions += 1\n            return collisions\n        else:\n            return 0\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        
#boundary_reward = -10 if self.outside_boundary(agent) else 0\n        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n        return main_reward\n\n    def outside_boundary(self, agent):\n        if agent.state.p_pos[0] > 1 or agent.state.p_pos[0] < -1 or agent.state.p_pos[1] > 1 or agent.state.p_pos[1] < -1:\n            return True\n        else:\n            return False\n\n\n    def agent_reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        rew = 0\n        shape = False\n        adversaries = self.adversaries(world)\n        if shape:\n            for adv in adversaries:\n                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))\n        if agent.collide:\n            for a in adversaries:\n                if self.is_collision(a, agent):\n                    rew -= 5\n        def bound(x):\n            if x < 0.9:\n                return 0\n            if x < 1.0:\n                return (x - 0.9) * 10\n            return min(np.exp(2 * x - 2), 10)  # 1 + (x - 1) * (x - 1)\n\n        for p in range(world.dim_p):\n            x = abs(agent.state.p_pos[p])\n            rew -= 2 * bound(x)\n\n        for food in world.food:\n            if self.is_collision(agent, food):\n                rew += 2\n        rew += 0.05 * min([np.sqrt(np.sum(np.square(food.state.p_pos - agent.state.p_pos))) for food in world.food])\n\n        return rew\n\n    def adversary_reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        rew = 0\n        shape = True\n        agents = self.good_agents(world)\n        adversaries = self.adversaries(world)\n        if shape:\n            rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in agents])\n        if agent.collide:\n            for ag in agents:\n                for adv in 
adversaries:\n                    if self.is_collision(ag, adv):\n                        rew += 5\n        return rew\n\n\n    def observation2(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        food_pos = []\n        for entity in world.food:\n            if not entity.boundary:\n                food_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not other.adversary:\n                other_vel.append(other.state.p_vel)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        in_forest = [np.array([-1]), np.array([-1])]\n        inf1 = False\n        inf2 = False\n        if self.is_collision(agent, world.forests[0]):\n            in_forest[0] = np.array([1])\n            inf1= True\n        if self.is_collision(agent, world.forests[1]):\n            in_forest[1] = np.array([1])\n            inf2 = True\n\n        food_pos = []\n        for entity in world.food:\n            if not entity.boundary:\n                food_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        
other_vel = []\n        for other in world.agents:\n            if other is agent and not other.adversary:\n                other_vel.append(other.state.p_vel) #\n                continue\n            comm.append(other.state.c)\n            oth_f1 = self.is_collision(other, world.forests[0])\n            oth_f2 = self.is_collision(other, world.forests[1])\n            if (inf1 and oth_f1) or (inf2 and oth_f2) or \\\n                    (not inf1 and not oth_f1 and not inf2 and not oth_f2) or \\\n                    agent.leader:  #without forest vis\n                other_pos.append(other.state.p_pos - agent.state.p_pos)\n                if not other.adversary:\n                    other_vel.append(other.state.p_vel)\n            else:\n                other_pos.append([0, 0])\n                if not other.adversary:\n                    other_vel.append([0, 0])\n\n        # to tell the pred when the prey are in the forest\n        prey_forest = []\n        ga = self.good_agents(world)\n        for a in ga:\n            if any([self.is_collision(a, f) for f in world.forests]):\n                prey_forest.append(np.array([1]))\n            else:\n                prey_forest.append(np.array([-1]))\n        # to tell leader when pred are in forest\n        prey_forest_lead = []\n        for f in world.forests:\n            if any([self.is_collision(a, f) for a in ga]):\n                prey_forest_lead.append(np.array([1]))\n            else:\n                prey_forest_lead.append(np.array([-1]))\n\n        comm = [world.agents[0].state.c]\n\n        if agent.adversary and not agent.leader:\n            return np.concatenate(in_forest + comm +\n                      [agent.state.p_vel] + [agent.state.p_pos] + other_vel +\n                      entity_pos + other_pos)\n        if agent.leader:\n            return np.concatenate(in_forest + comm +\n                [agent.state.p_vel] + [agent.state.p_pos] + other_vel +\n                                  entity_pos + 
other_pos)\n        else:\n            return np.concatenate(in_forest + [np.zeros_like(world.agents[0].state.c)] +\n               [agent.state.p_vel] + [agent.state.p_pos] + other_vel +\n              entity_pos + other_pos)\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/readme.md",
    "content": "# Readme\n\nThis project is for FOToM.\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/agents.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.autograd import Variable\nfrom torch.optim import Adam\nfrom .networks import MLPNetwork, RNN, SNNNetwork, LSTMClassifier\nfrom .misc import hard_update, gumbel_softmax, onehot_from_logits\nfrom .noise import OUNoise\nimport time\n\nclass DDPGAgent(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.policy = LSTMClassifier(num_in_pol, num_out_pol,#MLPNetwork\n                                 hidden_dim,)\n                                 # constrain_out=True,\n                                 # discrete_action=discrete_action)\n        self.critic = LSTMClassifier(num_in_critic, 1,\n                                 hidden_dim,)\n                                 # constrain_out=False)\n        self.target_policy = LSTMClassifier(num_in_pol, num_out_pol,\n                                        hidden_dim,)\n                                        # constrain_out=True,\n                                        # discrete_action=discrete_action)\n        self.target_critic = LSTMClassifier(num_in_critic, 1,\n                                        hidden_dim,)\n                                        # constrain_out=False)\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n            
self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        action = self.policy(obs)\n\n\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)\n                else:\n                    action = gumbel_softmax(action, hard=True)\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)\n                else:\n                    action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': 
self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n\nclass DDPGAgent_RNN(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.policy = RNN(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        self.critic = RNN(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=False)\n        self.target_policy = RNN(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=True,\n                                        discrete_action=discrete_action)\n        self.target_critic = RNN(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=False)\n\n       
 self.policy_hidden = None\n        self.policy_target_hidden = None\n        self.critic_hidden = None\n        self.critic_target_hidden = None\n        self.num_in_pol = num_in_pol\n        self.num_out_pol = num_out_pol\n        self.hidden_dim = hidden_dim\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n\n        action, self.policy_hidden = self.policy(obs, self.policy_hidden)\n        if self.discrete_action:\n            if explore:\n                action = gumbel_softmax(action, hard=True)\n            else:\n                action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n             
   'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n\n    def init_hidden(self, len_ep, policy_hidden=False, policy_target_hidden=False, \\\n                    critic_hidden=False, critic_target_hidden=False):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        if policy_hidden == True:\n            self.policy_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if policy_target_hidden == True:\n            self.policy_target_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if critic_hidden == True:\n            self.critic_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if critic_target_hidden == True:\n            self.critic_target_hidden = torch.zeros((len_ep, self.hidden_dim))\n\nclass DDPGAgent_SNN(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, output_style, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        
\"\"\"\n        self.policy = SNNNetwork(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 output_style=output_style)\n        self.critic = SNNNetwork(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 output_style=output_style)\n        self.target_policy = SNNNetwork(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        output_style=output_style)\n        self.target_critic = SNNNetwork(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        output_style=output_style)\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        # t1 = time.time()\n        action = self.policy(obs)\n        # t2 = time.time()\n        # print('time_interaction:', t2 - 
t1)\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)\n                else:\n                    action = gumbel_softmax(action, hard=True)\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)\n                else:\n                    action = onehot_from_logits(action)\n            # if explore:\n            #\n            #     action = gumbel_softmax(action, hard=True)\n            #\n            # else:\n            #     action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n\n            action = action.clamp(-1, 1)\n\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n\nclass DDPGAgent_ToM(object):\n    \"\"\"\n    General class for 
DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, num_in_mle, output_style,\n                 num_agents, device, hidden_dim=64, lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.device = device\n        self.policy = LSTMClassifier(num_in_pol, num_out_pol,hidden_dim) #SNNNetwork\n                                 # hidden_dim=hidden_dim,\n                                 # output_style=output_style)\n        self.critic = LSTMClassifier(num_in_critic, 1,hidden_dim)\n                                 # hidden_dim=hidden_dim,\n                                 # output_style=output_style)\n        self.target_policy = LSTMClassifier(num_in_pol, num_out_pol,hidden_dim)\n                                        # hidden_dim=hidden_dim,\n                                        # output_style=output_style)\n        self.target_critic = LSTMClassifier(num_in_critic, 1,hidden_dim)\n                                        # hidden_dim=hidden_dim,\n                                        # output_style=output_style)\n        # self.policy = SNNNetwork(num_in_pol, num_out_pol,\n        #                          hidden_dim=hidden_dim,\n        #                          output_style=output_style)\n        # self.critic = SNNNetwork(num_in_critic, 1,\n        #                          hidden_dim=hidden_dim,\n        #                          output_style=output_style)\n        # self.target_policy = SNNNetwork(num_in_pol, num_out_pol,\n        #                                 hidden_dim=hidden_dim,\n        #                                 output_style=output_style)\n        # self.target_critic = 
SNNNetwork(num_in_critic, 1,\n        #                                 hidden_dim=hidden_dim,\n        #                                 output_style=output_style)\n        # self.mle = [SNNNetwork(num_in_mle, num_out_pol,\n        #                       hidden_dim=hidden_dim,\n        #                       output_style=output_style)] * (num_agents - 1)\n        self.mle = []\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        self.mle_optimizer = []\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        action = self.policy.to(self.device)(obs.to(self.device))\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1).cpu()\n                else:\n                    action = gumbel_softmax(action, hard=True).cpu()\n            else:\n                if 
action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5], hard=True), onehot_from_logits(action[:, 5:], hard=True)), 1)\n                else:\n                    action = onehot_from_logits(action).cpu()\n            # if explore:\n            #     action = gumbel_softmax(action, hard=True).cpu()\n            # else:\n            #     action = onehot_from_logits(action).cpu()\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n\n        return action\n\n    def get_params(self):\n        params = {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict(),\n                }\n        # for i in range(len(self.mle)):\n        #     params['mle%d'%i] = self.mle[i].state_dict()\n        #     params['mle_optimizer%d'%i] = self.mle_optimizer[i].state_dict()\n        return params\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n        # for i in range(len(self.mle)):\n        #     self.mle[i].load_state_dict(params['mle%d'%i])\n        #     self.mle_optimizer[i].load_state_dict(params['mle_optimizer%d'%i])\n\nclass 
rDDPGAgent_ToM(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, num_in_mle, output_style,\n                 num_agents, device, hidden_dim=64, lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.device = device\n        self.policy = RNN(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                          constrain_out=True,\n                          discrete_action=discrete_action)\n        self.critic = RNN(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        self.target_policy = RNN(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        self.target_critic = RNN(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        # self.mle = [SNNNetwork(num_in_mle, num_out_pol,\n        #                       hidden_dim=hidden_dim,\n        #                       output_style=output_style)] * (num_agents - 1)\n        self.mle = []\n        self.policy_hidden = None\n        self.policy_target_hidden = None\n        self.critic_hidden = None\n        self.critic_target_hidden = None\n        self.hidden_dim = hidden_dim\n        hard_update(self.target_policy, self.policy)\n        
hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        self.mle_optimizer = []\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        action, self.policy_hidden = self.policy(obs, self.policy_hidden)\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1).cpu()\n                else:\n                    action = gumbel_softmax(action, hard=True).cpu()\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5], hard=True), onehot_from_logits(action[:, 5:], hard=True)), 1)\n                else:\n                    action = onehot_from_logits(action).cpu()\n            # if explore:\n            #     action = gumbel_softmax(action, hard=True).cpu()\n            # else:\n            #     action = 
onehot_from_logits(action).cpu()\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n\n        return action\n\n    def get_params(self):\n        params = {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict(),\n                }\n        # for i in range(len(self.mle)):\n        #     params['mle%d'%i] = self.mle[i].state_dict()\n        #     params['mle_optimizer%d'%i] = self.mle_optimizer[i].state_dict()\n        return params\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n        # for i in range(len(self.mle)):\n        #     self.mle[i].load_state_dict(params['mle%d'%i])\n        #     self.mle_optimizer[i].load_state_dict(params['mle_optimizer%d'%i])\n\n    def init_hidden(self, len_ep, policy_hidden=False, policy_target_hidden=False, \\\n                    critic_hidden=False, critic_target_hidden=False):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        if policy_hidden == True:\n            self.policy_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if policy_target_hidden == True:\n            self.policy_target_hidden = torch.zeros((len_ep, 
self.hidden_dim))\n        if critic_hidden == True:\n            self.critic_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if critic_target_hidden == True:\n            self.critic_target_hidden = torch.zeros((len_ep, self.hidden_dim))\n\nclass lDDPGAgent(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.policy = LSTMClassifier(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        self.critic = LSTMClassifier(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=False)\n        self.target_policy = LSTMClassifier(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=True,\n                                        discrete_action=discrete_action)\n        self.target_critic = LSTMClassifier(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=False)\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n            self.exploration = 
OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        action = self.policy(obs)\n\n\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)\n                else:\n                    action = gumbel_softmax(action, hard=True)\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)\n                else:\n                    action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n       
         'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/buffer.py",
    "content": "import numpy as np\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Variable\n\nclass ReplayBuffer(object):\n    \"\"\"\n    Replay Buffer for multi-agent RL with parallel rollouts\n    \"\"\"\n    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):\n        \"\"\"\n        Inputs:\n            max_steps (int): Maximum number of timepoints to store in buffer\n            num_agents (int): Number of agents in environment\n            obs_dims (list of ints): number of obervation dimensions for each\n                                     agent\n            ac_dims (list of ints): number of action dimensions for each agent\n        \"\"\"\n        self.device = device\n        self.max_steps = max_steps\n        self.num_agents = num_agents\n        self.obs_buffs = []\n        self.ac_buffs = []\n        self.rew_buffs = []\n        self.next_obs_buffs = []\n        self.done_buffs = []\n        for odim, adim in zip(obs_dims, ac_dims):\n            self.obs_buffs.append(np.zeros((max_steps, odim)))\n            self.ac_buffs.append(np.zeros((max_steps, adim)))\n            self.rew_buffs.append(np.zeros(max_steps))\n            self.next_obs_buffs.append(np.zeros((max_steps, odim)))\n            self.done_buffs.append(np.zeros(max_steps))\n\n        self.filled_i = 0  # index of first empty location in buffer (last index when full)\n        self.curr_i = 0  # current index to write to (ovewrite oldest data)\n\n    def __len__(self):\n        return self.filled_i\n\n    def push(self, observations, actions, rewards, next_observations, dones):\n        nentries = observations.shape[0]  # handle multiple parallel environments\n        if self.curr_i + nentries > self.max_steps:\n            rollover = self.max_steps - self.curr_i # num of indices to roll over\n            for agent_i in range(self.num_agents):\n                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],\n                                 
                 rollover, axis=0)\n                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],\n                                                 rollover, axis=0)\n                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],\n                                                  rollover)\n                self.next_obs_buffs[agent_i] = np.roll(\n                    self.next_obs_buffs[agent_i], rollover, axis=0)\n                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],\n                                                   rollover)\n            self.curr_i = 0\n            self.filled_i = self.max_steps\n        for agent_i in range(self.num_agents):\n            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                observations[:, agent_i])\n            # actions are already batched by agent, so they are indexed differently\n            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]\n            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]\n            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                next_observations[:, agent_i])\n            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]\n        self.curr_i += nentries\n        if self.filled_i < self.max_steps:\n            self.filled_i += nentries\n        if self.curr_i == self.max_steps:\n            self.curr_i = 0\n\n    def sample(self, N, to_gpu=False, norm_rews=True):\n        inds = np.random.choice(np.arange(self.filled_i), size=N,\n                                replace=False)\n        if to_gpu:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))\n        else:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False)\n        if norm_rews:\n            ret_rews = [cast((self.rew_buffs[i][inds] -\n                 
             self.rew_buffs[i][:self.filled_i].mean()) /\n                             self.rew_buffs[i][:self.filled_i].std())\n                        for i in range(self.num_agents)]\n        else:\n            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]\n        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],\n                ret_rews,\n                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])\n\n    def get_average_rewards(self, N):\n        if self.filled_i == self.max_steps:\n            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing\n        else:\n            inds = np.arange(max(0, self.curr_i - N), self.curr_i)\n        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]\n\nclass ReplayBuffer_pre(object):\n    \"\"\"\n    Replay Buffer for multi-agent RL with parallel rollouts\n    \"\"\"\n    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):\n        \"\"\"\n        Inputs:\n            max_steps (int): Maximum number of timepoints to store in buffer\n            num_agents (int): Number of agents in environment\n            obs_dims (list of ints): number of obervation dimensions for each\n                                     agent\n            ac_dims (list of ints): number of action dimensions for each agent\n        \"\"\"\n        self.device = device\n        self.max_steps = max_steps\n        self.num_agents = num_agents\n        self.ac_pre_buffs = []\n        self.obs_buffs = []\n        self.ac_buffs = []\n        self.rew_buffs = []\n        self.next_obs_buffs = []\n        self.done_buffs = []\n        for odim, adim in zip(obs_dims, ac_dims):\n            self.ac_pre_buffs.append(np.zeros((max_steps, 5)))\n            
self.obs_buffs.append(np.zeros((max_steps, odim)))\n            self.ac_buffs.append(np.zeros((max_steps, adim)))\n            self.rew_buffs.append(np.zeros(max_steps))\n            self.next_obs_buffs.append(np.zeros((max_steps, odim)))\n            self.done_buffs.append(np.zeros(max_steps))\n\n        self.filled_i = 0  # index of first empty location in buffer (last index when full)\n        self.curr_i = 0  # current index to write to (ovewrite oldest data)\n\n    def __len__(self):\n        return self.filled_i\n\n    def push(self, actions_pre, observations, actions, rewards, next_observations, dones):\n        nentries = observations.shape[0]  # handle multiple parallel environments\n        if self.curr_i + nentries > self.max_steps:\n            rollover = self.max_steps - self.curr_i # num of indices to roll over\n            for agent_i in range(self.num_agents):\n                self.ac_pre_buffs[agent_i] = np.roll(self.ac_pre_buffs[agent_i][:,:5],\n                                                 rollover, axis=0)\n                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],\n                                                  rollover, axis=0)\n                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],\n                                                 rollover, axis=0)\n                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],\n                                                  rollover)\n                self.next_obs_buffs[agent_i] = np.roll(\n                    self.next_obs_buffs[agent_i], rollover, axis=0)\n                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],\n                                                   rollover)\n            self.curr_i = 0\n            self.filled_i = self.max_steps\n        for agent_i in range(self.num_agents):\n            self.ac_pre_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions_pre[agent_i][:,:5]\n            
self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                observations[:, agent_i])\n            # actions are already batched by agent, so they are indexed differently\n            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]\n            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]\n            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                next_observations[:, agent_i])\n            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]\n        self.curr_i += nentries\n        if self.filled_i < self.max_steps:\n            self.filled_i += nentries\n        if self.curr_i == self.max_steps:\n            self.curr_i = 0\n\n    def sample(self, N, to_gpu=False, norm_rews=True):\n        inds = np.random.choice(np.arange(self.filled_i), size=N,\n                                replace=False)\n        # inds = np.arange(self.filled_i)[0:-1:self.filled_i//N]\n        if to_gpu:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))\n        else:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False)\n        if self.rew_buffs[0].sum() == False:\n            norm_rews = False\n        if norm_rews:\n            ret_rews = [cast((self.rew_buffs[i][inds] -\n                              self.rew_buffs[i][:self.filled_i].mean()) /\n                             self.rew_buffs[i][:self.filled_i].std())\n                        for i in range(self.num_agents)]\n        else:\n            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]\n        return ([cast(self.ac_pre_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],\n                ret_rews,\n      
          [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])\n\n    def get_average_rewards(self, N):\n        if self.filled_i == self.max_steps:\n            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing\n        else:\n            inds = np.arange(max(0, self.curr_i - N), self.curr_i)\n        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]\n\n\nclass ReplayBuffer_RNN(object):\n    \"\"\"\n    Replay Buffer for multi-agent RL with parallel rollouts\n    \"\"\"\n    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, ep_dims, device):\n        \"\"\"\n        Inputs:\n            max_steps (int): Maximum number of timepoints to store in buffer\n            num_agents (int): Number of agents in environment\n            obs_dims (list of ints): number of obervation dimensions for each\n                                     agent\n            ac_dims (list of ints): number of action dimensions for each agent\n            ep_dims (int): Number of steps in each episode\n        \"\"\"\n        self.device = device\n        self.max_steps = max_steps\n        self.num_agents = num_agents\n        self.obs_buffs = []\n        self.ac_buffs = []\n        self.rew_buffs = []\n        self.next_obs_buffs = []\n        self.done_buffs = []\n        for odim, adim in zip(obs_dims, ac_dims):\n            self.obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))\n            self.ac_buffs.append(np.zeros((max_steps, ep_dims, adim)))\n            self.rew_buffs.append(np.zeros((max_steps, ep_dims)))\n            self.next_obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))\n            self.done_buffs.append(np.zeros((max_steps, ep_dims)))\n\n        self.filled_i = 0  # index of first empty location in buffer (last index when full)\n        self.curr_i = 0  # current index to write to (ovewrite oldest 
data)\n\n    def __len__(self):\n        return self.filled_i\n\n    def push(self, observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep):\n        nentries = observations_ep[0].shape[0]  # handle multiple parallel environments\n        observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep = \\\n            np.array(observations_ep), np.array(actions_ep), np.array(rewards_ep),\\\n            np.array(next_observations_ep), np.array(dones_ep)\n        if self.curr_i + nentries > self.max_steps:\n            rollover = self.max_steps - self.curr_i # num of indices to roll over\n            for agent_i in range(self.num_agents):\n                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],\n                                                  rollover, axis=0)\n                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],\n                                                 rollover, axis=0)\n                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],\n                                                  rollover)\n                self.next_obs_buffs[agent_i] = np.roll(\n                    self.next_obs_buffs[agent_i], rollover, axis=0)\n                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],\n                                                   rollover)\n            self.curr_i = 0\n            self.filled_i = self.max_steps\n        for agent_i in range(self.num_agents):\n            for i in range(observations_ep[:,:,agent_i].shape[0]):\n                if i == 0:\n                    ob_ep = np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)\n                    ob_next_ep = np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)\n                else:\n                    ob_ep = np.vstack((ob_ep, np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)))\n                    ob_next_ep = np.vstack((ob_next_ep, 
np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)))\n\n            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_ep.transpose(1, 0, 2)\n            # actions are already batched by agent, so they are indexed differently\n            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = actions_ep[:,:,0,:].transpose(1, 0, 2)\n            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = rewards_ep[:, :, agent_i].transpose(1, 0)\n            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_next_ep.transpose(1, 0, 2)\n            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = dones_ep[:, :, agent_i].transpose(1, 0)\n        self.curr_i += nentries\n        if self.filled_i < self.max_steps:\n            self.filled_i += nentries\n        if self.curr_i == self.max_steps:\n            self.curr_i = 0\n\n    def sample(self, N, to_gpu=False, norm_rews=True):\n        inds = np.random.choice(np.arange(self.filled_i), size=N,\n                                replace=False)\n        if to_gpu:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))\n        else:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False)\n        if norm_rews:\n            ret_rews = [cast((self.rew_buffs[i][inds] -\n                              self.rew_buffs[i][:self.filled_i].mean()) /\n                             self.rew_buffs[i][:self.filled_i].std())\n                        for i in range(self.num_agents)]\n        else:\n            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]\n        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],\n                ret_rews,\n                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],\n                
[cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])\n\n    def get_average_rewards(self, N):\n        if self.filled_i == self.max_steps:\n            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing\n        else:\n            inds = np.arange(max(0, self.curr_i - N), self.curr_i)\n        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/env_wrappers.py",
    "content": "\"\"\"\nModified from OpenAI Baselines code to work with multi-agent envs\n\"\"\"\nimport numpy as np\nfrom multiprocessing import Process, Pipe\nfrom common.vec_env.vec_env import VecEnv, CloudpickleWrapper\n\n\ndef worker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            if all(done):\n                ob = env.reset()\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset()\n            remote.send(ob)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'close':\n            remote.close()\n            break\n        elif cmd == 'get_spaces':\n            remote.send((env.observation_space, env.action_space))\n        elif cmd == 'get_agent_types':\n            if all([hasattr(a, 'adversary') for a in env.agents]):\n                remote.send(['adversary' if a.adversary else 'agent' for a in\n                             env.agents])\n            else:\n                remote.send(['agent' for _ in env.agents])\n        elif cmd == 'get_num_landmarks':\n            remote.send(len(env.world.landmarks))\n        else:\n            raise NotImplementedError\n\n\nclass SubprocVecEnv(VecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = 
True # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, action_space = self.remotes[0].recv()\n        self.remotes[0].send(('get_agent_types', None))\n        self.agent_types = self.remotes[0].recv()\n        self.remotes[0].send(('get_num_landmarks', None))\n        self.num_lm = self.remotes[0].recv()\n        VecEnv.__init__(self, len(env_fns), observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:            \n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\nclass DummyVecEnv(VecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]        \n        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)\n        if all([hasattr(a, 'adversary') for a in env.agents]):\n            self.agent_types 
= ['adversary' if a.adversary else 'agent' for a in\n                                env.agents]\n        else:\n            self.agent_types = ['agent' for _ in env.agents]\n        self.ts = np.zeros(len(self.envs), dtype='int')        \n        self.actions = None\n\n    def step_async(self, actions):\n        self.actions = actions\n\n    def step_wait(self):\n        results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]\n        obs, rews, dones, infos = map(np.array, zip(*results))\n        self.ts += 1\n        for (i, done) in enumerate(dones):\n            if all(done): \n                obs[i] = self.envs[i].reset()\n                self.ts[i] = 0\n        self.actions = None\n        return np.array(obs), np.array(rews), np.array(dones), infos\n\n    def reset(self):        \n        results = [env.reset() for env in self.envs]\n        return np.array(results)\n\n    def close(self):\n        return"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/make_env.py",
    "content": "\"\"\"\nCode for creating a multiagent environment with one of the scenarios listed\nin ./scenarios/.\nCan be called by using, for example:\n    env = make_env('simple_speaker_listener')\nAfter producing the env object, can be used similarly to an OpenAI gym\nenvironment.\n\nA policy using this environment must output actions in the form of a list\nfor all agents. Each element of the list should be a numpy array,\nof size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede\ncommunication actions in this array. See environment.py for more details.\n\"\"\"\n\ndef make_env(scenario_name, num_good_agents, num_adversaries, benchmark=False, discrete_action=False):\n    '''\n    Creates a MultiAgentEnv object as env. This can be used similar to a gym\n    environment by calling env.reset() and env.step().\n    Use env.render() to view the environment on the screen.\n\n    Input:\n        scenario_name   :   name of the scenario from ./scenarios/ to be Returns\n                            (without the .py extension)\n        benchmark       :   whether you want to produce benchmarking data\n                            (usually only done during evaluation)\n\n    Some useful env properties (see environment.py):\n        .observation_space  :   Returns the observation space for each agent\n        .action_space       :   Returns the action space for each agent\n        .n                  :   Returns the number of Agents\n    '''\n    from multiagent.environment import MultiAgentEnv\n    import multiagent.scenarios as scenarios\n\n    # load scenario from script\n    scenario = scenarios.load(scenario_name + \".py\").Scenario()\n    # create world\n    world = scenario.make_world(num_good_agents, num_adversaries)\n    # create multiagent environment\n    if benchmark:        \n        env = MultiAgentEnv(world, scenario.reset_world, scenario.reward,\n                            scenario.observation, scenario.benchmark_data)\n    else:\n        
env = MultiAgentEnv(world, scenario.reset_world, scenario.reward,\n                            scenario.observation)\n    return env\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/misc.py",
    "content": "import os\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.autograd import Variable\nimport numpy as np\n\n# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11\ndef soft_update(target, source, tau):\n    \"\"\"\n    Perform DDPG soft update (move target params toward source based on weight\n    factor tau)\n    Inputs:\n        target (torch.nn.Module): Net to copy parameters to\n        source (torch.nn.Module): Net whose parameters to copy\n        tau (float, 0 < x < 1): Weight factor for update\n    \"\"\"\n    for target_param, param in zip(target.parameters(), source.parameters()):\n        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)\n\n# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15\ndef hard_update(target, source):\n    \"\"\"\n    Copy network parameters from source to target\n    Inputs:\n        target (torch.nn.Module): Net to copy parameters to\n        source (torch.nn.Module): Net whose parameters to copy\n    \"\"\"\n    for target_param, param in zip(target.parameters(), source.parameters()):\n        target_param.data.copy_(param.data)\n\n# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py\ndef average_gradients(model):\n    \"\"\" Gradient averaging. \"\"\"\n    size = float(dist.get_world_size())\n    for param in model.parameters():\n        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0)\n        param.grad.data /= size\n\n# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py\ndef init_processes(rank, size, fn, backend='gloo'):\n    \"\"\" Initialize the distributed environment. 
\"\"\"\n    os.environ['MASTER_ADDR'] = '127.0.0.1'\n    os.environ['MASTER_PORT'] = '29500'\n    dist.init_process_group(backend, rank=rank, world_size=size)\n    fn(rank, size)\n\ndef onehot_from_logits(logits, eps=0.0):\n    \"\"\"\n    Given batch of logits, return one-hot sample using epsilon greedy strategy\n    (based on given epsilon)\n    \"\"\"\n    # get best (according to current policy) actions in one-hot form\n    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()\n    if eps == 0.0:\n        return argmax_acs\n    # get random actions in one-hot form\n    rand_acs = Variable(torch.eye(logits.shape[1])[[np.random.choice(\n        range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False)\n    # chooses between best and random actions using epsilon greedy\n    return torch.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in\n                        enumerate(torch.rand(logits.shape[0]))])\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):\n    \"\"\"Sample from Gumbel(0, 1)\"\"\"\n    U = Variable(tens_type(*shape).uniform_(), requires_grad=False)\n    return -torch.log(-torch.log(U + eps) + eps)\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef gumbel_softmax_sample(logits, temperature):\n    \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n    y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)).to(logits.device)\n    return F.softmax(y / temperature, dim=1)\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef gumbel_softmax(logits, temperature=1.0, hard=False):\n    \"\"\"Sample from the Gumbel-Softmax distribution and optionally discretize.\n    Args:\n      logits: [batch_size, n_class] unnormalized log-probs\n      temperature: 
non-negative scalar\n      hard: if True, take argmax, but differentiate w.r.t. soft sample y\n    Returns:\n      [batch_size, n_class] sample from the Gumbel-Softmax distribution.\n      If hard=True, then the returned sample will be one-hot, otherwise it will\n      be a probabilitiy distribution that sums to 1 across classes\n    \"\"\"\n    y = gumbel_softmax_sample(logits, temperature)\n    if hard:\n        y_hard = onehot_from_logits(y)\n        y = (y_hard - y).detach() + y\n    return y\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/multiprocessing.py",
    "content": "# This code is from openai baseline\n# https://github.com/openai/baselines/tree/master/baselines/common/vec_env\nimport time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom multiprocessing import Process, Pipe\n\n\ndef _flatten_list(l):\n    assert isinstance(l, (list, tuple))\n    assert len(l) > 0\n    assert all([len(l_) > 0 for l_ in l])\n\n    return [l__ for l_ in l for l__ in l_]\n\n\ndef worker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            if done:\n                ob = env.reset()\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset()\n            remote.send(ob)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'world':\n            remote.send(env.world)\n        elif cmd == 'render':\n            ob = env.render(mode='rgb_array')\n            # print(len(ob), 'len(frames)')\n            # print(len(ob[0]), 'len(frames[0])')\n            # print(len(ob[0][0]), 'len(frames[0][0])')\n            remote.send(ob)  # rgb_array\n        elif cmd == 'observe':\n            ob = env.observe(data)\n            remote.send(ob)\n        elif cmd == 'agents':\n            remote.send(env.agents)\n        elif cmd == 'spec':\n            remote.send(env.spec)\n        elif cmd == 'get_spaces':\n            remote.send((env.observation_space, env.action_space))\n        elif cmd == 'close':\n            remote.close()\n            break\n        else:\n            raise NotImplementedError\n\n\nclass VecEnv(object):\n    \"\"\"\n    An abstract asynchronous, vectorized environment.\n    \"\"\"\n    closed = False\n    viewer = None\n\n    metadata = {\n        'render.modes': ['human', 'rgb_array']\n    }\n    def 
__init__(self, num_envs, observation_space, action_space):\n        self.num_envs = num_envs\n        self.observation_space = observation_space\n        self.action_space = action_space\n\n    def observe(self, agent):\n        pass\n\n    def reset(self):\n        \"\"\"\n        Reset all the environments and return an array of\n        observations, or a tuple of observation arrays.\n        If step_async is still doing work, that work will\n        be cancelled and step_wait() should not be called\n        until step_async() is invoked again.\n        \"\"\"\n        pass\n\n    def step_async(self, actions):\n        \"\"\"\n        Tell all the environments to start taking a step\n        with the given actions.\n        Call step_wait() to get the results of the step.\n        You should not call this if a step_async run is\n        already pending.\n        \"\"\"\n        pass\n\n    def step_wait(self):\n        \"\"\"\n        Wait for the step taken with step_async().\n        Returns (obs, rews, dones, infos):\n         - obs: an array of observations, or a tuple of\n                arrays of observations.\n         - rews: an array of rewards\n         - dones: an array of \"episode done\" booleans\n         - infos: a sequence of info objects\n        \"\"\"\n        pass\n\n    def close(self):\n        \"\"\"\n        Clean up the environments' resources.\n        \"\"\"\n        pass\n\n    def step(self, actions):\n        self.step_async(actions)\n        return self.step_wait()\n\n    def render(self, mode='human'):\n        imgs = self.get_images()\n        bigimg = self.tile_images(imgs)\n        if mode == 'human':\n            self.get_viewer().imshow(bigimg)    #\n            return self.get_viewer().isopen\n\n        elif mode == 'rgb_array':\n            return bigimg\n        else:\n            raise NotImplementedError\n\n    def get_images(self):\n        \"\"\"\n        Return RGB images from each environment\n        \"\"\"\n       
 raise NotImplementedError\n\n    def get_viewer(self):\n        if self.viewer is None:\n            from common import rendering\n            self.viewer = rendering.SimpleImageViewer()\n        return self.viewer\n\n    def tile_images(self, img_nhwc):\n        \"\"\"\n        Tile N images into one big PxQ image\n        (P,Q) are chosen to be as close as possible, and if N\n        is square, then P=Q.\n        input: img_nhwc, list or array of images, ndim=4 once turned into array\n            n = batch index, h = height, w = width, c = channel\n        returns:\n            bigim_HWc, ndarray with ndim=3\n        \"\"\"\n        img_nhwc = np.asarray(img_nhwc)\n        N, h, w, c = img_nhwc.shape\n        H = int(np.ceil(np.sqrt(N)))\n        W = int(np.ceil(float(N) / H))\n        img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(N, H * W)])\n        img_HWhwc = img_nhwc.reshape(H, W, h, w, c)\n        img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)\n        img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c)\n        return img_Hh_Ww_c\n\nclass CloudpickleWrapper(object):\n    \"\"\"\n    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)\n    \"\"\"\n\n    def __init__(self, x):\n        self.x = x\n\n    def __getstate__(self):\n        import cloudpickle\n        return cloudpickle.dumps(self.x)\n\n    def __setstate__(self, ob):\n        import pickle\n        self.x = pickle.loads(ob)\n\n\nclass SubprocVecEnv(VecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs_sc: list of gym environments to run in subprocesses\n        \"\"\"\n        # self.venv = venv\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.nenvs = nenvs\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for 
(work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True  # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, action_space = self.remotes[0].recv()\n\n        VecEnv.__init__(self, len(env_fns), observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):       # the input of step() : action\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]    # the output of step() : zip(*results)\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def step_wait_2(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        reward, done, _cumulative_rewards = zip(*results)\n        return reward, done, _cumulative_rewards\n\n    def step_wait_3(self):\n        results = [remote.recv() for remote in self.remotes]    # the output of step() : zip(*results)\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def agents(self):\n        for remote in self.remotes:\n            remote.send(('agents', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def world(self):\n        for remote in self.remotes:\n            remote.send(('world', None))\n        return np.stack([remote.recv() for remote in 
self.remotes])\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def spec(self):\n        for remote in self.remotes:\n            remote.send(('spec', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def get_images(self):\n        # self._assert_not_closed()\n        for pipe in self.remotes:\n            pipe.send(('render', None))\n        imgs = [pipe.recv() for pipe in self.remotes]\n        # imgs = _flatten_list(imgs)\n        return imgs\n\n    def observe(self, agent):\n        for remote, agent in zip(self.remotes, agent):\n            remote.send(('observe', agent))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    # def render(self, mode='human'):\n    #     return self.venv.render(mode=mode)\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n            self.closed = True\n\n    def __len__(self):\n        return self.nenvs\n\ndef _flatten_list(l):\n    assert isinstance(l, (list, tuple))\n    assert len(l) > 0\n    assert all([len(l_) > 0 for l_ in l])\n\n    return [l__ for l_ in l for l__ in l_]\n\nclass DummyVecEnv(VecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)\n        if all([hasattr(a, 'adversary') for a in env.agents]):\n            self.agent_types = ['adversary' if a.adversary else 'agent' for a in\n                                env.agents]\n        else:\n            self.agent_types = ['agent' for _ in env.agents]\n        self.ts = 
np.zeros(len(self.envs), dtype='int')\n        self.actions = None\n\n    def step_async(self, actions):\n        self.actions = actions\n\n    def step_wait(self):\n        results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]\n        obs, rews, dones, infos = map(np.array, zip(*results))\n        self.ts += 1\n        for (i, done) in enumerate(dones):\n            if all(done):\n                obs[i] = self.envs[i].reset()\n                self.ts[i] = 0\n        self.actions = None\n        return np.array(obs), np.array(rews), np.array(dones), infos\n\n    def reset(self):\n        results = [env.reset() for env in self.envs]\n        return np.array(results)\n\n    def close(self):\n        return"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/networks.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\n\nclass MLPNetwork(nn.Module):\n    \"\"\"\n    MLP network (can be used as value or policy)\n    \"\"\"\n    def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,\n                 constrain_out=False, norm_in=True, discrete_action=True):\n        \"\"\"\n        Inputs:\n            input_dim (int): Number of dimensions in input\n            out_dim (int): Number of dimensions in output\n            hidden_dim (int): Number of hidden dimensions\n            nonlin (PyTorch function): Nonlinearity to apply to hidden layers\n        \"\"\"\n        super(MLPNetwork, self).__init__()\n\n        if norm_in:  # normalize inputs\n            self.in_fn = nn.BatchNorm1d(input_dim)    #train\n            # self.in_fn = input_dim  #test\n            self.in_fn.weight.data.fill_(1)\n            self.in_fn.bias.data.fill_(0)\n        else:\n            self.in_fn = lambda x: x\n        self.fc1 = nn.Linear(input_dim, hidden_dim)\n        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n        self.fc3 = nn.Linear(hidden_dim, out_dim)\n        self.nonlin = nonlin\n        if constrain_out and not discrete_action:\n            # initialize small to prevent saturation\n            self.fc3.weight.data.uniform_(-3e-3, 3e-3)\n            self.out_fn = F.tanh\n        else:  # logits for discrete action (will softmax later)\n            self.out_fn = lambda x: x\n\n    def forward(self, X):\n        \"\"\"\n        Inputs:\n            X (PyTorch Matrix): Batch of observations\n        Outputs:\n            out (PyTorch Matrix): Output of network (actions, values, etc)\n        \"\"\"\n        h1 = self.nonlin(self.fc1(self.in_fn(X)))\n        h2 = self.nonlin(self.fc2(h1))\n        out = self.out_fn(self.fc3(h2))\n        return out\n\nclass RNN(nn.Module):\n    # Because all the agents_sc share the same network_sc, input_shape=obs_shape+n_actions+n_agents\n    def 
__init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,\n                 constrain_out=False, norm_in=True, discrete_action=True):\n        super(RNN, self).__init__()\n        self.rnn_hidden_dim = hidden_dim\n        if norm_in:  # normalize inputs\n            self.in_fn = nn.BatchNorm1d(input_dim)\n            self.in_fn.weight.data.fill_(1)\n            self.in_fn.bias.data.fill_(0)\n        else:\n            self.in_fn = lambda x: x\n        self.fc1 = nn.Linear(input_dim, hidden_dim)\n        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)\n        self.fc2 = nn.Linear(hidden_dim, out_dim)\n        self.nonlin = nonlin\n        if constrain_out and not discrete_action:\n            # initialize small to prevent saturation\n            self.fc3.weight.data.uniform_(-3e-3, 3e-3)\n            self.out_fn = F.tanh\n        else:  # logits for discrete action (will softmax later)\n            self.out_fn = lambda x: x\n    def forward(self, obs, hidden_state):\n        x = self.nonlin(self.fc1(obs))\n        # x = x.reshape(-1, self.rnn_hidden_dim)\n        # h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)\n        h = self.rnn(x, hidden_state)\n        q = self.fc2(h)\n        return q, h\n\nclass BCNoSpikingLIFNode(LIFNode):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, dv: torch.Tensor):\n        self.integral(dv)\n        return self.mem\n\nclass SNNNetwork(nn.Module):\n    \"\"\"\n    SNN network (can be used as value or policy or MLE)\n    \"\"\"\n    def __init__(self, input_dim, out_dim, hidden_dim=64, node=LIFNode, time_window=16,\n                 norm_in=True, output_style='sum'):\n        \"\"\"\n        Inputs:\n            input_dim (int): Number of dimensions in input\n            out_dim (int): Number of dimensions in output\n            hidden_dim (int): Number of hidden dimensions\n            nonlin (PyTorch function): Nonlinearity to apply to hidden layers\n        
\"\"\"\n        super(SNNNetwork, self).__init__()\n\n        self._threshold = 0.5\n        self.v_reset = 0.0\n        self._time_window = time_window\n        self.output_style = output_style\n        self._node1 = node(threshold=self._threshold, v_reset=self.v_reset)\n        self._node2 = node(threshold=self._threshold, v_reset=self.v_reset)\n\n        if norm_in:  # normalize inputs\n            self.in_fn = nn.BatchNorm1d(input_dim)    #train\n            self.in_fn.weight.data.fill_(1)\n            self.in_fn.bias.data.fill_(0)\n        else:\n            self.in_fn = lambda x: x\n\n        self.fc1 = nn.Linear(input_dim, hidden_dim)\n        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n        self.fc3 = nn.Linear(hidden_dim, out_dim)\n\n        if self.output_style == 'sum':\n            self._out_node = lambda x: x\n        elif self.output_style == 'voltage':\n            self._out_node = BCNoSpikingLIFNode()\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n    def forward(self, X):\n        qs = []\n        self.reset()\n        for t in range(self._time_window):\n            x = self.fc1((self.in_fn(X)+0.5)) #train\n            # x = self.fc1((X + 0.5)) #test\n            x = self._node1(x)\n            x = self.fc2(x)\n            x = self._node2(x)\n            x = self.fc3(x)\n            x = self._out_node(x)\n            qs.append(x)\n\n        if self.output_style == 'sum':\n            outputs = sum(qs) / self._time_window\n            return outputs\n        elif self.output_style == 'voltage':\n            outputs = x\n            return outputs\n\n\nclass LSTMClassifier(nn.Module):\n    def __init__(self, input_size, output_size, hidden_size=256):\n        super(LSTMClassifier, self).__init__()\n        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, batch_first=True)\n        self.fc = nn.Linear(hidden_size, 
output_size)\n\n    def forward(self, x):\n        x = x.unsqueeze(1)\n        output, (h, c) = self.lstm(x)\n        output = output.squeeze(1)\n        out = self.fc(output)\n\n        return out\n"
  },
  {
    "path": "examples/Social_Cognition/FOToM/utils/noise.py",
    "content": "import numpy as np\n\n\n# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py\nclass OUNoise:\n    def __init__(self, action_dimension, scale=0.1, mu=0, theta=0.15, sigma=0.2):\n        self.action_dimension = action_dimension\n        self.scale = scale\n        self.mu = mu\n        self.theta = theta\n        self.sigma = sigma\n        self.state = np.ones(self.action_dimension) * self.mu\n        self.reset()\n\n    def reset(self):\n        self.state = np.ones(self.action_dimension) * self.mu\n\n    def noise(self):\n        x = self.state\n        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(len(x))\n        self.state = x + dx\n        return self.state * self.scale\n"
  },
  {
    "path": "examples/Social_Cognition/Intention_Prediction/Intention_Prediction.py",
    "content": "import numpy as np\nimport torch,os,sys\nfrom torch import nn\nfrom torch.nn import Parameter \n\nimport abc\nimport math\nfrom abc import ABC\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\nfrom BrainCog.base.strategy.surrogate import *\n\nimport os\nos.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\nimport random\n\nfrom BrainCog.base.node.node import *\nfrom BrainCog.base.learningrule.STDP import MutliInputSTDP\n\nclass CustomLinear(nn.Module):\n    def __init__(self, weight,mask=None):\n        super().__init__()\n\n        self.weight = nn.Parameter(weight, requires_grad=True)\n        self.mask=mask\n    def forward(self, x: torch.Tensor):\n        #\n        # ret.shape = [C]\n        return x.mul(self.weight) # Changed\n\n    def update(self, dw):\n        with torch.no_grad():\n            if self.mask is not None:\n                dw *= self.mask\n            self.weight.data+= dw\n\n\nclass DLPFCNet(nn.Module):\n    def __init__(self,connection):\n        super().__init__()\n        # DLPFC, BG     \n        self.node = []\n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) \n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) \n\n        self.learning_rule = []\n        self.connection = connection\n\n        self.out_DLPFC=torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float) # Input-DLPFC\n        self.out_BG=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float) # DLPFC-BG\n\n    def forward(self, input): \n        self.out_DLPFC=self.node[0](self.connection[0](input))   \n        self.out_BG=self.node[1](self.connection[1](self.out_DLPFC))\n\n        BG_Spike = self.node[1].spike\n\n        if sum(sum(BG_Spike)).item() > 1:\n            num_neuron = len(BG_Spike)\n            BG_Spike_index = torch.argmax(BG_Spike)\n         
   BG_Spike_index_x = torch.floor(BG_Spike_index/num_neuron) \n            BG_Spike_index_y = BG_Spike_index - BG_Spike_index_x*num_neuron \n\n            BG_Spike = torch.zeros([num_neuron, num_neuron], dtype=torch.float) \n            BG_Spike[BG_Spike_index_x.long()][BG_Spike_index_y.long()] = 1 \n        return BG_Spike\n\n    def reset(self):\n        for i in range(len(self.node)):\n            self.node[i].n_reset()\n        for i in range(len(self.learning_rule)):\n            self.learning_rule[i].reset()\n\n    def UpdateWeight(self, i, W):\n        self.connection[i].weight.data = W\n\n\nclass OFCNet(nn.Module):\n    def __init__(self,connection):\n        super().__init__()\n        # OFC, MOFC, LOFC    \n        self.node = []\n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # OFC_1 \n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # OFC_2 \n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # MOFC\n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # LOFC\n        \n        self.connection = connection\n        self.learning_rule = []\n\n        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[3],self.connection[4]])) # OFC_2-LOFC, MOFC-LOFC\n        self.learning_rule.append(MutliInputSTDP(self.node[3], [self.connection[3],self.connection[5]])) # OFC_2-LOFC, OFC_1-LOFC\n\n        self.out_OFC_1=torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float) \n        self.out_OFC_2=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\n        self.out_MOFC=torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)\n        self.out_LOFC=torch.zeros((self.connection[5].weight.shape[1]), dtype=torch.float)\n        \n\n    def forward(self, Input_Tha, Input_SNc, Reward): \n        self.out_OFC_1 = 
self.node[0](self.connection[0](Input_Tha))\n        self.out_OFC_2 = self.node[1](self.connection[1](Input_SNc))\n        if Reward == 1:\n            self.out_MOFC = self.node[2](self.connection[2](self.out_OFC_1))\n            self.out_LOFC, dw_lofc = self.learning_rule[0](self.out_OFC_2, self.out_MOFC)\n        else:\n            self.out_MOFC = self.node[2](self.connection[2](self.out_OFC_1*0)) \n            self.out_LOFC, dw_lofc = self.learning_rule[1](self.out_OFC_2, self.out_OFC_1)\n        \n        MOFC_Spike = self.node[2].spike\n        LOFC_Spike = self.node[3].spike\n\n        return MOFC_Spike, LOFC_Spike\n    \n    def reset(self):\n        for i in range(len(self.node)):\n            self.node[i].n_reset()\n        for i in range(len(self.learning_rule)):\n            self.learning_rule[i].reset()\n\n\nclass BGNet(nn.Module):\n    def __init__(self,connection):\n        super().__init__()\n        # DLPFC, StrD1, StrD2       \n        self.node = []\n        self.node.append(IzhNodeMU(threshold=30., a=0.02, b=0.60, c=-65., d=8., mem=-70.)) # DLPFC\n        self.node.append(IzhNodeMU(threshold=30., a=0.01, b=0.01, c=-65., d=8., mem=-70.)) # StrD1 \n        self.node.append(IzhNodeMU(threshold=30., a=0.1, b=0.5, c=-65., d=8., mem=-70.)) # StrD2\n        \n        self.connection = connection\n        self.learning_rule = []\n        \n        self.out_DLPFC=torch.zeros((self.connection[0].weight.shape[1]), dtype=torch.float)\n        self.out_StrD1=torch.zeros((self.connection[1].weight.shape[1]), dtype=torch.float)\n        self.out_StrD2=torch.zeros((self.connection[2].weight.shape[1]), dtype=torch.float)\n\n\n    def forward(self, input1, input2, input3): \n        self.out_DLPFC=self.node[0](self.connection[0](input1))       \n        self.out_StrD1=self.node[1](self.connection[1](input2))\n        self.out_StrD2=self.node[2](self.connection[2](input3))\n\n        DLPFC_out = self.node[0].spike\n        BG_out = self.node[1].spike + 
self.node[2].spike\n\n        return DLPFC_out, BG_out\n\n    def reset(self):\n        for i in range(len(self.node)):\n            self.node[i].n_reset()\n        for i in range(len(self.learning_rule)):\n            self.learning_rule[i].reset()\n    def UpdateWeight(self, i, W):\n        self.connection[i].weight.data = W\n\n\ndef STDP(Pre_mat, Post_mat, W):  \n    T_Pre = 0\n    T_Post = 0\n    for i in range(len(Pre_mat)):       \n        C_Pre = Pre_mat[i]\n        C_Post = Post_mat[i]\n        if sum(sum(C_Pre)) > 0:\n            T_Pre = i\n            Spike_Pre = Pre_mat[T_Pre]\n        if sum(sum(C_Post)) > 0:\n            T_Post = i\n            Spike_Post = Post_mat[T_Post]\n        if T_Pre*T_Post > 0:\n            dT = T_Pre - T_Post   \n    \n            A_up = 0.777\n            A_down = -0.237\n            tau_up = 16.8\n            tau_down = -33.7\n            if dT < 0:\n                dW = A_up * math.exp(dT/tau_up)\n            else:\n                dW = A_down * math.exp(dT/tau_down)           \n            T_Post = 0         \n            dW_mat = torch.mul(Spike_Pre, Spike_Post)*dW\n            W = W + torch.mul(dW_mat, W)\n    return W\n              \n    \nif __name__==\"__main__\":\n    # number of neurons\n    num_neuron = 6\n    num_DLPFC = num_neuron \n    num_BG  = num_neuron\n    num_StrD1 = num_neuron\n    num_StrD2 = num_neuron\n    num_Thalamus = num_neuron\n    num_OFC = num_neuron\n    num_SNc = num_neuron\n    num_PMC = num_neuron\n\n\n    ##############################\n    # DLPFC\n    ##############################\n    WeightAdd = 20\n    # DLPFC-BG\n    DLPFC_BG_connection = []\n    # Input-DLPFC\n    con_matrix0 = torch.ones([num_DLPFC, num_DLPFC], dtype=torch.float)*WeightAdd\n    DLPFC_BG_connection.append(CustomLinear(con_matrix0))\n    # DLPFC-BG\n    W = torch.ones([num_DLPFC, num_BG], dtype=torch.float)*WeightAdd\n    DLPFC_BG_connection.append(CustomLinear(W))\n\n    DLPFC = DLPFCNet(DLPFC_BG_connection)\n\n\n  
  ##############################\n    # OFC\n    ##############################\n    WeightAdd = 20\n    OFC_connection = []\n    # Tha-OFC_1 (Input1)  \n    con_matrix0 = torch.ones([num_Thalamus, num_OFC], dtype=torch.float)*WeightAdd\n    OFC_connection.append(CustomLinear(con_matrix0))   \n    # SNc/VTA-OFC_2 (Input2)\n    con_matrix1 = torch.ones([num_SNc, num_OFC], dtype=torch.float)*WeightAdd\n    OFC_connection.append(CustomLinear(con_matrix1))\n    # OFC_1-MOFC\n    con_matrix2 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*5\n    OFC_connection.append(CustomLinear(con_matrix2)) \n    # OFC_2-LOFC\n    con_matrix3 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*5\n    OFC_connection.append(CustomLinear(con_matrix3)) \n    # MOFC-LOFC\n    con_matrix4 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*-10\n    OFC_connection.append(CustomLinear(con_matrix4)) \n    # OFC_1-LOFC\n    con_matrix5 = torch.ones([num_OFC, num_OFC], dtype=torch.float)*WeightAdd*5\n    OFC_connection.append(CustomLinear(con_matrix5)) \n\n    OFC = OFCNet(OFC_connection)\n\n\n    ##############################\n    # BGNet\n    ##############################\n    BG_connection = []\n    WeightAdd = 20\n    # Input1-DLPFC\n    con_matrix0 = torch.ones([num_DLPFC, num_DLPFC], dtype=torch.float)*WeightAdd\n    BG_connection.append(CustomLinear(con_matrix0))\n    # Input2-StrD1\n    con_matrix1 = torch.ones([num_StrD1,num_StrD1], dtype=torch.float)*WeightAdd\n    BG_connection.append(CustomLinear(con_matrix1))\n    # Input3-StrD2\n    con_matrix2 = torch.ones([num_StrD2,num_StrD2], dtype=torch.float)*WeightAdd\n    BG_connection.append(CustomLinear(con_matrix2))\n    BG_connection.append(CustomLinear(W))\n    # StrD1-BG\n    con_matrix4 = torch.ones([num_StrD1,num_BG], dtype=torch.float)*WeightAdd\n    BG_connection.append(CustomLinear(con_matrix4))\n    # StrD2-BG\n    con_matrix5 = torch.ones([num_StrD2,num_BG], dtype=torch.float)*WeightAdd\n    
BG_connection.append(CustomLinear(con_matrix5))\n\n    BG = BGNet(BG_connection)\n\n\n    ##############################\n    # Train\n    ##############################\n    # Intention-action corresponding rules\n    Intention_mat = range(num_neuron)\n    Action_mat = range(num_neuron)\n\n    TrainNum = 0\n    for k in range(len(Intention_mat)):\n        Intention = Intention_mat[k] \n        Intention_Action = Action_mat[k] \n        for j in range (len(Intention_mat)+1): \n\n            TrainNum = TrainNum + 1\n\n            # Intention prediction\n            for i in range(10):\n                DLPFC_Input = torch.zeros([num_DLPFC, num_DLPFC], dtype=torch.float)\n                DLPFC_Input[Intention,:] = 10\n                BG_Spike = DLPFC(DLPFC_Input)\n                if sum(sum(BG_Spike)).item() > 0:\n                    Action = torch.nonzero(BG_Spike).numpy()[0][1]\n                    break\n            DLPFC.reset()\n            \n            PMC = torch.zeros([1, num_PMC], dtype=torch.float)\n            PMC[0][Action] = 1\n\n            Thalamus = torch.zeros([num_Thalamus, num_Thalamus], dtype=torch.float) \n            Thalamus[Intention][Action] = 10 \n                  \n            if Intention_Action == Action:\n                # Positive reward\n                # Tha-OFC_1-MOFC, SNc-OFC_2&MOFC-LOFC\n                Reward = 1\n                Input_Tha = Thalamus \n                Input_SNc_Reward = torch.ones([num_OFC, num_OFC], dtype=torch.float) \n                Input_SNc = torch.mul(Input_SNc_Reward, PMC)*10 \n                for t in range(10):\n                    MOFC_Spike, LOFC_Spike = OFC(Input_Tha, Input_SNc, Reward)\n                \n            else:\n                # Negative reward\n                # Tha-OFC_1-LOFC, SNc-OFC_2-LOFC (SNC is zeros)\n                Reward = -1\n                Input_Tha = Thalamus \n                Input_SNc = torch.zeros([num_OFC, num_OFC], dtype=torch.float) \n                for t in 
range(10):\n                    MOFC_Spike, LOFC_Spike = OFC(Input_Tha, Input_SNc, Reward)\n            OFC.reset()\n\n            for i in range(1):\n                DLPFC_out_mat = []\n                BG_out_mat = []\n                State = 0\n                for t in range(10):   \n                    DLPFC_Input = MOFC_Spike + LOFC_Spike\n                    StrD1_Input = MOFC_Spike\n                    StrD2_Input = LOFC_Spike\n\n                    DLPFC_out, BG_out = BG(DLPFC_Input, StrD1_Input, StrD2_Input)\n                    DLPFC_out_mat.append(DLPFC_out)\n                    BG_out_mat.append(BG_out)\n                \n                W = STDP(DLPFC_out_mat, BG_out_mat, W)\n                BG.reset()\n\n                DLPFC.UpdateWeight(1, W)\n                BG.UpdateWeight(3, W)\n\n            if Reward == 1:\n                break\n    \n    print(\"Train End\")\n    print(\"W is: \\n\", W)\n    print(\"TrainNum is: \\n\", TrainNum)\n    print(\"*****************************\")\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/LICENSE",
    "content": "                    GNU GENERAL PUBLIC LICENSE\n                       Version 3, 29 June 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU General Public License is a free, copyleft license for\nsoftware and other kinds of works.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nthe GNU General Public License is intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.  We, the Free Software Foundation, use the\nGNU General Public License for most of our software; it applies also to\nany other work released this way by its authors.  You can apply it to\nyour programs, too.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  To protect your rights, we need to prevent others from denying you\nthese rights or asking you to surrender the rights.  Therefore, you have\ncertain responsibilities if you distribute copies of the software, or if\nyou modify it: responsibilities to respect the freedom of others.\n\n  For example, if you distribute copies of such a program, whether\ngratis or for a fee, you must pass on to the recipients the same\nfreedoms that you received.  You must make sure that they, too, receive\nor can get the source code.  
And you must show them these terms so they\nknow their rights.\n\n  Developers that use the GNU GPL protect your rights with two steps:\n(1) assert copyright on the software, and (2) offer you this License\ngiving you legal permission to copy, distribute and/or modify it.\n\n  For the developers' and authors' protection, the GPL clearly explains\nthat there is no warranty for this free software.  For both users' and\nauthors' sake, the GPL requires that modified versions be marked as\nchanged, so that their problems will not be attributed erroneously to\nauthors of previous versions.\n\n  Some devices are designed to deny users access to install or run\nmodified versions of the software inside them, although the manufacturer\ncan do so.  This is fundamentally incompatible with the aim of\nprotecting users' freedom to change the software.  The systematic\npattern of such abuse occurs in the area of products for individuals to\nuse, which is precisely where it is most unacceptable.  Therefore, we\nhave designed this version of the GPL to prohibit the practice for those\nproducts.  If such problems arise substantially in other domains, we\nstand ready to extend this provision to those domains in future versions\nof the GPL, as needed to protect the freedom of users.\n\n  Finally, every program is threatened constantly by software patents.\nStates should not allow patents to restrict development and use of\nsoftware on general-purpose computers, but in those that do, we wish to\navoid the special danger that patents applied to a free program could\nmake it effectively proprietary.  To prevent this, the GPL assures that\npatents cannot be used to render the program non-free.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. 
Definitions.\n\n  \"This License\" refers to version 3 of the GNU General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. 
Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. 
Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  
This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. 
Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  
If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  
A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  
Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  
You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  
If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  
If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  
For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  
To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  \"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  
You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Use with the GNU Affero General Public License.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU Affero General Public License into a single\ncombined work, and to convey the resulting work.  
The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the special requirements of the GNU Affero General Public License,\nsection 13, concerning interaction through a network will apply to the\ncombination as such.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU General Public License from time to time.  Such new versions will\nbe similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  
It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU General Public License as published by\n    the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU General Public License for more details.\n\n    You should have received a copy of the GNU General Public License\n    along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If the program does terminal interaction, make it output a short\nnotice like this when it starts in an interactive mode:\n\n    <program>  Copyright (C) <year>  <name of author>\n    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.\n    This is free software, and you are welcome to redistribute it\n    under certain conditions; type `show c' for details.\n\nThe hypothetical commands `show w' and `show c' should show the appropriate\nparts of the General Public License.  
Of course, your program's commands\nmight be different; for a GUI interface, you would use an \"about box\".\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU GPL, see\n<https://www.gnu.org/licenses/>.\n\n  The GNU General Public License does not permit incorporating your program\ninto proprietary programs.  If your program is a subroutine library, you\nmay consider it more useful to permit linking proprietary applications with\nthe library.  If this is what you want to do, use the GNU Lesser General\nPublic License instead of this License.  But first, please read\n<https://www.gnu.org/licenses/why-not-lgpl.html>.\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/agents/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/agents/agents.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.autograd import Variable\nfrom torch.optim import Adam\nfrom MPE.utils.networks import MLPNetwork, RNN, SNNNetwork\nfrom MPE.utils.misc import hard_update, gumbel_softmax, onehot_from_logits\nfrom MPE.utils.noise import OUNoise\nimport time\n\nclass DDPGAgent(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.policy = MLPNetwork(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        self.critic = MLPNetwork(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=False)\n        self.target_policy = MLPNetwork(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=True,\n                                        discrete_action=discrete_action)\n        self.target_critic = MLPNetwork(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=False)\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n         
   self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        action = self.policy(obs)\n\n\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)\n                else:\n                    action = gumbel_softmax(action, hard=True)\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)\n                else:\n                    action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': 
self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n\nclass DDPGAgent_RNN(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.policy = RNN(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=True,\n                                 discrete_action=discrete_action)\n        self.critic = RNN(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 constrain_out=False)\n        self.target_policy = RNN(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=True,\n                                        discrete_action=discrete_action)\n        self.target_critic = RNN(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        constrain_out=False)\n\n       
 self.policy_hidden = None\n        self.policy_target_hidden = None\n        self.critic_hidden = None\n        self.critic_target_hidden = None\n        self.num_in_pol = num_in_pol\n        self.num_out_pol = num_out_pol\n        self.hidden_dim = hidden_dim\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n\n        action, self.policy_hidden = self.policy(obs, self.policy_hidden)\n        if self.discrete_action:\n            if explore:\n                action = gumbel_softmax(action, hard=True)\n            else:\n                action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n             
   'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n\n    def init_hidden(self, len_ep, policy_hidden=False, policy_target_hidden=False, \\\n                    critic_hidden=False, critic_target_hidden=False):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        if policy_hidden == True:\n            self.policy_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if policy_target_hidden == True:\n            self.policy_target_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if critic_hidden == True:\n            self.critic_hidden = torch.zeros((len_ep, self.hidden_dim))\n        if critic_target_hidden == True:\n            self.critic_target_hidden = torch.zeros((len_ep, self.hidden_dim))\n\nclass DDPGAgent_SNN(object):\n    \"\"\"\n    General class for DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, output_style, hidden_dim=64,\n                 lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        
\"\"\"\n        self.policy = SNNNetwork(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 output_style=output_style)\n        self.critic = SNNNetwork(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 output_style=output_style)\n        self.target_policy = SNNNetwork(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        output_style=output_style)\n        self.target_critic = SNNNetwork(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        output_style=output_style)\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        # t1 = time.time()\n        action = self.policy(obs)\n        # t2 = time.time()\n        # print('time_interaction:', t2 - 
t1)\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1)\n                else:\n                    action = gumbel_softmax(action, hard=True)\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1)\n                else:\n                    action = onehot_from_logits(action)\n            # if explore:\n            #\n            #     action = gumbel_softmax(action, hard=True)\n            #\n            # else:\n            #     action = onehot_from_logits(action)\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n\n            action = action.clamp(-1, 1)\n\n        return action\n\n    def get_params(self):\n        return {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict()}\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n\nclass DDPGAgent_ToM(object):\n    \"\"\"\n    General class for 
DDPG agents (policy, critic, target policy, target\n    critic, exploration noise)\n    \"\"\"\n    def __init__(self, num_in_pol, num_out_pol, num_in_critic, num_in_mle, output_style,\n                 num_agents, device, hidden_dim=64, lr=0.01, discrete_action=True):\n        \"\"\"\n        Inputs:\n            num_in_pol (int): number of dimensions for policy input\n            num_out_pol (int): number of dimensions for policy output\n            num_in_critic (int): number of dimensions for critic input\n        \"\"\"\n        self.device = device\n        self.policy = SNNNetwork(num_in_pol, num_out_pol,\n                                 hidden_dim=hidden_dim,\n                                 output_style=output_style)\n        self.critic = SNNNetwork(num_in_critic, 1,\n                                 hidden_dim=hidden_dim,\n                                 output_style=output_style)\n        self.target_policy = SNNNetwork(num_in_pol, num_out_pol,\n                                        hidden_dim=hidden_dim,\n                                        output_style=output_style)\n        self.target_critic = SNNNetwork(num_in_critic, 1,\n                                        hidden_dim=hidden_dim,\n                                        output_style=output_style)\n        # self.mle = [SNNNetwork(num_in_mle, num_out_pol,\n        #                       hidden_dim=hidden_dim,\n        #                       output_style=output_style)] * (num_agents - 1)\n        self.mle = []\n        hard_update(self.target_policy, self.policy)\n        hard_update(self.target_critic, self.critic)\n        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)\n        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)\n        self.mle_optimizer = []\n        if not discrete_action:\n            self.exploration = OUNoise(num_out_pol)\n        else:\n            self.exploration = 0.3  # epsilon for eps-greedy\n        self.discrete_action = 
discrete_action\n\n    def reset_noise(self):\n        if not self.discrete_action:\n            self.exploration.reset()\n\n    def scale_noise(self, scale):\n        if self.discrete_action:\n            self.exploration = scale\n        else:\n            self.exploration.scale = scale\n\n    def step(self, obs, explore=False):\n        \"\"\"\n        Take a step forward in environment for a minibatch of observations\n        Inputs:\n            obs (PyTorch Variable): Observations for this agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            action (PyTorch Variable): Actions for this agent\n        \"\"\"\n        action = self.policy.to(self.device)(obs.to(self.device))\n        if self.discrete_action:\n            if explore:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (gumbel_softmax(action[:, :5], hard=True), gumbel_softmax(action[:, 5:], hard=True)), 1).cpu()\n                else:\n                    action = gumbel_softmax(action, hard=True).cpu()\n            else:\n                if action.shape[1] == 9:\n                    action = torch.cat(\n                        (onehot_from_logits(action[:, :5]), onehot_from_logits(action[:, 5:])), 1).cpu()\n                else:\n                    action = onehot_from_logits(action).cpu()\n            # if explore:\n            #     action = gumbel_softmax(action, hard=True).cpu()\n            # else:\n            #     action = onehot_from_logits(action).cpu()\n        else:  # continuous action\n            if explore:\n                action += Variable(Tensor(self.exploration.noise()),\n                                   requires_grad=False)\n            action = action.clamp(-1, 1)\n\n        return action\n\n    def get_params(self):\n        params = {'policy': self.policy.state_dict(),\n                'critic': self.critic.state_dict(),\n                
'target_policy': self.target_policy.state_dict(),\n                'target_critic': self.target_critic.state_dict(),\n                'policy_optimizer': self.policy_optimizer.state_dict(),\n                'critic_optimizer': self.critic_optimizer.state_dict(),\n                }\n        # for i in range(len(self.mle)):\n        #     params['mle%d'%i] = self.mle[i].state_dict()\n        #     params['mle_optimizer%d'%i] = self.mle_optimizer[i].state_dict()\n        return params\n\n    def load_params(self, params):\n        self.policy.load_state_dict(params['policy'])\n        self.critic.load_state_dict(params['critic'])\n        self.target_policy.load_state_dict(params['target_policy'])\n        self.target_critic.load_state_dict(params['target_critic'])\n        self.policy_optimizer.load_state_dict(params['policy_optimizer'])\n        self.critic_optimizer.load_state_dict(params['critic_optimizer'])\n        # for i in range(len(self.mle)):\n        #     self.mle[i].load_state_dict(params['mle%d'%i])\n        #     self.mle_optimizer[i].load_state_dict(params['mle_optimizer%d'%i])"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/common/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/common/distributions.py",
    "content": "# import tensorflow as tf\nimport tensorflow.compat.v1 as tf\ntf.compat.v1.disable_eager_execution()\nimport numpy as np\nimport maddpg.common.tf_util as U\nfrom tensorflow.python.ops import math_ops\nfrom multiagent.multi_discrete import MultiDiscrete\nfrom tensorflow.python.ops import nn\n\nclass Pd(object):\n    \"\"\"\n    A particular probability distribution\n    \"\"\"\n    def flatparam(self):\n        raise NotImplementedError\n    def mode(self):\n        raise NotImplementedError\n    def logp(self, x):\n        raise NotImplementedError\n    def kl(self, other):\n        raise NotImplementedError\n    def entropy(self):\n        raise NotImplementedError\n    def sample(self):\n        raise NotImplementedError\n\nclass PdType(object):\n    \"\"\"\n    Parametrized family of probability distributions\n    \"\"\"\n    def pdclass(self):\n        raise NotImplementedError\n    def pdfromflat(self, flat):\n        return self.pdclass()(flat)\n    def param_shape(self):\n        raise NotImplementedError\n    def sample_shape(self):\n        raise NotImplementedError\n    def sample_dtype(self):\n        raise NotImplementedError\n\n    def param_placeholder(self, prepend_shape, name=None):\n        return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)\n    def sample_placeholder(self, prepend_shape, name=None):\n        return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)\n\nclass CategoricalPdType(PdType):\n    def __init__(self, ncat):\n        self.ncat = ncat\n    def pdclass(self):\n        return CategoricalPd\n    def param_shape(self):\n        return [self.ncat]\n    def sample_shape(self):\n        return []\n    def sample_dtype(self):\n        return tf.int32\n\nclass SoftCategoricalPdType(PdType):\n    def __init__(self, ncat):\n        self.ncat = ncat\n    def pdclass(self):\n        return SoftCategoricalPd\n    def param_shape(self):\n      
  return [self.ncat]\n    def sample_shape(self):\n        return [self.ncat]\n    def sample_dtype(self):\n        return tf.float32\n\nclass MultiCategoricalPdType(PdType):\n    def __init__(self, low, high):\n        self.low = low\n        self.high = high\n        self.ncats = high - low + 1\n    def pdclass(self):\n        return MultiCategoricalPd\n    def pdfromflat(self, flat):\n        return MultiCategoricalPd(self.low, self.high, flat)\n    def param_shape(self):\n        return [sum(self.ncats)]\n    def sample_shape(self):\n        return [len(self.ncats)]\n    def sample_dtype(self):\n        return tf.int32\n\nclass SoftMultiCategoricalPdType(PdType):\n    def __init__(self, low, high):\n        self.low = low\n        self.high = high\n        self.ncats = high - low + 1\n    def pdclass(self):\n        return SoftMultiCategoricalPd\n    def pdfromflat(self, flat):\n        return SoftMultiCategoricalPd(self.low, self.high, flat)\n    def param_shape(self):\n        return [sum(self.ncats)]\n    def sample_shape(self):\n        return [sum(self.ncats)]\n    def sample_dtype(self):\n        return tf.float32\n\nclass DiagGaussianPdType(PdType):\n    def __init__(self, size):\n        self.size = size\n    def pdclass(self):\n        return DiagGaussianPd\n    def param_shape(self):\n        return [2*self.size]\n    def sample_shape(self):\n        return [self.size]\n    def sample_dtype(self):\n        return tf.float32\n\nclass BernoulliPdType(PdType):\n    def __init__(self, size):\n        self.size = size\n    def pdclass(self):\n        return BernoulliPd\n    def param_shape(self):\n        return [self.size]\n    def sample_shape(self):\n        return [self.size]\n    def sample_dtype(self):\n        return tf.int32\n\n# WRONG SECOND DERIVATIVES\n# class CategoricalPd(Pd):\n#     def __init__(self, logits):\n#         self.logits = logits\n#         self.ps = tf.nn.softmax(logits)\n#     @classmethod\n#     def fromflat(cls, flat):\n#      
   return cls(flat)\n#     def flatparam(self):\n#         return self.logits\n#     def mode(self):\n#         return U.argmax(self.logits, axis=1)\n#     def logp(self, x):\n#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)\n#     def kl(self, other):\n#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \\\n#                 - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)\n#     def entropy(self):\n#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)\n#     def sample(self):\n#         u = tf.random_uniform(tf.shape(self.logits))\n#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)\n\nclass CategoricalPd(Pd):\n    def __init__(self, logits):\n        self.logits = logits\n    def flatparam(self):\n        return self.logits\n    def mode(self):\n        return U.argmax(self.logits, axis=1)\n    def logp(self, x):\n        return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)\n    def kl(self, other):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        ea1 = tf.exp(a1)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        z1 = U.sum(ea1, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)\n    def entropy(self):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (tf.log(z0) - a0), axis=1)\n    def sample(self):\n        u = tf.random_uniform(tf.shape(self.logits))\n        return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass SoftCategoricalPd(Pd):\n    def __init__(self, logits):\n        self.logits = 
logits\n    def flatparam(self):\n        return self.logits\n    def mode(self):\n        return U.softmax(self.logits, axis=-1)\n    def logp(self, x):\n        return -tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=x)\n    def kl(self, other):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        ea1 = tf.exp(a1)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        z1 = U.sum(ea1, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)\n    def entropy(self):\n        a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)\n        ea0 = tf.exp(a0)\n        z0 = U.sum(ea0, axis=1, keepdims=True)\n        p0 = ea0 / z0\n        return U.sum(p0 * (tf.log(z0) - a0), axis=1)\n    def sample(self):\n        u = tf.random_uniform(tf.shape(self.logits))\n        return U.softmax(self.logits - tf.log(-tf.log(u)), axis=-1)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass MultiCategoricalPd(Pd):\n    def __init__(self, low, high, flat):\n        self.flat = flat\n        self.low = tf.constant(low, dtype=tf.int32)\n        self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))\n    def flatparam(self):\n        return self.flat\n    def mode(self):\n        return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)\n    def logp(self, x):\n        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])\n    def kl(self, other):\n        return tf.add_n([\n                p.kl(q) for p, q in zip(self.categoricals, other.categoricals)\n            ])\n    def entropy(self):\n        return tf.add_n([p.entropy() for p in self.categoricals])\n    def sample(self):\n        return 
self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass SoftMultiCategoricalPd(Pd):  # doesn't work yet\n    def __init__(self, low, high, flat):\n        self.flat = flat\n        self.low = tf.constant(low, dtype=tf.float32)\n        self.categoricals = list(map(SoftCategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))\n    def flatparam(self):\n        return self.flat\n    def mode(self):\n        x = []\n        for i in range(len(self.categoricals)):\n            x.append(self.low[i] + self.categoricals[i].mode())\n        return tf.concat(x, axis=-1)\n    def logp(self, x):\n        return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])\n    def kl(self, other):\n        return tf.add_n([\n                p.kl(q) for p, q in zip(self.categoricals, other.categoricals)\n            ])\n    def entropy(self):\n        return tf.add_n([p.entropy() for p in self.categoricals])\n    def sample(self):\n        x = []\n        for i in range(len(self.categoricals)):\n            x.append(self.low[i] + self.categoricals[i].sample())\n        return tf.concat(x, axis=-1)\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass DiagGaussianPd(Pd):\n    def __init__(self, flat):\n        self.flat = flat\n        mean, logstd = tf.split(axis=1, num_or_size_splits=2, value=flat)\n        self.mean = mean\n        self.logstd = logstd\n        self.std = tf.exp(logstd)\n    def flatparam(self):\n        return self.flat\n    def mode(self):\n        return self.mean\n    def logp(self, x):\n        return - 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=1) \\\n               - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) \\\n               - U.sum(self.logstd, axis=1)\n    def kl(self, other):\n        assert 
isinstance(other, DiagGaussianPd)\n        return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=1)\n    def entropy(self):\n        return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), 1)\n    def sample(self):\n        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\nclass BernoulliPd(Pd):\n    def __init__(self, logits):\n        self.logits = logits\n        self.ps = tf.sigmoid(logits)\n    def flatparam(self):\n        return self.logits\n    def mode(self):\n        return tf.round(self.ps)\n    def logp(self, x):\n        return - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=1)\n    def kl(self, other):\n        return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)\n    def entropy(self):\n        return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)\n    def sample(self):\n        p = tf.sigmoid(self.logits)\n        u = tf.random_uniform(tf.shape(p))\n        return tf.to_float(math_ops.less(u, p))\n    @classmethod\n    def fromflat(cls, flat):\n        return cls(flat)\n\ndef make_pdtype(ac_space):\n    from gym import spaces\n    if isinstance(ac_space, spaces.Box):\n        assert len(ac_space.shape) == 1\n        return DiagGaussianPdType(ac_space.shape[0])\n    elif isinstance(ac_space, spaces.Discrete):\n        # return CategoricalPdType(ac_space.n)\n        return SoftCategoricalPdType(ac_space.n)\n    elif isinstance(ac_space, MultiDiscrete):\n        #return MultiCategoricalPdType(ac_space.low, ac_space.high)\n        return SoftMultiCategoricalPdType(ac_space.low, ac_space.high)\n    elif isinstance(ac_space, 
spaces.MultiBinary):\n        return BernoulliPdType(ac_space.n)\n    else:\n        raise NotImplementedError\n\ndef shape_el(v, i):\n    maybe = v.get_shape()[i]\n    if maybe is not None:\n        return maybe\n    else:\n        return tf.shape(v)[i]\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/common/tile_images.py",
    "content": "import numpy as np\n\ndef tile_images(img_nhwc):\n    \"\"\"\n    Tile N images into one big PxQ image\n    (P,Q) are chosen to be as close as possible, and if N\n    is square, then P=Q.\n\n    input: img_nhwc, list or array of images, ndim=4 once turned into array\n        n = batch index, h = height, w = width, c = channel\n    returns:\n        bigim_HWc, ndarray with ndim=3\n    \"\"\"\n    img_nhwc = np.asarray(img_nhwc)\n    N, h, w, c = img_nhwc.shape\n    H = int(np.ceil(np.sqrt(N)))\n    W = int(np.ceil(float(N)/H))\n    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])\n    img_HWhwc = img_nhwc.reshape(H, W, h, w, c)\n    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)\n    img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)\n    return img_Hh_Ww_c\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/common/vec_env/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/common/vec_env/vec_env.py",
    "content": "import contextlib\nimport os\nfrom abc import ABC, abstractmethod\n\nfrom common.tile_images import tile_images\n\nclass AlreadySteppingError(Exception):\n    \"\"\"\n    Raised when an asynchronous step is running while\n    step_async() is called again.\n    \"\"\"\n\n    def __init__(self):\n        msg = 'already running an async step'\n        Exception.__init__(self, msg)\n\n\nclass NotSteppingError(Exception):\n    \"\"\"\n    Raised when an asynchronous step is not running but\n    step_wait() is called.\n    \"\"\"\n\n    def __init__(self):\n        msg = 'not running an async step'\n        Exception.__init__(self, msg)\n\n\nclass VecEnv(ABC):\n    \"\"\"\n    An abstract asynchronous, vectorized environment.\n    Used to batch data from multiple copies of an environment, so that\n    each observation becomes an batch of observations, and expected action is a batch of actions to\n    be applied per-environment.\n    \"\"\"\n    closed = False\n    viewer = None\n\n    metadata = {\n        'render.modes': ['human', 'rgb_array']\n    }\n\n    def __init__(self, num_envs, observation_space, action_space):\n        self.num_envs = num_envs\n        self.observation_space = observation_space\n        self.action_space = action_space\n\n    @abstractmethod\n    def reset(self):\n        \"\"\"\n        Reset all the environments and return an array of\n        observations, or a dict of observation arrays.\n\n        If step_async is still doing work, that work will\n        be cancelled and step_wait() should not be called\n        until step_async() is invoked again.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def step_async(self, actions):\n        \"\"\"\n        Tell all the environments to start taking a step\n        with the given actions.\n        Call step_wait() to get the results of the step.\n\n        You should not call this if a step_async run is\n        already pending.\n        \"\"\"\n        pass\n\n    
@abstractmethod\n    def step_wait(self):\n        \"\"\"\n        Wait for the step taken with step_async().\n\n        Returns (obs, rews, dones, infos):\n         - obs: an array of observations, or a dict of\n                arrays of observations.\n         - rews: an array of rewards\n         - dones: an array of \"episode done\" booleans\n         - infos: a sequence of info objects\n        \"\"\"\n        pass\n\n    def close_extras(self):\n        \"\"\"\n        Clean up the  extra resources, beyond what's in this base class.\n        Only runs when not self.closed.\n        \"\"\"\n        pass\n\n    def close(self):\n        if self.closed:\n            return\n        if self.viewer is not None:\n            self.viewer.close()\n        self.close_extras()\n        self.closed = True\n\n    def step(self, actions):\n        \"\"\"\n        Step the environments synchronously.\n\n        This is available for backwards compatibility.\n        \"\"\"\n        self.step_async(actions)\n        return self.step_wait()\n\n    def render(self, mode='human'):\n        imgs = self.get_images()\n        bigimg = tile_images(imgs)\n        if mode == 'human':\n            self.get_viewer().imshow(bigimg)\n            return self.get_viewer().isopen\n        elif mode == 'rgb_array':\n            return bigimg\n        else:\n            raise NotImplementedError\n\n    def get_images(self):\n        \"\"\"\n        Return RGB images from each environment\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def unwrapped(self):\n        if isinstance(self, VecEnvWrapper):\n            return self.venv.unwrapped\n        else:\n            return self\n\n    def get_viewer(self):\n        if self.viewer is None:\n            from gym.envs.classic_control import rendering\n            self.viewer = rendering.SimpleImageViewer()\n        return self.viewer\n\nclass VecEnvWrapper(VecEnv):\n    \"\"\"\n    An environment wrapper that applies to 
an entire batch\n    of environments at once.\n    \"\"\"\n\n    def __init__(self, venv, observation_space=None, action_space=None):\n        self.venv = venv\n        super().__init__(num_envs=venv.num_envs,\n                        observation_space=observation_space or venv.observation_space,\n                        action_space=action_space or venv.action_space)\n\n    def step_async(self, actions):\n        self.venv.step_async(actions)\n\n    @abstractmethod\n    def reset(self):\n        pass\n\n    @abstractmethod\n    def step_wait(self):\n        pass\n\n    def close(self):\n        return self.venv.close()\n\n    def render(self, mode='human'):\n        return self.venv.render(mode=mode)\n\n    def get_images(self):\n        return self.venv.get_images()\n\n    def __getattr__(self, name):\n        if name.startswith('_'):\n            raise AttributeError(\"attempted to get missing private attribute '{}'\".format(name))\n        return getattr(self.venv, name)\n\nclass VecEnvObservationWrapper(VecEnvWrapper):\n    @abstractmethod\n    def process(self, obs):\n        pass\n\n    def reset(self):\n        obs = self.venv.reset()\n        return self.process(obs)\n\n    def step_wait(self):\n        obs, rews, dones, infos = self.venv.step_wait()\n        return self.process(obs), rews, dones, infos\n\nclass CloudpickleWrapper(object):\n    \"\"\"\n    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)\n    \"\"\"\n\n    def __init__(self, x):\n        self.x = x\n\n    def __getstate__(self):\n        import cloudpickle\n        return cloudpickle.dumps(self.x)\n\n    def __setstate__(self, ob):\n        import pickle\n        self.x = pickle.loads(ob)\n\n\n@contextlib.contextmanager\ndef clear_mpi_env_vars():\n    \"\"\"\n    from mpi4py import MPI will call MPI_Init by default.  
If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.\n    This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing\n    Processes.\n    \"\"\"\n    removed_environment = {}\n    for k, v in list(os.environ.items()):\n        for prefix in ['OMPI_', 'PMI_']:\n            if k.startswith(prefix):\n                removed_environment[k] = v\n                del os.environ[k]\n    try:\n        yield\n    finally:\n        os.environ.update(removed_environment)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/main.py",
    "content": "import argparse\nimport torch\nimport time\nimport os\nimport numpy as np\nfrom gym.spaces import Box, Discrete, MultiDiscrete\nfrom pathlib import Path\nfrom torch.autograd import Variable\nfrom tensorboardX import SummaryWriter\nfrom utils.make_env import make_env\nfrom utils.buffer import ReplayBuffer, ReplayBuffer_pre\nfrom utils.env_wrappers import SubprocVecEnv, DummyVecEnv\nfrom policy.maddpg import MADDPG, MADDPG_SNN, MADDPG_ToM, ToM_SA, ToM_S, ToM_self\nfrom tqdm import tqdm\n\ndef get_common_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--env_id\", default='simple_world_comm', type=str,\n                        choices=['simple_tag', 'simple_adversary', 'simple_push', 'simple_world_comm'],\n                        help=\"Name of environment\")\n    parser.add_argument(\"--model_name\", default='ann', type=str,\n                        help=\"Name of directory to store \" +\n                             \"model/training contents\") #ToM_SA\n    parser.add_argument(\"--seed\",\n                        default=1, type=int,\n                        help=\"Random seed\")\n    parser.add_argument(\"--cuda_num\",\n                        default=7, type=int,\n                        help=\"device\")\n    parser.add_argument(\"--output_style\",\n                        default='sum', type=str,\n                        choices=['sum', 'voltage'])\n    parser.add_argument(\"--n_rollout_threads\", default=20, type=int)\n    parser.add_argument(\"--n_training_threads\", default=6, type=int)\n    parser.add_argument(\"--buffer_length\", default=int(1e6), type=int)\n    parser.add_argument(\"--n_episodes\", default=15000, type=int)#\n    parser.add_argument(\"--episode_length\", default=25, type=int)\n    parser.add_argument(\"--steps_per_update\", default=100, type=int)\n    parser.add_argument(\"--batch_size\",\n                        default=1024, type=int,#4\n                        help=\"Batch size for model 
training\")\n    parser.add_argument(\"--n_exploration_eps\", default=25000, type=int)\n    parser.add_argument(\"--init_noise_scale\", default=0.3, type=float)\n    parser.add_argument(\"--final_noise_scale\", default=0.0, type=float)\n    parser.add_argument(\"--save_interval\", default=1000, type=int)\n    parser.add_argument(\"--hidden_dim\", default=64, type=int)\n    parser.add_argument(\"--lr\", default=0.01, type=float)\n    parser.add_argument(\"--tau\", default=0.01, type=float)\n    parser.add_argument(\"--agent_alg\",\n                        default=\"MADDPG_ToM\", type=str,\n                        choices=['MADDPG', 'DDPG', 'MADDPG_SNN', 'MADDPG_ToM', 'ToM_SA', 'ToM_S', 'ToM_self'])\n    parser.add_argument(\"--adversary_alg\",\n                        default=\"MADDPG_ToM\", type=str,\n                        choices=['MADDPG', 'DDPG', 'MADDPG_SNN', 'MADDPG_ToM', 'ToM_SA', 'ToM_S', 'ToM_self'])\n    parser.add_argument(\"--discrete_action\",\n                        # default=False, type=bool,\n                        action='store_true')\n    args = parser.parse_args()\n    parser.add_argument('--device', type=str, default='cuda:{}'.format(args.cuda_num), help='whether to use the GPU')  #'cuda:1'\n    parser = parser.parse_args()\n    return parser\n\nUSE_CUDA = torch.cuda.is_available()\n\ndef make_parallel_env(env_id, n_rollout_threads, seed, discrete_action):\n    def get_env_fn(rank):\n        def init_env():\n            env = make_env(env_id, discrete_action=discrete_action)\n            env.seed(seed + rank * 1000)\n            np.random.seed(seed + rank * 1000)\n            return env\n        return init_env\n    if n_rollout_threads == 1:\n        return DummyVecEnv([get_env_fn(0)])\n    else:\n        return SubprocVecEnv([get_env_fn(i) for i in range(n_rollout_threads)])\n\ndef run(config):\n    pbar = tqdm(config.n_episodes)\n    model_dir = Path('./models') / config.env_id / config.model_name\n    if not model_dir.exists():\n        
curr_run = 'run1'\n    else:\n        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in\n                         model_dir.iterdir() if\n                         str(folder.name).startswith('run')]\n        if len(exst_run_nums) == 0:\n            curr_run = 'run1'\n        else:\n            curr_run = 'run%i' % (max(exst_run_nums) + 1)\n    run_dir = model_dir / curr_run\n    log_dir = run_dir / 'logs'\n    os.makedirs(log_dir)\n    logger = SummaryWriter(str(log_dir))\n\n    torch.manual_seed(config.seed)\n    np.random.seed(config.seed)\n    if not USE_CUDA:\n        torch.set_num_threads(config.n_training_threads)\n    env = make_parallel_env(config.env_id, config.n_rollout_threads, config.seed,\n                            config.discrete_action)\n    if config.agent_alg == 'MADDPG' or config.agent_alg == 'DDPG':\n        print('_____MADDPG_____')\n        maddpg = MADDPG.init_from_env(env, agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      device=config.device)\n    elif config.agent_alg == 'MADDPG_SNN':\n        print('_____MADDPG_SNN_____')\n        maddpg = MADDPG_SNN.init_from_env(env, agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                          device=config.device)\n    elif config.agent_alg == 'MADDPG_ToM':\n        print('_____MADDPG_ToM_____')\n        maddpg = MADDPG_ToM.init_from_env(env, agent_alg=config.agent_alg,\n                    
                  adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                      device=config.device)\n\n    elif config.agent_alg == 'ToM_SA':\n        print('_______ToM_SA_______')\n        maddpg = ToM_SA.init_from_env(env, agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                      device=config.device)\n    elif config.agent_alg == 'ToM_S':\n        print('_______ToM_S_______')\n        maddpg = ToM_S.init_from_env(env, agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                      device=config.device)\n\n    elif config.agent_alg == 'ToM_self':\n        print('_______ToM_self_______')\n        maddpg = ToM_self.init_from_env(env, agent_alg=config.agent_alg,\n                                      adversary_alg=config.adversary_alg,\n                                      tau=config.tau,\n                                      lr=config.lr,\n                                      hidden_dim=config.hidden_dim,\n                                      output_style=config.output_style,\n                                      device=config.device)\n\n\n    if config.agent_alg == 'ToM_SA'or 
config.agent_alg == 'ToM_S' or config.agent_alg == 'ToM_self' or config.agent_alg == 'ToM_SB':\n        replay_buffer = ReplayBuffer_pre(config.buffer_length, maddpg.nagents,\n                                     [obsp.shape[0] for obsp in env.observation_space],\n                                     [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)\n                                      for acsp in env.action_space],\n                                 device=config.device)\n    else:\n\n        replay_buffer = ReplayBuffer(config.buffer_length, maddpg.nagents,\n                                     [obsp.shape[0] for obsp in env.observation_space],\n                                     [acsp.n if isinstance(acsp, Discrete) else sum(acsp.high - acsp.low + 1)\n                                      for acsp in env.action_space],\n                                 device=config.device)\n\n    t = 0\n    total_reward = []\n    for agent_i in range(maddpg.nagents):\n        total_reward.append([])\n    for ep_i in range(0, config.n_episodes, config.n_rollout_threads):\n        # print(\"Episodes %i-%i of %i\" % (ep_i + 1,\n        #                                 ep_i + 1 + config.n_rollout_threads,\n        #                                 config.n_episodes))\n        obs = env.reset()\n        # obs.shape = (n_rollout_threads, nagent)(nobs), nobs differs per agent so not tensor\n        maddpg.prep_rollouts(device='cpu')\n        torch_agent_actions = [torch.zeros((config.n_rollout_threads, 5)) for i in range(maddpg.nagents)]\n        explr_pct_remaining = max(0, config.n_exploration_eps - ep_i) / config.n_exploration_eps\n        maddpg.scale_noise(config.final_noise_scale + (config.init_noise_scale - config.final_noise_scale) * explr_pct_remaining)\n        maddpg.reset_noise()\n        obs_ep = []\n        agent_actions_ep = []\n        rewards_ep = []\n        next_obs_ep = []\n        dones_ep = []\n        for et_i in 
range(config.episode_length):\n            torch_agent_actions_pre = torch_agent_actions\n            torch_agent_actions_pre = [ac.data.numpy() for ac in torch_agent_actions_pre]\n            # rearrange observations to be per agent, and convert to torch Variable\n            torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])),\n                                  requires_grad=False)\n                         for i in range(maddpg.nagents)]    #\n            # get actions as torch Variables\n            # t1 = time.time()\n            if config.agent_alg == 'ToM_SA' or config.agent_alg == 'ToM_S' or config.agent_alg == 'ToM_self' or config.agent_alg == 'ToM_SB':\n                torch_agent_actions = maddpg.step(torch_obs, torch_agent_actions, explore=True)\n            else:\n                torch_agent_actions = maddpg.step(torch_obs, explore=True)\n            # t2 = time.time()\n            # print('time_step:', t2-t1)\n            # convert actions to numpy arrays\n            agent_actions = [ac.data.numpy() for ac in torch_agent_actions] #\n            # rearrange actions to be per environment\n            actions = [[ac[i] for ac in agent_actions] for i in range(config.n_rollout_threads)]\n            next_obs, rewards, dones, infos = env.step(actions)\n            obs_ep.append(obs)                  #episode_id,process, n_agents, dim\n            agent_actions_ep.append(actions)    #episode_id, n_agents, process, dim\n            rewards_ep.append(rewards)          #episode_id,process, n_agents,\n            next_obs_ep.append(next_obs)            #episode_id,process, n_agents, dim\n            dones_ep.append(dones)              #episode_id,process, n_agents,\n            if config.agent_alg == 'ToM_SA' or config.agent_alg == 'ToM_S'  or config.agent_alg == 'ToM_self'or config.agent_alg == 'ToM_SB':\n                replay_buffer.push(torch_agent_actions_pre, obs, agent_actions, rewards, next_obs, dones)\n            else:\n                
replay_buffer.push(obs, agent_actions, rewards, next_obs, dones)\n            obs = next_obs\n            t += config.n_rollout_threads\n\n            if (len(replay_buffer) >= config.batch_size and\n                (t % config.steps_per_update) < config.n_rollout_threads):\n                if USE_CUDA:\n                    maddpg.prep_training(device='gpu')\n                else:\n                    maddpg.prep_training(device='cpu')\n                if config.n_episodes >300:\n                    rollout = 2\n                else:\n                    rollout = config.n_rollout_threads\n                for u_i in range(rollout):\n                    for a_i in range(maddpg.nagents):\n                        sample = replay_buffer.sample(config.batch_size,\n                                                      to_gpu=USE_CUDA)\n                        # t1 = time.time()\n                        maddpg.update(sample, a_i, logger=logger)\n                        # t2 = time.time()\n                        # print('trian_time:', t2-t1, u_i, a_i)\n                    maddpg.update_all_targets()\n                maddpg.prep_rollouts(device='cpu')\n        ep_rews = replay_buffer.get_average_rewards(\n            config.episode_length * config.n_rollout_threads)\n        for a_i, a_ep_rew in enumerate(ep_rews):\n            logger.add_scalar('agent%i/mean_episode_rewards' % a_i,\n                              a_ep_rew,\n                              ep_i)\n        logger.add_scalar('agent_mean/mean_episode_rewards',\n                          np.mean(ep_rews),\n                          ep_i)\n\n        if ep_i % config.save_interval < config.n_rollout_threads:\n            os.makedirs(run_dir / 'incremental', exist_ok=True)\n            maddpg.save(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1)))\n            maddpg.save(run_dir / 'model.pt')\n        pbar.update(config.n_rollout_threads)\n\n    pbar.close()\n    maddpg.save(run_dir / 'model.pt')\n    
env.close()\n    logger.export_scalars_to_json(str(log_dir / 'summary.json'))\n    logger.close()\n    for a_i, reward in enumerate(total_reward):\n        reward_dir = str(log_dir) + '/agent{}/mean_episode_rewards'.format(a_i) + '/episode_rewards_{}'.format(config.cuda_num)\n        os.makedirs(reward_dir)\n        np.save(reward_dir, reward)\n\n\nif __name__ == '__main__':\n    config = get_common_args()\n    # config.env_id = 'simple_tag'\n    # # config.model_name = 'ma2c'\n    config.agent_alg = 'ToM_SB'#\n    config.adversary_alg = 'ToM_SB'\n    #\n    run(config)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/__init__.py",
    "content": "from gym.envs.registration import register\n\n# Multiagent envs\n# ----------------------------------------\n\nregister(\n    id='MultiagentSimple-v0',\n    entry_point='multiagent.envs:SimpleEnv',\n    # FIXME(cathywu) currently has to be exactly max_path_length parameters in\n    # rllab run script\n    max_episode_steps=100,\n)\n\nregister(\n    id='MultiagentSimpleSpeakerListener-v0',\n    entry_point='multiagent.envs:SimpleSpeakerListenerEnv',\n    max_episode_steps=100,\n)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/__init__.py",
    "content": "import imp\nimport os.path as osp\n\n\ndef load(name):\n    pathname = osp.join(osp.dirname(__file__), name)\n    return imp.load_source('', pathname)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # add agents\n        world.agents = [Agent() for i in range(1)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.silent = True\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(1)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25,0.25,0.25])\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.75,0.75,0.75])\n        world.landmarks[0].color = np.array([0.75,0.25,0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        dist2 = np.sum(np.square(agent.state.p_pos - world.landmarks[0].state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            
entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + entity_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_crypto.py",
    "content": "\"\"\"\nScenario:\n1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from\nadversary to goal. Adversary is rewarded for its distance to the goal.\n\"\"\"\n\n\nimport numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\nimport random\n\n\nclass CryptoAgent(Agent):\n    def __init__(self):\n        super(CryptoAgent, self).__init__()\n        self.key = None\n\nclass Scenario(BaseScenario):\n\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        num_agents = 3\n        num_adversaries = 1\n        num_landmarks = 2\n        world.dim_c = 4\n        # add agents\n        world.agents = [CryptoAgent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.adversary = True if i < num_adversaries else False\n            agent.speaker = True if i == 2 else False\n            agent.movable = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25, 0.25, 0.25])\n            if agent.adversary:\n                agent.color = np.array([0.75, 0.25, 0.25])\n            agent.key = None\n        # random properties for landmarks\n        color_list = [np.zeros(world.dim_c) for i in world.landmarks]\n        for i, color in enumerate(color_list):\n            color[i] += 1\n        for color, landmark in 
zip(color_list, world.landmarks):\n            landmark.color = color\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        world.agents[1].color = goal.color\n        world.agents[2].key = np.random.choice(world.landmarks).color\n\n        for agent in world.agents:\n            agent.goal_a = goal\n\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return (agent.state.c, agent.goal_a.color)\n\n    # return all agents that are not adversaries\n    def good_listeners(self, world):\n        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot\n        good_listeners = self.good_listeners(world)\n        adversaries = self.adversaries(world)\n        good_rew = 0\n        adv_rew = 0\n        for a in good_listeners:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                
good_rew -= np.sum(np.square(a.state.c - agent.goal_a.color))\n        for a in adversaries:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.color))\n                adv_rew += adv_l1\n        return adv_rew + good_rew\n\n    def adversary_reward(self, agent, world):\n        # Adversary (Eve) is rewarded if it can reconstruct original goal\n        rew = 0\n        if not (agent.state.c == np.zeros(world.dim_c)).all():\n            rew -= np.sum(np.square(agent.state.c - agent.goal_a.color))\n        return rew\n\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = np.zeros(world.dim_color)\n        if agent.goal_a is not None:\n            goal_color = agent.goal_a.color\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None) or not other.speaker: continue\n            comm.append(other.state.c)\n\n        confer = np.array([0])\n\n        if world.agents[2].key is None:\n            confer = np.array([1])\n            key = np.zeros(world.dim_c)\n            goal_color = np.zeros(world.dim_c)\n        else:\n            key = world.agents[2].key\n\n        prnt = False\n        # speaker\n        if agent.speaker:\n            if prnt:\n                print('speaker')\n                print(agent.state.c)\n                print(np.concatenate([goal_color] + [key] + [confer] + [np.random.randn(1)]))\n            return np.concatenate([goal_color] + [key])\n        # listener\n        if not agent.speaker and not agent.adversary:\n            if prnt:\n                print('listener')\n             
   print(agent.state.c)\n                print(np.concatenate([key] + comm + [confer]))\n            return np.concatenate([key] + comm)\n        if not agent.speaker and agent.adversary:\n            if prnt:\n                print('adversary')\n                print(agent.state.c)\n                print(np.concatenate(comm + [confer]))\n            return np.concatenate(comm)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_push.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = 2\n        num_adversaries = 1\n        num_landmarks = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            if i < num_adversaries:\n                agent.adversary = True\n            else:\n                agent.adversary = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.1, 0.1, 0.1])\n            landmark.color[i + 1] += 0.8\n            landmark.index = i\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        for i, agent in enumerate(world.agents):\n            agent.goal_a = goal\n            agent.color = np.array([0.25, 0.25, 0.25])\n            if agent.adversary:\n                agent.color = np.array([0.75, 0.25, 0.25])\n            else:\n                j = goal.index\n                agent.color[j + 1] += 0.5\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            
agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # the distance to the goal\n        return -np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))\n\n    def adversary_reward(self, agent, world):\n        # keep the nearest good agents away from the goal\n        agent_dist = [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in world.agents if not a.adversary]\n        pos_rew = min(agent_dist)\n        #nearest_agent = world.good_agents[np.argmin(agent_dist)]\n        #neg_rew = np.sqrt(np.sum(np.square(nearest_agent.state.p_pos - agent.state.p_pos)))\n        neg_rew = np.sqrt(np.sum(np.square(agent.goal_a.state.p_pos - agent.state.p_pos)))\n        #neg_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in world.good_agents])\n        return pos_rew - neg_rew\n               \n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos 
- agent.state.p_pos)\n        if not agent.adversary:\n            return np.concatenate([agent.state.p_vel] + [agent.goal_a.state.p_pos - agent.state.p_pos] + [agent.color] + entity_pos + entity_color + other_pos)\n        else:\n            #other_pos = list(reversed(other_pos)) if random.uniform(0,1) > 0.5 else other_pos  # randomize position of other agents in adversary network\n            return np.concatenate([agent.state.p_vel] + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_reference.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        world.dim_c = 10\n        world.collaborative = True  # whether agents share rewards\n        # add agents\n        world.agents = [Agent() for i in range(2)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(3)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # assign goals to agents\n        for agent in world.agents:\n            agent.goal_a = None\n            agent.goal_b = None\n        # want other agent to go to the goal landmark\n        world.agents[0].goal_a = world.agents[1]\n        world.agents[0].goal_b = np.random.choice(world.landmarks)\n        world.agents[1].goal_a = world.agents[0]\n        world.agents[1].goal_b = np.random.choice(world.landmarks)\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25,0.25,0.25])               \n        # random properties for landmarks\n        world.landmarks[0].color = np.array([0.75,0.25,0.25]) \n        world.landmarks[1].color = np.array([0.25,0.75,0.25]) \n        world.landmarks[2].color = np.array([0.25,0.25,0.75]) \n        # special colors for goals\n        world.agents[0].goal_a.color = world.agents[0].goal_b.color                \n        world.agents[1].goal_a.color = world.agents[1].goal_b.color                               \n        # set 
random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        if agent.goal_a is None or agent.goal_b is None:\n            return 0.0\n        dist2 = np.sum(np.square(agent.goal_a.state.p_pos - agent.goal_b.state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)]\n        if agent.goal_b is not None:\n            goal_color[1] = agent.goal_b.color \n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n        return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm)\n            "
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_speaker_listener.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        world.dim_c = 3\n        num_landmarks = 3\n        world.collaborative = True\n        # add agents\n        world.agents = [Agent() for i in range(2)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.size = 0.075\n        # speaker\n        world.agents[0].movable = False\n        # listener\n        world.agents[1].silent = True\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.04\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # assign goals to agents\n        for agent in world.agents:\n            agent.goal_a = None\n            agent.goal_b = None\n        # want listener to go to the goal landmark\n        world.agents[0].goal_a = world.agents[1]\n        world.agents[0].goal_b = np.random.choice(world.landmarks)\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25,0.25,0.25])               \n        # random properties for landmarks\n        world.landmarks[0].color = np.array([0.65,0.15,0.15])\n        world.landmarks[1].color = np.array([0.15,0.65,0.15])\n        world.landmarks[2].color = np.array([0.15,0.15,0.65])\n        # special colors for goals\n        world.agents[0].goal_a.color = world.agents[0].goal_b.color + np.array([0.45, 0.45, 0.45])\n        # set random initial 
states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return self.reward(agent, world)\n\n    def reward(self, agent, world):\n        # squared distance from listener to landmark\n        a = world.agents[0]\n        dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = np.zeros(world.dim_color)\n        if agent.goal_b is not None:\n            goal_color = agent.goal_b.color\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None): continue\n            comm.append(other.state.c)\n        \n        # speaker\n        if not agent.movable:\n            return np.concatenate([goal_color])\n        # listener\n        if agent.silent:\n            return np.concatenate([agent.state.p_vel] + entity_pos + comm)\n            \n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_spread.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = 3\n        num_landmarks = 3\n        world.collaborative = True\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.size = 0.15\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.35, 0.35, 0.85])\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        rew = 0\n        collisions = 0\n        occupied_landmarks = 0\n        min_dists = 0\n        for l in world.landmarks:\n            dists = 
[np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n            min_dists += min(dists)\n            rew -= min(dists)\n            if min(dists) < 0.1:\n                occupied_landmarks += 1\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n                    collisions += 1\n        return (rew, collisions, min_dists, occupied_landmarks)\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions\n        rew = 0\n        for l in world.landmarks:\n            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n            rew -= min(dists)\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n        return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] 
+ entity_pos + other_pos + comm)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/multiagent/scenarios/simple_world_comm.py",
    "content": "import numpy as np\nfrom multiagent.core import World, Agent, Landmark\nfrom multiagent.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self):\n        world = World()\n        # set any world properties first\n        world.dim_c = 4\n        #world.damping = 1\n        num_good_agents = 2\n        num_adversaries = 4\n        num_agents = num_adversaries + num_good_agents\n        num_landmarks = 1\n        num_food = 2\n        num_forests = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.leader = True if i == 0 else False\n            agent.silent = True if i > 0 else False\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.075 if agent.adversary else 0.045\n            agent.accel = 3.0 if agent.adversary else 4.0\n            #agent.accel = 20.0 if agent.adversary else 25.0\n            agent.max_speed = 1.0 if agent.adversary else 1.3\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = True\n            landmark.movable = False\n            landmark.size = 0.2\n            landmark.boundary = False\n        world.food = [Landmark() for i in range(num_food)]\n        for i, landmark in enumerate(world.food):\n            landmark.name = 'food %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.03\n            landmark.boundary = False\n        world.forests = [Landmark() for i in range(num_forests)]\n        for i, landmark in enumerate(world.forests):\n            landmark.name = 'forest %d' % i\n            landmark.collide = False\n            
landmark.movable = False\n            landmark.size = 0.3\n            landmark.boundary = False\n        world.landmarks += world.food\n        world.landmarks += world.forests\n        #world.landmarks += self.set_boundaries(world)  # world boundaries now penalized with negative reward\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def set_boundaries(self, world):\n        boundary_list = []\n        landmark_size = 1\n        edge = 1 + landmark_size\n        num_landmarks = int(edge * 2 / landmark_size)\n        for x_pos in [-edge, edge]:\n            for i in range(num_landmarks):\n                l = Landmark()\n                l.state.p_pos = np.array([x_pos, -1 + i * landmark_size])\n                boundary_list.append(l)\n\n        for y_pos in [-edge, edge]:\n            for i in range(num_landmarks):\n                l = Landmark()\n                l.state.p_pos = np.array([-1 + i * landmark_size, y_pos])\n                boundary_list.append(l)\n\n        for i, l in enumerate(boundary_list):\n            l.name = 'boundary %d' % i\n            l.collide = True\n            l.movable = False\n            l.boundary = True\n            l.color = np.array([0.75, 0.75, 0.75])\n            l.size = landmark_size\n            l.state.p_vel = np.zeros(world.dim_p)\n\n        return boundary_list\n\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.45, 0.95, 0.45]) if not agent.adversary else np.array([0.95, 0.45, 0.45])\n            agent.color -= np.array([0.3, 0.3, 0.3]) if agent.leader else np.array([0, 0, 0])\n            # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        for i, landmark in enumerate(world.food):\n            landmark.color = np.array([0.15, 0.15, 0.65])\n        for i, 
landmark in enumerate(world.forests):\n            landmark.color = np.array([0.6, 0.9, 0.6])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n        for i, landmark in enumerate(world.food):\n            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n        for i, landmark in enumerate(world.forests):\n            landmark.state.p_pos = np.random.uniform(-0.9, +0.9, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        if agent.adversary:\n            collisions = 0\n            for a in self.good_agents(world):\n                if self.is_collision(a, agent):\n                    collisions += 1\n            return collisions\n        else:\n            return 0\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        #boundary_reward = -10 if self.outside_boundary(agent) else 0\n        
main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n        return main_reward\n\n    def outside_boundary(self, agent):\n        if agent.state.p_pos[0] > 1 or agent.state.p_pos[0] < -1 or agent.state.p_pos[1] > 1 or agent.state.p_pos[1] < -1:\n            return True\n        else:\n            return False\n\n\n    def agent_reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        rew = 0\n        shape = False\n        adversaries = self.adversaries(world)\n        if shape:\n            for adv in adversaries:\n                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))\n        if agent.collide:\n            for a in adversaries:\n                if self.is_collision(a, agent):\n                    rew -= 5\n        def bound(x):\n            if x < 0.9:\n                return 0\n            if x < 1.0:\n                return (x - 0.9) * 10\n            return min(np.exp(2 * x - 2), 10)  # 1 + (x - 1) * (x - 1)\n\n        for p in range(world.dim_p):\n            x = abs(agent.state.p_pos[p])\n            rew -= 2 * bound(x)\n\n        for food in world.food:\n            if self.is_collision(agent, food):\n                rew += 2\n        rew += 0.05 * min([np.sqrt(np.sum(np.square(food.state.p_pos - agent.state.p_pos))) for food in world.food])\n\n        return rew\n\n    def adversary_reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        rew = 0\n        shape = True\n        agents = self.good_agents(world)\n        adversaries = self.adversaries(world)\n        if shape:\n            rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in agents])\n        if agent.collide:\n            for ag in agents:\n                for adv in adversaries:\n                    if self.is_collision(ag, adv):\n               
         rew += 5\n        return rew\n\n\n    def observation2(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        food_pos = []\n        for entity in world.food:\n            if not entity.boundary:\n                food_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not other.adversary:\n                other_vel.append(other.state.p_vel)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        in_forest = [np.array([-1]), np.array([-1])]\n        inf1 = False\n        inf2 = False\n        if self.is_collision(agent, world.forests[0]):\n            in_forest[0] = np.array([1])\n            inf1= True\n        if self.is_collision(agent, world.forests[1]):\n            in_forest[1] = np.array([1])\n            inf2 = True\n\n        food_pos = []\n        for entity in world.food:\n            if not entity.boundary:\n                food_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent: 
continue\n            comm.append(other.state.c)\n            oth_f1 = self.is_collision(other, world.forests[0])\n            oth_f2 = self.is_collision(other, world.forests[1])\n            if (inf1 and oth_f1) or (inf2 and oth_f2) or (not inf1 and not oth_f1 and not inf2 and not oth_f2) or agent.leader:  #without forest vis\n                other_pos.append(other.state.p_pos - agent.state.p_pos)\n                if not other.adversary:\n                    other_vel.append(other.state.p_vel)\n            else:\n                other_pos.append([0, 0])\n                if not other.adversary:\n                    other_vel.append([0, 0])\n\n        # to tell the pred when the prey are in the forest\n        prey_forest = []\n        ga = self.good_agents(world)\n        for a in ga:\n            if any([self.is_collision(a, f) for f in world.forests]):\n                prey_forest.append(np.array([1]))\n            else:\n                prey_forest.append(np.array([-1]))\n        # to tell leader when pred are in forest\n        prey_forest_lead = []\n        for f in world.forests:\n            if any([self.is_collision(a, f) for a in ga]):\n                prey_forest_lead.append(np.array([1]))\n            else:\n                prey_forest_lead.append(np.array([-1]))\n\n        comm = [world.agents[0].state.c]\n\n        if agent.adversary and not agent.leader:\n            return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)\n        if agent.leader:\n            return np.concatenate(\n                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)\n        else:\n            return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + in_forest + other_vel)\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/policy/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/policy/maddpg.py",
    "content": "import torch\nfrom torch.optim import Adam\nimport torch.nn.functional as F\nfrom gym.spaces import Box, Discrete, MultiDiscrete\nfrom multiagent.multi_discrete import MultiDiscrete\nfrom utils.networks import MLPNetwork, SNNNetwork\nfrom utils.misc import soft_update, average_gradients, onehot_from_logits, gumbel_softmax\nfrom agents.agents import DDPGAgent, DDPGAgent_RNN, DDPGAgent_SNN, DDPGAgent_ToM\n# from commom.distributions import make_pdtype\n\n\nimport  time\nMSELoss = torch.nn.MSELoss()\n\nclass MADDPG(object):\n    \"\"\"\n    Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task\n    \"\"\"\n    def __init__(self, agent_init_params, alg_types, device,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.device = device\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent(lr=lr, discrete_action=discrete_action,\n                                 hidden_dim=hidden_dim,\n                                 **params)\n                       for params in agent_init_params]\n        
self.agent_init_params = agent_init_params\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def target_policies(self):\n        return [a.target_policy for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, explore=False):\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                                 observations)]\n\n    def update(self, sample, agent_i, parallel=False, logger=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. 
Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        obs, acs, rews, next_obs, dones = sample\n        curr_agent = self.agents[agent_i]\n\n        curr_agent.critic_optimizer.zero_grad()\n        if self.alg_types[agent_i] == 'MADDPG':\n            if self.discrete_action: # one-hot encode action\n                all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                                zip(self.target_policies, next_obs)]\n            else:\n                all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,\n                                                             next_obs)]\n            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n        else:  # DDPG\n            if self.discrete_action:\n                trgt_vf_in = torch.cat((next_obs[agent_i],\n                                        onehot_from_logits(\n                                            curr_agent.target_policy(\n                                                next_obs[agent_i]))),\n                                       dim=1)\n            else:\n                trgt_vf_in = torch.cat((next_obs[agent_i],\n                                        curr_agent.target_policy(next_obs[agent_i])),\n                                       dim=1)\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        if self.alg_types[agent_i] == 'MADDPG':\n            vf_in = torch.cat((*obs, *acs), dim=1)\n        else:  # DDPG\n            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)\n        actual_value = 
curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. Regardless, discrete policies don't seem to learn properly without it.\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        if self.alg_types[agent_i] == 'MADDPG':\n            all_pol_acs = []\n            for i, pi, ob in zip(range(self.nagents), self.policies, obs):\n                if i == agent_i:\n                    all_pol_acs.append(curr_pol_vf_in)\n                elif self.discrete_action:\n                    all_pol_acs.append(onehot_from_logits(pi(ob)))\n                else:\n                    all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n        else:  # DDPG\n            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in),\n                              dim=1)\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out**2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        
curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n           
 for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents]}\n        torch.save(save_dict, filename)\n\n    @classmethod\n    def init_from_env(cls, env, device, agent_alg=\"MADDPG\", adversary_alg=\"MADDPG\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            if algtype == \"MADDPG\":\n                num_in_critic = 0\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    if isinstance(oacsp, Box):\n                        discrete_action = False\n                        get_shape = lambda x: x.shape[0]\n                    elif isinstance(oacsp, Discrete):  
# Discrete\n                        discrete_action = True\n                        get_shape = lambda x: x.n\n                    elif isinstance(oacsp, MultiDiscrete):\n                        discrete_action = True\n                        get_shape = lambda x: sum(x.high - x.low + 1)\n                    num_in_critic += get_shape(oacsp)\n            else:\n                num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'device': device}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        return instance\n\nclass MADDPG_SNN(object):\n    \"\"\"\n    Wrapper class for DDPG-esque (i.e. 
also MADDPG) agents in multi-agent task\n    \"\"\"\n    def __init__(self, agent_init_params, alg_types,output_style, device,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.device = device\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent_SNN(lr=lr, discrete_action=discrete_action,\n                                 hidden_dim=hidden_dim,\n                                 **params, output_style=output_style)\n                       for params in agent_init_params]\n        self.agent_init_params = agent_init_params\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def 
target_policies(self):\n        return [a.target_policy for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, explore=False):\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                                 observations)]\n\n    def update(self, sample, agent_i, parallel=False, logger=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. 
Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        obs, acs, rews, next_obs, dones = sample\n        curr_agent = self.agents[agent_i]\n\n        curr_agent.critic_optimizer.zero_grad()\n        if self.alg_types[agent_i] == 'MADDPG_SNN':\n            all_trgt_acs = []\n            if self.discrete_action: # one-hot encode action\n                all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                                zip(self.target_policies, next_obs)]\n                # for nobs in next_obs:\n                #     if nobs.shape[1] == next_obs[agent_i].shape[1]:\n                #         all_trgt_acs.append(onehot_from_logits(self.target_policies[agent_i](nobs)))\n                #     else:\n                #         if next_obs[agent_i].shape[1] - nobs[:][3].shape[0] > 0 :\n                #             a = torch.zeros((nobs.shape[0], next_obs[agent_i].shape[1] - nobs[:][3].shape[0]))\n                #             a = a.to(torch.device(self.device))\n                #             obs_good = torch.cat((nobs, a), 1)\n                #             all_trgt_acs.append(onehot_from_logits(self.target_policies[agent_i](obs_good)))\n                #         else:\n                #             all_trgt_acs.append(onehot_from_logits(self.target_policies[agent_i](nobs[:, :next_obs[agent_i].shape[1]])))\n                # all_trgt_acs = [onehot_from_logits(self.target_policies[agent_i](nobs)) for nobs in\n                #                 next_obs]\n            else:\n                # all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,\n                #                                              next_obs)]\n                
all_trgt_acs = [self.target_policies[agent_i](nobs) for nobs in next_obs]   #self-experience\n            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n        else:  # DDPG\n            if self.discrete_action:\n                trgt_vf_in = torch.cat((next_obs[agent_i],\n                                        onehot_from_logits(\n                                            curr_agent.target_policy(\n                                                next_obs[agent_i]))),\n                                       dim=1)\n            else:\n                trgt_vf_in = torch.cat((next_obs[agent_i],\n                                        curr_agent.target_policy(next_obs[agent_i])),\n                                       dim=1)\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        if self.alg_types[agent_i] == 'MADDPG_SNN':\n            vf_in = torch.cat((*obs, *acs), dim=1)\n        else:  # DDPG\n            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)\n        actual_value = curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. 
Regardless, discrete policies don't seem to learn properly without it.\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        if self.alg_types[agent_i] == 'MADDPG_SNN':\n            all_pol_acs = []\n            for i, pi, ob in zip(range(self.nagents), self.policies, obs):\n                if i == agent_i:\n                    all_pol_acs.append(curr_pol_vf_in)\n                elif self.discrete_action:\n                    all_pol_acs.append(onehot_from_logits(pi(ob)))\n                else:\n                    all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n        else:  # DDPG\n            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in),\n                              dim=1)\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out**2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for a in self.agents:\n            a.policy.train()\n            
a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents]}\n        torch.save(save_dict, filename)\n\n    @classmethod\n    def init_from_env(cls, env, device, agent_alg=\"MADDPG_SNN\", adversary_alg=\"MADDPG_SNN\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):\n    # def init_from_env(cls, env, agent_alg=\"MADDPG_SNN\", 
adversary_alg=\"MADDPG_SNN\",\n    #                   gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):    #eval\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            if algtype == \"MADDPG_SNN\":\n                num_in_critic = 0\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    if isinstance(oacsp, Box):\n                        discrete_action = False\n                        get_shape = lambda x: x.shape[0]\n                    elif isinstance(oacsp, Discrete):  # Discrete\n                        discrete_action = True\n                        get_shape = lambda x: x.n\n                    elif isinstance(oacsp, MultiDiscrete):\n                        discrete_action = True\n                        get_shape = lambda x: sum(x.high - x.low + 1)\n                    num_in_critic += get_shape(oacsp)\n            else:\n                num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      
'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'output_style': output_style,\n                     'device': device}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        return instance\n\nclass MADDPG_ToM(object):\n    \"\"\"\n    Wrapper class for DDPG-esque (i.e. 
also MADDPG) agents in multi-agent task\n    \"\"\"\n\n    def __init__(self, agent_init_params, alg_types, output_style, device,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.device = device\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,\n                                     hidden_dim=hidden_dim,\n                                     **params, output_style=output_style,\n                                     num_agents=self.nagents,\n                                     device=self.device)\n                       for params in agent_init_params]\n        self.agent_init_params = agent_init_params\n        if self.nagents == 6:\n            self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14,      #simple_com\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-self\n                                  hidden_dim=hidden_dim, output_style=output_style),\n                             
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14,\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             # SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14,\n                             #            self.agent_init_params[3]['num_out_pol'],  # adv self-other\n                             #            hidden_dim=hidden_dim, output_style=output_style),\n                             # SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14,\n                             #            self.agent_init_params[3]['num_out_pol'],\n                             #            hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                             ]\n        if self.nagents == 4:\n            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2,      #simple_tag\n                                        self.agent_init_params[0]['num_out_pol'], #adv self-self\n                                  hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2,\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2,\n                                        self.agent_init_params[3]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                             ]\n        elif self.nagents == 3:\n            self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'],      #simple_adv\n                                        self.agent_init_params[1]['num_out_pol'], #adv self-other\n  
                                      hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle'],\n                                        self.agent_init_params[1]['num_out_pol'], #agent self-self\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle'],\n                                        self.agent_init_params[1]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n            ]\n        elif self.nagents == 2:\n            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2,      #simple_push\n                                        self.agent_init_params[0]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle']-2,\n                                        self.agent_init_params[1]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                ]\n        self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.mle_dev = 'cpu'\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def target_policies(self):\n        return [a.target_policy for a in 
self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, explore=False):    #simple_tag\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        # t1 = time.time()\n        observations_ = observations.copy()\n        for agent_i, obs in enumerate(observations):\n            obs_ = observations_.copy()\n            obs_.pop(agent_i)\n            # actions = [self.agents[agent_i].mle[j].cpu()(observations[agent_i]) for j, obs_j in enumerate(obs_)]\n            # observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)\n            if self.nagents == 6:\n\n                if agent_i < 4:\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[1], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:, 4:24].to(self.device)),\n                                              hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n\n                else:\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]\n                    # actions = 
[gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:, 4:24].to(self.device)),\n                                              hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n            if self.nagents == 4:\n                if agent_i < 3:\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:, 2:].to(self.device)),\n                                              hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(obs_j[:,2:-2].to(self.device)),\n                                              hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3: #simple_adv\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]\n                    # actions = [gumbel_softmax(\n                    #     self.agents[agent_i].mle[j].cpu()(torch.cat((obs_j[:, :2], observations_[agent_i]), 1)),\n                    #     hard=True)\n                    #            for j, obs_j in 
enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, :2],\n                             observations_[agent_i]), 1).to(self.device)), hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i >= 1:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(observations_[agent_i].to(self.device)),\n                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 2:\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(observations_[agent_i][:,2:].to(self.device)),\n                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:, 2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(observations_[agent_i][:,2:].to(self.device)),\n                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]\n            observations[agent_i] = torch.cat((observations[agent_i], 
torch.cat(actions, 1)), 1)\n        # t2 = time.time()\n        # print('step+time:', t2 - t1)\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                               observations)]\n\n\n    def _get_obs(self, observations):\n        observations_ = []\n        for agent_i, obs in enumerate(observations):\n            obs_ = observations.copy()\n            obs_.pop(agent_i)\n            if self.nagents == 6:\n                if agent_i < 4:   #simple_tag\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 4:24]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i > 4:\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 4:24]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n            if self.nagents == 4:\n                if agent_i < 3:   #simple_tag\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 2:]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 2:-2]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n       
     elif self.nagents == 3:\n                if agent_i < 1:     #simple_adv\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n\n                elif agent_i >= 1:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 2:\n                if agent_i < 1:     #simple_push\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i][:,2:]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i][:,2:]).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n\n            observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))\n\n        return observations_\n\n    def trian_tag(self, agent_i, KL_criterion, obs, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](obs[0][:, 2:])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                
average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](obs[3][:, 2:])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[3].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            self.mle_opts[2].zero_grad()\n            action_i = self.mle_base[2](obs[0][:, 2:-2])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[2])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            self.mle_opts[2].step()\n\n    def trian_adv(self, agent_i, KL_criterion, obs, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](torch.cat((obs[1][:,:2],obs[agent_i]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](obs[1])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n        
    torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            self.mle_opts[2].zero_grad()\n            action_i = self.mle_base[2](obs[1])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[2])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            self.mle_opts[2].step()\n\n    def trian_push(self, agent_i, KL_criterion, obs, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](obs[0][:, 2:])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](obs[1][:, 2:])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n    def trian_com(self, agent_i, KL_criterion, obs, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](obs[1][:, 4:24])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                
average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](obs[4][:, 4:24])\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[4].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n\n\n\n    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        # print('___update___')\n        obs, acs, rews, next_obs, dones = sample\n\n        next_obs_ = self._get_obs(next_obs)\n        obs_ = self._get_obs(obs)\n        curr_agent = self.agents[agent_i]\n        # mle\n        KL_criterion = torch.nn.KLDivLoss(reduction='sum')\n        # for i in range(len(curr_agent.mle)):\n        #     curr_agent.mle_optimizer[i].zero_grad()\n        #     action_i = curr_agent.mle[i](obs[agent_i]obs[agent_i])\n        #     action_pre = gumbel_softmax(action_i, hard=True)\n        #     loss = KL_criterion(action_pre.float(), acs[i].float())\n        #     loss.backward()\n        #     if 
parallel:\n        #         average_gradients(curr_agent.mle[i])\n        #     torch.nn.utils.clip_grad_norm_(curr_agent.mle[i].parameters(), 20)\n        #     curr_agent.policy_optimizer.step()\n        if self.nagents == 6:\n            self.trian_com(agent_i, KL_criterion, obs, parallel, acs)\n        if self.nagents == 4:\n            self.trian_tag(agent_i, KL_criterion, obs, parallel, acs)\n        elif self.nagents == 3:\n            self.trian_adv(agent_i, KL_criterion, obs, parallel, acs)\n        elif self.nagents == 2:\n            self.trian_push(agent_i, KL_criterion, obs, parallel, acs)\n\n        # center critic\n        curr_agent.critic_optimizer.zero_grad()\n        if self.alg_types[agent_i] == 'MADDPG_ToM':\n            all_trgt_acs = []\n            if self.discrete_action:  # one-hot encode action\n\n                all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                                zip(self.target_policies, next_obs_)]\n            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        if self.alg_types[agent_i] == 'MADDPG_ToM':\n            vf_in = torch.cat((*obs, *acs), dim=1)\n\n        actual_value = curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. 
The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. Regardless, discrete policies don't seem to learn properly without it.\n\n            curr_pol_out = curr_agent.policy(obs_[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        if self.alg_types[agent_i] == 'MADDPG_ToM':\n            all_pol_acs = []\n            for i, pi, ob in zip(range(self.nagents), self.policies, obs_):\n                if i == agent_i:\n                    all_pol_acs.append(curr_pol_vf_in)\n                elif self.discrete_action:\n                    all_pol_acs.append(onehot_from_logits(pi(ob)))\n                else:\n                    all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out ** 2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        # actor\n        curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        
self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for mle in self.mle_base:\n            mle.train()\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n            for mle_i in a.mle:\n                mle_i.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n        if not self.mle_dev == device:\n            for i, mle in enumerate(self.mle_base):\n                self.mle_base[i] = fn(mle)\n            for a in self.agents:\n                for i, mle_i in enumerate(a.mle):\n                    a.mle[i] = fn(mle_i)\n            self.mle_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into 
one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents],\n                     'mle_params': [self.get_params()],}\n        torch.save(save_dict, filename)\n\n    @classmethod\n    def init_from_env(cls, env, device, agent_alg=\"MADDPG_ToM\", adversary_alg=\"MADDPG_ToM\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            num_in_mle = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            if algtype == \"MADDPG_ToM\":\n                num_in_critic = 0\n                num_in_pol += (len(env.agent_types)-1) * 5\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    if isinstance(oacsp, Box):\n                        discrete_action = False\n                        get_shape = lambda x: x.shape[0]\n                    elif isinstance(oacsp, Discrete):  # 
Discrete\n                        discrete_action = True\n                        get_shape = lambda x: x.n\n                    elif isinstance(oacsp, MultiDiscrete):\n                        discrete_action = True\n                        get_shape = lambda x: sum(x.high - x.low + 1)\n                    num_in_critic += get_shape(oacsp)\n            else:\n                num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic,\n                                      'num_in_mle': num_in_mle,})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'device': device,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'output_style': output_style}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        for a, params in zip([instance], save_dict['mle_params']):\n            a.load_params(params)\n        return instance\n\n    def get_params(self):\n        params = {\n                }\n        for i in range(len(self.mle_base)):\n            params['mle%d'%i] = self.mle_base[i].state_dict()\n            params['mle_optimizer%d'%i] = self.mle_opts[i].state_dict()\n        return params\n\n    
def load_params(self, params):\n        for i in range(len(self.mle_base)):\n            self.mle_base[i].load_state_dict(params['mle%d'%i])\n            self.mle_opts[i].load_state_dict(params['mle_optimizer%d'%i])\n\nclass ToM_SA(object):\n    \"\"\"\n    Wrapper class for DDPG-esque (i.e. also MADDPG) agents in multi-agent task\n    \"\"\"\n\n    def __init__(self, agent_init_params, alg_types, output_style, device,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.device = device\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,\n                                     hidden_dim=hidden_dim,\n                                     **params, output_style=output_style,\n                                     num_agents=self.nagents,\n                                     device=self.device)\n                       for params in agent_init_params]\n        self.agent_init_params = agent_init_params\n        if self.nagents == 6:\n            self.mle_base = 
[SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,      #simple_com\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-self\n                                  hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,\n                                        self.agent_init_params[3]['num_out_pol'],  # adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,\n                                        self.agent_init_params[3]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                             ]\n        if self.nagents == 4:\n            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5,      #simple_tag\n                                        self.agent_init_params[0]['num_out_pol'], #adv self-self\n                                  hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,\n                                        self.agent_init_params[3]['num_out_pol'],\n                                        hidden_dim=hidden_dim, 
output_style=output_style),    ##agent self-other\n                             ]\n        elif self.nagents == 3:\n            self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,      #simple_adv\n                                        self.agent_init_params[1]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,\n                                        self.agent_init_params[1]['num_out_pol'], #agent self-self\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,\n                                        self.agent_init_params[1]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n            ]\n        elif self.nagents == 2:\n            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2  + 5,      #simple_push\n                                        self.agent_init_params[0]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle']-2  + 5,\n                                        self.agent_init_params[1]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                ]\n        self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # 
device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.mle_dev = 'cpu'\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def target_policies(self):\n        return [a.target_policy for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, actions_pre, explore=False):    #simple_tag\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        # t1 = time.time()\n        observations_ = observations.copy()\n        actions_pre_ = actions_pre.copy()\n        for agent_i, obs in enumerate(observations):\n            obs_ = observations_.copy()\n            acs_pre_ = actions_pre_.copy()\n            obs_.pop(agent_i)\n            acs_pre_.pop(agent_i)\n            # actions = [self.agents[agent_i].mle[j].cpu()(observations[agent_i]) for j, obs_j in enumerate(obs_)]\n            # observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)\n            if self.nagents == 6:\n                if agent_i < 4:\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[1], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    # t1 
= time.time()\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1).to(self.device)),\n                    #                           hard=True).cpu()\n                    #            for j, obs_j in enumerate(obs_)]\n                    # print(t1 - time.time())\n                    # t1 = time.time()\n                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()\n                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)\n                    # print(t1 - time.time())\n                    # print()\n                else:\n                    self.agents[agent_i].mle = [self.mle_base[3],self.mle_base[3], self.mle_base[3], self.mle_base[3], self.mle_base[2]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 4:24], acs_pre_[j]),1).to(self.device)),\n                    #                           hard=True).cpu()\n                    # for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[1]['num_out_pol']))\n                               for j, obs_j in enumerate(obs_)]\n                    actions = torch.cat(actions,1)\n                    # print()\n\n            if self.nagents == 4:\n                if agent_i < 3:\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], 
self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:], acs_pre_[j]),1).to(self.device)),\n                                              hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))\n                               for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3: #simple_adv\n                actions = []\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(0)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(2)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                elif agent_i == 2:\n                    self.agents[agent_i].mle = [self.mle_base[2], 
self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(0)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(1)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n\n            elif self.nagents == 2:\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])) for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:, 2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((observations_[agent_i][:,2:],\n                                                                     actions_pre[(self.nagents -1 - agent_i)]), 1).to(self.device)),\n                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]\n            if self.nagents == 6:\n                
observations[agent_i] = torch.cat((observations[agent_i], actions), 1)\n            else:\n                observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)\n        # t2 = time.time()\n        # print('step+time:', t2 - t1)\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                               observations)]\n\n\n\n    def _get_obs(self, observations, actions_pre):\n        observations_ = []\n        actions_pre_ = []\n        for agent_i, obs in enumerate(observations):\n            obs_ = observations.copy()\n            obs_.pop(agent_i)\n            actions_pre_ = actions_pre.copy()\n            actions_pre_.pop(agent_i)\n            if self.nagents == 6:\n                if agent_i < 4:   #simple_comm\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1)).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)\n                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)\n\n                    # print()\n                elif agent_i > 4:\n                    self.agents[agent_i].mle = [self.mle_base[3], self.mle_base[3], self.mle_base[3], self.mle_base[3], self.mle_base[2]]\n                    # actions = 
[gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(obs_j[:, 4:24]).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[1]['num_out_pol'])).to(\n                        torch.device(self.device)).detach()  for j, obs_j in enumerate(obs_)]\n                    actions = torch.cat(actions, 1)\n                    # print()\n            if self.nagents == 4:\n                if agent_i < 3:   #simple_tag\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:], actions_pre_[j]),1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()\n                               for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3:\n                actions = []\n                if agent_i < 1:     #simple_adv\n                    # self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()\n                    for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[1]]\n 
                   # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions.append(\n                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(2)]), 1).to(self.device)).detach(), hard=True))\n                elif agent_i == 2:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    actions.append(\n                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(1)]), 1).to(self.device)).detach(), hard=True))\n\n            elif self.nagents == 2:\n                if agent_i < 1:     #simple_push\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach() for j, obs_j in\n                     enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((observations[agent_i][:,2:],\n                                                                           actions_pre[(self.nagents -1 - agent_i)]), 1)).detach(), 
hard=True)\n                               for j, obs_j in enumerate(obs_)]\n\n            if self.nagents == 6:\n                observations_.append(torch.cat((observations[agent_i], actions), 1))\n            else:\n                observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))\n\n        return observations_\n\n    def trian_tag(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](torch.cat((obs[0][:, 2:], acs_pre[0]),1))#\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[3][:, 2:], acs_pre[3]),1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[3].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            # self.mle_opts[2].zero_grad()\n            # action_i = self.mle_base[2](obs[0][:, 2:-2])\n            # action_pre = gumbel_softmax(action_i, hard=True)\n            # loss = KL_criterion(action_pre.float(), acs[0].float())\n            # loss.backward()\n            # if parallel:\n            #     average_gradients(self.mle_base[2])\n            # torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            # self.mle_opts[2].step()\n\n    def trian_adv(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        
if agent_i == 0:\n            # self.mle_opts[0].zero_grad()\n            # action_i = self.mle_base[0](torch.cat((obs[1][:,:2],obs[agent_i]), 1))\n            # action_pre = gumbel_softmax(action_i, hard=True)\n            # loss = KL_criterion(action_pre.float(), acs[1].float())\n            # loss.backward(retain_graph=True)\n            # if parallel:\n            #     average_gradients(self.mle_base[0])\n            # torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            # self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[1], acs_pre[2]), 1)) #torch.cat((obs[1], acs_pre[2]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            self.mle_opts[2].zero_grad()\n            action_i = self.mle_base[2](torch.cat((obs[1], acs_pre[0]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[2])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            self.mle_opts[2].step()\n\n    def trian_push(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            # self.mle_opts[0].zero_grad()\n            # action_i = self.mle_base[0](obs[0][:, 2:])  #torch.cat((obs[agent_i][:,2:], actions[(self.nagents -1 - agent_i)]), 1)\n            # action_pre = gumbel_softmax(action_i, hard=True)\n            # loss = KL_criterion(action_pre.float(), acs[1].float())\n            # loss.backward(retain_graph=True)\n          
  # if parallel:\n            #     average_gradients(self.mle_base[0])\n            # torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            # self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[1][:,2:], acs_pre[(0)]), 1))  #obs[1][:, 2:]\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n    def trian_com(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](torch.cat((obs[1][:, 4:24], acs_pre[(1)]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[4][:, 4:24], acs_pre[(4)]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[4].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            
sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        # print('___update___')\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n\n        next_obs_ = self._get_obs(next_obs, acs)\n        obs_ = self._get_obs(obs, acs_pre)\n        curr_agent = self.agents[agent_i]\n        # mle\n        KL_criterion = torch.nn.KLDivLoss(reduction='sum')\n        # for i in range(len(curr_agent.mle)):\n        #     curr_agent.mle_optimizer[i].zero_grad()\n        #     action_i = curr_agent.mle[i](obs[agent_i]obs[agent_i])\n        #     action_pre = gumbel_softmax(action_i, hard=True)\n        #     loss = KL_criterion(action_pre.float(), acs[i].float())\n        #     loss.backward()\n        #     if parallel:\n        #         average_gradients(curr_agent.mle[i])\n        #     torch.nn.utils.clip_grad_norm_(curr_agent.mle[i].parameters(), 20)\n        #     curr_agent.policy_optimizer.step()\n        if self.nagents == 6:\n            self.trian_com(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 4:\n            self.trian_tag(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 3:\n            self.trian_adv(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 2:\n            self.trian_push(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n\n        # center critic\n        curr_agent.critic_optimizer.zero_grad()\n        all_trgt_acs = []\n        if self.discrete_action:  # one-hot encode 
action\n            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                            zip(self.target_policies, next_obs_)]\n        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        vf_in = torch.cat((*obs, *acs), dim=1)\n\n        actual_value = curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. 
Regardless, discrete policies don't seem to learn properly without it.\n\n            curr_pol_out = curr_agent.policy(obs_[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        all_pol_acs = []\n        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):\n            if i == agent_i:\n                all_pol_acs.append(curr_pol_vf_in)\n            elif self.discrete_action:\n                all_pol_acs.append(onehot_from_logits(pi(ob)))\n            else:\n                all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out ** 2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        # actor\n        curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for mle in self.mle_base:\n            mle.train()\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n            for mle_i in a.mle:\n            
    mle_i.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n        if not self.mle_dev == device:\n            for i, mle in enumerate(self.mle_base):\n                self.mle_base[i] = fn(mle)\n            for a in self.agents:\n                for i, mle_i in enumerate(a.mle):\n                    a.mle[i] = fn(mle_i)\n            self.mle_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents],\n                     'mle_params': [self.get_params()],}\n        torch.save(save_dict, 
filename)

    @classmethod
    def init_from_env(cls, env, device, agent_alg="ToM_SA", adversary_alg="ToM_SA",
                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):
        """
        Instantiate instance of this class from multi-agent environment
        """
        agent_init_params = []
        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for
                     atype in env.agent_types]
        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,
                                       alg_types):
            num_in_pol = obsp.shape[0]
            num_in_mle = obsp.shape[0]
            if isinstance(acsp, Box):
                discrete_action = False
                get_shape = lambda x: x.shape[0]
            elif isinstance(acsp, Discrete):  # Discrete
                discrete_action = True
                get_shape = lambda x: x.n
            elif isinstance(acsp, MultiDiscrete):
                discrete_action = True
                get_shape = lambda x: sum(x.high - x.low + 1)
            num_out_pol = get_shape(acsp)
            if algtype == "ToM_SA":
                num_in_critic = 0
                # policy input is widened by a 5-dim predicted action for every
                # OTHER agent (matches the MLE action predictions concatenated
                # onto observations at rollout time)
                num_in_pol += (len(env.agent_types)-1) * 5
                for oobsp in env.observation_space:
                    num_in_critic += oobsp.shape[0]
                for oacsp in env.action_space:
                    if isinstance(oacsp, Box):
                        discrete_action = False
                        get_shape = lambda x: x.shape[0]
                    elif isinstance(oacsp, Discrete):  # Discrete
                        discrete_action = True
                        get_shape = lambda x: x.n
                    elif isinstance(oacsp, MultiDiscrete):
                        discrete_action = True
                        get_shape = lambda x: sum(x.high - x.low + 1)
                    # centralized critic sees every agent's obs + action
                    num_in_critic += get_shape(oacsp)
            else:
                num_in_critic = obsp.shape[0] + get_shape(acsp)
            # NOTE(review): discrete_action keeps the value from the LAST action
            # space processed — fine only while all agents share one space type
            agent_init_params.append({'num_in_pol': num_in_pol,
                                      'num_out_pol': num_out_pol,
                                      'num_in_critic': num_in_critic,
                                      'num_in_mle': num_in_mle,})
        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,
                     'device': device,
                     'hidden_dim': hidden_dim,
                     'alg_types': alg_types,
                     'agent_init_params': agent_init_params,
                     'discrete_action': discrete_action,
                     'output_style': output_style}
        instance = cls(**init_dict)
        instance.init_dict = init_dict
        return instance

    @classmethod
    def init_from_save(cls, filename):
        """
        Instantiate instance of this class from file created by 'save' method
        """
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a, params in zip(instance.agents, save_dict['agent_params']):
            a.load_params(params)
        # 'mle_params' is a one-element list (see save()), so this zip loads
        # the shared MLE parameters onto the instance itself
        for a, params in zip([instance], save_dict['mle_params']):
            a.load_params(params)
        return instance

    def get_params(self):
        """Collect state dicts of the shared MLE networks and their optimizers."""
        params = {
                }
        for i in range(len(self.mle_base)):
            params['mle%d'%i] = self.mle_base[i].state_dict()
            params['mle_optimizer%d'%i] = self.mle_opts[i].state_dict()
        return params

    def load_params(self, params):
        """Restore the shared MLE networks/optimizers saved by get_params()."""
        for i in range(len(self.mle_base)):
            self.mle_base[i].load_state_dict(params['mle%d'%i])
            self.mle_opts[i].load_state_dict(params['mle_optimizer%d'%i])

class ToM_S(object):
    """
    Wrapper class for DDPG-esque (i.e. 
also MADDPG) agents in multi-agent task
    """

    def __init__(self, agent_init_params, alg_types, output_style, device,
                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,
                 discrete_action=False):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
                num_in_critic (int): Input dimensions to critic
            alg_types (list of str): Learning algorithm for each agent (DDPG
                                       or MADDPG)
            gamma (float): Discount factor
            tau (float): Target update rate
            lr (float): Learning rate for policy and critic
            hidden_dim (int): Number of hidden dimensions for networks
            discrete_action (bool): Whether or not to use discrete action space
        """
        self.device = device
        self.nagents = len(alg_types)
        self.alg_types = alg_types
        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,
                                     hidden_dim=hidden_dim,
                                     **params, output_style=output_style,
                                     num_agents=self.nagents,
                                     device=self.device)
                       for params in agent_init_params]
        self.agent_init_params = agent_init_params
        # Shared "theory of mind" MLE networks, one set per scenario (selected
        # by agent count). Input sizes below mirror the slices used in
        # step()/_get_obs(): an observation slice plus a 5-dim previous action.
        # NOTE(review): the "- 14 + 5" / "- 2 + 5" offsets are scenario-specific
        # magic — TODO confirm against each env's observation layout.
        if self.nagents == 6:
            self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,      #simple_com
                                        self.agent_init_params[3]['num_out_pol'], #adv self-self
                                  hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                        self.agent_init_params[3]['num_out_pol'], #adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                        self.agent_init_params[3]['num_out_pol'],  # adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,
                                        self.agent_init_params[3]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other
                             ]
        # NOTE(review): `if` (not `elif`) breaks the chain here; harmless since
        # nagents can only match one branch, but inconsistent with those below
        if self.nagents == 4:
            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5,      #simple_tag
                                        self.agent_init_params[0]['num_out_pol'], #adv self-self
                                  hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
                                        self.agent_init_params[3]['num_out_pol'], #adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,
                                        self.agent_init_params[3]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other
                             ]
        elif self.nagents == 3:
            self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,      #simple_adv
                                        self.agent_init_params[1]['num_out_pol'], #adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
                                        self.agent_init_params[1]['num_out_pol'], #agent self-self
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,
                                        self.agent_init_params[1]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other
            ]
        elif self.nagents == 2:
            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2  + 5,      #simple_push
                                        self.agent_init_params[0]['num_out_pol'], #adv self-other
                                        hidden_dim=hidden_dim, output_style=output_style),
                             SNNNetwork(self.agent_init_params[1]['num_in_mle']-2  + 5,
                                        self.agent_init_params[1]['num_out_pol'],
                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other
                ]
        # one Adam optimizer per shared MLE network
        self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.discrete_action = discrete_action
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.mle_dev = 'cpu'  # device for the shared MLE networks
        self.niter = 0

    @property
    def policies(self):
        return [a.policy for a in self.agents]

    @property
    def target_policies(self):
        return 
[a.target_policy for a in self.agents]

    def scale_noise(self, scale):
        """
        Scale noise for each agent
        Inputs:
            scale (float): scale of noise
        """
        for a in self.agents:
            a.scale_noise(scale)

    def reset_noise(self):
        for a in self.agents:
            a.reset_noise()

    def step(self, observations, actions_pre, explore=False):    #simple_tag
        """
        Take a step forward in environment with all agents
        Inputs:
            observations: List of observations for each agent
            actions_pre: List of each agent's previous action (fed to the
                         MLE models to predict the other agents' next actions)
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            actions: List of actions for each agent
        """
        # t1 = time.time()
        observations_ = observations.copy()
        actions_pre_ = actions_pre.copy()
        for agent_i, obs in enumerate(observations):
            # views of everyone EXCEPT the current agent
            obs_ = observations_.copy()
            acs_pre_ = actions_pre_.copy()
            obs_.pop(agent_i)
            acs_pre_.pop(agent_i)
            # actions = [self.agents[agent_i].mle[j].cpu()(observations[agent_i]) for j, obs_j in enumerate(obs_)]
            # observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)
            if self.nagents == 6:
                if agent_i < 4:
                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[1], self.mle_base[1]]

                    # MLE input: obs slice [4:24] plus 5-dim previous action
                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
                    # batch the 3 adversary / 2 agent predictions through the
                    # two shared networks, then split the results back out.
                    # NOTE(review): slice size 20 assumes a rollout batch of 20
                    # parallel envs — TODO confirm
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
                    # print(t1 - time.time())
                    # print()
                else:
                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]
                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]
                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()
                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()
                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)
                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[1]['num_out_pol']))
                    #            for j, obs_j in enumerate(obs_)]
                    # actions = torch.cat(actions,1)
                    # print()

            if self.nagents == 4:
                if agent_i < 3:
                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0], self.mle_base[1]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:], acs_pre_[j]),1).to(self.device)),
                                              hard=True).cpu()
                               for j, obs_j in enumerate(obs_)]

                elif agent_i == 3:
                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]
                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:-2], acs_pre_[j]),1).to(self.device)),
                                              hard=True).cpu()
                               for j, obs_j in enumerate(obs_)]
                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)
                    #        
    for j, obs_j in enumerate(obs_)]\n                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))\n                    #            for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3: #simple_adv\n                actions = []\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(0)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(2)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                elif agent_i == 2:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(0)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                    
actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(1)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n\n            elif self.nagents == 2:\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])) for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:, 2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((observations_[agent_i][:,2:],\n                                                                     actions_pre[(self.nagents -1 - agent_i)]), 1).to(self.device)),\n                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]\n            if self.nagents == 6:\n                observations[agent_i] = torch.cat((observations[agent_i], actions), 1)\n            else:\n                observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)\n        # t2 = time.time()\n        # print('step+time:', t2 - t1)\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                               observations)]\n\n\n\n    def _get_obs(self, observations, actions_pre):\n        observations_ = []\n        actions_pre_ = []\n    
    for agent_i, obs in enumerate(observations):\n            obs_ = observations.copy()\n            obs_.pop(agent_i)\n            actions_pre_ = actions_pre.copy()\n            actions_pre_.pop(agent_i)\n            if self.nagents == 6:\n                if agent_i < 4:   #simple_comm\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1)).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)\n                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)\n\n                    # print()\n                elif agent_i > 4:\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[1]]\n                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)\n                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)\n                    # print()\n            if self.nagents == 4:\n   
             if agent_i < 3:   #simple_tag\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:], actions_pre_[j]),1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:-2], actions_pre_[j]),1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3:\n                actions = []\n                if agent_i < 1:     #simple_adv\n                    # self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()\n                    for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions.append(\n                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))\n            
        actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(2)]), 1).to(self.device)).detach(), hard=True))\n                elif agent_i == 2:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    actions.append(\n                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(1)]), 1).to(self.device)).detach(), hard=True))\n\n            elif self.nagents == 2:\n                if agent_i < 1:     #simple_push\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach() for j, obs_j in\n                     enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((observations[agent_i][:,2:],\n                                                                           actions_pre[(self.nagents -1 - agent_i)]), 1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n\n            if self.nagents == 6:\n                observations_.append(torch.cat((observations[agent_i], actions), 1))\n            else:\n                observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))\n\n        return observations_\n\n    def trian_tag(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            
self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](torch.cat((obs[0][:, 2:], acs_pre[0]),1))#\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[3][:, 2:], acs_pre[3]),1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[3].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            self.mle_opts[2].zero_grad()\n            action_i = self.mle_base[2](torch.cat((obs[0][:, 2:-2], acs_pre[0]),1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[2])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            self.mle_opts[2].step()\n\n    def trian_adv(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            # self.mle_opts[0].zero_grad()\n            # action_i = self.mle_base[0](torch.cat((obs[1][:,:2],obs[agent_i]), 1))\n            # action_pre = gumbel_softmax(action_i, hard=True)\n            # loss = KL_criterion(action_pre.float(), acs[1].float())\n            # loss.backward(retain_graph=True)\n            # if parallel:\n            #     average_gradients(self.mle_base[0])\n            # 
torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            # self.mle_opts[0].step()

            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[1], acs_pre[2]), 1)) #torch.cat((obs[1], acs_pre[2]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[1].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()

            self.mle_opts[2].zero_grad()
            # NOTE(review): input pairs obs[1] with acs_pre[0] but the target
            # is acs[0] — asymmetric with the mle_base[1] update above; confirm
            # this "self-other" pairing is intended
            action_i = self.mle_base[2](torch.cat((obs[1], acs_pre[0]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[0].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[2])
            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)
            self.mle_opts[2].step()

    def trian_push(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
        """Train the shared MLE network for simple_push (only on agent 0's turn)."""
        if agent_i == 0:
            # self.mle_opts[0].zero_grad()
            # action_i = self.mle_base[0](obs[0][:, 2:])  #torch.cat((obs[agent_i][:,2:], actions[(self.nagents -1 - agent_i)]), 1)
            # action_pre = gumbel_softmax(action_i, hard=True)
            # loss = KL_criterion(action_pre.float(), acs[1].float())
            # loss.backward(retain_graph=True)
            # if parallel:
            #     average_gradients(self.mle_base[0])
            # torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            # self.mle_opts[0].step()

            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[1][:,2:], acs_pre[(0)]), 1))  #obs[1][:, 2:]
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[0].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()

    def trian_com(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):
        """Train the shared MLE networks for simple_comm (only on agent 0's turn)."""
        if agent_i == 0:
            self.mle_opts[0].zero_grad()
            action_i = self.mle_base[0](torch.cat((obs[1][:, 4:24], acs_pre[(1)]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[1].float())
            loss.backward(retain_graph=True)
            if parallel:
                average_gradients(self.mle_base[0])
            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)
            self.mle_opts[0].step()

            self.mle_opts[1].zero_grad()
            action_i = self.mle_base[1](torch.cat((obs[4][:, 4:24], acs_pre[(4)]), 1))
            action_pre = gumbel_softmax(action_i, hard=True)
            loss = KL_criterion(action_pre.float(), acs[4].float())
            loss.backward()
            if parallel:
                average_gradients(self.mle_base[1])
            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)
            self.mle_opts[1].step()


    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. 
Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        # print('___update___')\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n\n        next_obs_ = self._get_obs(next_obs, acs)\n        obs_ = self._get_obs(obs, acs_pre)\n        curr_agent = self.agents[agent_i]\n        # mle\n        KL_criterion = torch.nn.KLDivLoss(reduction='sum')\n        # for i in range(len(curr_agent.mle)):\n        #     curr_agent.mle_optimizer[i].zero_grad()\n        #     action_i = curr_agent.mle[i](obs[agent_i]obs[agent_i])\n        #     action_pre = gumbel_softmax(action_i, hard=True)\n        #     loss = KL_criterion(action_pre.float(), acs[i].float())\n        #     loss.backward()\n        #     if parallel:\n        #         average_gradients(curr_agent.mle[i])\n        #     torch.nn.utils.clip_grad_norm_(curr_agent.mle[i].parameters(), 20)\n        #     curr_agent.policy_optimizer.step()\n        if self.nagents == 6:\n            self.trian_com(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 4:\n            self.trian_tag(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 3:\n            self.trian_adv(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 2:\n            self.trian_push(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n\n        # center critic\n        curr_agent.critic_optimizer.zero_grad()\n        all_trgt_acs = []\n        if self.discrete_action:  # one-hot encode action\n            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                            zip(self.target_policies, next_obs_)]\n        trgt_vf_in = 
torch.cat((*next_obs, *all_trgt_acs), dim=1)\n\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        vf_in = torch.cat((*obs, *acs), dim=1)\n\n        actual_value = curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. 
Regardless, discrete policies don't seem to learn properly without it.\n\n            curr_pol_out = curr_agent.policy(obs_[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        all_pol_acs = []\n        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):\n            if i == agent_i:\n                all_pol_acs.append(curr_pol_vf_in)\n            elif self.discrete_action:\n                all_pol_acs.append(onehot_from_logits(pi(ob)))\n            else:\n                all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out ** 2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        # actor\n        curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for mle in self.mle_base:\n            mle.train()\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n            for mle_i in a.mle:\n            
    mle_i.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n        if not self.mle_dev == device:\n            for i, mle in enumerate(self.mle_base):\n                self.mle_base[i] = fn(mle)\n            for a in self.agents:\n                for i, mle_i in enumerate(a.mle):\n                    a.mle[i] = fn(mle_i)\n            self.mle_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents],\n                     'mle_params': [self.get_params()],}\n        torch.save(save_dict, 
filename)\n\n    @classmethod\n    def init_from_env(cls, env, device, agent_alg=\"ToM_S\", adversary_alg=\"ToM_S\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            num_in_mle = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            if algtype == \"ToM_S\":\n                num_in_critic = 0\n                num_in_pol += (len(env.agent_types)-1) * 5\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    if isinstance(oacsp, Box):\n                        discrete_action = False\n                        get_shape = lambda x: x.shape[0]\n                    elif isinstance(oacsp, Discrete):  # Discrete\n                        discrete_action = True\n                        get_shape = lambda x: x.n\n                    elif isinstance(oacsp, MultiDiscrete):\n                        discrete_action = True\n                        get_shape = lambda x: sum(x.high - x.low + 1)\n                    num_in_critic += get_shape(oacsp)\n     
       else:\n                num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic,\n                                      'num_in_mle': num_in_mle,})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'device': device,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'output_style': output_style}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        for a, params in zip([instance], save_dict['mle_params']):\n            a.load_params(params)\n        return instance\n\n    def get_params(self):\n        params = {\n                }\n        for i in range(len(self.mle_base)):\n            params['mle%d'%i] = self.mle_base[i].state_dict()\n            params['mle_optimizer%d'%i] = self.mle_opts[i].state_dict()\n        return params\n\n    def load_params(self, params):\n        for i in range(len(self.mle_base)):\n            self.mle_base[i].load_state_dict(params['mle%d'%i])\n            self.mle_opts[i].load_state_dict(params['mle_optimizer%d'%i])\n\nclass ToM_self(object):\n    \"\"\"\n    Wrapper class for DDPG-esque (i.e. 
also MADDPG) agents in multi-agent task\n    \"\"\"\n\n    def __init__(self, agent_init_params, alg_types, output_style, device,\n                 gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64,\n                 discrete_action=False):\n        \"\"\"\n        Inputs:\n            agent_init_params (list of dict): List of dicts with parameters to\n                                              initialize each agent\n                num_in_pol (int): Input dimensions to policy\n                num_out_pol (int): Output dimensions to policy\n                num_in_critic (int): Input dimensions to critic\n            alg_types (list of str): Learning algorithm for each agent (DDPG\n                                       or MADDPG)\n            gamma (float): Discount factor\n            tau (float): Target update rate\n            lr (float): Learning rate for policy and critic\n            hidden_dim (int): Number of hidden dimensions for networks\n            discrete_action (bool): Whether or not to use discrete action space\n        \"\"\"\n        self.device = device\n        self.nagents = len(alg_types)\n        self.alg_types = alg_types\n        self.agents = [DDPGAgent_ToM(lr=lr, discrete_action=discrete_action,\n                                     hidden_dim=hidden_dim,\n                                     **params, output_style=output_style,\n                                     num_agents=self.nagents,\n                                     device=self.device)\n                       for params in agent_init_params]\n        self.agent_init_params = agent_init_params\n        if self.nagents == 6:\n            self.mle_base = [SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,      #simple_com\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-self\n                                  hidden_dim=hidden_dim, output_style=output_style),\n                             
SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,\n                                        self.agent_init_params[3]['num_out_pol'],  # adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 14 + 5,\n                                        self.agent_init_params[3]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                             ]\n        if self.nagents == 4:\n            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle'] - 2 + 5,      #simple_tag\n                                        self.agent_init_params[0]['num_out_pol'], #adv self-self\n                                  hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,\n                                        self.agent_init_params[3]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[3]['num_in_mle'] - 2 + 5,\n                                        self.agent_init_params[3]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                             ]\n        elif self.nagents == 3:\n            self.mle_base = [SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,      #simple_adv\n                                        self.agent_init_params[1]['num_out_pol'], 
#adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,\n                                        self.agent_init_params[1]['num_out_pol'], #agent self-self\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle'] + 5,\n                                        self.agent_init_params[1]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n            ]\n        elif self.nagents == 2:\n            self.mle_base = [SNNNetwork(self.agent_init_params[0]['num_in_mle']-2  + 5,      #simple_push\n                                        self.agent_init_params[0]['num_out_pol'], #adv self-other\n                                        hidden_dim=hidden_dim, output_style=output_style),\n                             SNNNetwork(self.agent_init_params[1]['num_in_mle']-2  + 5,\n                                        self.agent_init_params[1]['num_out_pol'],\n                                        hidden_dim=hidden_dim, output_style=output_style),    ##agent self-other\n                ]\n        self.mle_opts = [Adam(self.mle_base[i].parameters(), lr=lr) for i in range(len(self.mle_base))]\n        self.gamma = gamma\n        self.tau = tau\n        self.lr = lr\n        self.discrete_action = discrete_action\n        self.pol_dev = 'cpu'  # device for policies\n        self.critic_dev = 'cpu'  # device for critics\n        self.trgt_pol_dev = 'cpu'  # device for target policies\n        self.trgt_critic_dev = 'cpu'  # device for target critics\n        self.mle_dev = 'cpu'\n        self.niter = 0\n\n    @property\n    def policies(self):\n        return [a.policy for a in self.agents]\n\n    @property\n    def target_policies(self):\n        return 
[a.target_policy for a in self.agents]\n\n    def scale_noise(self, scale):\n        \"\"\"\n        Scale noise for each agent\n        Inputs:\n            scale (float): scale of noise\n        \"\"\"\n        for a in self.agents:\n            a.scale_noise(scale)\n\n    def reset_noise(self):\n        for a in self.agents:\n            a.reset_noise()\n\n    def step(self, observations, actions_pre, explore=False):    #simple_tag\n        \"\"\"\n        Take a step forward in environment with all agents\n        Inputs:\n            observations: List of observations for each agent\n            explore (boolean): Whether or not to add exploration noise\n        Outputs:\n            actions: List of actions for each agent\n        \"\"\"\n        # t1 = time.time()\n        observations_ = observations.copy()\n        actions_pre_ = actions_pre.copy()\n        for agent_i, obs in enumerate(observations):\n            obs_ = observations_.copy()\n            acs_pre_ = actions_pre_.copy()\n            obs_.pop(agent_i)\n            acs_pre_.pop(agent_i)\n            # actions = [self.agents[agent_i].mle[j].cpu()(observations[agent_i]) for j, obs_j in enumerate(obs_)]\n            # observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)\n            if self.nagents == 6:\n                if agent_i < 4:\n                    self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0],self.mle_base[0], self.mle_base[0], self.mle_base[0]]\n\n                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()\n                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)\n    
                # print(t1 - time.time())\n                    # print()\n                else:\n                    self.agents[agent_i].mle = [self.mle_base[1],self.mle_base[1], self.mle_base[1], self.mle_base[1], self.mle_base[1]]\n                    actions = [torch.cat((obs_j[:, 4:24], acs_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)), hard=True).cpu()\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)), hard=True).cpu()\n                    actions = torch.cat((b1[:20], b1[20:40], b1[40:60], b2[:20], b2[20:40]), 1)\n                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[1]['num_out_pol']))\n                    #            for j, obs_j in enumerate(obs_)]\n                    # actions = torch.cat(actions,1)\n                    # print()\n\n            if self.nagents == 4:\n                if agent_i < 3:\n                    self.agents[agent_i].mle = [self.mle_base[1],self.mle_base[1], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:, 2:14], acs_pre_[j]),1).to(self.device)),\n                                              hard=True).cpu()\n                               for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[2], self.mle_base[2]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:,2:14], acs_pre_[j]),1).to(self.device)),\n                                              hard=True).cpu()  for j, obs_j in enumerate(obs_)]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(obs_j[:,2:-2]), hard=True)\n                    #            for j, obs_j in 
enumerate(obs_)]\n                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))\n                    #            for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3: #simple_adv\n                actions = []\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0]]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol']))\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(0)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(2)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                elif agent_i == 2:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(0)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n                    
actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations_[agent_i],\n                                                                     actions_pre[(1)]), 1).to(self.device)),\n                                              hard=True).cpu() )\n            elif self.nagents == 2:\n                if agent_i < 1:\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:,2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])) for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].cpu()(observations_[agent_i][:, 2:]), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((observations_[agent_i][:,2:],\n                                                                     actions_pre[(self.nagents -1 - agent_i)]), 1).to(self.device)),\n                                              hard=True).cpu() for j, obs_j in enumerate(obs_)]\n            if self.nagents == 6:\n                observations[agent_i] = torch.cat((observations[agent_i], actions), 1)\n            else:\n                observations[agent_i] = torch.cat((observations[agent_i], torch.cat(actions, 1)), 1)\n        # t2 = time.time()\n        # print('step+time:', t2 - t1)\n        return [a.step(obs, explore=explore) for a, obs in zip(self.agents,\n                                                               observations)]\n\n\n\n    def _get_obs(self, observations, actions_pre):\n        observations_ = []\n        actions_pre_ = []\n      
  for agent_i, obs in enumerate(observations):\n            obs_ = observations.copy()\n            obs_.pop(agent_i)\n            actions_pre_ = actions_pre.copy()\n            actions_pre_.pop(agent_i)\n            if self.nagents == 6:\n                if agent_i < 4:   #simple_comm\n                    self.agents[agent_i].mle = [self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0], self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1)).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)\n                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)\n\n                    # print()\n                elif agent_i > 4:\n                    self.agents[agent_i].mle = [self.mle_base[1], self.mle_base[1], self.mle_base[1], self.mle_base[1], self.mle_base[1]]\n                    actions = [torch.cat((obs_j[:, 4:24], actions_pre_[j][:,:5]),1) for j, obs_j in enumerate(obs_)]\n                    b1 = gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat(actions[:3]).to(self.device)).detach(), hard=True)\n                    b2 = gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat(actions[3:]).to(self.device)).detach(), hard=True)\n                    actions = torch.cat((b1[:1024], b1[1024:2048], b1[2048:3072], b2[:1024], b2[1024:2048]), 1)\n                    # print()\n            if self.nagents == 4:\n     
           if agent_i < 3:   #simple_tag\n                    self.agents[agent_i].mle = [self.mle_base[1], self.mle_base[1], self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:14], actions_pre_[j]),1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n                elif agent_i == 3:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[2], self.mle_base[2]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:, 2:14], actions_pre_[j]),1)).detach(), hard=True)\n                               for j, obs_j in enumerate(obs_)]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(self.device)(torch.cat((obs_j[:,2:-2], acs_pre_[j]),1).to(self.device)),\n                    #                           hard=True).cpu()  for j, obs_j in enumerate(obs_)]\n                    # actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()\n                    #            for j, obs_j in enumerate(obs_)]\n            elif self.nagents == 3:\n                actions = []\n                if agent_i < 1:     #simple_adv\n                    # self.agents[agent_i].mle = [self.mle_base[0],self.mle_base[0]]\n                    # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((obs_j[:,:2],observations[agent_i]),1)).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions = [torch.zeros((obs_j.shape[0],self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach()\n                    for j, obs_j in enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[2],self.mle_base[1]]\n         
           # actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(observations[agent_i]).detach(), hard=True)\n                    #            for j, obs_j in enumerate(obs_)]\n                    actions.append(\n                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(2)]), 1).to(self.device)).detach(), hard=True))\n                elif agent_i == 2:\n                    self.agents[agent_i].mle = [self.mle_base[2], self.mle_base[1]]\n                    actions.append(\n                        gumbel_softmax(self.agents[agent_i].mle[0].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(0)]), 1).to(self.device)).detach(), hard=True))\n                    actions.append(gumbel_softmax(self.agents[agent_i].mle[1].to(self.device)(torch.cat((observations[agent_i],\n                          actions_pre[(1)]), 1).to(self.device)).detach(), hard=True))\n\n            elif self.nagents == 2:\n                if agent_i < 1:     #simple_push\n                    self.agents[agent_i].mle = [self.mle_base[0]]\n                    actions = [torch.zeros((obs_j.shape[0], self.agent_init_params[0]['num_out_pol'])).to(torch.device(self.device)).detach() for j, obs_j in\n                     enumerate(obs_)]\n\n                elif agent_i == 1:\n                    self.agents[agent_i].mle = [self.mle_base[1]]\n                    actions = [gumbel_softmax(self.agents[agent_i].mle[j].to(torch.device(self.device))(torch.cat((observations[agent_i][:,2:],\n                                                                           actions_pre[(self.nagents -1 - agent_i)]), 1)).detach(), hard=True)\n  
                             for j, obs_j in enumerate(obs_)]\n\n            if self.nagents == 6:\n                observations_.append(torch.cat((observations[agent_i], actions), 1))\n            else:\n                observations_.append(torch.cat((observations[agent_i], torch.cat(actions, 1)), 1))\n\n        return observations_\n\n    def trian_tag(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[0][:, 2:14], acs_pre[0]),1))#\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            self.mle_opts[2].zero_grad()\n            action_i = self.mle_base[2](torch.cat((obs[3][:, 2:14], acs_pre[3]),1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[3].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[2])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            self.mle_opts[2].step()\n\n    def trian_adv(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[1], acs_pre[2]), 1)) #torch.cat((obs[1], acs_pre[2]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            
torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n            self.mle_opts[2].zero_grad()\n            action_i = self.mle_base[2](torch.cat((obs[1], acs_pre[0]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[2])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[2].parameters(), 20)\n            self.mle_opts[2].step()\n\n    def trian_push(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[1][:,2:], acs_pre[(0)]), 1))  #obs[1][:, 2:]\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[0].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n    def trian_com(self, agent_i, KL_criterion, obs, acs_pre, parallel, acs):\n        if agent_i == 0:\n            self.mle_opts[0].zero_grad()\n            action_i = self.mle_base[0](torch.cat((obs[1][:, 4:24], acs_pre[(1)]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            loss = KL_criterion(action_pre.float(), acs[1].float())\n            loss.backward(retain_graph=True)\n            if parallel:\n                average_gradients(self.mle_base[0])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[0].parameters(), 20)\n            self.mle_opts[0].step()\n\n            self.mle_opts[1].zero_grad()\n            action_i = self.mle_base[1](torch.cat((obs[4][:, 4:24], acs_pre[(4)]), 1))\n            action_pre = gumbel_softmax(action_i, hard=True)\n            
loss = KL_criterion(action_pre.float(), acs[4].float())\n            loss.backward()\n            if parallel:\n                average_gradients(self.mle_base[1])\n            torch.nn.utils.clip_grad_norm_(self.mle_base[1].parameters(), 20)\n            self.mle_opts[1].step()\n\n    def update(self, sample, agent_i, parallel=False, logger=None, sample_r=None):\n        \"\"\"\n        Update parameters of agent model based on sample from replay buffer\n        Inputs:\n            sample: tuple of (observations, actions, rewards, next\n                    observations, and episode end masks) sampled randomly from\n                    the replay buffer. Each is a list with entries\n                    corresponding to each agent\n            agent_i (int): index of agent to update\n            parallel (bool): If true, will average gradients across threads\n            logger (SummaryWriter from Tensorboard-Pytorch):\n                If passed in, important quantities will be logged\n        \"\"\"\n        # print('___update___')\n        acs_pre, obs, acs, rews, next_obs, dones = sample\n\n        next_obs_ = self._get_obs(next_obs, acs)\n        obs_ = self._get_obs(obs, acs_pre)\n        curr_agent = self.agents[agent_i]\n        # mle\n        KL_criterion = torch.nn.KLDivLoss(reduction='sum')\n        # for i in range(len(curr_agent.mle)):\n        #     curr_agent.mle_optimizer[i].zero_grad()\n        #     action_i = curr_agent.mle[i](obs[agent_i]obs[agent_i])\n        #     action_pre = gumbel_softmax(action_i, hard=True)\n        #     loss = KL_criterion(action_pre.float(), acs[i].float())\n        #     loss.backward()\n        #     if parallel:\n        #         average_gradients(curr_agent.mle[i])\n        #     torch.nn.utils.clip_grad_norm_(curr_agent.mle[i].parameters(), 20)\n        #     curr_agent.policy_optimizer.step()\n        if self.nagents == 6:\n            self.trian_com(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        
elif self.nagents == 4:\n            self.trian_tag(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 3:\n            self.trian_adv(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n        elif self.nagents == 2:\n            self.trian_push(agent_i, KL_criterion, obs, acs_pre, parallel, acs)\n\n        # center critic\n        curr_agent.critic_optimizer.zero_grad()\n        all_trgt_acs = []\n        if self.discrete_action:  # one-hot encode action\n            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in\n                            zip(self.target_policies, next_obs_)]\n        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)\n\n        target_value = (rews[agent_i].view(-1, 1) + self.gamma *\n                        curr_agent.target_critic(trgt_vf_in) *\n                        (1 - dones[agent_i].view(-1, 1)))\n\n        vf_in = torch.cat((*obs, *acs), dim=1)\n\n        actual_value = curr_agent.critic(vf_in)\n        vf_loss = MSELoss(actual_value, target_value.detach())\n        vf_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.critic)\n        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)\n        curr_agent.critic_optimizer.step()\n\n        curr_agent.policy_optimizer.zero_grad()\n        if self.discrete_action:\n            # Forward pass as if onehot (hard=True) but backprop through a differentiable\n            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop\n            # through discrete categorical samples, but I'm not sure if that is\n            # correct since it removes the assumption of a deterministic policy for\n            # DDPG. 
Regardless, discrete policies don't seem to learn properly without it.\n\n            curr_pol_out = curr_agent.policy(obs_[agent_i])\n            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)\n        else:\n            curr_pol_out = curr_agent.policy(obs[agent_i])\n            curr_pol_vf_in = curr_pol_out\n        all_pol_acs = []\n        for i, pi, ob in zip(range(self.nagents), self.policies, obs_):\n            if i == agent_i:\n                all_pol_acs.append(curr_pol_vf_in)\n            elif self.discrete_action:\n                all_pol_acs.append(onehot_from_logits(pi(ob)))\n            else:\n                all_pol_acs.append(pi(ob))\n            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)\n\n        pol_loss = -curr_agent.critic(vf_in).mean()\n        pol_loss += (curr_pol_out ** 2).mean() * 1e-3\n        pol_loss.backward()\n        if parallel:\n            average_gradients(curr_agent.policy)\n        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)\n        # actor\n        curr_agent.policy_optimizer.step()\n        if logger is not None:\n            logger.add_scalars('agent%i/losses' % agent_i,\n                               {'vf_loss': vf_loss,\n                                'pol_loss': pol_loss},\n                               self.niter)\n\n    def update_all_targets(self):\n        \"\"\"\n        Update all target networks (called after normal updates have been\n        performed for each agent)\n        \"\"\"\n        for a in self.agents:\n            soft_update(a.target_critic, a.critic, self.tau)\n            soft_update(a.target_policy, a.policy, self.tau)\n        self.niter += 1\n\n    def prep_training(self, device='gpu'):\n        for mle in self.mle_base:\n            mle.train()\n        for a in self.agents:\n            a.policy.train()\n            a.critic.train()\n            a.target_policy.train()\n            a.target_critic.train()\n            for mle_i in a.mle:\n            
    mle_i.train()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n        if not self.critic_dev == device:\n            for a in self.agents:\n                a.critic = fn(a.critic)\n            self.critic_dev = device\n        if not self.trgt_pol_dev == device:\n            for a in self.agents:\n                a.target_policy = fn(a.target_policy)\n            self.trgt_pol_dev = device\n        if not self.trgt_critic_dev == device:\n            for a in self.agents:\n                a.target_critic = fn(a.target_critic)\n            self.trgt_critic_dev = device\n        if not self.mle_dev == device:\n            for i, mle in enumerate(self.mle_base):\n                self.mle_base[i] = fn(mle)\n            for a in self.agents:\n                for i, mle_i in enumerate(a.mle):\n                    a.mle[i] = fn(mle_i)\n            self.mle_dev = device\n\n    def prep_rollouts(self, device='cpu'):\n        for a in self.agents:\n            a.policy.eval()\n        if device == 'gpu':\n            fn = lambda x: x.to(torch.device(self.device))\n        else:\n            fn = lambda x: x.cpu()\n        # only need main policy for rollouts\n        if not self.pol_dev == device:\n            for a in self.agents:\n                a.policy = fn(a.policy)\n            self.pol_dev = device\n\n    def save(self, filename):\n        \"\"\"\n        Save trained parameters of all agents into one file\n        \"\"\"\n        self.prep_training(device='cpu')  # move parameters to CPU before saving\n        save_dict = {'init_dict': self.init_dict,\n                     'agent_params': [a.get_params() for a in self.agents],\n                     'mle_params': [self.get_params()],}\n        torch.save(save_dict, 
filename)\n\n    @classmethod\n    def init_from_env(cls, env, device, agent_alg=\"ToM_self\", adversary_alg=\"ToM_self\",\n                      gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64, output_style='sum'):\n        \"\"\"\n        Instantiate instance of this class from multi-agent environment\n        \"\"\"\n        agent_init_params = []\n        alg_types = [adversary_alg if atype == 'adversary' else agent_alg for\n                     atype in env.agent_types]\n        for acsp, obsp, algtype in zip(env.action_space, env.observation_space,\n                                       alg_types):\n            num_in_pol = obsp.shape[0]\n            num_in_mle = obsp.shape[0]\n            if isinstance(acsp, Box):\n                discrete_action = False\n                get_shape = lambda x: x.shape[0]\n            elif isinstance(acsp, Discrete):  # Discrete\n                discrete_action = True\n                get_shape = lambda x: x.n\n            elif isinstance(acsp, MultiDiscrete):\n                discrete_action = True\n                get_shape = lambda x: sum(x.high - x.low + 1)\n            num_out_pol = get_shape(acsp)\n            if algtype == \"ToM_self\":\n                num_in_critic = 0\n                num_in_pol += (len(env.agent_types)-1) * 5\n                for oobsp in env.observation_space:\n                    num_in_critic += oobsp.shape[0]\n                for oacsp in env.action_space:\n                    if isinstance(oacsp, Box):\n                        discrete_action = False\n                        get_shape = lambda x: x.shape[0]\n                    elif isinstance(oacsp, Discrete):  # Discrete\n                        discrete_action = True\n                        get_shape = lambda x: x.n\n                    elif isinstance(oacsp, MultiDiscrete):\n                        discrete_action = True\n                        get_shape = lambda x: sum(x.high - x.low + 1)\n                    num_in_critic += 
get_shape(oacsp)\n            else:\n                num_in_critic = obsp.shape[0] + get_shape(acsp)\n            agent_init_params.append({'num_in_pol': num_in_pol,\n                                      'num_out_pol': num_out_pol,\n                                      'num_in_critic': num_in_critic,\n                                      'num_in_mle': num_in_mle,})\n        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr,\n                     'device': device,\n                     'hidden_dim': hidden_dim,\n                     'alg_types': alg_types,\n                     'agent_init_params': agent_init_params,\n                     'discrete_action': discrete_action,\n                     'output_style': output_style}\n        instance = cls(**init_dict)\n        instance.init_dict = init_dict\n        return instance\n\n    @classmethod\n    def init_from_save(cls, filename):\n        \"\"\"\n        Instantiate instance of this class from file created by 'save' method\n        \"\"\"\n        save_dict = torch.load(filename)\n        instance = cls(**save_dict['init_dict'])\n        instance.init_dict = save_dict['init_dict']\n        for a, params in zip(instance.agents, save_dict['agent_params']):\n            a.load_params(params)\n        for a, params in zip([instance], save_dict['mle_params']):\n            a.load_params(params)\n        return instance\n\n    def get_params(self):\n        params = {\n                }\n        for i in range(len(self.mle_base)):\n            params['mle%d'%i] = self.mle_base[i].state_dict()\n            params['mle_optimizer%d'%i] = self.mle_opts[i].state_dict()\n        return params\n\n    def load_params(self, params):\n        for i in range(len(self.mle_base)):\n            self.mle_base[i].load_state_dict(params['mle%d'%i])\n            self.mle_opts[i].load_state_dict(params['mle_optimizer%d'%i])\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/buffer.py",
    "content": "import numpy as np\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Variable\n\nclass ReplayBuffer(object):\n    \"\"\"\n    Replay Buffer for multi-agent RL with parallel rollouts\n    \"\"\"\n    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):\n        \"\"\"\n        Inputs:\n            max_steps (int): Maximum number of timepoints to store in buffer\n            num_agents (int): Number of agents in environment\n            obs_dims (list of ints): number of obervation dimensions for each\n                                     agent\n            ac_dims (list of ints): number of action dimensions for each agent\n        \"\"\"\n        self.device = device\n        self.max_steps = max_steps\n        self.num_agents = num_agents\n        self.obs_buffs = []\n        self.ac_buffs = []\n        self.rew_buffs = []\n        self.next_obs_buffs = []\n        self.done_buffs = []\n        for odim, adim in zip(obs_dims, ac_dims):\n            self.obs_buffs.append(np.zeros((max_steps, odim)))\n            self.ac_buffs.append(np.zeros((max_steps, adim)))\n            self.rew_buffs.append(np.zeros(max_steps))\n            self.next_obs_buffs.append(np.zeros((max_steps, odim)))\n            self.done_buffs.append(np.zeros(max_steps))\n\n        self.filled_i = 0  # index of first empty location in buffer (last index when full)\n        self.curr_i = 0  # current index to write to (ovewrite oldest data)\n\n    def __len__(self):\n        return self.filled_i\n\n    def push(self, observations, actions, rewards, next_observations, dones):\n        nentries = observations.shape[0]  # handle multiple parallel environments\n        if self.curr_i + nentries > self.max_steps:\n            rollover = self.max_steps - self.curr_i # num of indices to roll over\n            for agent_i in range(self.num_agents):\n                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],\n                                 
                 rollover, axis=0)\n                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],\n                                                 rollover, axis=0)\n                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],\n                                                  rollover)\n                self.next_obs_buffs[agent_i] = np.roll(\n                    self.next_obs_buffs[agent_i], rollover, axis=0)\n                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],\n                                                   rollover)\n            self.curr_i = 0\n            self.filled_i = self.max_steps\n        for agent_i in range(self.num_agents):\n            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                observations[:, agent_i])\n            # actions are already batched by agent, so they are indexed differently\n            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]\n            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]\n            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                next_observations[:, agent_i])\n            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]\n        self.curr_i += nentries\n        if self.filled_i < self.max_steps:\n            self.filled_i += nentries\n        if self.curr_i == self.max_steps:\n            self.curr_i = 0\n\n    def sample(self, N, to_gpu=False, norm_rews=True):\n        inds = np.random.choice(np.arange(self.filled_i), size=N,\n                                replace=False)\n        if to_gpu:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))\n        else:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False)\n        if norm_rews:\n            ret_rews = [cast((self.rew_buffs[i][inds] -\n                 
             self.rew_buffs[i][:self.filled_i].mean()) /\n                             self.rew_buffs[i][:self.filled_i].std())\n                        for i in range(self.num_agents)]\n        else:\n            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]\n        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],\n                ret_rews,\n                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])\n\n    def get_average_rewards(self, N):\n        if self.filled_i == self.max_steps:\n            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing\n        else:\n            inds = np.arange(max(0, self.curr_i - N), self.curr_i)\n        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]\n\nclass ReplayBuffer_pre(object):\n    \"\"\"\n    Replay Buffer for multi-agent RL with parallel rollouts\n    \"\"\"\n    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, device):\n        \"\"\"\n        Inputs:\n            max_steps (int): Maximum number of timepoints to store in buffer\n            num_agents (int): Number of agents in environment\n            obs_dims (list of ints): number of obervation dimensions for each\n                                     agent\n            ac_dims (list of ints): number of action dimensions for each agent\n        \"\"\"\n        self.device = device\n        self.max_steps = max_steps\n        self.num_agents = num_agents\n        self.ac_pre_buffs = []\n        self.obs_buffs = []\n        self.ac_buffs = []\n        self.rew_buffs = []\n        self.next_obs_buffs = []\n        self.done_buffs = []\n        for odim, adim in zip(obs_dims, ac_dims):\n            self.ac_pre_buffs.append(np.zeros((max_steps, 5)))\n            
self.obs_buffs.append(np.zeros((max_steps, odim)))\n            self.ac_buffs.append(np.zeros((max_steps, adim)))\n            self.rew_buffs.append(np.zeros(max_steps))\n            self.next_obs_buffs.append(np.zeros((max_steps, odim)))\n            self.done_buffs.append(np.zeros(max_steps))\n\n        self.filled_i = 0  # index of first empty location in buffer (last index when full)\n        self.curr_i = 0  # current index to write to (ovewrite oldest data)\n\n    def __len__(self):\n        return self.filled_i\n\n    def push(self, actions_pre, observations, actions, rewards, next_observations, dones):\n        nentries = observations.shape[0]  # handle multiple parallel environments\n        if self.curr_i + nentries > self.max_steps:\n            rollover = self.max_steps - self.curr_i # num of indices to roll over\n            for agent_i in range(self.num_agents):\n                self.ac_pre_buffs[agent_i] = np.roll(self.ac_pre_buffs[agent_i][:,:5],\n                                                 rollover, axis=0)\n                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],\n                                                  rollover, axis=0)\n                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],\n                                                 rollover, axis=0)\n                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],\n                                                  rollover)\n                self.next_obs_buffs[agent_i] = np.roll(\n                    self.next_obs_buffs[agent_i], rollover, axis=0)\n                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],\n                                                   rollover)\n            self.curr_i = 0\n            self.filled_i = self.max_steps\n        for agent_i in range(self.num_agents):\n            self.ac_pre_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions_pre[agent_i][:,:5]\n            
self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                observations[:, agent_i])\n            # actions are already batched by agent, so they are indexed differently\n            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i]\n            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i]\n            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack(\n                next_observations[:, agent_i])\n            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i]\n        self.curr_i += nentries\n        if self.filled_i < self.max_steps:\n            self.filled_i += nentries\n        if self.curr_i == self.max_steps:\n            self.curr_i = 0\n\n    def sample(self, N, to_gpu=False, norm_rews=True):\n        inds = np.random.choice(np.arange(self.filled_i), size=N,\n                                replace=False)\n        if to_gpu:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))\n        else:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False)\n        if norm_rews:\n            ret_rews = [cast((self.rew_buffs[i][inds] -\n                              self.rew_buffs[i][:self.filled_i].mean()) /\n                             self.rew_buffs[i][:self.filled_i].std())\n                        for i in range(self.num_agents)]\n        else:\n            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]\n        return ([cast(self.ac_pre_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],\n                ret_rews,\n                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.done_buffs[i][inds]) for i in 
range(self.num_agents)])\n\n    def get_average_rewards(self, N):\n        if self.filled_i == self.max_steps:\n            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative indexing\n        else:\n            inds = np.arange(max(0, self.curr_i - N), self.curr_i)\n        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]\n\n\nclass ReplayBuffer_RNN(object):\n    \"\"\"\n    Replay Buffer for multi-agent RL with parallel rollouts\n    \"\"\"\n    def __init__(self, max_steps, num_agents, obs_dims, ac_dims, ep_dims):\n        \"\"\"\n        Inputs:\n            max_steps (int): Maximum number of timepoints to store in buffer\n            num_agents (int): Number of agents in environment\n            obs_dims (list of ints): number of obervation dimensions for each\n                                     agent\n            ac_dims (list of ints): number of action dimensions for each agent\n            ep_dims (int): Number of steps in each episode\n        \"\"\"\n        self.max_steps = max_steps\n        self.num_agents = num_agents\n        self.obs_buffs = []\n        self.ac_buffs = []\n        self.rew_buffs = []\n        self.next_obs_buffs = []\n        self.done_buffs = []\n        for odim, adim in zip(obs_dims, ac_dims):\n            self.obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))\n            self.ac_buffs.append(np.zeros((max_steps, ep_dims, adim)))\n            self.rew_buffs.append(np.zeros((max_steps, ep_dims)))\n            self.next_obs_buffs.append(np.zeros((max_steps, ep_dims, odim)))\n            self.done_buffs.append(np.zeros((max_steps, ep_dims)))\n\n        self.filled_i = 0  # index of first empty location in buffer (last index when full)\n        self.curr_i = 0  # current index to write to (ovewrite oldest data)\n\n    def __len__(self):\n        return self.filled_i\n\n    def push(self, observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep):\n        nentries = 
observations_ep[0].shape[0]  # handle multiple parallel environments\n        observations_ep, actions_ep, rewards_ep, next_observations_ep, dones_ep = \\\n            np.array(observations_ep), np.array(actions_ep), np.array(rewards_ep),\\\n            np.array(next_observations_ep), np.array(dones_ep)\n        if self.curr_i + nentries > self.max_steps:\n            rollover = self.max_steps - self.curr_i # num of indices to roll over\n            for agent_i in range(self.num_agents):\n                self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i],\n                                                  rollover, axis=0)\n                self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i],\n                                                 rollover, axis=0)\n                self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i],\n                                                  rollover)\n                self.next_obs_buffs[agent_i] = np.roll(\n                    self.next_obs_buffs[agent_i], rollover, axis=0)\n                self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i],\n                                                   rollover)\n            self.curr_i = 0\n            self.filled_i = self.max_steps\n        for agent_i in range(self.num_agents):\n            for i in range(observations_ep[:,:,agent_i].shape[0]):\n                if i == 0:\n                    ob_ep = np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)\n                    ob_next_ep = np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)\n                else:\n                    ob_ep = np.vstack((ob_ep, np.expand_dims(np.vstack(observations_ep[:,:,agent_i][i]), 0)))\n                    ob_next_ep = np.vstack((ob_next_ep, np.expand_dims(np.vstack(next_observations_ep[:,:,agent_i][i]), 0)))\n\n            self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_ep.transpose(1, 0, 2)\n            # actions are already batched 
by agent, so they are indexed differently\n            self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = actions_ep[:,:,0,:].transpose(1, 0, 2)\n            self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = rewards_ep[:, :, agent_i].transpose(1, 0)\n            self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = ob_next_ep.transpose(1, 0, 2)\n            self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries, :] = dones_ep[:, :, agent_i].transpose(1, 0)\n        self.curr_i += nentries\n        if self.filled_i < self.max_steps:\n            self.filled_i += nentries\n        if self.curr_i == self.max_steps:\n            self.curr_i = 0\n\n    def sample(self, N, to_gpu=False, norm_rews=True):\n        inds = np.random.choice(np.arange(self.filled_i), size=N,\n                                replace=False)\n        if to_gpu:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False).to(torch.device(self.device))\n        else:\n            cast = lambda x: Variable(Tensor(x), requires_grad=False)\n        if norm_rews:\n            ret_rews = [cast((self.rew_buffs[i][inds] -\n                              self.rew_buffs[i][:self.filled_i].mean()) /\n                             self.rew_buffs[i][:self.filled_i].std())\n                        for i in range(self.num_agents)]\n        else:\n            ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)]\n        return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)],\n                ret_rews,\n                [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)],\n                [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)])\n\n    def get_average_rewards(self, N):\n        if self.filled_i == self.max_steps:\n            inds = np.arange(self.curr_i - N, self.curr_i)  # allow for negative 
indexing\n        else:\n            inds = np.arange(max(0, self.curr_i - N), self.curr_i)\n        return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)]\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/env_wrappers.py",
    "content": "\"\"\"\nModified from OpenAI Baselines code to work with multi-agent envs\n\"\"\"\nimport numpy as np\nfrom multiprocessing import Process, Pipe\nfrom common.vec_env.vec_env import VecEnv, CloudpickleWrapper\n\n\ndef worker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            if all(done):\n                ob = env.reset()\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset()\n            remote.send(ob)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'close':\n            remote.close()\n            break\n        elif cmd == 'get_spaces':\n            remote.send((env.observation_space, env.action_space))\n        elif cmd == 'get_agent_types':\n            if all([hasattr(a, 'adversary') for a in env.agents]):\n                remote.send(['adversary' if a.adversary else 'agent' for a in\n                             env.agents])\n            else:\n                remote.send(['agent' for _ in env.agents])\n        else:\n            raise NotImplementedError\n\n\nclass SubprocVecEnv(VecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True # if the main process crashes, we should not cause things to hang\n            p.start()\n 
       for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, action_space = self.remotes[0].recv()\n        self.remotes[0].send(('get_agent_types', None))\n        self.agent_types = self.remotes[0].recv()\n        VecEnv.__init__(self, len(env_fns), observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:            \n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\nclass DummyVecEnv(VecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]        \n        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)\n        if all([hasattr(a, 'adversary') for a in env.agents]):\n            self.agent_types = ['adversary' if a.adversary else 'agent' for a in\n                                env.agents]\n        else:\n            self.agent_types = ['agent' for _ in env.agents]\n        self.ts = 
np.zeros(len(self.envs), dtype='int')        \n        self.actions = None\n\n    def step_async(self, actions):\n        self.actions = actions\n\n    def step_wait(self):\n        results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]\n        obs, rews, dones, infos = map(np.array, zip(*results))\n        self.ts += 1\n        for (i, done) in enumerate(dones):\n            if all(done): \n                obs[i] = self.envs[i].reset()\n                self.ts[i] = 0\n        self.actions = None\n        return np.array(obs), np.array(rews), np.array(dones), infos\n\n    def reset(self):        \n        results = [env.reset() for env in self.envs]\n        return np.array(results)\n\n    def close(self):\n        return"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/make_env.py",
    "content": "\"\"\"\nCode for creating a multiagent environment with one of the scenarios listed\nin ./scenarios/.\nCan be called by using, for example:\n    env = make_env('simple_speaker_listener')\nAfter producing the env object, can be used similarly to an OpenAI gym\nenvironment.\n\nA policy using this environment must output actions in the form of a list\nfor all agents. Each element of the list should be a numpy array,\nof size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede\ncommunication actions in this array. See environment.py for more details.\n\"\"\"\n\ndef make_env(scenario_name, benchmark=False, discrete_action=False):\n    '''\n    Creates a MultiAgentEnv object as env. This can be used similar to a gym\n    environment by calling env.reset() and env.step().\n    Use env.render() to view the environment on the screen.\n\n    Input:\n        scenario_name   :   name of the scenario from ./scenarios/ to be Returns\n                            (without the .py extension)\n        benchmark       :   whether you want to produce benchmarking data\n                            (usually only done during evaluation)\n\n    Some useful env properties (see environment.py):\n        .observation_space  :   Returns the observation space for each agent\n        .action_space       :   Returns the action space for each agent\n        .n                  :   Returns the number of Agents\n    '''\n    from multiagent.environment import MultiAgentEnv\n    import multiagent.scenarios as scenarios\n\n    # load scenario from script\n    scenario = scenarios.load(scenario_name + \".py\").Scenario()\n    # create world\n    world = scenario.make_world()\n    # create multiagent environment\n    if benchmark:        \n        env = MultiAgentEnv(world, scenario.reset_world, scenario.reward,\n                            scenario.observation, scenario.benchmark_data)\n    else:\n        env = MultiAgentEnv(world, scenario.reset_world, scenario.reward,\n  
                          scenario.observation)\n    return env\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/misc.py",
    "content": "import os\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.autograd import Variable\nimport numpy as np\n\n# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11\ndef soft_update(target, source, tau):\n    \"\"\"\n    Perform DDPG soft update (move target params toward source based on weight\n    factor tau)\n    Inputs:\n        target (torch.nn.Module): Net to copy parameters to\n        source (torch.nn.Module): Net whose parameters to copy\n        tau (float, 0 < x < 1): Weight factor for update\n    \"\"\"\n    for target_param, param in zip(target.parameters(), source.parameters()):\n        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)\n\n# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15\ndef hard_update(target, source):\n    \"\"\"\n    Copy network parameters from source to target\n    Inputs:\n        target (torch.nn.Module): Net to copy parameters to\n        source (torch.nn.Module): Net whose parameters to copy\n    \"\"\"\n    for target_param, param in zip(target.parameters(), source.parameters()):\n        target_param.data.copy_(param.data)\n\n# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py\ndef average_gradients(model):\n    \"\"\" Gradient averaging. \"\"\"\n    size = float(dist.get_world_size())\n    for param in model.parameters():\n        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0)\n        param.grad.data /= size\n\n# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py\ndef init_processes(rank, size, fn, backend='gloo'):\n    \"\"\" Initialize the distributed environment. 
\"\"\"\n    os.environ['MASTER_ADDR'] = '127.0.0.1'\n    os.environ['MASTER_PORT'] = '29500'\n    dist.init_process_group(backend, rank=rank, world_size=size)\n    fn(rank, size)\n\ndef onehot_from_logits(logits, eps=0.0):\n    \"\"\"\n    Given batch of logits, return one-hot sample using epsilon greedy strategy\n    (based on given epsilon)\n    \"\"\"\n    # get best (according to current policy) actions in one-hot form\n    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()\n    if eps == 0.0:\n        return argmax_acs\n    # get random actions in one-hot form\n    rand_acs = Variable(torch.eye(logits.shape[1])[[np.random.choice(\n        range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False)\n    # chooses between best and random actions using epsilon greedy\n    return torch.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in\n                        enumerate(torch.rand(logits.shape[0]))])\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):\n    \"\"\"Sample from Gumbel(0, 1)\"\"\"\n    U = Variable(tens_type(*shape).uniform_(), requires_grad=False)\n    return -torch.log(-torch.log(U + eps) + eps)\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef gumbel_softmax_sample(logits, temperature):\n    \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n    y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)).to(logits.device)\n    return F.softmax(y / temperature, dim=1)\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef gumbel_softmax(logits, temperature=1.0, hard=False):\n    \"\"\"Sample from the Gumbel-Softmax distribution and optionally discretize.\n    Args:\n      logits: [batch_size, n_class] unnormalized log-probs\n      temperature: 
non-negative scalar\n      hard: if True, take argmax, but differentiate w.r.t. soft sample y\n    Returns:\n      [batch_size, n_class] sample from the Gumbel-Softmax distribution.\n      If hard=True, then the returned sample will be one-hot, otherwise it will\n      be a probabilitiy distribution that sums to 1 across classes\n    \"\"\"\n    y = gumbel_softmax_sample(logits, temperature)\n    if hard:\n        y_hard = onehot_from_logits(y)\n        y = (y_hard - y).detach() + y\n    return y\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/multiprocessing.py",
    "content": "# This code is from openai baseline\n# https://github.com/openai/baselines/tree/master/baselines/common/vec_env\nimport time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom multiprocessing import Process, Pipe\n\n\ndef _flatten_list(l):\n    assert isinstance(l, (list, tuple))\n    assert len(l) > 0\n    assert all([len(l_) > 0 for l_ in l])\n\n    return [l__ for l_ in l for l__ in l_]\n\n\ndef worker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            if done:\n                ob = env.reset()\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset()\n            remote.send(ob)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'render':\n            ob = env.render(mode='rgb_array')\n            # print(len(ob), 'len(frames)')\n            # print(len(ob[0]), 'len(frames[0])')\n            # print(len(ob[0][0]), 'len(frames[0][0])')\n            remote.send(ob)  # rgb_array\n        elif cmd == 'observe':\n            ob = env.observe(data)\n            remote.send(ob)\n        elif cmd == 'agents':\n            remote.send(env.agents)\n        elif cmd == 'spec':\n            remote.send(env.spec)\n        elif cmd == 'get_spaces':\n            remote.send((env.observation_space, env.action_space))\n        elif cmd == 'close':\n            remote.close()\n            break\n        else:\n            raise NotImplementedError\n\n\nclass VecEnv(object):\n    \"\"\"\n    An abstract asynchronous, vectorized environment.\n    \"\"\"\n    closed = False\n    viewer = None\n\n    metadata = {\n        'render.modes': ['human', 'rgb_array']\n    }\n    def __init__(self, num_envs, observation_space, action_space):\n        
self.num_envs = num_envs\n        self.observation_space = observation_space\n        self.action_space = action_space\n\n    def observe(self, agent):\n        pass\n\n    def reset(self):\n        \"\"\"\n        Reset all the environments and return an array of\n        observations, or a tuple of observation arrays.\n        If step_async is still doing work, that work will\n        be cancelled and step_wait() should not be called\n        until step_async() is invoked again.\n        \"\"\"\n        pass\n\n    def step_async(self, actions):\n        \"\"\"\n        Tell all the environments to start taking a step\n        with the given actions.\n        Call step_wait() to get the results of the step.\n        You should not call this if a step_async run is\n        already pending.\n        \"\"\"\n        pass\n\n    def step_wait(self):\n        \"\"\"\n        Wait for the step taken with step_async().\n        Returns (obs, rews, dones, infos):\n         - obs: an array of observations, or a tuple of\n                arrays of observations.\n         - rews: an array of rewards\n         - dones: an array of \"episode done\" booleans\n         - infos: a sequence of info objects\n        \"\"\"\n        pass\n\n    def close(self):\n        \"\"\"\n        Clean up the environments' resources.\n        \"\"\"\n        pass\n\n    def step(self, actions):\n        self.step_async(actions)\n        return self.step_wait()\n\n    def render(self, mode='human'):\n        imgs = self.get_images()\n        bigimg = self.tile_images(imgs)\n        if mode == 'human':\n            self.get_viewer().imshow(bigimg)    #\n            return self.get_viewer().isopen\n\n        elif mode == 'rgb_array':\n            return bigimg\n        else:\n            raise NotImplementedError\n\n    def get_images(self):\n        \"\"\"\n        Return RGB images from each environment\n        \"\"\"\n        raise NotImplementedError\n\n    def get_viewer(self):\n        if 
self.viewer is None:\n            from common import rendering\n            self.viewer = rendering.SimpleImageViewer()\n        return self.viewer\n\n    def tile_images(self, img_nhwc):\n        \"\"\"\n        Tile N images into one big PxQ image\n        (P,Q) are chosen to be as close as possible, and if N\n        is square, then P=Q.\n        input: img_nhwc, list or array of images, ndim=4 once turned into array\n            n = batch index, h = height, w = width, c = channel\n        returns:\n            bigim_HWc, ndarray with ndim=3\n        \"\"\"\n        img_nhwc = np.asarray(img_nhwc)\n        N, h, w, c = img_nhwc.shape\n        H = int(np.ceil(np.sqrt(N)))\n        W = int(np.ceil(float(N) / H))\n        img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(N, H * W)])\n        img_HWhwc = img_nhwc.reshape(H, W, h, w, c)\n        img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)\n        img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c)\n        return img_Hh_Ww_c\n\nclass CloudpickleWrapper(object):\n    \"\"\"\n    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)\n    \"\"\"\n\n    def __init__(self, x):\n        self.x = x\n\n    def __getstate__(self):\n        import cloudpickle\n        return cloudpickle.dumps(self.x)\n\n    def __setstate__(self, ob):\n        import pickle\n        self.x = pickle.loads(ob)\n\n\nclass SubprocVecEnv(VecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs_sc: list of gym environments to run in subprocesses\n        \"\"\"\n        # self.venv = venv\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.nenvs = nenvs\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, 
self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True  # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, action_space = self.remotes[0].recv()\n\n        VecEnv.__init__(self, len(env_fns), observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):       # the input of step() : action\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]    # the output of step() : zip(*results)\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def step_wait_2(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        reward, done, _cumulative_rewards = zip(*results)\n        return reward, done, _cumulative_rewards\n\n    def step_wait_3(self):\n        results = [remote.recv() for remote in self.remotes]    # the output of step() : zip(*results)\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def agents(self):\n        for remote in self.remotes:\n            remote.send(('agents', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def spec(self):\n       
 for remote in self.remotes:\n            remote.send(('spec', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def get_images(self):\n        # self._assert_not_closed()\n        for pipe in self.remotes:\n            pipe.send(('render', None))\n        imgs = [pipe.recv() for pipe in self.remotes]\n        # imgs = _flatten_list(imgs)\n        return imgs\n\n    def observe(self, agent):\n        for remote, agent in zip(self.remotes, agent):\n            remote.send(('observe', agent))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    # def render(self, mode='human'):\n    #     return self.venv.render(mode=mode)\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n            self.closed = True\n\n    def __len__(self):\n        return self.nenvs\n\ndef _flatten_list(l):\n    assert isinstance(l, (list, tuple))\n    assert len(l) > 0\n    assert all([len(l_) > 0 for l_ in l])\n\n    return [l__ for l_ in l for l__ in l_]\n\nclass DummyVecEnv(VecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)\n        if all([hasattr(a, 'adversary') for a in env.agents]):\n            self.agent_types = ['adversary' if a.adversary else 'agent' for a in\n                                env.agents]\n        else:\n            self.agent_types = ['agent' for _ in env.agents]\n        self.ts = np.zeros(len(self.envs), dtype='int')\n        self.actions = None\n\n    def step_async(self, actions):\n        self.actions = actions\n\n    def step_wait(self):\n        results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]\n   
     obs, rews, dones, infos = map(np.array, zip(*results))\n        self.ts += 1\n        for (i, done) in enumerate(dones):\n            if all(done):\n                obs[i] = self.envs[i].reset()\n                self.ts[i] = 0\n        self.actions = None\n        return np.array(obs), np.array(rews), np.array(dones), infos\n\n    def reset(self):\n        results = [env.reset() for env in self.envs]\n        return np.array(results)\n\n    def close(self):\n        return"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/networks.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\nfrom braincog.base.node.node import LIFNode\n\nclass MLPNetwork(nn.Module):\n    \"\"\"\n    MLP network (can be used as value or policy)\n    \"\"\"\n    def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,\n                 constrain_out=False, norm_in=True, discrete_action=True):\n        \"\"\"\n        Inputs:\n            input_dim (int): Number of dimensions in input\n            out_dim (int): Number of dimensions in output\n            hidden_dim (int): Number of hidden dimensions\n            nonlin (PyTorch function): Nonlinearity to apply to hidden layers\n        \"\"\"\n        super(MLPNetwork, self).__init__()\n\n        if norm_in:  # normalize inputs\n            self.in_fn = nn.BatchNorm1d(input_dim)    #train\n            # self.in_fn = input_dim  #test\n            self.in_fn.weight.data.fill_(1)\n            self.in_fn.bias.data.fill_(0)\n        else:\n            self.in_fn = lambda x: x\n        self.fc1 = nn.Linear(input_dim, hidden_dim)\n        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n        self.fc3 = nn.Linear(hidden_dim, out_dim)\n        self.nonlin = nonlin\n        if constrain_out and not discrete_action:\n            # initialize small to prevent saturation\n            self.fc3.weight.data.uniform_(-3e-3, 3e-3)\n            self.out_fn = F.tanh\n        else:  # logits for discrete action (will softmax later)\n            self.out_fn = lambda x: x\n\n    def forward(self, X):\n        \"\"\"\n        Inputs:\n            X (PyTorch Matrix): Batch of observations\n        Outputs:\n            out (PyTorch Matrix): Output of network (actions, values, etc)\n        \"\"\"\n        h1 = self.nonlin(self.fc1(self.in_fn(X)))\n        h2 = self.nonlin(self.fc2(h1))\n        out = self.out_fn(self.fc3(h2))\n        return out\n\n\nclass BCNoSpikingLIFNode(LIFNode):\n    def __init__(self, *args, **kwargs):\n        
super().__init__(*args, **kwargs)\n\n    def forward(self, dv: torch.Tensor):\n        self.integral(dv)\n        return self.mem\n\nclass SNNNetwork(nn.Module):\n    \"\"\"\n    SNN network (can be used as value or policy or MLE)\n    \"\"\"\n    def __init__(self, input_dim, out_dim, hidden_dim=64, node=LIFNode, time_window=16,\n                 norm_in=True, output_style='sum'):\n        \"\"\"\n        Inputs:\n            input_dim (int): Number of dimensions in input\n            out_dim (int): Number of dimensions in output\n            hidden_dim (int): Number of hidden dimensions\n            nonlin (PyTorch function): Nonlinearity to apply to hidden layers\n        \"\"\"\n        super(SNNNetwork, self).__init__()\n\n        self._threshold = 0.5\n        self.v_reset = 0.0\n        self._time_window = time_window\n        self.output_style = output_style\n        self._node1 = node(threshold=self._threshold, v_reset=self.v_reset)\n        self._node2 = node(threshold=self._threshold, v_reset=self.v_reset)\n\n        if norm_in:  # normalize inputs\n            self.in_fn = nn.BatchNorm1d(input_dim)    #train\n            self.in_fn.weight.data.fill_(1)\n            self.in_fn.bias.data.fill_(0)\n        else:\n            self.in_fn = lambda x: x\n\n        self.fc1 = nn.Linear(input_dim, hidden_dim)\n        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n        self.fc3 = nn.Linear(hidden_dim, out_dim)\n\n        if self.output_style == 'sum':\n            self._out_node = lambda x: x\n        elif self.output_style == 'voltage':\n            self._out_node = BCNoSpikingLIFNode()\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n    def forward(self, X):\n        qs = []\n        self.reset()\n        for t in range(self._time_window):\n            x = self.fc1((self.in_fn(X)+0.5)) #train\n            # x = self.fc1((X + 0.5)) #test\n            x = self._node1(x)\n    
        x = self.fc2(x)\n            x = self._node2(x)\n            x = self.fc3(x)\n            x = self._out_node(x)\n            qs.append(x)\n\n        if self.output_style == 'sum':\n            outputs = sum(qs) / self._time_window\n            return outputs\n        elif self.output_style == 'voltage':\n            outputs = x\n            return outputs\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/MPE/utils/noise.py",
    "content": "import numpy as np\n\n\n# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py\nclass OUNoise:\n    def __init__(self, action_dimension, scale=0.1, mu=0, theta=0.15, sigma=0.2):\n        self.action_dimension = action_dimension\n        self.scale = scale\n        self.mu = mu\n        self.theta = theta\n        self.sigma = sigma\n        self.state = np.ones(self.action_dimension) * self.mu\n        self.reset()\n\n    def reset(self):\n        self.state = np.ones(self.action_dimension) * self.mu\n\n    def noise(self):\n        x = self.state\n        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(len(x))\n        self.state = x + dx\n        return self.state * self.scale\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/README.md",
    "content": "# MAToM-SNN"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/agents/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/agents/sagent.py",
    "content": "import numpy as np\nimport torch\nfrom torch.distributions import Categorical\nfrom braincog.base.encoder.population_coding import PEncoder\nfrom spikingjelly.activation_based import functional\nTIMESTEPS = 15\nM = 5\n\n# Agent\nclass Agents:\n    def __init__(self, args):\n        self.n_actions = args.n_actions\n        self.n_agents = args.n_agents\n        self.obs_shape = args.obs_shape\n        # encoder\n        self.pencoder = PEncoder(TIMESTEPS, 'population_voltage')\n\n        if args.alg == 'ppo':\n            from policy.ppo import PPO\n            self.policy = PPO(args)\n        if args.alg == 'iql':\n            from policy.iql import IQL\n            self.policy = IQL(args)\n            if args.mode == 'test':\n                self.policy.load_model(395000)\n        if args.alg == 'svdn':\n            from policy.svdn import SVDN\n            self.policy = SVDN(args)\n        if args.alg == 'scovdn':\n            from policy.scovdn import SCOVDN\n            self.policy = SCOVDN(args)\n        if args.alg == 'stomvdn':\n            from policy.stomvdn import SToMVDN\n            self.policy = SToMVDN(args)\n        if args.alg == 'scovdn_weight':\n            from policy.scovdn_weight import SCOVDN_W\n            self.policy = SCOVDN_W(args)\n        if args.alg == 'siql':\n            from policy.siql import SIQL\n            self.policy = SIQL(args)\n        if args.alg == 'scoiql':\n            from policy.scoiql import SCOIQL\n            self.policy = SCOIQL(args)\n        if args.alg == 'siql_e':\n            from policy.siql_encoder import SIQL_E\n            self.policy = SIQL_E(args)\n        if args.alg == 'siql_e2':\n            from policy.siql_encoder2 import SIQL_EE\n            self.policy = SIQL_EE(args)\n        if args.alg == 'siql_no_rnn':\n            from policy.siql_no_rnn import SIQLUR\n            self.policy = SIQLUR(args)\n        if args.alg == 'siql_no_rnn2':\n            from policy.siql_no_rnn2 import 
SIQLUR2\n            self.policy = SIQLUR2(args)\n        self.args = args\n\n    def choose_action(self, num_env, obs, last_action, agent_num, avail_actions, epsilon, maven_z=None, evaluate=False):\n        inputs = obs.copy()\n        avail_actions_ind = np.nonzero(avail_actions)[0]  # index of actions which can be choose\n\n        # transform agent_num to onehot vector\n        agent_id = np.zeros((num_env, self.n_agents))\n        agent_id[:, agent_num] = 1.\n\n        if self.args.last_action:\n            inputs = np.hstack((inputs, last_action))\n        if self.args.reuse_network:\n            inputs = np.hstack((inputs, agent_id))\n\n        # transform the shape of inputs from (42,) to (1,42)\n        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)  # torch.Size([1, 17])\n        # init hidden tensor\n        if self.args.alg == 'siql_e' or self.args.alg == 'siql_e2':\n            h1_mem = self.policy.eval_h1_mem[:, agent_num, :, :, :, :]\n            h1_spike = self.policy.eval_h1_spike[:, agent_num, :, :, :, :]\n            h2_mem = self.policy.eval_h2_mem[:, agent_num, :, :, :, :]\n            h2_spike = self.policy.eval_h2_spike[:, agent_num, :, :, :, :]\n            inputs_, _ = self.pencoder(inputs=inputs, num_popneurons=M, VTH=0.99)    ###########################################################\n            inputs = torch.transpose(inputs_, 0, 3)\n            inputs = inputs.squeeze().unsqueeze(0)\n\n        else:\n\n            h1_mem = self.policy.eval_h1_mem[:, agent_num, :, :]    #\n            h1_spike = self.policy.eval_h1_spike[:, agent_num, :, :]\n            h2_mem = self.policy.eval_h2_mem[:, agent_num, :, :]\n            h2_spike = self.policy.eval_h2_spike[:, agent_num, :, :]\n\n\n        avail_actions = torch.tensor(avail_actions, dtype=torch.float32).unsqueeze(0)\n        if self.args.cuda:\n            inputs = inputs.cuda(self.args.device)\n            h1_mem = h1_mem.cuda(self.args.device)\n            h1_spike = 
h1_spike.cuda(self.args.device)\n            h2_mem = h2_mem.cuda(self.args.device)\n            h2_spike = h2_spike.cuda(self.args.device)\n        # get q value\n        if self.args.alg == 'siql_no_rnn' or self.args.alg == 'siql_no_rnn2':\n            self.policy.eval_snn.reset()\n            q_value = self.policy.eval_snn(inputs)\n            # functional.reset_net(self.policy_sc.eval_snn)\n        else:\n            q_value, self.policy.eval_h1_mem[:, agent_num, :], self.policy.eval_h1_spike[:, agent_num, :],\\\n                self.policy.eval_h2_mem[:, agent_num, :], self.policy.eval_h2_spike[:, agent_num, :]= \\\n                self.policy.eval_snn(inputs, h1_mem, h1_spike, h2_mem, h2_spike)\n\n        # choose action from q value\n        # q_value[avail_actions == 0.0] = - float(\"inf\")\n        if self.args.alg == 'siql_e' or self.args.alg == 'siql_e2':\n            q_value = q_value.sum(dim=2)\n            q_value = q_value.sum(dim=2)\n        if np.random.uniform() < epsilon:\n            # action = np.random.choice(avail_actions_ind)  # action是一个整数\n            action = torch.tensor([[np.random.choice(avail_actions_ind) for i in range(num_env)]])\n        else:\n            action = torch.argmax(q_value, 2)\n        return action\n\n\n    def _choose_action_from_softmax(self, inputs, avail_actions, epsilon, evaluate=False):\n        \"\"\"\n        :param_sc inputs: # q_value of all actions\n        \"\"\"\n        action_num = avail_actions.sum(dim=1, keepdim=True).float().repeat(1, avail_actions.shape[-1])  # num of avail_actions\n        # 先将Actor网络的输出通过softmax转换成概率分布\n        prob = torch.nn.functional.softmax(inputs, dim=-1)\n        # add noise of epsilon\n        prob = ((1 - epsilon) * prob + torch.ones_like(prob) * epsilon / action_num)\n        prob[avail_actions == 0] = 0.0  # 不能执行的动作概率为0\n\n        \"\"\"\n        不能执行的动作概率为0之后，prob中的概率和不为1，这里不需要进行正则化，因为torch.distributions.Categorical\n        
会将其进行正则化。要注意在训练的过程中没有用到Categorical，所以训练时取执行的动作对应的概率需要再正则化。\n        \"\"\"\n\n        if epsilon == 0 and evaluate:\n            action = torch.argmax(prob)\n        else:\n            action = Categorical(prob).sample().long()\n        return action\n\n    def _get_max_episode_len(self, batch):\n        terminated = batch['TERMINATE']\n        episode_num = terminated.shape[0]\n        max_episode_len = 0\n        for episode_idx in range(episode_num):\n            for transition_idx in range(self.args.episode_limit):\n                if terminated[episode_idx, transition_idx, 0] == 1:\n                    if transition_idx + 1 >= max_episode_len:\n                        max_episode_len = transition_idx + 1\n                    break\n        if max_episode_len == 0:  # 防止所有的episode都没有结束，导致terminated中没有1\n            max_episode_len = self.args.episode_limit\n        return max_episode_len\n\n    def train(self, batch, train_step, epsilon=None):  # coma needs epsilon for training\n\n        # different episode has different length, so we need to get max length of the batch\n        max_episode_len = self._get_max_episode_len(batch)\n        for key in batch.keys():\n            if key != 'z':\n                batch[key] = batch[key][:, :max_episode_len]\n        self.policy.learn(batch, max_episode_len, train_step, epsilon)\n        if train_step > 0 and train_step % self.args.save_cycle == 0:\n            self.policy.save_model(train_step)\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/common_sr/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/common_sr/arguments.py",
    "content": "import argparse\n\ndef get_common_args():\n    parser = argparse.ArgumentParser()\n\n    ## multiprocessing\n    parser.add_argument('--process', type=int, default=5, help='multiprocessing')\n\n    ## the environment setting  'CLASSIC', 'HUNT', 'HARVEST', 'ESCALATION'\n    parser.add_argument('--ENV', type=str, default='HUNT', help='the version of the game, choose from [\"CLASSIC\", \"HUNT\", \"HARVEST\", \"ESCALATION\"]')\n    parser.add_argument('--env_name', type=str, default='stag_stay', help='the version of the game, choose from [\"CLASSIC\", \"HUNT\", \"HARVEST\", \"ESCALATION\"]')\n\n    parser.add_argument('--obs_type', type=str, default='coords', help='Can be \"image\" for pixel-array based observations, or \"coords\" for just the entity coordinates')\n    parser.add_argument('--forage_quantity', type=int, default=2, help='the number of trees')\n    parser.add_argument('--opponent_policy', type=str, default='random', help='the poliocy of opponent')\n    parser.add_argument('--replay_dir', type=str, default='', help='absolute path to save the replay')\n\n    ## The alternative policy   ################################################\n    parser.add_argument('--num_run', type=int, default='4', help='the number of run')\n    ## 'svdn',  'stomvdn'\n    parser.add_argument('--alg', type=str, default='svdn', help='the algorithm to train the agent')\n    parser.add_argument('--mode', type=str, default='train', help='the mode')\n    parser.add_argument('--n_steps', type=int, default=1000000, help='total time steps')#2000000\n    parser.add_argument('--n_episodes', type=int, default=2, help='the number of episodes before once training')\n    parser.add_argument('--last_action', type=bool, default=True, help='whether to use the last action to choose action')\n    parser.add_argument('--reuse_network', type=bool, default=True, help='whether to use one network_sc for all agents_sc')\n    parser.add_argument('--gamma', type=float, default=0.99, 
help='discount factor')\n    parser.add_argument('--epsilon', type=float, default=1.0, help='epsilon  factor')\n    ############# \"Adam\"\n    parser.add_argument('--optimizer', type=str, default=\"RMS\", help='optimizer')\n    parser.add_argument('--evaluate_cycle', type=int, default=10, help='how often to evaluate the model')#5000\n    parser.add_argument('--evaluate_epoch', type=int, default=6, help='number of the epoch to evaluate the agent')#32\n\n    ## save weights->model/args->log/reward->result/plot\n    parser.add_argument('--model_dir', type=str, default='./model', help='model directory of the policy_base')\n    parser.add_argument('--result_dir', type=str, default='./result', help='result directory of the policy_base')#./result#/home/zhaozhuoya/exp2/ToM2_test/result\n    parser.add_argument('--log_dir', type=str, default='./log', help='args directory')\n    parser.add_argument('--plot_dir', type=str, default='./plot', help='args directory')\n\n    parser.add_argument('--exp_dir', type=str, default='/exp_vdn', help='result directory of the policy_base')\n    parser.add_argument('--save_model_dir', type=str, default='/199_rnn_net_params_hunt1.pkl', help='load weights and bias')\n    parser.add_argument('--load_model', type=bool, default=False, help='whether to load the pretrained model')\n    parser.add_argument('--evaluate', type=bool, default=False, help='whether to evaluate the model')\n    parser.add_argument('--cuda', type=bool, default=True, help='whether to use the GPU')   #True\n    parser.add_argument('--mini_batch_size', type=int, default=250, help='whether to use the GPU')\n    args = parser.parse_args()\n    parser.add_argument('--device', type=str, default='cuda:{}'.format(args.num_run), help='whether to use the GPU')  #'cuda:1'\n    args = parser.parse_args()\n    return args\n\n# arguments of coma\ndef get_coma_args(args):\n    # network_sc\n    args.rnn_hidden_dim = 64\n    args.critic_dim = 128\n    args.lr_actor = 1e-4\n    
args.lr_critic = 1e-3\n\n    # epsilon-greedy\n    # args.epsilon = 0.5\n    args.anneal_epsilon = 0.00064\n    args.min_epsilon = 0.02\n    args.epsilon_anneal_scale = 'episode'\n\n    # lambda of td-lambda return\n    args.td_lambda = 0.8\n\n    # how often to save the model\n    args.save_cycle = 5000\n\n    # how often to update the target_net\n    args.target_update_cycle = 200\n\n    # prevent gradient explosion\n    args.grad_norm_clip = 10\n\n    return args\n\n# arguments of vnd、 qmix、 qtran\ndef get_mixer_args(args):\n    # network_sc\n    args.rnn_hidden_dim = 64\n    args.qmix_hidden_dim = 32\n    args.two_hyper_layers = False\n    args.hyper_hidden_dim = 64\n    args.qtran_hidden_dim = 64\n    args.ppo_hidden_size = 64\n    args.lr = 5e-4\n\n    # epsilon greedy\n    # args.epsilon = 1\n    args.min_epsilon = 0.05\n    anneal_steps = 50000\n    args.anneal_epsilon = (args.epsilon - args.min_epsilon) / anneal_steps\n    args.epsilon_anneal_scale = 'step'\n\n    # the number of the train steps in one epoch\n    args.train_steps = 1\n\n    # experience replay\n    args.batch_size = 32\n    args.buffer_size = int(5e3)\n\n    # how often to save the model\n    args.save_cycle = 5000\n\n    # how often to update the target_net\n    args.target_update_cycle = 200\n\n    # QTRAN lambda\n    args.lambda_opt = 1\n    args.lambda_nopt = 1\n\n    # prevent gradient explosion\n    args.grad_norm_clip = 10\n\n    # MAVEN\n    args.noise_dim = 16\n    args.lambda_mi = 0.001\n    args.lambda_ql = 1\n    args.entropy_coefficient = 0.001\n    return args\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/common_sr/dummy_vec_env.py",
    "content": "import numpy as np\nfrom .vec_env import VecEnv\nfrom .util import copy_obs_dict, dict_to_obs, obs_space_info\n\nclass DummyVecEnv(VecEnv):\n    \"\"\"\n    VecEnv that does runs multiple environments sequentially, that is,\n    the step and reset commands are send to one environment at a time.\n    Useful when debugging and when num_env == 1 (in the latter case,\n    avoids communication overhead)\n    \"\"\"\n    def __init__(self, env_fns):\n        \"\"\"\n        Arguments:\n\n        env_fns: iterable of callables      functions that build environments\n        \"\"\"\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)\n        obs_space = env.observation_space\n        self.keys, shapes, dtypes = obs_space_info(obs_space)\n\n        self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }\n        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)\n        self.buf_rews  = np.zeros((self.num_envs,), dtype=np.float32)\n        self.buf_infos = [{} for _ in range(self.num_envs)]\n        self.actions = None\n        self.spec = self.envs[0].spec\n\n    def step_async(self, actions):\n        listify = True\n        try:\n            if len(actions) == self.num_envs:\n                listify = False\n        except TypeError:\n            pass\n\n        if not listify:\n            self.actions = actions\n        else:\n            assert self.num_envs == 1, \"actions {} is either not a list or has a wrong size - cannot match to {} environments\".format(actions, self.num_envs)\n            self.actions = [actions]\n\n    def step_wait(self):\n        for e in range(self.num_envs):\n            action = self.actions[e]\n            # if isinstance(self.envs_sc[e].action_space, spaces.Discrete):\n            #    action = int(action)\n\n            obs, self.buf_rews[e], 
self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)\n            if self.buf_dones[e]:\n                obs = self.envs[e].reset()\n            self._save_obs(e, obs)\n        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),\n                self.buf_infos.copy())\n\n    def reset(self):\n        for e in range(self.num_envs):\n            obs = self.envs[e].reset()\n            self._save_obs(e, obs)\n        return self._obs_from_buf()\n\n    def _save_obs(self, e, obs):\n        for k in self.keys:\n            if k is None:\n                self.buf_obs[k][e] = obs\n            else:\n                self.buf_obs[k][e] = obs[k]\n\n    def _obs_from_buf(self):\n        return dict_to_obs(copy_obs_dict(self.buf_obs))\n\n    def get_images(self):\n        return [env.render(mode='rgb_array') for env in self.envs]\n\n    def render(self, mode='human'):\n        if self.num_envs == 1:\n            return self.envs[0].render(mode=mode)\n        else:\n            return super().render(mode=mode)\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/common_sr/multiprocessing_env.py",
    "content": "# This code is from openai baseline\n# https://github.com/openai/baselines/tree/master/baselines/common/vec_env\n\nimport numpy as np\nfrom multiprocessing import Process, Pipe\n\ndef _flatten_list(l):\n    assert isinstance(l, (list, tuple))\n    assert len(l) > 0\n    assert all([len(l_) > 0 for l_ in l])\n\n    return [l__ for l_ in l for l__ in l_]\n\ndef worker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            if np.array(done).all():\n            # if done:\n                ob = env.reset()\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset()\n            remote.send(ob)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'render':\n            ob = env.render()\n            remote.send(ob) #rgb_array\n        elif cmd == 'close':\n            remote.close()\n            break\n        elif cmd == 'get_spaces':\n            remote.send((env.observation_space, env.action_space))\n\n        else:\n            raise NotImplementedError\n\nclass VecEnv(object):\n    \"\"\"\n    An abstract asynchronous, vectorized environment.\n    \"\"\"\n    def __init__(self, num_envs, observation_space, action_space):\n        self.num_envs = num_envs\n        self.observation_space = observation_space\n        self.action_space = action_space\n\n    def reset(self):\n        \"\"\"\n        Reset all the environments and return an array of\n        observations, or a tuple of observation arrays.\n        If step_async is still doing work, that work will\n        be cancelled and step_wait() should not be called\n        until step_async() is invoked again.\n        \"\"\"\n        pass\n\n    def step_async(self, actions):\n        \"\"\"\n        
Tell all the environments to start taking a step\n        with the given actions.\n        Call step_wait() to get the results of the step.\n        You should not call this if a step_async run is\n        already pending.\n        \"\"\"\n        pass\n\n    def step_wait(self):\n        \"\"\"\n        Wait for the step taken with step_async().\n        Returns (obs, rews, dones, infos):\n         - obs: an array of observations, or a tuple of\n                arrays of observations.\n         - rews: an array of rewards\n         - dones: an array of \"episode done\" booleans\n         - infos: a sequence of info objects\n        \"\"\"\n        pass\n\n    def close(self):\n        \"\"\"\n        Clean up the environments' resources.\n        \"\"\"\n        pass\n\n    def step(self, actions):\n        self.step_async(actions)\n        return self.step_wait()\n\n    def render(self, mode='human'):\n        imgs = self.get_images()\n        # bigimg = tile_images(imgs)\n        # if mode == 'human':\n        #     self.get_viewer().imshow(bigimg)    #\n        #     return self.get_viewer().isopen\n    #     elif mode == 'rgb_array':\n    #         return bigimg\n    #     else:\n    #         raise NotImplementedError\n\n    def get_images(self):\n        \"\"\"\n        Return RGB images from each environment\n        \"\"\"\n        raise NotImplementedError\n\n    \nclass CloudpickleWrapper(object):\n    \"\"\"\n    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)\n    \"\"\"\n    def __init__(self, x):\n        self.x = x\n    def __getstate__(self):\n        import cloudpickle\n        return cloudpickle.dumps(self.x)\n    def __setstate__(self, ob):\n        import pickle\n        self.x = pickle.loads(ob)\n\n        \nclass SubprocVecEnv(VecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs_sc: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = 
False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.nenvs = nenvs\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, action_space = self.remotes[0].recv()\n        VecEnv.__init__(self, len(env_fns), observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def get_images(self):\n        # self._assert_not_closed()\n        for pipe in self.remotes:\n            pipe.send(('render', None))\n        imgs = [pipe.recv() for pipe in self.remotes]\n        # imgs = _flatten_list(imgs)\n        return imgs\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:            \n                remote.recv()\n        for remote 
 in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n    def __len__(self):\n        return self.nenvs"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/common_sr/replay_buffer.py",
    "content": "import numpy as np\nimport threading\n\n\nclass ReplayBuffer:\n    def __init__(self, args):\n        self.args = args\n        self.n_actions = self.args.n_actions\n        self.n_agents = self.args.n_agents\n        # self.state_shape = self.args.state_shape\n        self.obs_shape = self.args.obs_shape\n        self.size = self.args.buffer_size\n        self.episode_limit = self.args.episode_limit\n        # memory management\n        self.current_idx = 0\n        self.current_size = 0\n        # create the buffer to store info\n        self.buffers = {'O': np.empty([self.size, self.episode_limit, self.n_agents, self.obs_shape]),\n                        'U': np.empty([self.size, self.episode_limit, self.n_agents, 1]),\n                        # 's': np.empty([self.size, self.episode_limit, self.state_shape]),\n                        'R': np.empty([self.size, self.episode_limit, self.n_agents, 1]),\n                        'O_NEXT': np.empty([self.size, self.episode_limit, self.n_agents, self.obs_shape]),\n                        # 's_next': np.empty([self.size, self.episode_limit, self.state_shape]),\n                        'AVAIL_U': np.empty([self.size, self.episode_limit, self.n_agents, self.n_actions]),\n                        'AVAIL_U_NEXT': np.empty([self.size, self.episode_limit, self.n_agents, self.n_actions]),\n                        'U_ONEHOT': np.empty([self.size, self.episode_limit, self.n_agents, self.n_actions]),\n                        'PADDED': np.empty([self.size, self.episode_limit, 1]),\n                        'TERMINATE': np.empty([self.size, self.episode_limit, 1])\n                        }\n        # thread lock\n        self.lock = threading.Lock()\n\n        # store the episode\n    def store_episode(self, episode_batch):\n        batch_size = episode_batch['O'].shape[0]  # episode_number\n        with self.lock:\n            idxs = self._get_storage_idx(inc=batch_size)\n            # store the informations\n       
     self.buffers['O'][idxs] = episode_batch['O']\n            self.buffers['U'][idxs] = episode_batch['U']\n            # self.buffers['s'][idxs] = episode_batch['s']\n            self.buffers['R'][idxs] = episode_batch['R']\n            self.buffers['O_NEXT'][idxs] = episode_batch['O_NEXT']\n            # self.buffers['s_next'][idxs] = episode_batch['s_next']\n            self.buffers['AVAIL_U'][idxs] = episode_batch['AVAIL_U']\n            self.buffers['AVAIL_U_NEXT'][idxs] = episode_batch['AVAIL_U_NEXT']\n            self.buffers['U_ONEHOT'][idxs] = episode_batch['U_ONEHOT']\n            self.buffers['PADDED'][idxs] = episode_batch['PADDED']\n            self.buffers['TERMINATE'][idxs] = episode_batch['TERMINATE']\n            if self.args.alg == 'maven':\n                self.buffers['z'][idxs] = episode_batch['z']\n\n    def sample(self, batch_size):\n        temp_buffer = {}\n        idx = np.random.randint(0, self.current_size, batch_size)\n        for key in self.buffers.keys():\n            temp_buffer[key] = self.buffers[key][idx]\n        return temp_buffer\n\n    def _get_storage_idx(self, inc=None):\n        inc = inc or 1\n        if self.current_idx + inc <= self.size:\n            idx = np.arange(self.current_idx, self.current_idx + inc)\n            self.current_idx += inc\n        elif self.current_idx < self.size:\n            overflow = inc - (self.size - self.current_idx)\n            idx_a = np.arange(self.current_idx, self.size)\n            idx_b = np.arange(0, overflow)\n            idx = np.concatenate([idx_a, idx_b])\n            self.current_idx = overflow\n        else:\n            idx = np.arange(0, inc)\n            self.current_idx = inc\n        self.current_size = min(self.size, self.current_size + inc)\n        if inc == 1:\n            idx = idx[0]\n        return idx\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/common_sr/srollout.py",
    "content": "import numpy as np\nimport torch\n\nclass RolloutWorker:\n    def __init__(self, env, agents, args):\n        self.env = env\n        self.agents = agents\n        self.episode_limit = args.episode_limit\n        self.n_actions = args.n_actions\n        self.n_agents = args.n_agents\n        self.obs_shape = args.obs_shape\n        self.args = args\n\n        self.epsilon = args.epsilon\n        self.anneal_epsilon = args.anneal_epsilon\n        self.min_epsilon = args.min_epsilon\n        print('Init RolloutWorker')\n\n    def generate_episode(self, episode_num=None, evaluate=False):\n        # if self.args.alg == 'siql_no_rnn':\n        #     from policy_sc.siql_no_rnn import SIQLUR\n        #     self.policy_sc = SIQLUR(self.args)\n        #     self.policy_sc.eval_snn.reset()\n\n        if self.args.replay_dir != '' and evaluate and episode_num == 0:  # prepare for save replay of evaluation\n            self.env.close()\n        # Store all data\n        EPISODE = dict(\n                        O = [],\n                        U = [],\n                        R = [],\n                        O_NEXT = [],\n                        U_ONEHOT = [],\n                        AVAIL_U = [],\n                        AVAIL_U_NEXT = [],\n                        PADDED = [],\n                        TERMINATE = [],\n        )\n\n        NUM_EPISODES = self.args.n_episodes if evaluate==False else self.args.evaluate_epoch\n        episode_num = 0 if evaluate == False else self.args.evaluate_epoch\n\n        episode_reward = np.zeros((self.args.process, self.n_agents))\n\n        for episode_idx in range(NUM_EPISODES):\n            # Store one multiprocessing data\n            o, u, r, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], []\n            obs = self.env.reset()\n            obs1 = obs.copy()\n            obs1[:, 0], obs1[:, 1], obs1[:, 2], obs1[:, 3] = \\\n                obs[:, 2], obs[:, 3], obs[:, 0], obs[:, 1]\n            obs_ = 
(obs, obs1)\n            obs_ = np.stack((obs, obs1), axis=0).transpose(1, 0, 2)\n            num_env = obs.shape[0]\n\n            last_action = np.zeros((self.args.n_agents, num_env, self.args.n_actions))\n            self.agents.policy.init_hidden(1, num_env)\n            terminated = False\n            win_tag = False\n            step = 0\n\n            # epsilon\n            epsilon = 0 if evaluate else self.epsilon\n            if self.args.epsilon_anneal_scale == 'episode':\n                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon\n\n            # for each episode (include 50 steps and num_env multiprocessing)\n            while not terminated and step < self.episode_limit:\n                # time.sleep(0.2)\n                obs = np.array(obs_)    #A perspective, B perspective\n                avail_action = [1] * self.args.n_actions\n                actions, avail_actions, actions_onehot = [], [], []\n\n                for agent_id in range(self.n_agents):\n                    action = self.agents.choose_action(num_env, obs[:, agent_id, :], last_action[agent_id],\n                                                       agent_id, avail_action, epsilon, evaluate)\n                    # generate onehot vector of th action\n                    action_onehot = np.zeros((num_env, self.args.n_actions))\n                    for i in range(num_env): action_onehot[i, action[0, i]] = 1\n                    actions.append(action[0].cpu().numpy().tolist())    #np.int(action)\n                    actions_onehot.append(action_onehot)\n                    avail_actions.append(avail_action)\n                    last_action[agent_id] = action_onehot\n\n                actions = np.array(actions).transpose(1,0)                  #[num_env, num_agent](4, 2)\n                obs_, reward, done, info = self.env.step(actions=actions)   #[num_env,num_agent,num_state],[num_env, num_agent],[num_env] (4, 2, 10) (4, 2)\n\n                if 
self.args.load_model == True:\n                    self.env.render(mode=\"human\")\n                    print(reward)\n\n                o.append(obs)\n                u.append(np.expand_dims(actions, self.n_agents))\n                u_onehot.append(actions_onehot)\n                avail_u.append(avail_actions)\n                r.append(np.expand_dims(reward, 2))\n                terminate.append(np.expand_dims(np.array([terminated]*num_env), 1))\n                padded.append(np.expand_dims(np.array([0.]*num_env), 1))\n                episode_reward = episode_reward + reward\n                step += 1\n                if self.args.epsilon_anneal_scale == 'step':\n                    epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon\n\n            # last obs\n            obs = np.array(obs_)\n            o.append(obs)\n            o_next = o[1:]\n            o = o[:-1]\n            # get avail_action for last obs，because target_q needs avail_action in training\n            avail_actions = []\n            for agent_id in range(self.n_agents):\n                avail_action = [1] * self.args.n_actions\n                avail_actions.append(avail_action)\n            avail_u.append(avail_actions)\n            avail_u_next = avail_u[1:]\n            avail_u = avail_u[:-1]\n\n            # if step < self.episode_limit，padding (if termined before the max steps, add data to max steps)\n            for i in range(step, self.episode_limit):\n                o.append(np.zeros((self.n_agents, self.obs_shape)))\n                u.append(np.zeros([self.n_agents, 1]))\n                r.append(np.zeros([self.n_agents, 1]))\n                o_next.append(np.zeros((self.n_agents, self.obs_shape)))\n                u_onehot.append(np.zeros((self.n_agents, self.n_actions)))\n                avail_u.append(np.zeros((self.n_agents, self.n_actions)))\n                avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))\n                
padded.append([1.]*num_env)\n                terminate.append([1.]*num_env)\n\n            # Processing data for each episode\n            EPISODE['O'].append(np.stack(o, axis=0).transpose(1, 0, 2, 3))\n            EPISODE['U'].append(np.stack(u, axis=0).transpose(1, 0, 2, 3).astype(int))\n            EPISODE['R'].append(np.stack(r, axis=0).transpose(1, 0, 2, 3))\n            EPISODE['O_NEXT'].append(np.stack(o_next, axis=0).transpose(1, 0, 2, 3))\n            EPISODE['U_ONEHOT'].append(np.stack(u_onehot, axis=0).transpose(2, 0, 1, 3))\n            EPISODE['AVAIL_U'].append(np.ones(EPISODE['U_ONEHOT'][0].shape))\n            EPISODE['AVAIL_U_NEXT'].append(np.ones(EPISODE['U_ONEHOT'][0].shape))\n            EPISODE['PADDED'].append(np.stack(padded, axis=0).transpose(1, 0, 2))\n            EPISODE['TERMINATE'].append(np.stack(terminate, axis=0).transpose(1, 0, 2))\n\n        episode_reward = episode_reward.sum(0)\n\n        for i in EPISODE.keys():\n            EPISODE[i] = np.concatenate(EPISODE[i], axis=0)\n        step = step * self.args.n_episodes * num_env\n\n        if not evaluate:\n            self.epsilon = epsilon\n        if evaluate and episode_num == self.args.evaluate_epoch and self.args.replay_dir != '':\n            self.env.save_replay()\n            self.env.close()\n        return EPISODE, episode_reward, win_tag, step\n\n    def generate_episode_sample(self, episodes, steps, episode_num=None, evaluate=False):\n        if self.args.replay_dir != '' and evaluate and episode_num == 0:  # prepare for save replay of evaluation\n            self.env.close()\n        o, u, r, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], []\n        obs = self.env.reset()\n        obs_ = (obs, self.env.game._flip_coord_observation_perspective(obs))  # A perspective, B perspective\n        terminated = False\n        win_tag = False\n        step = 0\n        episode_reward = (0, 0)  # cumulative rewards\n\n        # ###\n        # for param_sc in 
self.agents_sc.policy_base.parameters():\n        #     param_sc.requires_grad = False\n        # self.agents_sc.policy_base.eval()\n\n        last_action = np.zeros((self.args.n_agents, self.args.n_actions))\n        self.agents.policy.init_hidden(1)\n\n        # epsilon\n        epsilon = 0 if evaluate else self.epsilon\n        if self.args.epsilon_anneal_scale == 'episode':\n            epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon\n\n        while not terminated and step < self.episode_limit:\n            # time.sleep(0.2)\n            obs = np.array(obs_)    #A perspective, B perspective\n            avail_action = [1] * self.args.n_actions\n            actions, avail_actions, actions_onehot = [], [], []\n            for agent_id in range(self.n_agents):\n                action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,\n                                                       avail_action, epsilon, evaluate)\n                # generate onehot vector of th action\n                action_onehot = np.zeros(self.args.n_actions)\n                action_onehot[action] = 1\n                actions.append(int(action))\n                actions_onehot.append(action_onehot)\n                avail_actions.append(avail_action)\n                last_action[agent_id] = action_onehot\n\n            obs_, reward, done, info = self.env.step(actions=actions)\n            # print(actions,reward)\n            win_tag = True if terminated else False\n            # save obs, actions, avail_actions, reward at time t\n            o.append(obs)\n            u.append(np.reshape(actions, [self.n_agents, 1]))\n            u_onehot.append(actions_onehot)\n            avail_u.append(avail_actions)\n            r.append(np.reshape(reward, [self.n_agents, 1]))    #reward\n            terminate.append([terminated])\n            padded.append([0.])\n            # episode_reward += reward\n            episode_reward = 
[episode_reward[i] + reward[i] for i in range(min(len(episode_reward), len(reward)))]\n\n            step += 1\n            if self.args.epsilon_anneal_scale == 'step':\n                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon\n\n            if self.args.load_model == True:\n                self.env.render(mode=\"human\")\n\n        # last obs\n        obs = np.array(obs_)\n        o.append(obs)\n        o_next = o[1:]\n        o = o[:-1]\n        # get avail_action for last obs，because target_q needs avail_action in training\n        avail_actions = []\n        for agent_id in range(self.n_agents):\n            avail_action = [1] * self.args.n_actions\n            avail_actions.append(avail_action)\n        avail_u.append(avail_actions)\n        avail_u_next = avail_u[1:]\n        avail_u = avail_u[:-1]\n\n        # if step < self.episode_limit，padding\n        for i in range(step, self.episode_limit):\n            o.append(np.zeros((self.n_agents, self.obs_shape)))\n            u.append(np.zeros([self.n_agents, 1]))\n            r.append(np.zeros([self.n_agents, 1]))\n            o_next.append(np.zeros((self.n_agents, self.obs_shape)))\n            u_onehot.append(np.zeros((self.n_agents, self.n_actions)))\n            avail_u.append(np.zeros((self.n_agents, self.n_actions)))\n            avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))\n            padded.append([1.])\n            terminate.append([1.])\n\n        episode = dict(o=o.copy(),\n                       u=u.copy(),\n                       r=r.copy(),\n                       avail_u=avail_u.copy(),\n                       o_next=o_next.copy(),\n                       avail_u_next=avail_u_next.copy(),\n                       u_onehot=u_onehot.copy(),\n                       padded=padded.copy(),\n                       terminated=terminate.copy()\n                       )\n\n        episodes[episode_num] = episode\n        steps[episode_num] = 
step\n\n        # add episode dim\n        for key in episode.keys():\n            episode[key] = np.array([episode[key]])\n        if not evaluate:\n            self.epsilon = epsilon\n        if evaluate and episode_num == self.args.evaluate_epoch - 1 and self.args.replay_dir != '':\n            self.env.save_replay()\n            self.env.close()\n       # return episode, episode_reward, win_tag, step\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/envs/Stag_Hunt_env.py",
    "content": "import gym\nimport gym_stag_hunt\nfrom ray import tune\nfrom ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv\nfrom gym_stag_hunt.envs.pettingzoo.hunt import raw_env\n\nif __name__ == \"__main__\":\n    def env_creator(args):\n        return PettingZooEnv(raw_env(**args))\n\n    tune.register_env(\"StagHunt-Hunt-PZ-v0\", env_creator)\n\n    model = tune.run(\n        \"DQN\",\n        name=\"stag_hunt\",\n        stop={\"episodes_total\": 10000},\n        checkpoint_freq=100,\n        checkpoint_at_end=True,\n        config={\n            \"horizon\": 100,\n            \"framework\": \"tf2\",\n            # Environment specific\n            \"env\": \"StagHunt-Hunt-PZ-v0\",\n            # General\n            \"num_workers\": 2,\n            # Method specific\n            \"multiagent\": {\n                \"policies\": {\"player_0\", \"player_1\"},\n                \"policy_mapping_fn\": (lambda agent_id, episode, **kwargs: agent_id),\n                \"policies_to_train\": [\"player_0\", \"player_1\"]\n            },\n            # Env Specific\n            \"env_config\": {\n                \"obs_type\": \"coords\",\n                \"forage_reward\": 1.0,\n                \"stag_reward\": 5.0,\n                \"stag_follows\": True,\n                \"mauling_punishment\": -.5,\n                \"enable_multiagent\": True,\n            }\n        }\n    )"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/envs/__init__.py",
    "content": "from ToM2.envs.grid_env1 import *\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/envs/abstract.py",
    "content": "\"\"\"\nImplements abstract class for meta-reinforcement learning environments.\n\"\"\"\n\nfrom typing import Generic, TypeVar, Tuple\nimport abc\n\n\nObsType = TypeVar('ObsType')\n\n\nclass MetaEpisodicEnv(abc.ABC, Generic[ObsType]):\n    @property\n    @abc.abstractmethod\n    def max_episode_len(self) -> int:\n        \"\"\"\n        Return the maximum episode length.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def new_env(self) -> None:\n        \"\"\"\n        Reset the environment's structure by resampling\n        the state transition probabilities and/or reward function\n        from a prior distribution.\n\n        Returns:\n            None\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def reset(self) -> ObsType:\n        \"\"\"\n        Resets the environment's state to some designated initial state.\n        This is distinct from resetting the environment's structure\n            via self.new_env().\n\n        Returns:\n            initial observation.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def step(\n        self,\n        action: int,\n        auto_reset: bool = True\n    ) -> Tuple[ObsType, float, bool, dict]:\n        \"\"\"\n        Step the env.\n\n        Args:\n            action: integer action indicating which action to take\n            auto_reset: whether or not to automatically reset the environment\n                on done. if true, next observation will be given by self.reset()\n\n        Returns:\n            next observation, reward, and done flat\n        \"\"\"\n        pass\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/envs/constants.py",
    "content": "\"\"\"\nGlobal constants for the gridworld env.\n\"\"\"\n\n# =============================================================================\n# set the value of interface\n# =============================================================================\nFPS = 25\nWinWidth = 340 #window width\nWinHeight = 260 #window width\nBoxSize = 20    #the size of one grid\nGridWidth = 7   #the number of lattices are there in the x-axis\nGridHeight = 7  #the number of lattices are there in the y-axis\nXMargin = int((WinWidth - GridWidth * BoxSize)/2)\nTopMargin = int((WinHeight - GridHeight * BoxSize))/2-5\n# =============================================================================\n# set color\n# =============================================================================\nWhite = (255, 255, 255)\nGray = (185, 185, 185)\nBlack = (0, 0, 0)\nRed = (255, 0, 0)\nGreen = (0, 128, 0)\nSpringGreen = (60, 179, 113)\nDarkOrange = (255, 140, 0)\nRoyalBlue = (65, 105, 225)\nDarkVoilet = (148, 0, 211)\nHotPink = (255, 105, 180)\nBoardColor = White\nBGColor = White\nTextColor = White\nTest = []\n# =============================================================================\n# set maps\n# =============================================================================\n# BlankBox = 1\n# shadow = 0\n# Wall = 5\n# Obstacle = 5\n# observer = 8\n# button = 7\n# obeservation_1 = 11\n# obeservation_2 = 22\n# obeservation_3 = 33\n\"\"\"\n'S' : starting point\n'F' or '.': free space\n'W' or 'x': wall\n'H' or 'o': hole (terminates episode)\n'G' : goal\n\"\"\"\nStart = 'S'\nFree_space  = 'F'\nWall = 'W'\nDanger = 'H'\nGoal = 'G'\nShadow = 'Sh'\n\n\nMAPs = {\n    0: [\n        \"FFFFF\",\n        \"FHFWF\",\n        \"FFFFF\",\n        \"WFFFF\",\n        \"FFFGF\"\n    ],\n    1: [\n        \"FFFFF\",\n        \"FHWFF\",\n        \"FFFFF\",\n        \"WFGFF\",\n        \"FFFFF\"\n    ],\n    2: [\n        \"FFFF\",\n        \"FWFW\",\n        \"FFFW\",\n        \"WFFG\"\n    ],\n\n    
3: [\n        \"FFFF\",\n        \"FHFW\",\n        \"FFFW\",\n        \"GFFF\"\n    ],\n\n    4: [\n        \"FFFF\",\n        \"FHFW\",\n        \"FFFW\",\n        \"WGFF\"\n    ],\n}\n# 2: [\n#     \"SFFFFFFF\",\n#     \"FFFFFFFF\",\n#     \"FFFHFFFF\",\n#     \"FFFFFWFF\",\n#     \"FFFHFFFF\",\n#     \"FWHFFFWF\",\n#     \"FWFFHFWF\",\n#     \"FFFWFFFG\"\n# ],\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/main_spiking.py",
    "content": "from common_sr.arguments import get_common_args, get_coma_args, get_mixer_args\nfrom common_sr.multiprocessing_env import SubprocVecEnv\nfrom runner import Runner\nfrom time import sleep\nfrom gym_stag_hunt.envs.gym.escalation import EscalationEnv\nfrom gym_stag_hunt.envs.gym.harvest import HarvestEnv\nfrom gym_stag_hunt.envs.gym.hunt import HuntEnv\nfrom gym_stag_hunt.envs.gym.simple import SimpleEnv\nfrom gym_stag_hunt.src.games.abstract_grid_game import UP, LEFT, DOWN, RIGHT, STAND\n\nimport json\nimport os\nos.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\n\nif __name__ == '__main__':\n    ENVS = {\n        \"CLASSIC\": SimpleEnv,\n        \"HUNT\": HuntEnv,\n        \"HARVEST\": HarvestEnv,\n        \"ESCALATION\": EscalationEnv,\n    }\n\n    args = get_common_args()\n    args = get_mixer_args(args)\n\n    if args.ENV == 'HUNT':\n        args.n_actions = 5  # [5] # up, down, left, right or stand\n        args.n_agents = 2  # [2]\n        args.obs_shape = 6 + args.forage_quantity * 2\n\n    elif args.ENV == 'ESCALATION':\n        args.n_actions = 5  # [5] # up, down, left, right or stand\n        args.n_agents = 2  # [2]\n        args.obs_shape = 6\n\n    elif args.ENV == 'HARVEST':\n        args.n_actions = 5  # [5] # up, down, left, right or stand\n        args.n_agents = 2  # [2]\n        args.obs_shape = 6 + args.forage_quantity * 5\n\n    args.episode_limit = 50\n    args.train_steps = 100\n\n    save_path = args.log_dir + '/' + args.alg + args.exp_dir\n    print(os.path.exists(save_path))\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n    # save args\n    argsDict = args.__dict__\n    with open(save_path + '/args_{}'.format(args.num_run), 'w') as f:\n        f.writelines('------------------ start ------------------' + '\\n')\n        for eachArg, value in argsDict.items():\n            f.writelines(eachArg + ' : ' + str(value) + '\\n')\n        f.writelines('------------------- end -------------------')\n\n    def 
make_env():\n        def _thunk():\n            if args.ENV == 'HUNT':\n                env = ENVS[args.ENV](obs_type=\"coords\", enable_multiagent=True, opponent_policy=\"random\", \\\n                                     forage_quantity=args.forage_quantity, run_away_after_maul=True)\n            elif args.ENV == 'ESCALATION':\n                env = ENVS[args.ENV](obs_type=\"coords\", enable_multiagent=True)\n\n            elif args.ENV == 'HARVEST':\n                env = ENVS[args.ENV](obs_type=\"coords\", enable_multiagent=True)\n\n            return env\n\n        return _thunk\n\n    # for i in range(args.num_run):\n\n    envs = [make_env() for i in range(args.process)]\n    envs = SubprocVecEnv(envs)\n\n    runner = Runner(envs, args)\n    if not args.evaluate:\n        runner.run(args.num_run)\n    else:\n        win_rate, _ = runner.evaluate()\n        print('The win rate of {} is  {}'.format(args.alg, win_rate))\n    envs.close()\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/network/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/network/spiking_net.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as f\nfrom torch.distributions import Normal\n\nfrom braincog.base.node.node import IFNode, LIFNode\nfrom braincog.base.strategy.surrogate import AtanGrad\n\n\nthresh = 0.3\nlens = 0.25\ndecay = 0.3\nTIMESTEPS = 15\nM = 5\n\n\n# BrainCog\nclass BCNoSpikingLIFNode(LIFNode):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, dv: torch.Tensor):\n        self.integral(dv)\n        return self.mem\n\n\nclass BCNoSpikingIFNode(IFNode):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, dv: torch.Tensor):\n        self.integral(dv)\n        return self.mem\n\n\n# Sug\nclass ActFun(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        return input.gt(thresh).float()\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, = ctx.saved_tensors\n        grad_input = grad_output.clone()\n        temp = abs(input - thresh) < lens\n        return grad_input * temp.float() / (2 * lens)\n\n\n#act_fun = ActFun.apply\nact_fun = AtanGrad(alpha=2.,requires_grad=False)\n\ndef mem_update(fc, x, mem, spike):\n    mem = mem * decay * (1 - spike) + fc(x)\n    #spike = act_fun(mem)\n    spike = act_fun(x=mem-1)\n    return mem, spike\n\n\nclass Critic(nn.Module):\n    def __init__(self, input_shape, args):\n        super(Critic, self).__init__()\n        self.args = args\n        self.fc1 = nn.Linear(input_shape, args.ppo_hidden_size)\n        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim, bias = True)\n        self.fc3 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim, bias = True)\n        self.fc4 = nn.Linear(args.rnn_hidden_dim, args.n_actions)#\n        self.req_grad = False\n\n    def forward(self, inputs, h1_mem, h1_spike, h2_mem, h2_spike):\n        # if self.req_grad == False:\n      
  # [1, 17] -> [1, process, 64]\n        x = self.fc1(inputs)\n        # x = IFNode()(x)\n        x = LIFNode()(x)\n        if self.args.alg == 'siql_e':\n            h1_mem = h1_mem.reshape(-1,  M,TIMESTEPS, self.args.rnn_hidden_dim)\n            h1_spike = h1_spike.reshape(-1,  M,TIMESTEPS, self.args.rnn_hidden_dim)\n            h2_mem = h2_mem.reshape(-1, M, TIMESTEPS, self.args.rnn_hidden_dim)\n            h2_spike = h2_spike.reshape(-1, M, TIMESTEPS, self.args.rnn_hidden_dim)\n\n        else:\n        # [1, 64] -> [process, 64]\n            h1_mem = h1_mem.reshape(-1, self.args.rnn_hidden_dim)\n            h1_spike = h1_spike.reshape(-1, self.args.rnn_hidden_dim)\n            h2_mem = h2_mem.reshape(-1, self.args.rnn_hidden_dim)\n            h2_spike = h2_spike.reshape(-1, self.args.rnn_hidden_dim)\n\n        h1_mem, h1_spike = mem_update(self.fc2, x, h1_mem, h1_spike)\n        h2_mem, h2_spike = mem_update(self.fc3, h1_spike, h2_mem, h2_spike)\n        # [1, 5]\n        value = BCNoSpikingLIFNode(tau=2.0)(self.fc4(h2_mem))\n\n        return value, h1_mem, h1_spike, h2_mem, h2_spike\n\n\n\n\nclass VDNNet(nn.Module):\n    def __init__(self):\n        super(VDNNet, self).__init__()\n\n    def forward(self, q_values):\n        return torch.sum(q_values, dim=2, keepdim=True)\n\nclass Linear_weight(nn.Module):\n    def __init__(self, input_shape, out_shape, args):\n        super(Linear_weight,self).__init__()\n        self.args = args\n        # self.fc  = nn.Linear(input_shape, out_shape)\n        self.alpha = nn.Parameter(torch.Tensor(out_shape))\n\n    def forward(self, x):\n        # return self.fc(x)\n        if self.args.alg == 'scovdn_weight':\n            x = x[:,:,:,0] * self.alpha + x[:,:,:,1] * (1 - self.alpha)\n            return x.unsqueeze(3)\n        elif self.args.alg == 'stomvdn':\n            x = x[:, :, :, :, 0] * self.alpha + x[:, :, :, :, 1] * (1 - self.alpha)\n            return x.unsqueeze(4)\n\nclass BiasNet(nn.Module):\n    def 
__init__(self, args):\n        super(BiasNet, self).__init__()\n        self.args = args\n        input_shape = self.args.obs_shape + self.args.rnn_hidden_dim\n        #\n        # self.h1_mem = self.h1_spike = torch.zeros(self.args.n_episodes * self.args.process,\n        #            self.args.episode_limit, self.args.rnn_hidden_dim)\n        # if self.args.cuda:\n        #     self.h1_mem = self.h1_mem.cuda(self.args.device)\n        #     self.h1_spike = self.h1_spike.cuda(self.args.device)\n\n        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)#neuron.IFNode()\n        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim, bias = True)\n        self.fc3 = nn.Linear(args.rnn_hidden_dim, 1)#\n\n    def reset(self, episode_num):\n        self.h1_mem = self.h1_spike = torch.zeros(episode_num,\n                   self.args.episode_limit, self.args.rnn_hidden_dim)\n        if self.args.cuda:\n            self.h1_mem = self.h1_mem.cuda(self.args.device)\n            self.h1_spike = self.h1_spike.cuda(self.args.device)\n\n    def forward(self, state, hidden):\n        episode_num, max_episode_len, n_agents, _ = hidden.shape\n        state = state.reshape(episode_num * max_episode_len, -1)\n        state = state * 0.2\n        hidden = \\\n            hidden.reshape(episode_num * max_episode_len, n_agents, -1).sum(dim=-2)\n        inputs = torch.cat([state, hidden], dim=-1)\n\n        x = self.fc1(inputs)\n        x = neuron.IFNode()(x)\n        # x = IFNode()(x)\n        # x = LIFNode()(x)      #bad\n\n        # [1, 64] -> [process, 64]\n        self.h1_mem = self.h1_mem.reshape(-1, self.args.rnn_hidden_dim)\n        self.h1_spike = self.h1_spike.reshape(-1, self.args.rnn_hidden_dim)\n\n        self.h1_mem, self.h1_spike = mem_update(self.fc2, x, self.h1_mem, self.h1_spike)\n        # [1, 5]\n        # value = NonSpikingLIFNode(tau=2.0)(self.fc4(h2_mem))\n        # value = BCNoSpikingLIFNode(tau=2.0)(self.fc4(h2_mem))\n        value = 
BCNoSpikingIFNode(tau=2.0)(self.fc3(self.h1_mem))\n\n        return value\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/policy/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/policy/dqn.py",
    "content": "import torch\nimport os\nfrom network.base_net import RNN\n\nclass DQN:\n    def __init__(self, args, model_eval, model_target, agent_id):\n        self.n_actions = args.n_actions\n        self.n_agents = args.n_agents\n        self.obs_shape = args.obs_shape\n        self.agent_id = agent_id\n        input_shape = self.obs_shape\n        # 根据参数决定RNN的输入维度\n        if args.last_action:\n            input_shape += self.n_actions\n        # if args.reuse_network:\n        #     input_shape += self.n_agents\n\n        # 神经网络\n        self.eval_rnn = model_eval\n        self.target_rnn = model_target\n        self.args = args\n        if self.args.cuda:\n            self.eval_rnn.cuda(self.args.device)\n            self.target_rnn.cuda(self.args.device)\n\n        # self.model_dir = args.model_dir + '/' + args.alg\n        self.model_dir = '/home/zhaozhuoya/exp2/MARL_test_exp/model' + '/' + args.alg + args.exp_dir#+ args.model_dir\n\n        # 如果存在模型则加载模型\n        if self.args.load_model:\n            if os.path.exists(self.model_dir + args.save_model_dir):\n                # path_snn = '/home/zhaozhuoya/exp2/ToM2/model/iql/199_rnn_net_params.pkl'\n                map_location = self.args.device if self.args.cuda else 'cpu'\n                self.eval_rnn.load_state_dict(torch.load(self.model_dir + args.save_model_dir, map_location=map_location))\n                # self.eval_rnn.load_state_dict(torch.load(self.model_dir))\n                print('Successfully load the model: {}'.format(self.model_dir + args.save_model_dir))\n            else:\n                raise Exception(\"No model!\")\n\n        # 让target_net和eval_net的网络参数相同\n        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())\n\n        self.eval_parameters = list(self.eval_rnn.parameters())\n        if args.optimizer == \"RMS\":\n            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)\n\n        # 执行过程中，要为每个agent都维护一个eval_hidden\n        # 
学习过程中，要为每个episode的每个agent都维护一个eval_hidden、target_hidden\n        self.eval_hidden = None\n        self.target_hidden = None\n        print('Init alg DQN')\n\n    def learn(self, batch, max_episode_len, train_step, epsilon=None):  # train_step表示是第几次学习，用来控制更新target_net网络的参数\n        '''\n        在learn的时候，抽取到的数据是四维的，四个维度分别为 1——第几个episode 2——episode中第几个transition\n        3——第几个agent的数据 4——具体obs维度。因为在选动作时不仅需要输入当前的inputs，还要给神经网络输入hidden_state，\n        hidden_state和之前的经验相关，因此就不能随机抽取经验进行学习。所以这里一次抽取多个episode，然后一次给神经网络\n        传入每个episode的同一个位置的transition\n        '''\n        episode_num = batch['O'].shape[0]\n\n        self.init_hidden_learn(episode_num)\n        # hidden_state = self.policy.eval_hidden[:, self.agent_id, :, :]\n        eval_hidden, target_hidden = \\\n            self.eval_hidden[:, self.agent_id, :], self.target_hidden[:, self.agent_id, :]\n\n        # for key in batch.keys():  # 把batch里的数据转化成tensor\n        #     if key == 'u':\n        #         batch[key] = torch.tensor(batch[key], dtype=torch.long)\n        #     else:\n        #         batch[key] = torch.tensor(batch[key], dtype=torch.float32)\n        # u, r, avail_u, avail_u_next, terminated = batch['u'], batch['r'].squeeze(-1),  batch['avail_u'], \\\n        #                                           batch['avail_u_next'], batch['terminated'].repeat(1, 1, self.n_agents)\n        # mask = (1 - batch[\"padded\"].float()).repeat(1, 1, self.n_agents)  # 用来把那些填充的经验的TD-error置0，从而不让它们影响到学习\n        for key in batch.keys():  # 把batch里的数据转化成tensor\n            if key == 'O':\n                batch[key] = torch.tensor(batch[key], dtype=torch.long)\n            else:\n                batch[key] = torch.tensor(batch[key], dtype=torch.float32)\n        u, r, avail_u, avail_u_next, terminated = batch['U'], batch['R'].squeeze(-1),  batch['AVAIL_U'], \\\n                                                  batch['AVAIL_U_NEXT'], batch['TERMINATE'].repeat(1, 1, self.n_agents)\n        mask = (1 - 
batch[\"PADDED\"].float()).repeat(1, 1, self.n_agents)  # 用来把那些填充的经验的TD-error置0，从而不让它们影响到学习\n\n        # 得到每个agent对应的Q值，维度为(episode个数, max_episode_len， n_agents， n_actions)\n        q_evals, q_targets = self.get_q_values(batch, max_episode_len, eval_hidden, target_hidden)\n        if self.args.cuda:\n            u = u.cuda(self.args.device)\n            r = r.cuda(self.args.device)\n            terminated = terminated.cuda(self.args.device)\n            mask = mask.cuda(self.args.device)\n        # 取每个agent动作对应的Q值，并且把最后不需要的一维去掉，因为最后一维只有一个值了\n        u = u.to(torch.int64)\n        q_evals = torch.gather(q_evals, dim=3, index=u[:, :, self.agent_id, :].unsqueeze(3)).squeeze(3)\n\n        # 得到target_q\n        q_targets[avail_u_next[:, :, self.agent_id, :].unsqueeze(2) == 0.0] = - 9999999\n        q_targets = q_targets.max(dim=3)[0]\n\n        targets = r[:, :, self.agent_id].unsqueeze(2) + self.args.gamma * q_targets * (1 - terminated[:, :, self.agent_id].unsqueeze(2))\n\n        td_error = (q_evals - targets.detach())\n        masked_td_error = mask[:, :, self.agent_id].unsqueeze(2) * td_error  # 抹掉填充的经验的td_error\n\n        # 不能直接用mean，因为还有许多经验是没用的，所以要求和再比真实的经验数，才是真正的均值\n        loss = (masked_td_error ** 2).sum() / mask.sum()\n        # print('loss is ', loss)\n        self.optimizer.zero_grad()\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)\n        self.optimizer.step()\n\n        if train_step > 0 and train_step % self.args.target_update_cycle == 0:\n            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())\n        return loss\n\n    def _get_inputs(self, batch, transition_idx):\n        # 取出所有episode上该transition_idx的经验，u_onehot要取出所有，因为要用到上一条\n        # obs, obs_next, u_onehot = batch['o'][:, transition_idx], \\\n        #                           batch['o_next'][:, transition_idx], batch['u_onehot'][:]\n        # 取出所有episode上该transition_idx的经验，u_onehot要取出所有，因为要用到上一条\n        obs, 
obs_next, u_onehot = batch['O'][:, transition_idx], \\\n                                  batch['O_NEXT'][:, transition_idx], batch['U_ONEHOT'][:]\n        episode_num = obs.shape[0]\n        inputs, inputs_next = [], []\n        inputs.append(obs[:, self.agent_id, :])\n        inputs_next.append(obs_next[:, self.agent_id, :])\n        # 给obs添加上一个动作、agent编号\n\n        if self.args.last_action:\n            if transition_idx == 0:  # 如果是第一条经验，就让前一个动作为0向量\n                inputs.append(torch.zeros_like(u_onehot[:, :, self.agent_id, :][:, transition_idx]))\n            else:\n                inputs.append(u_onehot[:, :, self.agent_id, :][:, transition_idx - 1])\n            inputs_next.append(u_onehot[:, :, self.agent_id, :][:, transition_idx])\n        # if self.args.reuse_network:\n            # 因为当前的obs三维的数据，每一维分别代表(episode编号，agent编号，obs维度)，直接在dim_1上添加对应的向量\n            # 即可，比如给agent_0后面加(1, 0, 0, 0, 0)，表示5个agent中的0号。而agent_0的数据正好在第0行，那么需要加的\n            # agent编号恰好就是一个单位矩阵，即对角线为1，其余为0\n            # inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n            # inputs_next.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n            # inputs.append(torch.zeros((self.args.n_agents,  self.args.n_agents)).unsqueeze(0).expand(episode_num, -1, -1))\n            # inputs_next.append(torch.zeros((self.args.n_agents,  self.args.n_agents)).unsqueeze(0).expand(episode_num, -1, -1))\n\n        # 要把obs中的三个拼起来，并且要把episode_num个episode、self.args.n_agents个agent的数据拼成40条(40,96)的数据，\n        # 因为这里所有agent共享一个神经网络，每条数据中带上了自己的编号，所以还是自己的数据\n        inputs = torch.cat([x for x in inputs], dim=1)\n        inputs_next = torch.cat([x for x in inputs_next], dim=1)\n        return inputs, inputs_next\n\n    def get_q_values(self, batch, max_episode_len, eval_hidden, target_hidden):\n        # episode_num = batch['o'].shape[0]\n        episode_num = batch['O'].shape[0]\n\n        q_evals, q_targets = [], []\n        for 
transition_idx in range(max_episode_len):\n            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # 给obs加last_action、agent_id\n            if self.args.cuda:\n                inputs = inputs.cuda(self.args.device)\n                inputs_next = inputs_next.cuda(self.args.device)\n                eval_hidden = eval_hidden.cuda(self.args.device)\n                target_hidden = target_hidden.cuda(self.args.device)\n            q_eval, self.eval_hidden = self.eval_rnn(inputs, eval_hidden)  # inputs维度为(40,96)，得到的q_eval维度为(40,n_actions)\n            q_target, self.target_hidden = self.target_rnn(inputs_next, target_hidden)\n\n            # 把q_eval维度重新变回(8, 5,n_actions)\n            q_eval = q_eval.view(episode_num, 1, -1)\n            q_target = q_target.view(episode_num, 1, -1)\n            q_evals.append(q_eval)\n            q_targets.append(q_target)\n        # 得的q_eval和q_target是一个列表，列表里装着max_episode_len个数组，数组的的维度是(episode个数, n_agents，n_actions)\n        # 把该列表转化成(episode个数, max_episode_len， n_agents，n_actions)的数组\n        q_evals = torch.stack(q_evals, dim=1)\n        q_targets = torch.stack(q_targets, dim=1)\n        return q_evals, q_targets\n    def init_hidden(self, episode_num, num_env):\n    # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        self.eval_hidden = torch.zeros((episode_num, self.n_agents, num_env,self.args.rnn_hidden_dim))\n        self.target_hidden = torch.zeros((episode_num, self.n_agents, num_env,self.args.rnn_hidden_dim))\n\n    def init_hidden_learn(self, episode_num):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        self.eval_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))\n        self.target_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))\n\n    def save_model(self, train_step):\n        num = str(train_step // self.args.save_cycle)\n        if not os.path.exists(self.model_dir):\n            os.makedirs(self.model_dir)\n        
torch.save(self.eval_rnn.state_dict(),  self.model_dir + '/' + num + '_rnn_net_params.pkl')\n\n    def load_model(self, train_step):\n        num = str(train_step // self.args.save_cycle)\n\n        path = torch.load(self.model_dir + '/' + num + '_rnn_net_params.pkl')\n\n        self.eval_rnn.load_state_dict(path)"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/policy/stomvdn.py",
    "content": "import torch\nimport os\nfrom network.spiking_net import Critic, VDNNet, Linear_weight, BiasNet\nimport copy\n\n\nclass SToMVDN:\n    def __init__(self, args):\n        self.n_actions = args.n_actions\n        self.n_agents = args.n_agents\n        self.obs_shape = args.obs_shape\n        input_shape = self.obs_shape\n        # 根据参数决定RNN的输入维度\n        if args.last_action:\n            input_shape += self.n_actions\n        if args.reuse_network:\n            input_shape += self.n_agents\n\n        self.loss_trade_off_target = 0\n        self.loss_trade_off_eval = 0\n\n        # 神经网络\n        self.eval_snn = Critic(input_shape, args)\n        self.target_snn = Critic(input_shape, args)\n        self.eval_vdn_snn = VDNNet()  # 把agentsQ值加起来的网络\n        self.target_vdn_snn = VDNNet()\n        self.bias_net = BiasNet(args)\n        self.trade_off_net = Linear_weight(2, 1, args)\n\n        self.args = args\n        if self.args.cuda:\n            self.eval_snn.cuda(self.args.device)\n            self.target_snn.cuda(self.args.device)\n            self.eval_vdn_snn.cuda(self.args.device)\n            self.target_vdn_snn.cuda(self.args.device)\n            self.trade_off_net.cuda(self.args.device)\n            self.bias_net.cuda(self.args.device)\n\n        self.model_dir = args.model_dir + '/' + args.alg + args.exp_dir + args.save_model_dir\n        # 如果存在模型则加载模型\n        if self.args.load_model:\n            if os.path.exists(self.model_dir):\n                # path_snn = '/home/zhaozhuoya/exp2/ToM2_test/model/siql/199_snn_net_params.pkl'\n                map_location = self.args.device if self.args.cuda else 'cpu'\n                self.eval_snn.load_state_dict(torch.load(self.model_dir, map_location=map_location))\n                print('Successfully load the model: {}'.format(self.model_dir))\n            else:\n                print(self.model_dir)\n                raise Exception(\"No model!\")\n\n        # 让target_net和eval_net的网络参数相同\n        
self.target_snn.load_state_dict(self.eval_snn.state_dict())\n        self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())\n\n        self.eval_parameters = list(self.eval_snn.parameters()) + \\\n                               list(self.eval_vdn_snn.parameters()) + \\\n                                    list(self.trade_off_net.parameters()) + \\\n                                    list(self.bias_net.parameters())\n\n        self.trade_off_parameters = list(self.trade_off_net.parameters())\n        if args.optimizer == \"RMS\":\n            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)\n            self.optimizer_T = torch.optim.RMSprop(self.trade_off_parameters, lr=args.lr)\n\n        # 执行过程中，要为每个agent都维护一个eval_hidden\n        # 学习过程中，要为每个episode的每个agent都维护一个eval_hidden、target_hidden\n        self.eval_h1_mem, self.eval_h1_spike = None, None\n        self.target_h1_mem, self.target_h1_spike = None, None\n        self.eval_h2_mem, self.eval_h2_spike = None, None\n        self.target_h2_mem, self.target_h2_spike = None, None\n        print('Init alg SCOVDN_ToM')\n\n    def learn(self, batch, max_episode_len, train_step, epsilon=None):  # train_step表示是第几次学习，用来控制更新target_net网络的参数\n        '''\n        在learn的时候，抽取到的数据是四维的，四个维度分别为 1——第几个episode 2——episode中第几个transition\n        3——第几个agent的数据 4——具体obs维度。因为在选动作时不仅需要输入当前的inputs，还要给神经网络输入hidden_state，\n        hidden_state和之前的经验相关，因此就不能随机抽取经验进行学习。所以这里一次抽取多个episode，然后一次给神经网络\n        传入每个episode的同一个位置的transition\n        '''\n        episode_num = batch['O'].shape[0]\n        self.init_hidden_learn(episode_num)\n        self.bias_net.reset(episode_num)\n        for key in batch.keys():  # 把batch里的数据转化成tensor\n            if key == 'U':  # 'O'\n                batch[key] = torch.tensor(batch[key], dtype=torch.long)\n            else:\n                batch[key] = torch.tensor(batch[key], dtype=torch.float32)\n        u, r, avail_u, avail_u_next, terminated = batch['U'], 
batch['R'].squeeze(-1), batch['AVAIL_U'], \\\n                                                  batch['AVAIL_U_NEXT'], batch['TERMINATE'].repeat(1, 1, self.n_agents)\n        mask = (1 - batch[\"PADDED\"].float()).repeat(1, 1, self.n_agents)  # 用来把那些填充的经验的TD-error置0，从而不让它们影响到学习\n        if self.args.cuda:\n            u = u.cuda(self.args.device)\n            r = r.cuda(self.args.device)\n            terminated = terminated.cuda(self.args.device)\n            mask = mask.cuda(self.args.device)\n            # self.bias_net.cuda(self.args.device)\n        u = u.to(torch.int64)\n\n        # ---------------------------------------independent_Q_net------------------------------------------------------\n        # 得到每个agent对应的Q值，维度为(episode个数, max_episode_len， n_agents， n_actions)\n        q_evals, q_targets, hidden_evals, hidden_targets = self.get_q_values(batch, max_episode_len)\n        # ---------------------------------------independent_Q_net------------------------------------------------------\n\n        # --------------------------------------------bias_net----------------------------------------------------------\n        # 得到每个agent对应的Q值，维度为(episode个数, max_episode_len， n_agents， n_actions)\n        v = self.get_bias(batch, hidden_evals, hidden_targets, episode_num)\n        # --------------------------------------------bias_net----------------------------------------------------------\n\n\n        # ---------------------------------------_self+Q_other_net------------------------------------------------------\n        q_other_evals, q_other_targets = q_evals[:, :, [1, 0], :].unsqueeze(4), q_targets[:, :, [1, 0], :].unsqueeze(4)\n        q_evals_, q_targets_ = q_evals.unsqueeze(4), q_targets.unsqueeze(4)\n        q_total_evals = torch.cat((q_evals_, q_other_evals), 4)\n        q_total_targets = torch.cat((q_targets_, q_other_targets), 4)\n        # q_total_evals = self.trade_off_net(q_total_evals)   #([10, 50, 2, 5, 1])\n        # q_total_targets = 
self.trade_off_net(q_total_targets)  # ([10, 50, 2, 5, 1])\n        q_total_evals = q_evals_ + q_other_targets\n        q_total_targets = q_targets_ + q_other_targets\n        # --------------------------------------------_self+Q_other_net-------------------------------------------------\n\n        # --------------------------------------------L_self/other------------------------------------------------------\n        q_total_targets[avail_u_next == 0.0] = - 9999999\n        q_total_targets = q_total_targets.max(dim=3)[0].squeeze()\n\n        q_total_evals = torch.gather(q_total_evals.squeeze(4), dim=3, index=u).squeeze(3)\n\n        y = r + self.args.gamma * q_total_targets * (1 - terminated)\n        td_error = q_total_evals - y.detach()\n        l_so = ((td_error * mask) ** 2).sum() / mask.sum()\n        # --------------------------------------------L_self/other------------------------------------------------------\n\n        # --------------------------------------------action_prob_Q-----------------------------------------------------\n        # probablity of action\n        action_prob = self._get_action_prob(batch, max_episode_len, 0.4)  # 每个agent的所有动作的概率self.args.epsilon\n        pi_taken = torch.gather(action_prob, dim=3, index=u).squeeze(3)  # 每个agent的选择的动作对应的概率\n        pi_taken[mask == 0] = 1.0  # 因为要取对数，对于那些填充的经验，所有概率都为0，取了log就是负无穷了，所以让它们变成1\n        log_pi_taken = torch.log(pi_taken)\n        # --------------------------------------------action_prob_Q-----------------------------------------------------\n\n        # ----------------------------------------------L_coma----------------------------------------------------------\n        # q_evals = torch.gather(q_evals * action_prob, dim=3, index=u).squeeze(3)\n        q_evals_coma = (q_evals * action_prob).sum(dim=3, keepdim=True).squeeze(3)\n        coma_error = q_evals_coma.sum(dim=-1) - q_total_targets.detach().sum(dim=-1) + v\n        l_coma = ((coma_error * mask[:,:,0]) ** 2).sum() / 
mask[:,:,0].sum()\n        # ----------------------------------------------L_coma----------------------------------------------------------\n\n        # -----------------------------------------------L_sum----------------------------------------------------------\n        q_evals_sum = self.eval_vdn_snn(q_evals)\n        sum_error = q_evals_sum.sum(dim=-1).squeeze(2) - q_total_targets.detach().sum(dim=-1) + v\n        l_sum = ((sum_error * mask[:,:,0]) ** 2).sum() / mask[:,:,0].sum()\n        # -----------------------------------------------L_sum----------------------------------------------------------\n\n        LOSS = l_so + l_coma + l_sum\n\n        self.optimizer.zero_grad()\n        LOSS.backward()\n\n        if train_step > 0 and train_step % self.args.target_update_cycle == 0:\n            self.target_snn.load_state_dict(self.eval_snn.state_dict())\n            self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())\n        return LOSS\n\n    def _get_inputs(self, batch, transition_idx):\n        # 取出所有episode上该transition_idx的经验，u_onehot要取出所有，因为要用到上一条\n        obs, obs_next, u_onehot = batch['O'][:, transition_idx], \\\n                                  batch['O_NEXT'][:, transition_idx], batch['U_ONEHOT'][:]\n        episode_num = obs.shape[0]\n        inputs, inputs_next = [], []\n        inputs.append(obs)\n        inputs_next.append(obs_next)\n        # 给obs添加上一个动作、agent编号\n\n        if self.args.last_action:\n            if transition_idx == 0:  # 如果是第一条经验，就让前一个动作为0向量\n                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))\n            else:\n                inputs.append(u_onehot[:, transition_idx - 1])\n            inputs_next.append(u_onehot[:, transition_idx])\n        if self.args.reuse_network:\n            # 因为当前的obs三维的数据，每一维分别代表(episode编号，agent编号，obs维度)，直接在dim_1上添加对应的向量\n            # 即可，比如给agent_0后面加(1, 0, 0, 0, 0)，表示5个agent中的0号。而agent_0的数据正好在第0行，那么需要加的\n            # agent编号恰好就是一个单位矩阵，即对角线为1，其余为0\n           
 inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n            inputs_next.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n        # 要把obs中的三个拼起来，并且要把episode_num个episode、self.args.n_agents个agent的数据拼成40条(40,96)的数据，\n        # 因为这里所有agent共享一个神经网络，每条数据中带上了自己的编号，所以还是自己的数据\n        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)\n        inputs_next = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next], dim=1)\n        return inputs, inputs_next\n\n    def get_q_values(self, batch, max_episode_len):\n        episode_num = batch['O'].shape[0]\n        q_evals, q_targets, eval_h2_mems, target_h2_mems = [], [], [], []\n        for transition_idx in range(max_episode_len):\n            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # 给obs加last_action、agent_id\n            if self.args.cuda:\n                inputs = inputs.cuda(self.args.device)\n                inputs_next = inputs_next.cuda(self.args.device)\n                self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \\\n                    self.eval_h1_mem.cuda(self.args.device), self.eval_h1_spike.cuda(\n                        self.args.device), self.eval_h2_mem.cuda(self.args.device), self.eval_h2_spike.cuda(\n                        self.args.device)\n                self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \\\n                    self.target_h1_mem.cuda(self.args.device), self.target_h1_spike.cuda(\n                        self.args.device), self.target_h2_mem.cuda(self.args.device), self.target_h2_spike.cuda(\n                        self.args.device)\n            q_eval, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \\\n                self.eval_snn(inputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem,\n                              
self.eval_h2_spike)  # inputs维度为(40,96)，得到的q_eval维度为(40,n_actions)\n            q_target, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \\\n                self.target_snn(inputs_next, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem,\n                                self.target_h2_spike)\n\n            # 把q_eval维度重新变回(8, 5,n_actions)\n            q_eval = q_eval.view(episode_num, self.n_agents, -1)\n            q_target = q_target.view(episode_num, self.n_agents, -1)\n            eval_h2_mem = self.eval_h2_mem.view(episode_num, self.n_agents, -1)\n            target_h2_mem = self.target_h2_mem.view(episode_num, self.n_agents, -1)\n            q_evals.append(q_eval)\n            q_targets.append(q_target)\n            eval_h2_mems.append(eval_h2_mem)\n            target_h2_mems.append(target_h2_mem)\n        # 得的q_eval和q_target是一个列表，列表里装着max_episode_len个数组，数组的的维度是(episode个数, n_agents，n_actions)\n        # 把该列表转化成(episode个数, max_episode_len， n_agents，n_actions)的数组\n        q_evals = torch.stack(q_evals, dim=1)\n        q_targets = torch.stack(q_targets, dim=1)\n        hidden_evals = torch.stack(eval_h2_mems, dim=1)\n        hidden_targets = torch.stack(target_h2_mems, dim=1)\n        return q_evals, q_targets, hidden_evals, hidden_targets\n\n    def get_bias(self, batch, hidden_evals, hidden_targets, episode_num, hat=False):\n        # episode_num, max_episode_len, _, _ = hidden_targets.shape\n        max_episode_len = self.args.episode_limit\n        states = batch['O'][:, :max_episode_len]\n        states_next = batch['O_NEXT'][:, :max_episode_len]\n        u_onehot = batch['U_ONEHOT'][:, :max_episode_len]\n        if self.args.cuda:\n            states = states.cuda(self.args.device)[:,:,0,:]\n            states_next = states_next.cuda(self.args.device)[:,:,0,:]\n            u_onehot = u_onehot.cuda(self.args.device)\n            hidden_evals = hidden_evals.cuda(self.args.device)\n            hidden_targets = 
hidden_targets.cuda(self.args.device)\n        if hat:\n            v = None\n        else:\n            v = self.bias_net(states, hidden_evals)\n            # 把q_eval、q_target、v维度变回(episode_num, max_episode_len)\n            v = v.view(episode_num, -1, 1).squeeze(-1)\n        return v\n\n\n    def _get_actor_inputs(self, batch, transition_idx):\n        # 取出所有episode上该transition_idx的经验，u_onehot要取出所有，因为要用到上一条\n        obs, u_onehot = batch['O'][:, transition_idx], batch['U_ONEHOT'][:]\n        episode_num = obs.shape[0]\n        inputs = []\n        inputs.append(obs)\n        # 给inputs添加上一个动作、agent编号\n\n        if self.args.last_action:\n            if transition_idx == 0:  # 如果是第一条经验，就让前一个动作为0向量\n                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))\n            else:\n                inputs.append(u_onehot[:, transition_idx - 1])\n        if self.args.reuse_network:\n            # 因为当前的inputs三维的数据，每一维分别代表(episode编号，agent编号，inputs维度)，直接在dim_1上添加对应的向量\n            # 即可，比如给agent_0后面加(1, 0, 0, 0, 0)，表示5个agent中的0号。而agent_0的数据正好在第0行，那么需要加的\n            # agent编号恰好就是一个单位矩阵，即对角线为1，其余为0\n            inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n        # 要把inputs中的三个拼起来，并且要把episode_num个episode、self.args.n_agents个agent的数据拼成40条(40,96)的数据，\n        # 因为这里所有agent共享一个神经网络，每条数据中带上了自己的编号，所以还是自己的数据\n        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)\n        return inputs\n\n    def _get_action_prob(self, batch, max_episode_len, epsilon):\n        episode_num = batch['O'].shape[0]\n        avail_actions = batch['AVAIL_U']\n        action_prob = []\n        for transition_idx in range(max_episode_len):\n            inputs = self._get_actor_inputs(batch, transition_idx)  # 给obs加last_action、agent_id\n            if self.args.cuda:\n                inputs = inputs.cuda(self.args.device)\n                # self.eval_hidden = self.eval_hidden.cuda(self.args.device)\n       
         self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \\\n                    self.eval_h1_mem.cuda(self.args.device), self.eval_h1_spike.cuda(\n                        self.args.device), self.eval_h2_mem.cuda(self.args.device), self.eval_h2_spike.cuda(\n                        self.args.device)\n                self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \\\n                    self.target_h1_mem.cuda(self.args.device), self.target_h1_spike.cuda(\n                        self.args.device), self.target_h2_mem.cuda(self.args.device), self.target_h2_spike.cuda(\n                        self.args.device)\n\n            # outputs, self.eval_hidden = self.eval_snn(inputs, self.eval_hidden)  # inputs维度为(40,96)，得到的q_eval维度为(40,n_actions)\n            outputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \\\n                self.eval_snn(inputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem,\n                              self.eval_h2_spike)  # inputs维度为(40,96)，得到的q_eval维度为(40,n_actions)\n\n            # 把q_eval维度重新变回(8, 5,n_actions)\n            outputs = outputs.view(episode_num, self.n_agents, -1)\n            prob = torch.nn.functional.softmax(outputs, dim=-1)\n            action_prob.append(prob)\n        # 得的action_prob是一个列表，列表里装着max_episode_len个数组，数组的的维度是(episode个数, n_agents，n_actions)\n        # 把该列表转化成(episode个数, max_episode_len， n_agents，n_actions)的数组\n        action_prob = torch.stack(action_prob, dim=1).cpu()\n\n        action_num = avail_actions.sum(dim=-1, keepdim=True).float().repeat(1, 1, 1,\n                                                                            avail_actions.shape[-1])  # 可以选择的动作的个数\n        action_prob = ((1 - epsilon) * action_prob + torch.ones_like(action_prob) * epsilon / action_num)\n        action_prob[avail_actions == 0] = 0.0  # 不能执行的动作概率为0\n\n        # 
因为上面把不能执行的动作概率置为0，所以概率和不为1了，这里要重新正则化一下。执行过程中Categorical会自己正则化。\n        action_prob = action_prob / action_prob.sum(dim=-1, keepdim=True)\n        # 因为有许多经验是填充的，它们的avail_actions都填充的是0，所以该经验上所有动作的概率都为0，在正则化的时候会得到nan。\n        # 因此需要再一次将该经验对应的概率置为0\n        action_prob[avail_actions == 0] = 0.0\n        if self.args.cuda:\n            action_prob = action_prob.cuda(self.args.device)\n        return action_prob\n\n    def init_hidden(self, episode_num, num_env):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        self.eval_h1_mem = self.eval_h1_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h1_mem = self.target_h1_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                                self.args.rnn_hidden_dim)\n        self.eval_h2_mem = self.eval_h2_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h2_mem = self.target_h2_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                                self.args.rnn_hidden_dim)\n\n    def init_hidden_learn(self, episode_num):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        self.eval_h1_mem = self.eval_h1_spike = torch.zeros(episode_num, self.n_agents,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h1_mem = self.target_h1_spike = torch.zeros(episode_num, self.n_agents,\n                                                                self.args.rnn_hidden_dim)\n        self.eval_h2_mem = self.eval_h2_spike = torch.zeros(episode_num, self.n_agents,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h2_mem = self.target_h2_spike = 
torch.zeros(episode_num, self.n_agents,\n                                                                self.args.rnn_hidden_dim)\n\n    def save_model(self, train_step):\n        num = str(train_step // self.args.save_cycle)\n        if not os.path.exists(self.model_dir):\n            os.makedirs(self.model_dir)\n        torch.save(self.eval_snn.state_dict(),\n                   self.model_dir + '/' + num + '_snn_net_params_{}.pkl'.format(self.args.num_run))\n\n    def load_model(self, train_step):\n        num = str(train_step // self.args.save_cycle)\n\n        path = torch.load(self.model_dir + '/' + num + '_snn_net_params_{}.pkl'.format(self.args.num_run))\n\n        self.eval_snn.load_state_dict(path)\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/policy/svdn.py",
    "content": "import torch\nimport os\nfrom network.spiking_net import Critic, VDNNet\n\n\nclass SVDN:\n    def __init__(self, args):\n        self.n_actions = args.n_actions\n        self.n_agents = args.n_agents\n        self.obs_shape = args.obs_shape\n        input_shape = self.obs_shape\n        # 根据参数决定RNN的输入维度\n        if args.last_action:\n            input_shape += self.n_actions\n        if args.reuse_network:\n            input_shape += self.n_agents\n\n        # 神经网络\n        self.eval_snn = Critic(input_shape, args)\n        self.target_snn = Critic(input_shape, args)\n        self.eval_vdn_snn = VDNNet()  # 把agentsQ值加起来的网络\n        self.target_vdn_snn = VDNNet()\n        self.args = args\n        if self.args.cuda:\n            self.eval_snn.cuda(self.args.device)\n            self.target_snn.cuda(self.args.device)\n            self.eval_vdn_snn.cuda(self.args.device)\n            self.target_vdn_snn.cuda(self.args.device)\n            \n        self.model_dir = args.model_dir + '/' + args.alg + args.exp_dir + args.save_model_dir\n        # 如果存在模型则加载模型\n        if self.args.load_model:\n            if os.path.exists(self.model_dir):\n                # path_snn = '/home/zhaozhuoya/exp2/ToM2_test/model/siql/199_snn_net_params.pkl'\n                map_location = self.args.device if self.args.cuda else 'cpu'\n                self.eval_snn.load_state_dict(torch.load(self.model_dir, map_location=map_location))\n                print('Successfully load the model: {}'.format(self.model_dir))\n            else:\n                print(self.model_dir)\n                raise Exception(\"No model!\")\n\n        # 让target_net和eval_net的网络参数相同\n        self.target_snn.load_state_dict(self.eval_snn.state_dict())\n        self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())\n\n        self.eval_parameters = list(self.eval_snn.parameters()) + list(self.eval_vdn_snn.parameters())\n        if args.optimizer == \"RMS\":\n            self.optimizer = 
torch.optim.RMSprop(self.eval_parameters, lr=args.lr)\n\n        # 执行过程中，要为每个agent都维护一个eval_hidden\n        # 学习过程中，要为每个episode的每个agent都维护一个eval_hidden、target_hidden\n        self.eval_h1_mem, self.eval_h1_spike = None, None\n        self.target_h1_mem, self.target_h1_spike = None, None\n        self.eval_h2_mem, self.eval_h2_spike = None, None\n        self.target_h2_mem, self.target_h2_spike = None, None\n        print('Init alg SVDN')\n\n    def learn(self, batch, max_episode_len, train_step, epsilon=None):  # train_step表示是第几次学习，用来控制更新target_net网络的参数\n        '''\n        在learn的时候，抽取到的数据是四维的，四个维度分别为 1——第几个episode 2——episode中第几个transition\n        3——第几个agent的数据 4——具体obs维度。因为在选动作时不仅需要输入当前的inputs，还要给神经网络输入hidden_state，\n        hidden_state和之前的经验相关，因此就不能随机抽取经验进行学习。所以这里一次抽取多个episode，然后一次给神经网络\n        传入每个episode的同一个位置的transition\n        '''\n        episode_num = batch['O'].shape[0]\n        self.init_hidden_learn(episode_num)\n        for key in batch.keys():  # 把batch里的数据转化成tensor\n            if key == 'U':\n                batch[key] = torch.tensor(batch[key], dtype=torch.long)\n            else:\n                batch[key] = torch.tensor(batch[key], dtype=torch.float32)\n        u, r, avail_u, avail_u_next, terminated = batch['U'], batch['R'].squeeze(-1),  batch['AVAIL_U'], \\\n                                                  batch['AVAIL_U_NEXT'], batch['TERMINATE'].repeat(1, 1, self.n_agents)\n        mask = (1 - batch[\"PADDED\"].float()).repeat(1, 1, self.n_agents)  # 用来把那些填充的经验的TD-error置0，从而不让它们影响到学习\n\n        # 得到每个agent对应的Q值，维度为(episode个数, max_episode_len， n_agents， n_actions)\n        q_evals, q_targets = self.get_q_values(batch, max_episode_len)\n        if self.args.cuda:\n            u = u.cuda(self.args.device)\n            r = r.cuda(self.args.device)\n            terminated = terminated.cuda(self.args.device)\n            mask = mask.cuda(self.args.device)\n\n        # 取每个agent动作对应的Q值，并且把最后不需要的一维去掉，因为最后一维只有一个值了\n        u = 
u.to(torch.int64)\n        q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)\n\n        # 得到target_q\n        q_targets[avail_u_next == 0.0] = - 9999999\n        q_targets = q_targets.max(dim=3)[0]\n\n        q_total_eval = self.eval_vdn_snn(q_evals)\n        q_total_target = self.target_vdn_snn(q_targets)\n\n        targets = r + self.args.gamma * q_total_target * (1 - terminated)\n\n        td_error = targets.detach() - q_total_eval\n        masked_td_error = mask * td_error  # 抹掉填充的经验的td_error\n\n        # 不能直接用mean，因为还有许多经验是没用的，所以要求和再比真实的经验数，才是真正的均值\n        loss = (masked_td_error ** 2).sum() / mask.sum()\n        # print('loss is ', loss)\n        self.optimizer.zero_grad()\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)\n        self.optimizer.step()\n\n        if train_step > 0 and train_step % self.args.target_update_cycle == 0:\n            self.target_snn.load_state_dict(self.eval_snn.state_dict())\n            self.target_vdn_snn.load_state_dict(self.eval_vdn_snn.state_dict())\n\n        return loss\n\n    def _get_inputs(self, batch, transition_idx):\n        # 取出所有episode上该transition_idx的经验，u_onehot要取出所有，因为要用到上一条\n        obs, obs_next, u_onehot = batch['O'][:, transition_idx], \\\n                                  batch['O_NEXT'][:, transition_idx], batch['U_ONEHOT'][:]\n        episode_num = obs.shape[0]\n        inputs, inputs_next = [], []\n        inputs.append(obs)\n        inputs_next.append(obs_next)\n        # 给obs添加上一个动作、agent编号\n\n        if self.args.last_action:\n            if transition_idx == 0:  # 如果是第一条经验，就让前一个动作为0向量\n                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))\n            else:\n                inputs.append(u_onehot[:, transition_idx - 1])\n            inputs_next.append(u_onehot[:, transition_idx])\n        if self.args.reuse_network:\n            # 因为当前的obs三维的数据，每一维分别代表(episode编号，agent编号，obs维度)，直接在dim_1上添加对应的向量\n            
# 即可，比如给agent_0后面加(1, 0, 0, 0, 0)，表示5个agent中的0号。而agent_0的数据正好在第0行，那么需要加的\n            # agent编号恰好就是一个单位矩阵，即对角线为1，其余为0\n            inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n            inputs_next.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))\n        # 要把obs中的三个拼起来，并且要把episode_num个episode、self.args.n_agents个agent的数据拼成40条(40,96)的数据，\n        # 因为这里所有agent共享一个神经网络，每条数据中带上了自己的编号，所以还是自己的数据\n        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)\n        inputs_next = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next], dim=1)\n        return inputs, inputs_next\n\n    def get_q_values(self, batch, max_episode_len):\n        episode_num = batch['O'].shape[0]\n        q_evals, q_targets = [], []\n        for transition_idx in range(max_episode_len):\n            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # 给obs加last_action、agent_id\n            if self.args.cuda:\n                inputs = inputs.cuda(self.args.device)\n                inputs_next = inputs_next.cuda(self.args.device)\n                self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \\\n                    self.eval_h1_mem.cuda(self.args.device), self.eval_h1_spike.cuda(self.args.device), self.eval_h2_mem.cuda(self.args.device), self.eval_h2_spike.cuda(self.args.device)\n                self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \\\n                    self.target_h1_mem.cuda(self.args.device), self.target_h1_spike.cuda(self.args.device), self.target_h2_mem.cuda(self.args.device), self.target_h2_spike.cuda(self.args.device)\n\n            q_eval, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike = \\\n                self.eval_snn(inputs, self.eval_h1_mem, self.eval_h1_spike, self.eval_h2_mem, self.eval_h2_spike)  # 
inputs维度为(40,96)，得到的q_eval维度为(40,n_actions)\n            q_target, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike = \\\n                self.target_snn(inputs_next, self.target_h1_mem, self.target_h1_spike, self.target_h2_mem, self.target_h2_spike)\n\n            # 把q_eval维度重新变回(8, 5,n_actions)\n            q_eval = q_eval.view(episode_num, self.n_agents, -1)\n            q_target = q_target.view(episode_num, self.n_agents, -1)\n            q_evals.append(q_eval)\n            q_targets.append(q_target)\n        # 得的q_eval和q_target是一个列表，列表里装着max_episode_len个数组，数组的的维度是(episode个数, n_agents，n_actions)\n        # 把该列表转化成(episode个数, max_episode_len， n_agents，n_actions)的数组\n        q_evals = torch.stack(q_evals, dim=1)\n        q_targets = torch.stack(q_targets, dim=1)\n        return q_evals, q_targets\n\n    def init_hidden(self, episode_num, num_env):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        self.eval_h1_mem = self.eval_h1_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h1_mem = self.target_h1_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                                self.args.rnn_hidden_dim)\n        self.eval_h2_mem = self.eval_h2_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h2_mem = self.target_h2_spike = torch.zeros(episode_num, self.n_agents, num_env,\n                                                                self.args.rnn_hidden_dim)\n\n    def init_hidden_learn(self, episode_num):\n        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden\n        self.eval_h1_mem = self.eval_h1_spike = torch.zeros(episode_num, self.n_agents,\n                                                            
self.args.rnn_hidden_dim)\n        self.target_h1_mem = self.target_h1_spike = torch.zeros(episode_num, self.n_agents,\n                                                                self.args.rnn_hidden_dim)\n        self.eval_h2_mem = self.eval_h2_spike = torch.zeros(episode_num, self.n_agents,\n                                                            self.args.rnn_hidden_dim)\n        self.target_h2_mem = self.target_h2_spike = torch.zeros(episode_num, self.n_agents,\n                                                                self.args.rnn_hidden_dim)\n\n    def save_model(self, train_step):\n        num = str(train_step // self.args.save_cycle)\n        if not os.path.exists(self.model_dir):\n            os.makedirs(self.model_dir)\n        torch.save(self.eval_snn.state_dict(),  self.model_dir + '/' + num + '_snn_net_params_{}.pkl'.format(self.args.num_run))\n\n    def load_model(self, train_step):\n        num = str(train_step // self.args.save_cycle)\n\n        path = torch.load(self.model_dir + '/' + num + '_snn_net_params_{}.pkl'.format(self.args.num_run))\n\n        self.eval_snn.load_state_dict(path)\n\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/preprocessoing/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/preprocessoing/common.py",
    "content": "\"\"\"\npreprocess\n\"\"\"\n\nfrom typing import Union\nimport abc\n\nimport torch as tc\nimport numpy as np\n\nclass Preprocessing(abc.ABC, tc.nn.Module):\n    def forward(\n        self,\n        curr_obs: Union[tc.LongTensor, tc.FloatTensor],\n        prev_action: tc.LongTensor,\n        prev_reward: tc.FloatTensor,\n        prev_done: tc.FloatTensor\n    ) -> tc.FloatTensor:\n        \"\"\"\n        Creates an input vector for a meta-learning agent.\n\n        Args:\n            curr_obs: either tc.LongTensor or tc.FloatTensor of shape [B, ...].\n            prev_action: tc.LongTensor of shape [B, ...]\n            prev_reward: tc.FloatTensor of shape [B, ...]\n            prev_done: tc.FloatTensor of shape [B, ...]\n\n        Returns:\n            tc.FloatTensor of shape [B, ..., ?]\n        \"\"\"\n        pass\n\n\ndef one_hot_torch(ys: tc.LongTensor, depth: int, device) -> tc.FloatTensor:\n    \"\"\"\n    Applies one-hot encoding to a batch of vectors.\n\n    Args:\n        ys: tc.LongTensor of shape [B].\n        depth: int specifying the number of possible y values.\n\n    Returns:\n        the one-hot encodings of tensor ys.\n    \"\"\"\n\n    vecs_shape = list(ys.shape) + [depth]\n    vecs = tc.zeros(dtype=tc.float32, size=vecs_shape).to(device)\n    vecs.scatter_(dim=-1, index=ys.unsqueeze(-1),\n                  src=tc.ones(dtype=tc.float32, size=vecs_shape).to(device))\n    return vecs.float()\n\n\ndef one_hot(ys: int, depth: int) -> list:\n    \"\"\"\n    Applies one-hot encoding to a batch of vectors.\n\n    Args:\n        ys: tc.LongTensor of shape [B].\n        depth: int specifying the number of possible y values.\n\n    Returns:\n        the one-hot encodings of tensor ys.\n    \"\"\"\n\n    letter = [0 for _ in range(depth)]\n    letter[ys-1] = 1\n    letter = np.array(letter)\n    # print(letter)\n    return letter\n"
  },
  {
    "path": "examples/Social_Cognition/MAToM-SNN/STAG/runner.py",
    "content": "import os\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch.multiprocessing as mp\n\nfrom common_sr.srollout import RolloutWorker\nfrom agents.sagent import Agents\nfrom common_sr.replay_buffer import ReplayBuffer\n\nimport time\nfrom tqdm import tqdm\n\nclass Runner:\n    def __init__(self, env, args):\n        self.env = env\n\n        self.agents = Agents(args)\n\n        self.rolloutWorker = RolloutWorker(env, self.agents, args)\n\n        if not args.evaluate:\n            self.buffer = ReplayBuffer(args)\n        self.args = args\n        self.win_rates = []\n        self.episode_rewards = []\n\n        # 用来保存plt和pkl\n        self.save_path = self.args.result_dir + '/' + args.alg + args.exp_dir\n        if not os.path.exists(self.save_path):\n            os.makedirs(self.save_path)\n\n    def run(self, num):\n        time_steps, train_steps, evaluate_steps = 0, 0, -1\n        pbar = tqdm(self.args.n_steps)\n        if self.args.load_model == False:\n            while time_steps < self.args.n_steps:\n                # print('Run {}, time_steps {}'.format(num, time_steps))\n                if time_steps // self.args.evaluate_cycle > evaluate_steps:\n                    win_rate, episode_reward = self.evaluate()\n                    # episode_reward = [i for i in [2, 3]]\n                    self.episode_rewards.append(episode_reward)\n                    # self.plt(time_steps // self.args.evaluate_cycle)\n                    # print(time_steps // self.args.evaluate_cycle)\n                    evaluate_steps += self.args.evaluate_epoch\n                # 收集self.args.n_episodes个episodes\n                episodes = []\n                start = time.time()\n                episode_batch, _, _, steps = self.rolloutWorker.generate_episode()\n                end = time.time()\n                # print(end - start, 'sample with multiprocessing:', self.args.process)\n                time_steps += steps\n                pbar.update(steps)\n 
               self.buffer.store_episode(episode_batch)\n\n                start = time.time()\n                for train_step in range(self.args.train_steps):\n                    mini_batch = self.buffer.sample(min(self.buffer.current_size, self.args.batch_size))\n                    if self.args.alg.find('o') > -1:\n                        self.agents.train(mini_batch, train_steps, self.args.epsilon)\n                    else:\n                        self.agents.train(mini_batch, train_steps)\n                    train_steps += 1\n                end = time.time()\n                    # print(end - start, 'training')\n        pbar.close()\n        win_rate, episode_reward = self.evaluate()\n        # print('win_rate is ', win_rate)\n        self.win_rates.append(win_rate)\n        self.episode_rewards.append(episode_reward)\n        if self.args.load_model == False:\n            self.plt(num)\n\n    def evaluate(self):\n        win_number = 0\n        episode_rewards = (0, 0)  # cumulative rewards\n\n        _, episode_rewards, win_tag, _ = self.rolloutWorker.generate_episode(evaluate=True)\n\n        episode_rewards = [episode_rewards[i] / self.args.evaluate_epoch / self.args.process for i in range(len(episode_rewards))]\n        return win_number / self.args.evaluate_epoch, episode_rewards\n\n    def plt(self, num):\n        # plt.figure()\n        # plt.ylim([0, 105])\n        # plt.cla()\n        # plt.subplot(2, 1, 1)\n        # plt.plot(range(len(self.win_rates)), self.win_rates)\n        # plt.xlabel('step*{}'.format(self.args.evaluate_cycle))\n        # plt.ylabel('win_rates')\n        #\n        # plt.subplot(2, 1, 2)\n        # plt.plot(range(len(self.episode_rewards)), self.episode_rewards)\n        # plt.xlabel('step*{}'.format(self.args.evaluate_cycle))\n        # plt.ylabel('episode_rewards')\n        #\n        # plt.savefig(self.save_path + '/plt_{}.png'.format(num), format='png')\n        # np.save(self.save_path + '/win_rates_{}'.format(num), 
self.win_rates)\n        # np.save(self.save_path + '/episode_rewards_{}'.format(num), self.episode_rewards)\n        # plt.close()\n        # plt.figure()\n        # plt.ylim([0, 105])\n        # plt.cla()\n        # plt.plot(2, 1, 1)\n        # plt.plot(range(len(self.episode_rewards)), self.episode_rewards[0][0])\n        # plt.xlabel('step*{}'.format(self.args.evaluate_cycle))\n        # plt.ylabel('episode_rewards_A')\n        #\n        # plt.plot(2, 1, 2)\n        # plt.plot(range(len(self.episode_rewards)), self.episode_rewards[0][1])\n        # plt.xlabel('step*{}'.format(self.args.evaluate_cycle))\n        # plt.ylabel('episode_rewards_B')\n        # plt.savefig(self.save_path + '/plt_{}.png'.format(num), format='png')\n        # np.save(self.save_path + '/win_rates_{}'.format(num), self.win_rates)\n        np.save(self.save_path + '/episode_rewards_{}'.format(num), self.episode_rewards)   #\n        # print(self.episode_rewards)\n        # plt.close()"
  },
  {
    "path": "examples/Social_Cognition/ReadMe.md",
    "content": "\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/dqn.py",
    "content": "import os\r\nimport time\r\nimport random\r\nfrom itertools import count\r\nfrom collections import namedtuple, deque\r\n\r\nimport numpy as np\r\nimport pandas as pd\r\nimport torch\r\nimport imageio\r\nfrom torch import nn, optim\r\nimport torch.nn.functional as F\r\nfrom torch.utils.tensorboard import SummaryWriter\r\n\r\nfrom side_effect_eval import *\r\nfrom qnets import *\r\nfrom environment import *\r\n\r\nTransition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))\r\n\r\n\r\nclass ReplayBuffer:\r\n    def __init__(self, capacity):\r\n        self.capacity = capacity\r\n        self.memory = deque(maxlen=capacity)\r\n\r\n    def push(self, *args):\r\n        self.memory.append(Transition(*args))\r\n        if len(self.memory) > self.capacity:\r\n            self.memory.popleft()\r\n\r\n    def sample(self, batch_size):\r\n        batch = random.sample(self.memory, batch_size)\r\n        return batch\r\n\r\n\r\nclass AnseEmpDQN:\r\n    def __init__(self, env, net_type='SNN',\r\n                 init_buffer_size=10000, replay_buffer_size=100000, batch_size=100, target_update_interval=1000,\r\n                 weight_sep=20, weight_emp_impact=None):\r\n\r\n        self.env = env\r\n\r\n        input_dim = 3\r\n        output_dim = env.action_space.n\r\n        assert net_type in ['SNN', 'ANN']\r\n        self.net_type = net_type\r\n        if self.net_type == 'SNN':\r\n            self.policy_net = SNNQnet(input_dim, output_dim).cuda()\r\n            self.target_net = SNNQnet(input_dim, output_dim).cuda()\r\n        elif self.net_type == 'ANN':\r\n            self.policy_net = CNNQnet(input_dim, output_dim).cuda()\r\n            self.target_net = CNNQnet(input_dim, output_dim).cuda()\r\n\r\n        self.init_buffer_size = init_buffer_size\r\n        self.replay_buffer_size = replay_buffer_size\r\n        self.replay_buffer = ReplayBuffer(replay_buffer_size)\r\n        self.batch_size = batch_size\r\n        
self.target_update_interval = target_update_interval\r\n\r\n        # empathy\r\n        self.num_others = env.num_humans\r\n        self.baseline = StepwiseInactionModel(noop_action=env.actions.noop)\r\n        self.baselines_others = [StepwiseInactionModel(noop_action=env.actions.noop)] * self.num_others\r\n        self.deviation = AttainableUtilityMeasure(uf_num=30, uf_discount=0.99)\r\n        self.weight_sep = weight_sep\r\n        self.weight_emp_impact = weight_emp_impact if weight_emp_impact is not None else weight_sep\r\n\r\n    def state2tensor(self, state):\r\n        state_arr = self.env._decode(state)\r\n        array = np.zeros([3] + list(self.env.p.map_shape))\r\n        array[0] = self.env.map  # [0]为环境信息\r\n        for i, pos in enumerate(self.env.p.vat_pos):\r\n            array[0][pos] = self.env.cells.vat if state_arr[i] else self.env.cells.empty  # 根据state还原各个缸的状态，因为replay_buffer中存的状态对应的map与当前的可能不一样\r\n        agent_pos = tuple(state_arr[self.env.num_vats:self.env.num_vats + 2])\r\n        array[1][agent_pos] = 1  # [1]为agent位置信息\r\n        for pos in self.env.human_pos:\r\n            array[2][tuple(pos)] = 1  # [2]为human位置信息\r\n        return torch.Tensor(array).float().cuda()\r\n\r\n    def epsilon_greedy(self, net, state, epsilon):\r\n        num_actions = self.env.action_space.n\r\n        p = np.ones(num_actions) * epsilon / num_actions\r\n        state_tensor = self.state2tensor(state).unsqueeze(0)\r\n        best_action = torch.argmax(net(state_tensor)).item()\r\n        p[best_action] += 1 - epsilon\r\n        action = np.random.choice(self.env.action_space.n, p=p)\r\n        return action\r\n\r\n    def train(self, lr=1e-3, num_episodes=10000, gamma=0.99,\r\n              epsilon_start=1, epsilon_end=0.01, decay_start=0.05, decay_end=0.95,\r\n              checkpoint_interval=1000,\r\n              checkpoint_dir='./models/',\r\n              log_dir='./log'):\r\n        policy_net_opt = optim.Adam(self.policy_net.parameters(), 
lr=lr)\r\n\r\n        epsilons = [epsilon_end] * num_episodes\r\n        decay_start_episode = int(num_episodes * decay_start)\r\n        decay_end_episode = int(num_episodes * decay_end)\r\n        epsilons[0:decay_start_episode] = np.full(decay_start_episode, epsilon_start)\r\n        epsilons[decay_start_episode:decay_end_episode] = np.linspace(epsilon_start, epsilon_end, decay_end_episode - decay_start_episode)\r\n\r\n        tb_logger = SummaryWriter(log_dir=log_dir)\r\n        pd_timestep_logger = pd.DataFrame(columns=['loss', 'reward', 'impact', 'aup_impact', 'empathy_impact'])\r\n        pd_episode_logger = pd.DataFrame(columns=['step', 'ep_reward', 'ep_reward_mean',\r\n                                                  'ep_impact', 'ep_aup_impact', 'ep_empathy_impact',\r\n                                                  'num_vat_broken', 'num_human_saved'])\r\n\r\n        def optimize_net():\r\n            transitions = self.replay_buffer.sample(self.batch_size)\r\n            batch = Transition(*zip(*transitions))\r\n\r\n            state_batch = torch.stack([self.state2tensor(state[0]) for state in batch.state]).float().cuda()\r\n            action_batch = torch.tensor(batch.action).unsqueeze(-1).cuda()\r\n            reward_batch = torch.tensor(batch.reward, dtype=torch.float).cuda()\r\n            next_state_batch = torch.stack([self.state2tensor(state[0]) for state in batch.next_state]).float().cuda()\r\n            done_batch = torch.tensor(batch.done).cuda()\r\n            best_actions = self.policy_net(next_state_batch).max(1)[1].detach()\r\n            next_state_values = self.target_net(next_state_batch).gather(1, best_actions.unsqueeze(1)).squeeze(1).detach()\r\n\r\n            expected_state_action_values = reward_batch + gamma * next_state_values * torch.logical_not(done_batch)\r\n            state_action_values = self.policy_net(state_batch).gather(1, action_batch)\r\n            loss = F.mse_loss(state_action_values, 
expected_state_action_values.unsqueeze(1))\r\n\r\n            policy_net_opt.zero_grad()\r\n            loss.backward()\r\n            policy_net_opt.step()\r\n            tb_logger.add_scalar('loss', loss.item(), total_step)\r\n            pd_timestep_logger.loc[total_step, 'loss'] = loss.item()\r\n\r\n        def calculate_impact(prev_states, prev_action, current_states):\r\n            if self.weight_sep==0 and self.weight_emp_impact==0:\r\n                return 0, 0, 0\r\n\r\n            prev_state_agent = prev_states[0]\r\n            current_state_agent = current_states[0]\r\n            prev_states_others = prev_states[1:]\r\n            current_states_others = current_states[1:]\r\n\r\n            baseline_state_agent = self.baseline.calculate(prev_state_agent, prev_action, current_state_agent)\r\n            self.deviation.update(prev_state_agent, prev_action, current_state_agent)\r\n            dev_self = self.deviation.calculate(current_state_agent, baseline_state_agent, lambda x: abs(np.minimum(0, x)))\r\n            weighted_dev_self = -self.weight_sep * dev_self\r\n\r\n            dev_others = []\r\n            for prev_state, current_state, baseline in zip(prev_states_others, current_states_others, self.baselines_others):\r\n                baseline_state = baseline.calculate(prev_state, prev_action, current_state)\r\n                dev_others.append(self.deviation.calculate(current_state, baseline_state, lambda x: x))\r\n            dev_others_mean = sum(dev_others) / len(dev_others) if len(dev_others) > 0 else 0\r\n            weighted_dev_others = self.weight_emp_impact * dev_others_mean\r\n            total_impact = weighted_dev_self + weighted_dev_others\r\n            return total_impact, weighted_dev_self, weighted_dev_others\r\n\r\n        # 初始化replay buffer\r\n        state = self.env.reset()\r\n        for i in range(self.init_buffer_size):\r\n            # action = self.epsilon_greedy(self.empathy_net, state[0], epsilon_start)\r\n        
    action = self.epsilon_greedy(self.policy_net, state[0], epsilon_start)\r\n            next_state, reward, done, info = self.env.step(action)\r\n\r\n            if self.weight_sep != 0 or self.weight_emp_impact != 0:  # 如果都等于0则退化为标准DQN\r\n                impact, _, _ = calculate_impact(state, action, next_state)\r\n                reward += impact\r\n                reward /= (self.weight_sep + self.weight_emp_impact) / 2  # 正则化操作，防止reward绝对值过大使得网络发散\r\n\r\n            self.replay_buffer.push(state, action, reward, next_state, done)\r\n            if done:\r\n                state = self.env.reset()\r\n            else:\r\n                state = next_state\r\n\r\n        # 开始训练\r\n        total_step = 0\r\n        for episode in range(num_episodes):\r\n            if episode % checkpoint_interval == 0 and episode != 0:\r\n                torch.save(self.policy_net.state_dict(), checkpoint_dir + f\"/checkpoint_{episode}.pth\")\r\n\r\n            tb_logger.add_scalar('epsilon', epsilons[episode], episode + 1)\r\n            state = self.env.reset()\r\n            if episode % 100 == 0:\r\n                print('Episode {} of {}'.format(episode + 1, num_episodes))\r\n\r\n            episode_reward = 0\r\n            episode_impact = 0\r\n            ep_aup_impact = 0\r\n            ep_empathy_impact = 0\r\n            for step in count():\r\n                if total_step % self.target_update_interval == 0:\r\n                    self.target_net.load_state_dict(self.policy_net.state_dict())\r\n\r\n                action = self.epsilon_greedy(self.policy_net, state[0], epsilons[episode])\r\n                next_state, reward, done, info = self.env.step(action)\r\n\r\n                impact, aup_impact, empathy_impact = calculate_impact(state, action, next_state)\r\n                if self.weight_sep != 0 or self.weight_emp_impact != 0:  # 如果都等于0则退化为标准DQN\r\n                    reward += impact\r\n                    reward /= (self.weight_sep + 
self.weight_emp_impact) / 2\r\n\r\n                episode_reward += reward\r\n                episode_impact += impact\r\n                ep_aup_impact += aup_impact\r\n                ep_empathy_impact += empathy_impact\r\n\r\n                self.replay_buffer.push(state, action, reward, next_state, done)\r\n                optimize_net()\r\n\r\n                pd_timestep_logger.loc[total_step, 'reward'] = reward\r\n                pd_timestep_logger.loc[total_step, 'impact'] = impact\r\n                pd_timestep_logger.loc[total_step, 'aup_impact'] = aup_impact\r\n                pd_timestep_logger.loc[total_step, 'empathy_impact'] = empathy_impact\r\n\r\n                if done:\r\n                    # print(f'step: {step}, reward: {episode_reward:.2f}')\r\n                    tb_logger.add_scalar('step', step + 1, episode + 1)\r\n                    tb_logger.add_scalar('reward', episode_reward, episode + 1)\r\n                    tb_logger.add_scalar('ep-reward-mean', episode_reward / (step + 1), episode + 1)\r\n                    tb_logger.add_scalar('impact', episode_impact, episode + 1)\r\n\r\n                    pd_episode_logger.loc[episode, 'step'] = step + 1\r\n                    pd_episode_logger.loc[episode, 'ep_reward'] = episode_reward\r\n                    pd_episode_logger.loc[episode, 'ep_reward_mean'] = episode_reward / (step + 1)\r\n                    pd_episode_logger.loc[episode, 'ep_impact'] = episode_impact\r\n                    pd_episode_logger.loc[episode, 'ep_aup_impact'] = ep_aup_impact\r\n                    pd_episode_logger.loc[episode, 'ep_empathy_impact'] = ep_empathy_impact\r\n\r\n                    env_map = self.env.map\r\n                    num_vat_broken = 0\r\n                    num_human_saved = 0\r\n                    for pos in self.env.p.vat_pos:\r\n                        if env_map[pos] != self.env.cells.vat:\r\n                            num_vat_broken += 1\r\n                            if pos in 
self.env.p.human_pos:\r\n                                num_human_saved += 1\r\n                    pd_episode_logger.loc[episode, 'num_vat_broken'] = num_vat_broken\r\n                    pd_episode_logger.loc[episode, 'num_human_saved'] = num_human_saved\r\n                    break\r\n                else:\r\n                    state = next_state\r\n                total_step += 1\r\n        tb_logger.close()\r\n        pd_timestep_logger.to_csv(log_dir + '/timestep_logger.csv', index=False)\r\n        pd_episode_logger.to_csv(log_dir + '/episode_logger.csv', index=False)\r\n\r\n    def save(self, path):\r\n        if not os.path.exists(path):\r\n            os.makedirs(path)\r\n        torch.save(self.policy_net.state_dict(), path + \"/policy_net.pth\")\r\n\r\n    def load(self, path):\r\n        self.policy_net.load_state_dict(torch.load(path + \"/policy_net.pth\"))\r\n\r\n    def run(self, gif_name=None):\r\n        self.policy_net.eval()\r\n        obs = self.env.reset()\r\n        images = []\r\n        for step in count():\r\n            # self.env.render()\r\n            images.append(self.env.render(mode='rgb_array'))\r\n\r\n            obs_tensor = self.state2tensor(obs[0]).unsqueeze(0)\r\n            action_p = self.policy_net(obs_tensor)\r\n            action = torch.argmax(self.policy_net(obs_tensor)).item()\r\n            print(self.env.actions(action))\r\n            next_state, reward, done, _ = self.env.step(action)\r\n\r\n            time.sleep(1)\r\n            if done:\r\n                images.append(self.env.render(mode='rgb_array'))\r\n                images.append(self.env.render(mode='rgb_array'))\r\n                break\r\n            else:\r\n                obs = next_state\r\n        if gif_name is not None:\r\n            imageio.mimsave(gif_name, images, 'GIF', duration=0.5)\r\n\r\n\r\ndef set_seed(seed=114514):\r\n    random.seed(seed)  # replay buffer 中使用了random.sample\r\n    np.random.seed(seed)  # e-greedy 
中使用了np.random_choice\r\n    torch.manual_seed(seed)\r\n    torch.cuda.manual_seed(seed)\r\n\r\n\r\ndef main():\r\n    set_seed(1919810)\r\n    env = BasicVatGoalEnv()\r\n    env.render()\r\n    model = AnseEmpDQN(env, net_type='ANN', init_buffer_size=10000, replay_buffer_size=100000, batch_size=100, target_update_interval=1000)\r\n\r\n    model.train(lr=1e-3, num_episodes=10000, gamma=0.99,\r\n                epsilon_start=1, epsilon_end=0.01, decay_start=0.05, decay_end=0.95,\r\n                checkpoint_interval=1000,\r\n                checkpoint_dir='./models/ANN-BasicVatGoalEnv-test',\r\n                log_dir='./log/ANN-BasicVatGoalEnv-test')\r\n    model.save(\"./models/ANN-BasicVatGoalEnv-test\")\r\n    model.load(\"./models/ANN-BasicVatGoalEnv-test\")\r\n    model.run(\"ANN-BasicVatGoalEnv-test.gif\")\r\n\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/environment.py",
    "content": "import copy\nfrom enum import IntEnum\nimport numpy as np\nimport gymnasium as gym\nimport imageio\n\nfrom window import Window\n\n\nclass HumanVatGoalEnv(gym.Env):\n    \"\"\"General HumanVatGoalEnv Class\"\"\"\n\n    class Actions(IntEnum):\n        noop = 0\n        left = 1\n        right = 2\n        up = 3\n        down = 4\n        smash = 5  # destroying all surrounding vat(s) at same time\n        pass\n\n    class Cells(IntEnum):\n        empty = 0\n        wall = 1\n        goal = 2\n        vat = 3\n        pass\n\n    class CellsRender(object):\n        # empty = np.full(shape=(64, 64, 4), fill_value=255)\n        wall = imageio.imread('./materials/wall.png')\n        goal = imageio.imread('./materials/goal.png')\n        vat = imageio.imread('./materials/vat.png')\n        agent = imageio.imread('./materials/agent.png')\n        human = imageio.imread('./materials/adult.png')\n\n    class Params(object):\n        \"\"\"Params for the Environment\"\"\"\n\n        def __init__(\n                self,\n                map_shape=(7, 5),\n                agent_pos=(1, 2),\n                human_pos=((2, 3),),\n                vat_pos=((3, 2), (2, 3),),\n                goal_pos=((-2, 2),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=None,\n                # human_policy = \"noop\"\n        ):\n            self.map_shape = map_shape\n            self.agent_pos = agent_pos\n            self.human_pos = human_pos\n            self.vat_pos = vat_pos\n            self.goal_pos = goal_pos\n            self.wall_pos = wall_pos\n            self.max_steps = max_steps\n            self.env_name = env_name\n\n    def __init__(self, env_params=Params()):\n        super(HumanVatGoalEnv, self).__init__()\n        self.p = env_params\n\n        self.actions = HumanVatGoalEnv.Actions\n        self.cells = HumanVatGoalEnv.Cells\n\n        
self.action_space = gym.spaces.Discrete(len(self.actions))\n        # observation_dim: dimension of observation space\n        #   With fixed human_pos, currently we only consider: each vat state (broken or not) + agent_pos (x, y)\n        #   e.g. [2,2,7,5] means vat1*vat2*agent_y*agent_x\n        self.observation_dim = [2] * len(self.p.vat_pos) + list(self.p.map_shape)\n        self.observation_space = gym.spaces.Discrete(np.prod(self.observation_dim))\n\n        self.window = None\n\n        if self.p.env_name is not None:\n            self.descr = self.p.env_name\n        else:\n            n_human = len(self.p.human_pos)\n            n_vat = len(self.p.vat_pos)\n            n_goal = len(self.p.goal_pos)\n            self.descr = self.__class__.__name__.lower()\n            self.descr = self.descr.replace(\"human\", \"human\" + str(n_human) + \"-\")\n            self.descr = self.descr.replace(\"vat\", \"vat\" + str(n_vat) + \"-\")\n            self.descr = self.descr.replace(\"goal\", \"goal\" + str(n_goal) + \"-\")\n\n        self.num_vats = len(self.p.vat_pos)\n        self.num_humans = len(self.p.human_pos)  # for empathy qlearning\n\n        self.reset()\n        return\n\n    def reset(self):\n        # reset env state\n        self._gen_map()\n\n        # reset agent & human state\n        self.agent_pos = np.array(self.p.agent_pos)\n        self.human_pos = [np.array(pos) for pos in self.p.human_pos]\n\n        # reset episode statistics\n        self.step_count = 0\n        self.total_reward = 0\n        # self.total_hidden_reward = 0  ##TODO\n        # self.total_human_rewards = [0]*len(self.p.human_pos)  ##TODO\n\n        # generate observation from state\n        obs = self._gen_obs()\n        return obs\n\n    def _gen_map(self):\n        self.map = np.full(shape=self.p.map_shape, fill_value=self.cells.empty)\n\n        # place default surrounding walls\n        self.map[0] = self.cells.wall\n        self.map[-1] = self.cells.wall\n        
self.map[:, 0] = self.cells.wall\n        self.map[:, -1] = self.cells.wall\n\n        # place user-defined walls\n        for pos in self.p.wall_pos:\n            self.map[pos] = self.cells.wall\n\n        # place goals\n        for pos in self.p.goal_pos:\n            self.map[pos] = self.cells.goal\n\n        # place vats\n        for pos in self.p.vat_pos:\n            self.map[pos] = self.cells.vat\n\n        return\n\n    def _gen_obs(self):\n        # internal state\n        s_env = [(self.map[pos] == self.cells.vat) for pos in self.p.vat_pos]\n        s_agent = list(self.agent_pos)\n        s_human = [list(pos) for pos in self.human_pos]\n\n        # external observation\n        obs_agent = self._encode(s_env + s_agent)\n        obs_human = [self._encode(s_env + s_h) for s_h in s_human]\n\n        return [obs_agent, *obs_human]\n\n    def _encode(self, obs):\n        i = 0\n        for idx, dim in enumerate(self.observation_dim):\n            i *= dim\n            i += obs[idx]\n        assert 0 <= i <= self.observation_space.n\n        return i\n\n    def _decode(self, i):\n        out = []\n        for dim in reversed(self.observation_dim):\n            out.append(i % dim)\n            i = i // dim\n        assert i == 0\n        return list(reversed(out))\n\n    def render(self, mode=\"window\", cell_size=64, style=\"realistic\"):\n        if mode == \"rgb_array\":\n            return self._gen_img(cell_size, style)\n\n        elif mode == \"window\":\n            if not isinstance(self.window, Window):\n                self.window = Window(self.descr)\n\n            if self.window.is_open():\n                img = self._gen_img(cell_size, style)\n                self.window.show_img(img)\n                self.window.show(block=False)\n                return\n\n    def _gen_img(self, cell_size, style):\n        if style == \"abstract\":\n\n            h, w = self.map.shape\n            img = np.full(shape=(h * cell_size, w * cell_size, 3), 
fill_value=255)\n\n            def draw_cell(cell_type, cell_pos, cell_size):\n                if cell_type == self.cells.empty:\n                    pass\n                elif cell_type == self.cells.wall:\n                    x, y = np.array(cell_pos) * cell_size\n                    img[x: (x + cell_size), y: (y + cell_size), :] = np.array(\n                        [128, 128, 128]\n                    )\n                elif cell_type == self.cells.goal:\n                    x, y = np.array(cell_pos) * cell_size\n                    img[x: (x + cell_size), y: (y + cell_size), :] = np.array([0, 255, 0])\n                elif cell_type == self.cells.vat:\n                    x, y = np.array(cell_pos) * cell_size\n                    img[x: (x + cell_size), y: (y + cell_size), :] = np.array([255, 0, 0])\n                else:\n                    pass\n\n            def draw_agent(pos, cell_size):\n                # draw rectangle\n                x, y = np.array(pos) * cell_size\n                img[\n                int(x + 0.2 * cell_size): int(x + 0.8 * cell_size + 1),\n                int(y + 0.2 * cell_size): int(y + 0.8 * cell_size + 1),\n                :,\n                ] = np.array([0, 0, 0])\n\n                # # draw cicle\n                # def fill_circle(img, cx, cy, r, color):\n                #     h, w = img.shape[0:2]\n                #     X, Y = np.ogrid[:h, :w]\n                #     mask = (X-cx)**2+(Y-cy)**2 <= r**2\n                #     img[mask] = color\n                #     # return img\n                # x0, y0 = np.array(pos) * cell_size\n                # sub_img = img[int(x0):int(x0+cell_size), int(y0):int(y0+cell_size), :]\n                # cx, cy, r = np.array([0.5, 0.5, 0.3]) * cell_size\n                # color = np.array([0,0,0])\n                # fill_circle(sub_img, cx, cy, r, color)\n                pass\n\n            def draw_human(pos, cell_size):\n                # # draw rectangle\n                # x, y = 
np.array(pos) * cell_size\n                # img[int(x+0.2*cell_size):int(x+0.8*cell_size+1),\n                #     int(y+0.2*cell_size):int(y+0.8*cell_size+1),:] = np.array([255,255,0])\n\n                # draw cicle\n                def fill_circle(img, cx, cy, r, color):\n                    h, w = img.shape[0:2]\n                    X, Y = np.ogrid[:h, :w]\n                    mask = (X - cx) ** 2 + (Y - cy) ** 2 <= r ** 2\n                    img[mask] = color\n                    # return img\n\n                x0, y0 = np.array(pos) * cell_size\n                sub_img = img[\n                          int(x0): int(x0 + cell_size), int(y0): int(y0 + cell_size), :\n                          ]\n                cx, cy, r = np.array([0.5, 0.5, 0.36]) * cell_size\n                color = np.array([255, 255, 0])\n                fill_circle(sub_img, cx, cy, r, color)\n                pass\n\n            def draw_gridline(cell_size):\n                img[::cell_size, :] = np.array([255, 255, 255])\n                img[-1::-cell_size, :] = np.array([255, 255, 255])\n                img[:, ::cell_size] = np.array([255, 255, 255])\n                img[:, -1::-cell_size] = np.array([255, 255, 255])\n                pass\n\n            for i in range(h):\n                for j in range(w):\n                    draw_cell(self.map[i, j], (i, j), cell_size)\n            for pos in self.human_pos:\n                draw_human(list(pos), cell_size)\n            draw_agent(self.agent_pos, cell_size)\n            draw_gridline(cell_size)\n\n            return img.astype(np.uint8)\n\n        elif style == \"realistic\":\n\n            cell_size = 64  ##In realistic mode, we fix cell_size to 64 to avoid resize of image\n            h, w = self.map.shape\n            img = np.full(shape=(h * cell_size, w * cell_size, 4), fill_value=255)\n\n            def draw_cell_realistic(cell_type, cell_pos, cell_size):\n                if cell_type == self.cells.empty:\n                    
# img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.empty)\n                    pass\n                elif cell_type == self.cells.wall:\n                    img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.wall)\n                elif cell_type == self.cells.goal:\n                    img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.goal)\n                elif cell_type == self.cells.vat:\n                    img_paste(cell_pos, cell_size, HumanVatGoalEnv.CellsRender.vat)\n                else:\n                    pass\n\n            def draw_agent_realistic(pos, cell_size):\n                img_paste(pos, cell_size, HumanVatGoalEnv.CellsRender.agent)\n                pass\n\n            def draw_human_realistic(pos, cell_size):\n                img_paste(pos, cell_size, HumanVatGoalEnv.CellsRender.human)\n                pass\n\n            ##paste a png(RGBA) image on to existing img depending on the alpha channel of img_in\n            def img_paste(pos, cell_size, img_in):\n                x, y = np.array(pos) * cell_size\n                img[x:(x + cell_size), y:(y + cell_size), 0:3][img_in[:, :, 3] > 128] = img_in[:, :, 0:3][img_in[:, :, 3] > 128]\n\n            def draw_gridline_realistic(cell_size):\n                img[::cell_size, :] = np.array([255, 255, 255, 0])\n                img[-1::-cell_size, :] = np.array([255, 255, 255, 0])\n                img[:, ::cell_size] = np.array([255, 255, 255, 0])\n                img[:, -1::-cell_size] = np.array([255, 255, 255, 0])\n                pass\n\n            for i in range(h):\n                for j in range(w):\n                    draw_cell_realistic(self.map[i, j], (i, j), cell_size)\n            for pos in self.human_pos:\n                draw_human_realistic(list(pos), cell_size)\n            draw_agent_realistic(self.agent_pos, cell_size)\n            draw_gridline_realistic(cell_size)\n\n            return img.astype(np.uint8)\n\n    def close(self):\n        if 
isinstance(self.window, Window):\n            self.window.close()\n            self.window = None\n        pass\n\n    def step(self, action):\n        def apply_env_dynamics(cur_pos, action):\n            assert type(cur_pos) == np.ndarray\n\n            reward = -0.01\n\n            cur_cell = self.map[tuple(cur_pos)]\n            if cur_cell == self.cells.vat:  # got trapped in vat\n                next_pos = cur_pos\n                # reward += -0.05  ##TODO: should we give explicit penalty here?\n            else:\n                next_pos = copy.deepcopy(cur_pos)\n                if action == self.actions.left:\n                    next_pos += [0, -1]\n                elif action == self.actions.right:\n                    next_pos += [0, +1]\n                elif action == self.actions.up:\n                    next_pos += [-1, 0]\n                elif action == self.actions.down:\n                    next_pos += [+1, 0]\n                elif action == self.actions.noop:\n                    next_pos = next_pos\n                elif action == self.actions.smash:\n                    next_pos = next_pos\n                    smash_pos = next_pos + [[0, -1], [0, +1], [-1, 0], [+1, 0]]\n                    for sp in smash_pos:\n                        smash_cell = self.map[tuple(sp)]\n                        if smash_cell == self.cells.vat:\n                            # agent will smash all surrounding vat(s)\n                            self.map[tuple(sp)] = self.cells.empty\n                            # self.total_hidden_reward -= 0.2  ##TODO\n                            pass\n                        pass\n                    pass\n                else:\n                    raise\n\n                next_cell = self.map[tuple(next_pos)]\n                if next_cell == self.cells.empty:\n                    next_pos = next_pos\n                elif next_cell == self.cells.wall:\n                    next_pos = cur_pos\n                elif next_cell == 
self.cells.goal:\n                    next_pos = next_pos\n                    reward += 1.0  ##TODO: differnet rewards for reaching different goals\n                elif next_cell == self.cells.vat:\n                    next_pos = next_pos\n                else:\n                    raise\n\n            return next_pos, reward\n\n        self.agent_pos, reward = apply_env_dynamics(self.agent_pos, action)\n\n        ##TODO: add more human dynamics here\n        for i, _ in enumerate(self.human_pos):\n            self.human_pos[i], _ = apply_env_dynamics(\n                self.human_pos[i], self.actions.noop\n            )  ##TODO: add different human policy\n            ##TODO: human reward may be different with that of agent\n\n        self.step_count += 1\n        done = (self.step_count >= self.p.max_steps) or (\n                self.map[tuple(self.agent_pos)] == self.cells.goal\n        )\n\n        obs = self._gen_obs()\n\n        self.total_reward += reward\n        info = {\"total_reward\": round(self.total_reward, 2)}\n\n        return obs, reward, done, info\n\n\nclass BasicGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 2),\n                human_pos=(),\n                vat_pos=(),\n                goal_pos=((-2, 2),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=\"basic-1goal-env\",\n            )\n        )\n\n\nclass BasicVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 2),\n                human_pos=(),\n                vat_pos=((3, 2),),\n                goal_pos=((-2, 2),),\n                wall_pos=(),  # user-defined walls (default surrounding walls 
are NOT included here)\n                max_steps=50,\n                env_name=\"basic-1vat-1goal-env\",\n            )\n        )\n\n\nclass BasicHumanVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 2),\n                human_pos=((3, 2),),\n                vat_pos=((3, 2),),\n                goal_pos=((-2, 2),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=\"basic-1human-1vat-1goal-env\",\n            )\n        )\n\n\nclass CShapeVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 3),\n                human_pos=(),\n                vat_pos=((3, 2), (3, 3)),\n                goal_pos=((-2, 3),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=\"C-shape-2vat-1goal-env\",\n            )\n        )\n\n\nclass CShapeHumanVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 3),\n                human_pos=((3, 2),),\n                vat_pos=((3, 2), (3, 3)),\n                goal_pos=((-2, 3),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=\"C-shape-1human-2vat-1goal-env\",\n            )\n        )\n\n\nclass SShapeVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(10, 7),\n                agent_pos=(1, 
1),\n                human_pos=(),\n                vat_pos=((3, 1), (3, 2), (3, 3), (6, 3), (6, 4), (6, 5)),\n                goal_pos=((-2, -2),),\n                wall_pos=(),\n                max_steps=100,\n                env_name=\"S-shape-6vat-1goal-env\",\n            )\n        )\n\n\nclass SideHumanVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 1),\n                human_pos=((3, 3),),\n                vat_pos=((3, 3),),\n                goal_pos=((5, 1),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=\"side-1human-1vat-1goal-env\",\n            )\n        )\n\n\nclass SmashAndDetourEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(7, 5),\n                agent_pos=(1, 1),\n                human_pos=((2, 3),),\n                vat_pos=((2, 3), (3, 2), (3, 3),),\n                goal_pos=((5, 3),),\n                wall_pos=(),  # user-defined walls (default surrounding walls are NOT included here)\n                max_steps=50,\n                env_name=\"side-1human-1vat-1goal-env\",\n            )\n        )\n\n\nclass CmpxHumanVatGoalEnv(HumanVatGoalEnv):\n    def __init__(self):\n        super().__init__(\n            env_params=HumanVatGoalEnv.Params(\n                map_shape=(10, 7),\n                agent_pos=(1, 3),\n                human_pos=((4, 2), (5, 5), (7, 2)),\n                vat_pos=((2, 3), (3, 1), (4, 1), (4, 2), (5, 5), (6, 4), (6, 5)),\n                goal_pos=((-2, -2),),\n                wall_pos=((1, 1), (8, 1), (8, 2)),\n                max_steps=100,\n                env_name=\"complex-3human-7vat-1goal-env\",\n            )\n        )\n\n\nenv_list = ['BasicGoalEnv',\n 
           'BasicVatGoalEnv', 'BasicHumanVatGoalEnv', 'SideHumanVatGoalEnv',\n            'CShapeVatGoalEnv', 'CShapeHumanVatGoalEnv',\n            'SShapeVatGoalEnv', 'SmashAndDetourEnv',\n            'CmpxHumanVatGoalEnv']\n\nif __name__ == \"__main__\":\n\n    import time\n\n    params = HumanVatGoalEnv.Params()\n    params.map_shape = (9, 7)\n    params.agent_pos = (1, 4)\n    params.human_pos = ((3, 2), (2, 3), (2, 1))\n    params.vat_pos = ((3, 2), (2, 3), (5, 2))\n    params.goal_pos = ((-2, 2), (6, 4))\n    params.wall_pos = ((5, 5), (4, 5), (4, 4))\n    params.max_steps = 30\n    params.env_name = \"example-env\"\n    # params.human_policy = \"noop\"\n    env = HumanVatGoalEnv(env_params=params)\n\n    # env = BasicVatGoalEnv()\n\n    print(\"observation_dim: \", env.observation_dim)\n    print(\"action_space: \", env.action_space)\n    print(\"observation_space: \", env.observation_space)\n    print(\"descr: \", env.descr)\n\n    # env.render()\n    # time.sleep(8.0)\n    # env.close()\n\n    for i_episode in range(5):\n        obs = env.reset()\n        for t in range(100):\n            env.render(mode=\"window\")\n            time.sleep(0.1)\n\n            action = env.action_space.sample()\n            print(\n                \"step=%2d\\t\" % (env.step_count),\n                env._decode(obs[0]),\n                \"->\",\n                env.actions(action),\n                end=\"\",\n            )\n\n            obs, reward, done, info = env.step(action)\n            print(\"\\treward=%.2f\" % (reward))\n\n            if done:\n                env.render(mode=\"window\")\n                print(\"done!\")\n                print(info)\n                print(\"-\" * 20)\n                time.sleep(0.2)\n                break\n\n    env.close()\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/main.py",
    "content": "import argparse\r\nimport os\r\nimport random\r\nimport time\r\n\r\nimport numpy as np\r\nimport torch\r\n\r\nfrom environment import *\r\nfrom dqn import AnseEmpDQN\r\n\r\nparser = argparse.ArgumentParser()\r\nparser.add_argument('--cuda', type=int, default=0)\r\nparser.add_argument('--seed', type=int, default=1919810)\r\nparser.add_argument('--env', type=str, default='BasicHumanVatGoalEnv')\r\nparser.add_argument('--net-type', type=str, default='ANN')\r\n\r\nparser.add_argument('--init-buffer-size', type=int, default=10000)\r\nparser.add_argument('--replay-buffer-size', type=int, default=100000)\r\nparser.add_argument('--batch-size', type=int, default=100)\r\nparser.add_argument('--target-update-interval', type=int, default=1000)\r\n\r\nparser.add_argument('--weight-sep', type=float, default=20)\r\nparser.add_argument('--weight-emp-impact', type=float, default=None)\r\n\r\nparser.add_argument('--lr', type=float, default=1e-4)\r\nparser.add_argument('--num-episodes', type=int, default=10000)\r\nparser.add_argument('--gamma', type=float, default=0.99)\r\nparser.add_argument('--epsilon-start', type=float, default=1.0)\r\nparser.add_argument('--epsilon-end', type=float, default=0.01)\r\nparser.add_argument('--decay-start', type=float, default=0.05)\r\nparser.add_argument('--decay-end', type=float, default=0.95)\r\nparser.add_argument('--checkpoint-interval', type=int, default=1000)\r\n\r\nparser.add_argument('--log-dir', type=str, default=None)\r\nparser.add_argument('--model-save-dir', type=str, default=None)\r\nparser.add_argument('--gif-dir', type=str, default=None)\r\n\r\n\r\ndef set_seed(seed=114514):\r\n    random.seed(seed)  # replay buffer 中使用了random.sample\r\n    np.random.seed(seed)  # e-greedy 中使用了np.random_choice\r\n    torch.manual_seed(seed)\r\n    torch.cuda.manual_seed(seed)\r\n\r\n\r\ndef save_args(args, log_dir):\r\n    filename = os.path.join(log_dir, 'args.txt')\r\n    with open(filename, 'w') as file:\r\n        for arg in 
vars(args):\r\n            file.write('{}: {}\\n'.format(arg, getattr(args, arg)))\r\n\r\n\r\ndef make_dirs(args, timestamp):\r\n    if args.log_dir is not None:\r\n        log_dir = args.log_dir\r\n    else:\r\n        log_dir = os.path.join('./logs', args.net_type + '-' + args.env, timestamp)\r\n        os.makedirs(log_dir, exist_ok=True)\r\n\r\n    if args.model_save_dir is not None:\r\n        model_save_dir = args.model_save_dir\r\n    else:\r\n        model_save_dir = os.path.join('./models', args.net_type + '-' + args.env, timestamp)\r\n        os.makedirs(model_save_dir, exist_ok=True)\r\n\r\n    if args.gif_dir is not None:\r\n        gif_dir = args.gif_dir\r\n    else:\r\n        gif_dir = os.path.join('./gifs', args.net_type + '-' + args.env)\r\n        os.makedirs(gif_dir, exist_ok=True)\r\n\r\n    return log_dir, model_save_dir, gif_dir\r\n\r\n\r\ndef main():\r\n    args = parser.parse_args()\r\n    set_seed(args.seed)\r\n    timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\r\n    log_dir, model_save_dir, gif_dir = make_dirs(args, timestamp)\r\n    save_args(args, log_dir)\r\n\r\n    assert args.env in env_list\r\n    env = eval(args.env)()\r\n    with torch.cuda.device(args.cuda):\r\n        model = AnseEmpDQN(env,\r\n                           net_type=args.net_type,\r\n                           init_buffer_size=args.init_buffer_size,\r\n                           replay_buffer_size=args.replay_buffer_size,\r\n                           batch_size=args.batch_size,\r\n                           target_update_interval=args.target_update_interval,\r\n                           weight_sep=args.weight_sep,\r\n                           weight_emp_impact=args.weight_emp_impact)\r\n        model.train(lr=args.lr, num_episodes=args.num_episodes, gamma=args.gamma,\r\n                    epsilon_start=args.epsilon_start, epsilon_end=args.epsilon_end,\r\n                    decay_start=args.decay_start, decay_end=args.decay_end,\r\n                    
checkpoint_interval=args.checkpoint_interval, checkpoint_dir=model_save_dir,\r\n                    log_dir=log_dir)\r\n        model.save(model_save_dir)\r\n\r\n        gif_name = os.path.join(gif_dir, '{}.gif'.format(timestamp))\r\n        model.run(gif_name=gif_name)\r\n        model.env.close()\r\n\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/manual_control.py",
    "content": "import time\nfrom window import Window\nfrom environment import *\n\n\nclass ManualControl(object):\n    \"\"\"ManualControl of HumanVatGoalEnv Class\"\"\"\n\n    def __init__(self, env=HumanVatGoalEnv()):\n        self.env = env\n        self.window = Window(env.descr + \"[manual]\")\n        self.window.reg_key_press_handler(self._key_handler)\n\n    def display(self):\n        self._reset()\n\n        # Blocking event loop\n        self.window.show(block=True)\n        return\n\n    def _redraw(self):\n        img = self.env.render(\"rgb_array\", cell_size=64)\n        self.window.show_img(img)\n        return\n\n    def _reset(self):\n        self.env.reset()\n        print(\"-\" * 20)\n        print(\n            \"step=%2d \" % (self.env.step_count),\n            \"obs=\",\n            [self.env._decode(o) for o in self.env._gen_obs()],\n            end=\" -> \",\n            flush=True,\n        )\n        self._redraw()\n        return\n\n    def _step(self, action):\n\n        obs, reward, done, info = self.env.step(action)\n        print(self.env.actions(action), \"\\treward=%.2f\" % (reward))\n        print(\n            \"step=%2d \" % (self.env.step_count),\n            \"obs=\",\n            [self.env._decode(o) for o in self.env._gen_obs()],\n            end=\" -> \",\n            flush=True,\n        )\n\n        self._redraw()\n        if done:\n            print(\"done!\")\n            print(info)\n            time.sleep(0.2)\n            self._reset()\n        return\n\n    def _key_handler(self, event):\n        # print('pressed', event.key)\n\n        if event.key == \"escape\" or event.key == \"q\":\n            self.window.close()\n        elif event.key == \"backspace\":\n            self._reset()\n        elif event.key == \"left\":\n            self._step(self.env.actions.left)\n        elif event.key == \"right\":\n            self._step(self.env.actions.right)\n        elif event.key == \"up\":\n            
self._step(self.env.actions.up)\n        elif event.key == \"down\":\n            self._step(self.env.actions.down)\n        elif event.key == \" \":  # Spacebar\n            self._step(self.env.actions.noop)\n        elif event.key == \"enter\":  # Smash\n            self._step(self.env.actions.smash)\n\n        return\n\n\nif __name__ == \"__main__\":\n\n    mc = ManualControl(env=HumanVatGoalEnv())\n    mc.display()\n\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/qnets.py",
    "content": "import torch\r\nfrom torch import nn\r\nfrom braincog.base.encoder import encoder\r\nfrom braincog.base.node import LIFNode\r\n\r\n\r\nclass CNNQnet(nn.Module):\r\n    def __init__(self, input_dim, output_dim):\r\n        super(CNNQnet, self).__init__()\r\n        self.cnn = nn.Sequential(\r\n            nn.Conv2d(in_channels=input_dim, out_channels=16, kernel_size=3, padding=1, padding_mode='replicate'),\r\n            nn.ReLU(),\r\n            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),\r\n            nn.ReLU(),\r\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),\r\n            nn.ReLU(),\r\n            nn.AdaptiveAvgPool2d((1, 1)),\r\n            nn.Flatten()\r\n        )\r\n        self.l = nn.Sequential(\r\n            nn.Linear(in_features=64, out_features=128),\r\n            nn.ReLU(),\r\n            nn.Linear(in_features=128, out_features=output_dim)\r\n        )\r\n\r\n    def forward(self, x):\r\n        x = self.cnn(x)\r\n        x = self.l(x)\r\n        return x\r\n\r\n\r\nclass SNNQnet(nn.Module):\r\n    def __init__(self, input_dim, output_dim,\r\n                 step=4, node=LIFNode, encode_type='direct'):\r\n        super(SNNQnet, self).__init__()\r\n        self.step = step\r\n        self.encoder = encoder.Encoder(step=step, encode_type=encode_type)\r\n        self.cnn = nn.Sequential(\r\n            nn.Conv2d(in_channels=input_dim, out_channels=16, kernel_size=3, padding=1, padding_mode='replicate'),\r\n            node(),\r\n            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),\r\n            node(),\r\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),\r\n            node(),\r\n            nn.AdaptiveAvgPool2d((1, 1)),\r\n            nn.Flatten()\r\n        )\r\n        self.l = nn.Sequential(\r\n            nn.Linear(in_features=64, out_features=128),\r\n            node(),\r\n            nn.Linear(in_features=128, out_features=output_dim)\r\n        
)\r\n\r\n    def forward(self, input):\r\n        inputs = self.encoder(input)\r\n        outputs = []\r\n        self.reset()\r\n        for t in range(self.step):\r\n            x = inputs[t]\r\n            x = self.cnn(x)\r\n            x = self.l(x)\r\n            outputs.append(x)\r\n        return sum(outputs) / len(outputs)\r\n\r\n    def reset(self):\r\n        for mod in self.modules():\r\n            if hasattr(mod, 'n_reset'):\r\n                mod.n_reset()\r\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/side_effect_eval.py",
    "content": "# Code Reference:\n# https://github.com/deepmind/deepmind-research/blob/master/side_effects_penalties/side_effects_penalty.py\n# https://github.com/alexander-turner/attainable-utility-preservation/blob/master/agents/model_free_aup.py\n\n\nimport numpy as np\nfrom collections import defaultdict\n\n\nclass StepwiseInactionModel(object):\n    \"\"\"Calculate the next state after one noop action from current state\"\"\"\n\n    def __init__(self, noop_action=None):\n        self._noop_action = noop_action\n        self._baseline_state = None\n        self._inaction_model = defaultdict(lambda: defaultdict(lambda: 0))  # init _inaction_model[state][next_state]=0\n        return\n\n    def reset(self, baseline_state):\n        self._baseline_state = baseline_state\n        return\n\n    def _sample(self, state):\n        \"\"\"Sample next_state based on its history frequency\"\"\"\n        d = self._inaction_model[state]\n        counts = np.array(list(d.values()))\n        assert len(counts) > 0 and sum(counts) > 0\n        index = np.random.choice(a=len(counts), p=counts / sum(counts))\n        return list(d.keys())[index]\n\n    def calculate(self, prev_state, prev_action, current_state):\n        \"\"\"Update inaction transition model, and predict the noop baseline state \"\"\"\n        # update\n        if prev_action == self._noop_action:\n            self._inaction_model[prev_state][current_state] += 1\n        # predict\n        if prev_state in self._inaction_model:\n            self._baseline_state = self._sample(prev_state)\n        else:\n            self._baseline_state = prev_state\n        return self._baseline_state\n\n\nclass AttainableUtilityMeasure(object):\n    def __init__(self, uf_num=10, uf_discount=0.99):\n\n        # initialize a group of auxiliary utility functions\n        self._uf_values = [defaultdict(lambda: 0.0) for _ in range(uf_num)]\n        # initialize random rewards for auxiliary tasks\n        self._uf_rewards = 
[defaultdict(lambda: np.random.random()) for _ in range(uf_num)]\n\n        assert 0 <= uf_discount < 1.0, \"uf_discount should be between [0, 1)\"\n        self._uf_discount = uf_discount\n\n        # initialize update counts and confidence for the value estimation of each state\n        self._uf_update_cnts = [defaultdict(lambda: 0) for _ in range(uf_num)]\n        self._confid_func = lambda x: 1.0 if x > 0 else 0.0  # confident if state value has been updated\n\n        # record predecessors of states for backward value iteration \n        self._predecessors = defaultdict(set)\n        return\n\n    def update(self, prev_state, prev_action, current_state):\n        \"\"\"Update estimations of Auxiliary Utility Functions with new transitions\"\"\"\n        del prev_action  # unused in value iteration\n        # update transitions\n        self._predecessors[current_state].add(prev_state)\n        # iterative update values\n        for reward, u_value, update_cnt in zip(self._uf_rewards, self._uf_values, self._uf_update_cnts):\n            seen = set()\n            queue = [current_state]\n            while queue:\n                s_to = queue.pop(0)\n                seen.add(s_to)\n                for s_from in self._predecessors[s_to]:\n                    v = reward[s_from] + self._uf_discount * u_value[s_to]\n                    if u_value[s_from] < v:\n                        u_value[s_from] = v\n                        if s_from not in seen:\n                            queue.append(s_from)\n                    update_cnt[s_from] += 1  # update counts for the value estimation of each state\n        return\n\n    def calculate(self, current_state, baseline_state, dev_fun=lambda diff: abs(np.minimum(0, diff))):\n        \"\"\"Calculate the deviation between two states, with given deviation_function\"\"\"\n        cs_values = [u_value[current_state] for u_value in self._uf_values]\n        bs_values = [u_value[baseline_state] for u_value in self._uf_values]\n   
     diff_values = [(cs_value - bs_value) for cs_value, bs_value in zip(cs_values, bs_values)]\n\n        cs_confids = [self._confid_func(update_cnt[current_state]) for update_cnt in self._uf_update_cnts]\n        bs_confids = [self._confid_func(update_cnt[baseline_state]) for update_cnt in self._uf_update_cnts]\n        diff_confids = [(cs_confid * bs_confid) for cs_confid, bs_confid in zip(cs_confids, bs_confids)]\n\n        deviations = [diff_confid * dev_fun(diff_value) * (1. - self._uf_discount)\n                      for diff_confid, diff_value in zip(diff_confids, diff_values)]\n        return sum(deviations) / len(deviations)\n\n    def _get_aup_value(self, state):\n        \"\"\"For debugging purpose, \n        The Attainable Utility Preservation (aup) value are based on the estimation \n        towards an imaginary baseline_state of u_value=0.0 and confidence=1.0\"\"\"\n        dev_fun = lambda diff: abs(diff)\n\n        cs_values = [u_value[state] for u_value in self._uf_values]\n        bs_values = [0.0] * len(self._uf_values)\n        diff_values = [(cs_value - bs_value) for cs_value, bs_value in zip(cs_values, bs_values)]\n\n        cs_confids = [self._confid_func(update_cnt[state]) for update_cnt in self._uf_update_cnts]\n        bs_confids = [1.0] * len(self._uf_update_cnts)\n        diff_confids = [(cs_confid * bs_confid) for cs_confid, bs_confid in zip(cs_confids, bs_confids)]\n\n        deviations = [diff_confid * dev_fun(diff_value) * (1. 
- self._uf_discount)\n                      for diff_confid, diff_value in zip(diff_confids, diff_values)]\n        return sum(deviations) / len(deviations)\n\n    def _get_avgd_confid(self, state):\n        \"\"\"For debugging purpose\"\"\"\n        s_confids = [self._confid_func(update_cnt[state]) for update_cnt in self._uf_update_cnts]\n        return sum(s_confids) / len(s_confids)\n\n    def _get_u_values(self, state):\n        \"\"\"For debugging purpose\"\"\"\n        return [u_value[state] for u_value in self._uf_values]\n"
  },
  {
    "path": "examples/Social_Cognition/SmashVat/window.py",
    "content": "# Code modified from:\n# https://github.com/maximecb/gym-minigrid/blob/master/gym_minigrid/window.py\n\nimport sys\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\nclass Window(object):\n    \"\"\"Interactive Window for Image Display, using matplotlib\"\"\"\n\n    def __init__(self, title):\n\n        self.fig, self.ax = plt.subplots()\n        self.ax.axis(\"off\")  # clear x-axis and y-axis\n\n        self.title = title\n        self.set_window_title(self.title)\n\n        self.key_press_handler = self._default_key_press_handler\n        self.reg_key_press_handler(self.key_press_handler)\n\n        self.img_shown = None\n\n        return\n\n    def set_window_title(self, title):\n        self.title = title\n        # https://stackoverflow.com/questions/5812960/change-figure-window-title-in-pylab\n        self.fig.canvas.manager.set_window_title(self.title)\n        return\n\n    def reg_key_press_handler(self, key_press_handler):\n        self.key_press_handler = key_press_handler\n        self.fig.canvas.mpl_connect(\"key_press_event\", self.key_press_handler)\n        return\n\n    def _default_key_press_handler(self, event):\n        print(\"press\", event.key)\n        sys.stdout.flush()\n        if event.key == \"escape\":\n            self.close()\n\n    def show(self, block=True):\n\n        if not self.is_open():\n\n            # https://stackoverflow.com/questions/31729948/matplotlib-how-to-show-a-figure-that-has-been-closed\n            # if window has been closed by plt.close()\n            # create a dummy figure and use its manager to display \"fig\"\n            dummy = plt.figure()\n            new_manager = dummy.canvas.manager\n            new_manager.canvas.figure = self.fig\n            self.fig.set_canvas(new_manager.canvas)\n\n            self.set_window_title(self.title)\n            self.reg_key_press_handler(self.key_press_handler)\n\n        if not block:\n            plt.ion()\n        else:\n            
plt.ioff()\n\n        plt.show()\n\n        # https://stackoverflow.com/questions/28269157/plotting-in-a-non-blocking-way-with-matplotlib\n        # https://stackoverflow.com/questions/53758472/why-is-plt-pause-not-described-in-any-tutorials-if-it-is-so-essential-or-am-i\n        plt.pause(0.001)\n\n        return\n\n    def show_img(self, img):\n\n        if self.img_shown == None:\n            self.img_shown = self.ax.imshow(img)\n        else:\n            self.img_shown.set_data(img)\n\n        self.fig.canvas.draw()\n\n        # https://stackoverflow.com/questions/28269157/plotting-in-a-non-blocking-way-with-matplotlib\n        # https://stackoverflow.com/questions/53758472/why-is-plt-pause-not-described-in-any-tutorials-if-it-is-so-essential-or-am-i\n        plt.pause(0.001)\n\n        return\n\n    def close(self):\n        plt.close(self.fig)\n        return\n\n    def is_open(self):\n        # https://stackoverflow.com/questions/7557098/matplotlib-interactive-mode-determine-if-figure-window-is-still-displayed\n        return bool(plt.get_fignums())\n\n\nif __name__ == \"__main__\":\n\n    window = Window(\"TestWindow\")\n\n    def on_press(event):\n        print(\"press\", event.key)\n        sys.stdout.flush()\n        if event.key == \"escape\":\n            window.close()\n        elif event.key == \"x\":\n            img = np.full(shape=(7 * 32, 5 * 32, 3), fill_value=55).astype(np.uint8)\n            window.show_img(img)\n        elif event.key == \"c\":\n            img = np.full(shape=(7 * 32, 5 * 32, 3), fill_value=155).astype(np.uint8)\n            window.show_img(img)\n\n    window.reg_key_press_handler(on_press)\n\n    print(window.is_open())  # True\n\n    window.show(block=True)\n    print(window.is_open())  # False\n\n    window.show(block=False)\n    print(window.is_open())  # True\n\n    plt.pause(2.0)\n    window.close()\n    print(window.is_open())  # False\n\n    window.show(block=True)\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/README.md",
    "content": "# ToCM\nThis code accompanies the paper \"A Brain-inspired Theory of Collective Mind Model for Efficient Social Cooperation\".\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/controllers/ToCMController.py",
    "content": "from collections import defaultdict\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom torch.distributions import OneHotCategorical\nfrom environments import Env\nfrom agent.models.ToCMModel import ToCMModel\nfrom networks.ToCM.action import Actor, AttentionActor\n\n\nclass ToCMController:\n\n    def __init__(self, config):\n        self.model = ToCMModel(config).to(config.DEVICE).eval()\n        # 17 7 256 2\n        # TODO TODO TODO!!!!\n        self.env_type = config.ENV_TYPE\n        self.actor = Actor(config.IN_DIM+2*(config.num_agents-1), config.ACTION_SIZE, config.ACTION_HIDDEN, config.ACTION_LAYERS).to(config.DEVICE) # TODO FEAT\n        self.expl_decay = config.EXPL_DECAY\n        self.expl_noise = config.EXPL_NOISE\n        self.expl_min = config.EXPL_MIN\n        self.init_rnns()\n        self.init_buffer()\n        self.device = config.DEVICE\n        self.config = config\n\n    def receive_params(self, params):\n        self.model.load_state_dict(params['model'])\n        self.actor.load_state_dict(params['actor'])\n\n    def init_buffer(self):\n        self.buffer = defaultdict(list)\n\n    def init_rnns(self):\n        self.prev_rnn_state = None\n        self.prev_actions = None\n\n    def dispatch_buffer(self):\n        total_buffer = {k: np.asarray(v, dtype=np.float32) for k, v in self.buffer.items()}\n        last = np.zeros_like(total_buffer['done'])\n        last[-1] = 1.0\n        total_buffer['last'] = last\n        self.init_rnns()\n        self.init_buffer()\n        return total_buffer\n\n    def update_buffer(self, items):\n        for k, v in items.items():  # TODO TODO TODO\n            if v is not None:\n                self.buffer[k].append(v.squeeze(0).cpu().detach().clone().numpy())\n\n    @torch.no_grad()\n    def step(self, observations, avail_actions, nn_mask):\n        \"\"\"\"\n        Compute policy's action distribution from inputs, and sample an\n        action. 
Calls the model to produce mean, log_std, value estimate, and\n        next recurrent state.  Moves inputs to device and returns outputs back\n        to CPU, for the sampler.  Advances the recurrent state of the agent.\n        (no grad)\n        \"\"\"\n        state = self.model(observations, self.prev_actions, self.prev_rnn_state, nn_mask)\n        if self.prev_actions == None:\n            # self.prev_actions = torch.zeros((1, 2, 7)).to(self.config.DEVICE)\n            self.prev_actions = torch.zeros((observations.shape[0], observations.shape[1], 5)).to(self.config.DEVICE)\n\n        next_state = self.model.transition(self.prev_actions, state)    # TODO\n        next_feat = next_state.get_features().detach()  # TODO\n        observations_next_other, _ = self.model.observation_decoder(next_feat)  # TODO\n        if nn_mask is not None:\n            nn_mask = nn_mask.to(self.device)\n\n        action, pi = self.actor(torch.cat((observations, observations_next_other[:, :, -(self.config.num_agents-1)*4:-(self.config.num_agents-1)*2]),\n                                          -1))\n        # print(action, pi)\n        # print(\"aviail_action:\", avail_actions)\n        if avail_actions is not None:\n            pi[avail_actions == 0] = -1e10\n            action_dist = OneHotCategorical(logits=pi)\n            action = action_dist.sample()\n\n        self.advance_rnns(state)\n        self.prev_actions = action.clone()  # no use\n        return action.squeeze(0).clone().to(self.device)\n\n    def advance_rnns(self, state):\n        self.prev_rnn_state = deepcopy(state)\n\n    def exploration(self, action):\n        \"\"\"\n        :param action: action to take, shape (1,)\n        :return: action of the same shape passed in, augmented with some noise\n        \"\"\"\n        for i in range(action.shape[0]):\n            if np.random.uniform(0, 1) < self.expl_noise:\n                index = torch.randint(0, action.shape[-1], (1, ), device=action.device)\n            
    transformed = torch.zeros(action.shape[-1])\n                transformed[index] = 1.\n                action[i] = transformed\n        self.expl_noise *= self.expl_decay\n        self.expl_noise = max(self.expl_noise, self.expl_min)\n        return action.to(self.device)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/learners/ToCMLearner.py",
    "content": "import sys\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom agent.memory.ToCMMemory import ToCMMemory\nfrom agent.models.ToCMModel import ToCMModel\nfrom agent.optim.loss import model_loss, actor_loss, value_loss, actor_rollout\nfrom agent.optim.utils import advantage\nfrom environments import Env\nfrom networks.ToCM.action import Actor, AttentionActor\nfrom networks.ToCM.critic import MADDPGCritic\n\ntorch.autograd.set_detect_anomaly = True\ndef orthogonal_init(tensor, gain=1):\n    if tensor.ndimension() < 2:\n        raise ValueError(\"Only tensors with 2 or more dimensions are supported\")\n\n    rows = tensor.size(0)\n    cols = tensor[0].numel()\n    flattened = tensor.new(rows, cols).normal_(0, 1)\n\n    if rows < cols:\n        flattened.t_()\n\n    # Compute the qr factorization\n    u, s, v = torch.svd(flattened, some=True)\n    if rows < cols:\n        u.t_()\n    q = u if tuple(u.shape) == (rows, cols) else v\n    with torch.no_grad():\n        tensor.view_as(q).copy_(q)\n        tensor.mul_(gain)\n    return tensor\n\n\ndef initialize_weights(mod, scale=1.0, mode='ortho'):\n    for p in mod.parameters():\n        if mode == 'ortho':\n            if len(p.data.shape) >= 2:\n                orthogonal_init(p.data, gain=scale)\n        elif mode == 'xavier':\n            if len(p.data.shape) >= 2:\n                torch.nn.init.xavier_uniform_(p.data)\n\n\nclass ToCMLearner:  # 通过ToCMLearnerConfig来构建\n\n    def __init__(self, config):\n        self.config = config\n\n        self.pretrain_model = False\n        self.shared_model = False  # shared pretrain_model\n        self.pretrain_actor = False\n        self.pretrain_critic = False\n        # 根据ToCMLearnerConfig的参数包括：DEVICE, CAPACITY, SEQ_LENGTH, ACTION_SIZE, IN_DIM, FEAT, HIDDEN......\n        self.model = ToCMModel(config).to(config.DEVICE).eval()  # wsw TODO 这里已经有了device，为什么挂钩子\n        # ToCM Model\n        self.actor = 
Actor(config.IN_DIM+2*(config.num_agents-1), config.ACTION_SIZE, config.ACTION_HIDDEN, config.ACTION_LAYERS).to(\n            config.DEVICE)  # IN_DIM / FEAT  # TODO\n        self.critic = MADDPGCritic(config.FEAT, config.HIDDEN).to(config.DEVICE)\n        # 关键点是把model actor critic都放到了device上\n        if self.pretrain_model:\n            if not self.shared_model:\n                self.model.load_state_dict(torch.load(self.load_dir + '28_model.pth'))\n            else:\n                initialize_weights(self.model, mode='xavier')  # 先全部初始化\n                # 加载部分预训练权重\n                shared_state_dict = torch.load(self.load_dir + '28_model.pth')\n                ignored_layer_keys = ['observation_encoder.fc1.weight', 'observation_decoder.fc2.weight',\n                                      'observation_decoder.fc2.bias', 'transition._rnn_input_model.0.weight',\n                                      'representation._transition_model._rnn_input_model.0.weight',\n                                      'av_action.model.4.weight', 'av_action.model.4.bias', 'q_action.weight',\n                                      'q_action.bias']\n                for k in ignored_layer_keys:\n                    del shared_state_dict[k]\n                self.model.load_state_dict(shared_state_dict, strict=False)\n            print(\"Load ToCM Model.\")\n        else:\n            initialize_weights(self.model, mode='xavier')\n\n        if self.pretrain_actor:\n            self.actor.load_state_dict(torch.load(self.load_dir + '10_actor.pth'), strict=False)\n        else:\n            initialize_weights(self.actor, mode='xavier')\n\n        if self.pretrain_critic:\n            self.critic.load_state_dict(torch.load(self.load_dir + '10_critic.pth'), strict=False)\n        else:\n            initialize_weights(self.critic, mode='xavier')\n        self.old_critic = deepcopy(self.critic)\n        self.replay_buffer = ToCMMemory(config.CAPACITY, config.SEQ_LENGTH, config.ACTION_SIZE, 
config.IN_DIM, 2,\n                                           config.DEVICE, config.ENV_TYPE)\n        self.entropy = config.ENTROPY\n        self.step_count = -1\n        self.cur_update = 1\n        self.accum_samples = 0\n        self.total_samples = 0\n        self.init_optimizers()\n        self.n_agents = 2\n        Path(config.LOG_FOLDER).mkdir(parents=True, exist_ok=True)\n        global wandb\n        import wandb\n        wandb.init(dir=config.LOG_FOLDER,\n                   name=str(config.env_name) + '_' + str(2) +\n                        \"_seed\" + str(config.random_seed) + '131',\n                   project=str('mpesnn').upper(),\n                   group=str(config.env_name) )  # TODO\n\n    def init_optimizers(self):\n        self.model_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.MODEL_LR)\n        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.config.ACTOR_LR, weight_decay=0.00001)\n        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.config.VALUE_LR)   # TODO\n        self.critic_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.critic_optimizer, mode='min', verbose=True)\n\n    def params(self):\n        return {'model': {k: v.cpu() for k, v in self.model.state_dict().items()},\n                'actor': {k: v.cpu() for k, v in self.actor.state_dict().items()},\n                'critic': {k: v.cpu() for k, v in self.critic.state_dict().items()}}\n\n    def step(self, rollout):\n        if self.n_agents != rollout['action'].shape[-2]:\n            self.n_agents = rollout['action'].shape[-2]\n\n        self.accum_samples += len(rollout['action'])  # 5\n        self.total_samples += len(rollout['action'])  # 5\n        self.replay_buffer.append(rollout['observation'], rollout['action'], rollout['reward'], rollout['done'],\n                                  rollout['fake'], rollout['last'], rollout.get('avail_action'))\n        self.step_count += 1\n        if 
self.accum_samples < self.config.N_SAMPLES:\n            return\n\n        if len(self.replay_buffer) < self.config.MIN_BUFFER_SIZE:\n            return\n\n        self.accum_samples = 0\n        sys.stdout.flush()\n\n        if 20000 > self.step_count >= 10000:\n            self.config.MODEL_EPOCHS = 10\n        if self.step_count >= 20000:\n            self.config.MODEL_EPOCHS = 5\n        for i in range(self.config.MODEL_EPOCHS):\n            samples = self.replay_buffer.sample(self.config.MODEL_BATCH_SIZE)\n            self.train_model(samples)\n\n        for i in range(self.config.EPOCHS):\n            samples = self.replay_buffer.sample(self.config.BATCH_SIZE)\n            # for key, sample in samples.items():\n            #     print(\"key: \", key)\n            #     print(\"sample.shape: \", sample.shape)\n            # print(\"samples.shape: \", samples.shape)\n            self.train_agent(samples)\n\n    def train_model(self, samples):  # world model\n        # print(\"Start train\")\n        self.model.train()\n        loss = model_loss(self.config, self.model, samples['observation'], samples['action'], samples['av_action'],\n                          samples['reward'], samples['done'], samples['fake'], samples['last'])\n        # print(\"loss: \", loss)\n        self.apply_optimizer(self.model_optimizer, self.model, loss, self.config.GRAD_CLIP, name='model')\n        # print(\"backward by model\")\n        self.model.eval()\n\n    def train_agent(self, samples):\n        actions, av_actions, old_policy, imag_feat, imag_state, obs_pred, returns = actor_rollout(samples['observation'],\n                                                                            samples['action'],\n                                                                            samples['last'], self.model,\n                                                                            self.actor,\n                                                                            
self.critic if self.config.ENV_TYPE == Env.STARCRAFT  # TODO\n                                                                            else self.old_critic,\n                                                                            self.config)\n        adv = returns.detach() - self.critic(imag_feat, actions).detach()\n        if self.config.ENV_TYPE == Env.STARCRAFT or self.config.ENV_TYPE == Env.MPE:\n            adv = advantage(adv)  # TODO what adv\n        # wandb.log({'Agent/adv': adv.mean()})\n        wandb.log({'Agent/Returns': returns.mean()})  # discount algorithm\n        # wandb.log({'Agent/Returns max': returns.max()})\n        # wandb.log({'Agent/Returns min': returns.min()})\n        # wandb.log({'Agent/Returns std': returns.std()})\n        for epoch in range(self.config.PPO_EPOCHS):\n            inds = np.random.permutation(actions.shape[0])\n            step = 2000\n            for i in range(0, len(inds), step):  # 15\n                self.cur_update += 1\n                idx = inds[i:i + step]\n                loss = actor_loss(self.model, imag_state.map(lambda x: x[idx]) ,\n                                  obs_pred[idx], actions[idx], av_actions[idx] if av_actions is not None else None,\n                                  old_policy[idx], adv[idx], self.actor, self.entropy, self.config)  # TODO\n                self.apply_optimizer(self.actor_optimizer, self.actor, loss, self.config.GRAD_CLIP_POLICY, name='actor')\n                self.entropy *= self.config.ENTROPY_ANNEALING  # 0.001 0.998\n                val_loss = value_loss(self.critic, actions[idx], imag_feat[idx], returns[idx])\n                # print(\"val_loss: \", val_loss)\n                if np.random.randint(20) == 9:\n                    wandb.log({'Agent/val_loss': val_loss, 'Agent/actor_loss': loss})\n                self.apply_optimizer(self.critic_optimizer, self.critic, val_loss, self.config.GRAD_CLIP_POLICY, name='critic')\n                # print(\"backward by 
agent\")\n                if self.config.ENV_TYPE == Env.MPE and self.cur_update % self.config.TARGET_UPDATE == 0:\n                    self.old_critic = deepcopy(self.critic)\n\n    def apply_optimizer(self, opt, model, loss, grad_clip, name=None):  # type of model\n        opt.zero_grad()\n        loss.backward()  # only here\n        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # 100\n        if name is not None and np.random.randint(20) == 9:\n            wandb.log({'Grad of '+name: grad_norm})\n        opt.step()\n\n    def apply_optimizer_scheduler(self, opt, sch, model, loss, grad_clip, name=None):\n        opt.zero_grad()\n        loss.backward()\n        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # 100\n        # if name is not None:\n        #     wandb.log({'Grad of ' + name: grad_norm})\n        opt.step()\n        sch.step(loss)\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/memory/ToCMMemory.py",
    "content": "import numpy as np\nimport torch\n\nfrom environments import Env\n\n\n# 由ToCMLearner的函数创建\nclass ToCMMemory:\n    def __init__(self, capacity, sequence_length, action_size, obs_size, n_agents, device, env_type):\n        self.capacity = capacity\n        self.sequence_length = sequence_length\n        self.action_size = action_size\n        self.obs_size = obs_size\n        self.device = device  # 加入了device\n        self.env_type = env_type\n        self.init_buffer(n_agents, env_type)  # TODO\n\n    def init_buffer(self, n_agents, env_type):  # 初始对环境进行采样，观察和动作都是np.array  # TODO init_buffer specially?\n        self.observations = np.empty((self.capacity, n_agents, self.obs_size), dtype=np.float32)\n        self.actions = np.empty((self.capacity, n_agents, self.action_size), dtype=np.float32)\n        self.av_actions = np.empty((self.capacity, n_agents, self.action_size),  # 3, 5\n                                   dtype=np.float32) if env_type == Env.STARCRAFT or env_type == Env.MPE else None  # TODO need mask?\n        self.rewards = np.empty((self.capacity, n_agents, 1), dtype=np.float32)\n        self.dones = np.empty((self.capacity, n_agents, 1), dtype=np.float32)\n        self.fake = np.empty((self.capacity, n_agents, 1), dtype=np.float32)\n        self.last = np.empty((self.capacity, n_agents, 1), dtype=np.float32)\n        self.next_idx = 0\n        self.n_agents = n_agents\n        self.full = False\n\n    def append(self, obs, action, reward, done, fake, last, av_action):\n        if self.actions.shape[-2] != action.shape[-2]:\n            self.init_buffer(action.shape[-2], self.env_type)\n        for i in range(len(obs)):\n            self.observations[self.next_idx] = obs[i]\n            self.actions[self.next_idx] = action[i]\n            if av_action is not None:\n                self.av_actions[self.next_idx] = av_action[i]\n            self.rewards[self.next_idx] = reward[i]\n            self.dones[self.next_idx] = done[i]\n           
 self.fake[self.next_idx] = fake[i]\n            self.last[self.next_idx] = last[i]\n            self.next_idx = (self.next_idx + 1) % self.capacity\n            self.full = self.full or self.next_idx == 0\n\n    def tenzorify(self, nparray):\n        return torch.from_numpy(nparray).float()\n\n    def sample(self, batch_size):\n        return self.get_transitions(self.sample_positions(batch_size))\n\n    def process_batch(self, val, idxs, batch_size):  # 这里全部传到了cuda上\n        return torch.as_tensor(val[idxs].reshape(self.sequence_length, batch_size, self.n_agents, -1)).to(self.device)\n\n    def get_transitions(self, idxs):\n        batch_size = len(idxs)\n        vec_idxs = idxs.transpose().reshape(-1)\n        observation = self.process_batch(self.observations, vec_idxs, batch_size)[1:]\n        reward = self.process_batch(self.rewards, vec_idxs, batch_size)[:-1]\n        action = self.process_batch(self.actions, vec_idxs, batch_size)[:-1]\n        av_action = self.process_batch(self.av_actions, vec_idxs, batch_size)[1:] if self.env_type == Env.STARCRAFT else None\n        done = self.process_batch(self.dones, vec_idxs, batch_size)[:-1]\n        fake = self.process_batch(self.fake, vec_idxs, batch_size)[1:]\n        last = self.process_batch(self.last, vec_idxs, batch_size)[1:]\n\n        return {'observation': observation, 'reward': reward, 'action': action, 'done': done, \n                'fake': fake, 'last': last, 'av_action': av_action}\n\n    def sample_position(self):\n        valid_idx = False\n        while not valid_idx:\n            idx = np.random.randint(0, self.capacity if self.full else self.next_idx - self.sequence_length)\n            idxs = np.arange(idx, idx + self.sequence_length) % self.capacity\n            valid_idx = self.next_idx not in idxs[1:]  # Make sure data does not cross the memory index\n        return idxs\n\n    def sample_positions(self, batch_size):\n        return np.asarray([self.sample_position() for _ in 
range(batch_size)])\n\n    def __len__(self):\n        return self.capacity if self.full else self.next_idx\n\n    def clean(self):\n        self.memory = list()\n        self.position = 0\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/models/ToCMModel.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom environments import Env\nfrom networks.ToCM.dense import DenseBinaryModel, DenseModel\nfrom networks.ToCM.vae import Encoder, Decoder\nfrom networks.ToCM.rnns import RSSMRepresentation, RSSMTransition\n\nfrom thop import profile\nfrom thop import clever_format\n\nclass ToCMModel(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        self.action_size = config.ACTION_SIZE\n\n        self.observation_encoder = Encoder(in_dim=config.IN_DIM, hidden=config.HIDDEN, embed=config.EMBED)  # in_dim:\n        self.observation_decoder = Decoder(embed=config.FEAT, hidden=config.HIDDEN, out_dim=config.IN_DIM)\n\n        self.transition = RSSMTransition(config, config.MODEL_HIDDEN)\n        self.representation = RSSMRepresentation(config, self.transition)  # ann\n        self.reward_model = DenseModel(config.FEAT, 1, config.REWARD_LAYERS, config.REWARD_HIDDEN)  # ann\n        self.pcont = DenseBinaryModel(config.FEAT, 1, config.PCONT_LAYERS, config.PCONT_HIDDEN)\n\n        if config.ENV_TYPE == Env.STARCRAFT:\n            # print(\"config.FEAT, config.ACTION_SIZE, config.PCONT_LAYERS, config.PCONT_HIDDEN:\", config.FEAT,\n            #       config.ACTION_SIZE, config.PCONT_LAYERS, config.PCONT_HIDDEN)  # 1280 7 2 256\n            self.av_action = DenseBinaryModel(config.FEAT, config.ACTION_SIZE, config.PCONT_LAYERS, config.PCONT_HIDDEN)\n        else:\n            self.av_action = None\n\n        self.q_features = DenseModel(config.HIDDEN, config.PCONT_HIDDEN, 1, config.PCONT_HIDDEN)\n        self.q_action = nn.Linear(config.PCONT_HIDDEN, config.ACTION_SIZE)\n\n        # input_encoder = torch.randn(1, 10, config.IN_DIM)\n        # macs, params = profile(self.observation_encoder, inputs=(input,))\n\n\n    def forward(self, observations, prev_actions=None, prev_states=None, mask=None):\n        if prev_actions is None:\n            prev_actions = torch.zeros(observations.size(0), 
observations.size(1), self.action_size,\n                                       device=observations.device)\n\n        if prev_states is None:\n            prev_states = self.representation.initial_state(prev_actions.size(0), observations.size(1),\n                                                            device=observations.device)\n\n        return self.get_state_representation(observations, prev_actions, prev_states, mask)\n\n    def get_state_representation(self, observations, prev_actions, prev_states, mask):\n        \"\"\"\n        :param observations: size(batch, n_agents, in_dim)\n        :param prev_actions: size(batch, n_agents, action_size)\n        :param prev_states: size(batch, n_agents, state_size)\n        :return: RSSMState\n        \"\"\"\n        # print(\"mask = \", mask)\n        obs_embeds = self.observation_encoder(observations)\n        # print(\"obs_embeds=\", obs_embeds)\n        _, states = self.representation(obs_embeds, prev_actions, prev_states, mask)\n        # print(\"state = \", states)\n        return states\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/optim/loss.py",
    "content": "import numpy as np\nimport torch\nimport wandb\nimport torch.nn.functional as F\n\nfrom agent.optim.utils import rec_loss, compute_return, state_divergence_loss, calculate_ppo_loss, \\\n    batch_multi_agent, log_prob_loss, info_loss\nfrom agent.utils.params import FreezeParameters\nfrom networks.ToCM.rnns import rollout_representation, rollout_policy\n\n\ndef model_loss(config, model, obs, action, av_action, reward, done, fake, last):\n    time_steps = obs.shape[0]\n    batch_size = obs.shape[1]\n    n_agents = obs.shape[2]\n\n    embed = model.observation_encoder(obs.reshape(-1, n_agents, obs.shape[-1]))\n    embed = embed.reshape(time_steps, batch_size, n_agents, -1)\n\n    prev_state = model.representation.initial_state(batch_size, n_agents, device=obs.device)\n    prior, post, deters = rollout_representation(model.representation, time_steps, embed, action, prev_state, last)\n    feat = torch.cat([post.stoch, deters], -1)\n    feat_dec = post.get_features()\n    # decoder inputs:reshape obs[:-1] to self.SEQ_LENGTH * self.MODEL_BATCH_SIZE, n_agents, dim\n    reconstruction_loss, i_feat = rec_loss(model.observation_decoder,   # decoder\n                   feat_dec.reshape(-1, n_agents, feat_dec.shape[-1]),  # input of decoder\n                   obs[:-1].reshape(-1, n_agents, obs.shape[-1]),   # label real obs\n                   1. - fake[:-1].reshape(-1, n_agents, 1))  # fake\n    reward_loss = F.smooth_l1_loss(model.reward_model(feat), reward[1:])    # reward\n    # print(\"pcont_loss\")\n    pcont_loss = log_prob_loss(model.pcont, feat, (1. - done[1:]))\n    av_action_loss = log_prob_loss(model.av_action, feat_dec, av_action[:-1]) if av_action is not None else 0.\n    i_feat = i_feat.reshape(time_steps - 1, batch_size, n_agents, -1)\n\n    dis_loss = info_loss(i_feat[1:], model, action[1:-1], 1. 
- fake[1:-1].reshape(-1))\n    div = state_divergence_loss(prior, post, config)    #kl\n\n    model_loss = div + reward_loss + dis_loss + reconstruction_loss + pcont_loss + av_action_loss\n    if np.random.randint(20) == 4:\n        wandb.log({'Model/reward_loss': reward_loss, 'Model/div': div, 'Model/av_action_loss': av_action_loss,\n                   'Model/reconstruction_loss': reconstruction_loss, 'Model/info_loss': dis_loss,\n                   'Model/pcont_loss': pcont_loss})\n\n    return model_loss\n\n\ndef actor_rollout(obs, action, last, model, actor, critic, config):  # model=ToCMLearnerModel\n    n_agents = obs.shape[2]  # 2\n    with FreezeParameters([model]):\n        embed = model.observation_encoder(obs.reshape(-1, n_agents, obs.shape[-1]))\n        embed = embed.reshape(obs.shape[0], obs.shape[1], n_agents, -1)\n        prev_state = model.representation.initial_state(obs.shape[1], obs.shape[2], device=obs.device)\n        prior, post, _ = rollout_representation(model.representation, obs.shape[0], embed, action,\n                                                prev_state, last)\n        post = post.map(lambda x: x.reshape((obs.shape[0] - 1) * obs.shape[1], n_agents, -1))\n\n        items = rollout_policy(model, model.av_action, config.HORIZON, actor, post, action, config)  # horizon is 15 TODO  av_action: 49 40 2 7\n        #\n    imag_feat = items[\"imag_states\"].get_features()\n    obs_pred = items[\"obs_preds\"]   # TODO\n    # old_policy = items['old_policy']\n    imag_rew_feat = torch.cat([items[\"imag_states\"].stoch[:-1], items[\"imag_states\"].deter[1:]], -1)\n    # obs_pred_rew = items[\"obs_preds\"][:-1]  # TODO\n    returns = critic_rollout(model, critic, imag_feat, imag_rew_feat, items[\"actions\"],\n             items[\"imag_states\"].map(lambda x: x.reshape(-1, n_agents, x.shape[-1])), config)\n    output = [items[\"actions\"][:-1].detach(),\n              items[\"av_actions\"][:-1].detach() if items[\"av_actions\"] is not None else 
None,\n              items[\"old_policy\"][:-1].detach(),  # TODO pi\n              imag_feat[:-1].detach(),\n              items[\"imag_states\"].map(lambda x: x[:-1]),\n              obs_pred[:-1].detach(),\n              returns.detach()]\n\n    return [batch_multi_agent(v, n_agents) for v in output]\n\n\ndef critic_rollout(model, critic, states, rew_states, actions, raw_states, config):\n    with FreezeParameters([model, critic]):\n        imag_reward = calculate_next_reward(model, actions, raw_states)\n        imag_reward = imag_reward.reshape(actions.shape[:-1]).unsqueeze(-1).mean(-2, keepdim=True)[:-1]\n        # print(\"states:\", states.shape)\n        # print(\"actions: \", actions.shape)\n        value = critic(states, actions)\n        # print(\"discount_arr\")\n        discount_arr = model.pcont(rew_states).mean\n        wandb.log({'Value/Max reward': imag_reward.max(), 'Value/Min reward': imag_reward.min(),\n                   'Value/Reward': imag_reward.mean(), 'Value/Discount': discount_arr.mean(),\n                   'Value/Value': value.mean()})\n    returns = compute_return(imag_reward, value[:-1], discount_arr, bootstrap=value[-1], lmbda=config.DISCOUNT_LAMBDA,\n                             gamma=config.GAMMA)\n    return returns\n\n\ndef calculate_reward(model, states, mask=None):\n    imag_reward = model.reward_model(states)\n    if mask is not None:\n        imag_reward *= mask\n    return imag_reward\n\n\ndef calculate_next_reward(model, actions, states):\n    actions = actions.reshape(-1, actions.shape[-2], actions.shape[-1])\n    next_state = model.transition(actions, states)\n    imag_rew_feat = torch.cat([states.stoch, next_state.deter], -1)\n    return calculate_reward(model, imag_rew_feat)\n\n\ndef actor_loss(model, imag_state, obs_pred, actions, av_actions, old_policy, advantage, actor, ent_weight, config):\n    next_state = model.transition(actions, imag_state)  # TODO\n    next_feat = next_state.get_features().detach()  # TODO\n    
observations_next_other, _ = model.observation_decoder(next_feat)  # TODO\n    _, new_policy = actor(torch.cat((obs_pred, observations_next_other[:, :, -(config.num_agents-1)*4:-(config.num_agents-1)*2]), -1))  # TODO\n    if av_actions is not None:\n        new_policy[av_actions == 0] = -1e10\n    actions = actions.argmax(-1, keepdim=True)\n    rho = (F.log_softmax(new_policy, dim=-1).gather(2, actions) -  # new policy is PPO pi\n           F.log_softmax(old_policy, dim=-1).gather(2, actions)).exp()  # old policy is actor_rollout pi\n    ppo_loss, ent_loss = calculate_ppo_loss(new_policy, rho, advantage)  # normalized\n    if np.random.randint(10) == 9:\n        wandb.log({'Policy/Entropy': ent_loss.mean(), 'Policy/Mean action': actions.float().mean()})\n    return (ppo_loss + ent_loss.unsqueeze(-1) * ent_weight).mean()\n\n\ndef value_loss(critic, actions, imag_feat, targets):\n    value_pred = critic(imag_feat, actions)\n\n    mse_loss = (targets - value_pred) ** 2 / 2.0\n    return torch.mean(mse_loss)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/optim/utils.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ndef rec_loss(decoder, z, x, fake):\n    x_pred, feat = decoder(z)\n    batch_size = np.prod(list(x.shape[:-1]))\n    gen_loss1 = (F.smooth_l1_loss(x_pred, x, reduction='none') * fake).sum() / batch_size\n    return gen_loss1, feat\n\n\ndef ppo_loss(A, rho, eps=0.2):\n    return -torch.min(rho * A, rho.clamp(1 - eps, 1 + eps) * A)\n\n\ndef mse(model, x, target):\n    pred = model(x)\n    return ((pred - target) ** 2 / 2).mean()\n\n\ndef entropy_loss(prob, logProb):\n    return (prob * logProb).sum(-1)\n\n\ndef advantage(A):\n    std = 1e-4 + A.std() if len(A) > 0 else 1\n    adv = (A - A.mean()) / std\n    adv = adv.detach()\n    adv[adv != adv] = 0  # TODO what?\n    return adv\n\n\ndef calculate_ppo_loss(logits, rho, A):  # pi rho adv\n    prob = F.softmax(logits, dim=-1)\n    logProb = F.log_softmax(logits, dim=-1)\n    polLoss = ppo_loss(A, rho)\n    entLoss = entropy_loss(prob, logProb)\n    return polLoss, entLoss\n\n\ndef batch_multi_agent(tensor, n_agents):\n    if tensor is not None:\n        return tensor.map(lambda x: x.view(-1, n_agents, x.shape[-1])) if tensor.type() == None \\\n        else tensor.view(-1, n_agents, tensor.shape[-1])\n    else :\n        return None\n\n\ndef compute_return(reward, value, discount, bootstrap, lmbda, gamma):\n    next_values = torch.cat([value[1:], bootstrap[None]], 0)\n    target = reward + gamma * discount * next_values * (1 - lmbda)\n    outputs = []\n    accumulated_reward = bootstrap\n    for t in reversed(range(reward.shape[0])):\n        discount_factor = discount[t]\n        accumulated_reward = target[t] + gamma * discount_factor * accumulated_reward * lmbda\n        outputs.append(accumulated_reward)\n    returns = torch.flip(torch.stack(outputs), [0])\n    return returns\n\n\ndef info_loss(feat, model, actions, fake):\n    q_feat = F.relu(model.q_features(feat))\n    action_logits = 
model.q_action(q_feat)\n    return (fake * action_information_loss(action_logits, actions)).mean()\n\n\ndef action_information_loss(logits, target):\n    criterion = nn.CrossEntropyLoss(reduction='none')\n    return criterion(logits.view(-1, logits.shape[-1]), target.argmax(-1).view(-1))\n\n\ndef log_prob_loss(model, x, target):\n    pred = model(x)\n    return -torch.mean(pred.log_prob(target))\n\n\ndef kl_div_categorical(p, q):\n    eps = 1e-7\n    return (p * (torch.log(p + eps) - torch.log(q + eps))).sum(-1)\n\n\ndef reshape_dist(dist, config):\n    return dist.get_dist(dist.deter.shape[:-1], config.N_CATEGORICALS, config.N_CLASSES)\n\n\ndef state_divergence_loss(prior, posterior, config, reduce=True, balance=0.2):\n    prior_dist = reshape_dist(prior, config)\n    post_dist = reshape_dist(posterior, config)\n    post = kl_div_categorical(post_dist, prior_dist.detach())\n    pri = kl_div_categorical(post_dist.detach(), prior_dist)\n    kl_div = balance * post.mean(-1) + (1 - balance) * pri.mean(-1)\n    if reduce:\n        return torch.mean(kl_div)\n    else:\n        return kl_div\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/runners/ToCMRunner.py",
    "content": "import ray\nimport wandb\nimport os\nimport torch\nimport numpy as np\nfrom agent.workers.ToCMWorker import ToCMWorker\nfrom environments import Env\n\nclass ToCMServer:\n    def __init__(self, n_workers, env_config, controller_config, model):\n        # ray.init(local_mode=True) #ray.init()\n        ray.init(dashboard_port=8625, object_store_memory=8*1024*1024*1024, _memory=8*1024*1024*1024, _temp_dir='~/temp/')\n        # ray.init()\n        self.workers = [ToCMWorker.remote(i, env_config, controller_config) for i in range(n_workers)]\n        self.tasks = [worker.run.remote(model) for worker in self.workers]\n\n    def append(self, idx, update):\n        self.tasks.append(self.workers[idx].run.remote(update))\n\n    def run(self):\n        done_id, tasks = ray.wait(self.tasks)\n        self.tasks[:] = tasks\n        del tasks\n        recvs = ray.get(done_id)[0]\n        return recvs\n\n\nclass ToCMRunner:\n\n    def __init__(self, env_config, learner_config, controller_config, n_workers):\n        self.env_config = env_config\n        self.env_type = env_config.ENV_TYPE\n        self.n_workers = n_workers\n        self.learner = learner_config.create_learner()\n        self.server = ToCMServer(n_workers, env_config, controller_config, self.learner.params())  # share weight\n        self.save_dir = '~/ToCM/weights/seed'\\\n                        + str(controller_config.random_seed) + 'num_agent_2' + '/'+ learner_config.env_name + '/'\n        if not os.path.exists(self.save_dir):\n            os.makedirs(self.save_dir)\n        self.pretrain = True\n\n    def run(self, max_steps=10 ** 10, max_episodes=10 ** 10):  # 10**10 50000\n        print(\"Start ToCM Runner!\")\n        cur_steps, cur_episode = 0, 0\n\n        wandb.define_metric(\"steps\")\n        wandb.define_metric(\"reward\", step_metric=\"steps\")\n        episode_rewards = []\n\n        while True:\n\n            rollout, info = self.server.run()  # control_config -> worker\n         
   episode_rewards.append(info[\"reward\"])\n            self.learner.step(rollout)\n            cur_steps += info[\"steps_done\"]\n            cur_episode += 1\n\n            if self.env_type == Env.MPE:\n                if cur_steps % 1000 == 0:\n                    episode_average_rewards = np.mean(episode_rewards)\n                    episode_rewards = []\n                    wandb.log({'reward': episode_average_rewards, 'steps': cur_steps})\n                    print('cur_steps:', cur_steps, 'total_samples:',\n                          self.learner.total_samples, 'reward', episode_average_rewards)\n            else:\n                wandb.log({'reward': info[\"\"\n                                          \"reward\"], 'steps': cur_steps})\n\n                print('cur_episode:', cur_episode, 'total_samples:',\n                      self.learner.total_samples, 'reward', info[\"reward\"])\n            if cur_episode >= max_episodes or cur_steps >= max_steps:\n                break\n\n            if cur_episode % 100 == 1 and self.pretrain:\n                path = self.save_dir + str( cur_episode // 100)\n                torch.save(self.learner.params()['model'], path +  '_model.pth')\n                torch.save(self.learner.params()['actor'], path +  '_actor.pth')\n                torch.save(self.learner.params()['critic'], path +  '_critic.pth')\n                print(\"Save model_\" + str( cur_episode // 100))\n            self.server.append(info['idx'], self.learner.params())\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/utils/params.py",
    "content": "from typing import Iterable\nfrom torch.nn import Module\n\n\ndef get_parameters(modules: Iterable[Module]):\n    \"\"\"\n    Given a list of torch modules, returns a list of their parameters.\n    :param modules: iterable of modules\n    :returns: a list of parameters\n    \"\"\"\n    model_parameters = []\n    for module in modules:\n        model_parameters += list(module.parameters())\n    return model_parameters\n\n\nclass FreezeParameters:\n    def __init__(self, modules: Iterable[Module]):\n        \"\"\"\n        Context manager to locally freeze gradients.\n        In some cases with can speed up computation because gradients aren't calculated for these listed modules.\n        example:\n        ```\n        with FreezeParameters([module]):\n            output_tensor = module(input_tensor)\n        ```\n        :param modules: iterable of modules. used to call .parameters() to freeze gradients.\n        \"\"\"\n        self.modules = modules\n        self.param_states = [p.requires_grad for p in get_parameters(self.modules)]\n\n    def __enter__(self):\n        for param in get_parameters(self.modules):\n            param.requires_grad = False\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        for i, param in enumerate(get_parameters(self.modules)):\n            param.requires_grad = self.param_states[i]"
  },
  {
    "path": "examples/Social_Cognition/ToCM/agent/workers/ToCMWorker.py",
    "content": "from copy import deepcopy\nimport numpy as np\nimport ray\nimport torch\nfrom collections import defaultdict\n\nfrom environments import Env\n\n\n@ray.remote(num_gpus=1) # TODO\nclass ToCMWorker:\n\n    def __init__(self, idx, env_config, controller_config):\n        self.runner_handle = idx\n        self.env = env_config.create_env()\n        self.controller = controller_config.create_controller()  # controller\n        self.in_dim = controller_config.IN_DIM\n        self.env_type = env_config.ENV_TYPE\n        self.controller_config = controller_config\n        self.device = env_config.device\n\n    def _check_handle(self, handle):\n        if self.env_type == Env.STARCRAFT:\n            return self.done[handle] == 0\n\n        else:  # TODO\n            return self.env.agents[handle].movable\n\n    def _select_actions(self, state):\n        avail_actions = []\n        observations = []\n        fakes = []\n\n        nn_mask = None\n\n        for handle in range(self.env.n_agents):\n            if self.env_type == Env.STARCRAFT:\n                avail_actions.append(torch.tensor(self.env.get_avail_agent_actions(handle)))\n\n            if self._check_handle(handle) and handle in state:\n                fakes.append(torch.zeros(1, 1))\n                observations.append(state[handle].unsqueeze(0))\n            elif self.done[handle] == 1:  # handle is not in state\n                fakes.append(torch.ones(1, 1))  # fake move\n                observations.append((self.get_absorbing_state()).to(self.device))\n            else:\n                fakes.append(torch.zeros(1, 1))\n                obs = (torch.tensor(self.env.obs_builder._get_internal(handle)).float().unsqueeze(0)).to(self.device)\n                observations.append(obs)\n\n        # print(\"observations:\", observations)\n        observations = torch.cat(observations).unsqueeze(0)  # TODO\n        # print(\"observations:\", observations)\n        av_action = 
torch.stack(avail_actions).unsqueeze(0).to(self.device) if len(avail_actions) > 0 else None\n        # print(\"av_actions:\", av_action)\n        nn_mask = nn_mask.unsqueeze(0).repeat(8, 1, 1).to(self.device) if nn_mask is not None else None\n        # print(\"nn_mask:\", nn_mask)\n        actions = self.controller.step(observations, av_action, nn_mask).to(self.device)\n        # print(\"actions:\", actions)\n        return actions, observations, torch.cat(fakes).unsqueeze(0), av_action   # TODO use controller to model and pred\n\n    def _wrap(self, d):\n        for key, value in d.items():\n            d[key] = torch.tensor(value).to(self.controller_config.DEVICE).float()\n        return d\n\n    def get_absorbing_state(self):\n        state = torch.zeros(1, self.in_dim).to(self.device)  # TODO\n        return state\n\n    def augment(self, data, inverse=False):\n        aug = []\n        default = list(data.values())[0].reshape(1, -1)\n        for handle in range(self.env.n_agents):\n            if handle in data.keys():\n                aug.append(data[handle].reshape(1, -1))\n            else:\n                aug.append(torch.ones_like(default) if inverse else torch.zeros_like(default))\n        return torch.cat(aug).unsqueeze(0).to(self.device)  # TODO\n\n    def _check_termination(self, info, steps_done):\n        if self.env_type == Env.STARCRAFT or self.env_type == Env.MPE:\n            return \"episode_limit\" not in info\n        else:\n            return steps_done < self.env.max_time_steps  # can not chao shi\n\n    def run(self, ToCM_params):\n        f\"\"\"\n        interact with environment\n        :param ToCM_params: \n        :return: rollout: dict reward steps_done\n        \"\"\"\n        self.controller.receive_params(ToCM_params)\n        # Share the parameters learned by the learner with the controller.\n        # freeze the parameters\n\n        state = self._wrap(self.env.reset())  # to device\n        steps_done = 0\n        self.done = 
defaultdict(lambda: False)\n        episode_rewards = []\n        while True:\n            steps_done += 1\n            # print(\"state=\", state)\n            actions, obs, fakes, av_actions = self._select_actions(state)  # use controller to select action\n            if self.env_type == Env.MPE:\n                next_state, reward, done, info = self.env.step(actions)  # use env to update, with cpu\n                rewards = []\n                for key, value in reward.items():\n                    rewards.append(value)\n                episode_rewards.append(rewards)\n            else:\n                next_state, reward, done, info = self.env.step([action.argmax() for i, action in enumerate(actions)])\n            next_state, reward, done = self._wrap(deepcopy(next_state)), self._wrap(deepcopy(reward)), \\\n                self._wrap(deepcopy(done))  # to device\n            self.done = done\n            self.controller.update_buffer({\"action\": actions,\n                                           \"observation\": obs,\n                                           \"reward\": self.augment(reward),\n                                           \"done\":  self.augment(done),\n                                           \"fake\": fakes,\n                                           \"avail_action\": av_actions})\n\n            state = next_state\n            if all([done[key] == 1 for key in range(self.env.n_agents)]):\n                # print(\"Done\")\n                if self._check_termination(info, steps_done):\n                    # print(\"Done!\")\n                    obs = torch.cat([self.get_absorbing_state() for i in range(self.env.n_agents)]).unsqueeze(0)\n                    actions = torch.zeros(1, self.env.n_agents, actions.shape[-1])\n                    index = torch.randint(0, actions.shape[-1], actions.shape[:-1], device=actions.device)\n                    actions.scatter_(2, index.unsqueeze(-1), 1.)\n                    items = {\"observation\": 
obs,\n                             \"action\": actions,\n                             \"reward\": torch.zeros(1, self.env.n_agents, 1),\n                             \"fake\": torch.ones(1, self.env.n_agents, 1),\n                             \"done\": torch.ones(1, self.env.n_agents, 1),\n                             \"avail_action\": torch.ones_like(actions) if self.env_type == Env.STARCRAFT else None}\n                    self.controller.update_buffer(items)\n                    self.controller.update_buffer(items)  # why two\n                break\n        if self.env_type == Env.MPE:\n            reward = np.mean(np.sum(episode_rewards, axis=0))  # TODO\n        else:\n            reward = 1. if 'battle_won' in info and info['battle_won'] else 0.\n        return self.controller.dispatch_buffer(), {\"idx\": self.runner_handle,\n                                                   \"reward\": reward,  # a num\n                                                   \"steps_done\": steps_done}\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/Config.py",
    "content": "from collections.abc import Iterable\n\n\n# train->agent_configs = [ToCMControllerConfig(ToCMConfig),] -> class ToCMConfig(Config) -> Config\nclass Config:\n    def __init__(self):\n        pass\n\n    def to_dict(self, prefix=\"\"):\n        res_dict = dict()\n        for key, value in self.__dict__.items():\n            if isinstance(value, Config):\n                res_dict.update(value.to_dict(prefix + str(key) + \"_\"))\n            elif isinstance(value, Iterable):\n                if value and isinstance(value[0], Config):\n                    for i, v in enumerate(value):\n                        res_dict.update(v.to_dict(prefix + str(key) + str(i) + \"_\"))\n                else:\n                    res_dict[prefix + str(key)] = value\n            else:\n                res_dict[prefix + str(key)] = value\n        return res_dict\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/EnvConfigs.py",
    "content": "from configs.Config import Config\nfrom env.starcraft.StarCraft import StarCraft\nfrom env.mpe.MPE import MPE\n\n\nclass EnvConfig(Config):\n    def __init__(self):\n        pass\n\n    def create_env(self):\n        pass\n\n\n# TODO\nclass MPEConfig(EnvConfig):\n    def __init__(self, args):\n        self.args = args\n\n    def create_env(self):\n        return MPE(self.args)  # an env object with base class MultiAgentEnv(gym.Env)\n\n\nclass StarCraftConfig(EnvConfig):\n    def __init__(self, env_name, random_seed):\n        self.env_name = env_name\n        self.random_seed = random_seed  # TODO\n\n    def create_env(self):\n        return StarCraft(self.env_name, self.random_seed)\n\n\nclass EnvCurriculumConfig(EnvConfig):\n    def __init__(self, env_configs, env_episodes, env_type, device, obs_builder_config=None, reward_config=None):\n        self.env_configs = env_configs\n        self.env_episodes = env_episodes  # （100，）\n        self.ENV_TYPE = env_type  #\n        self.device = device  # TODO\n\n        if obs_builder_config is not None:\n            self.set_obs_builder_config(obs_builder_config)\n\n        if reward_config is not None:\n            self.set_reward_config(reward_config)\n\n    def update_random_seed(self):\n        for conf in self.env_configs:\n            conf.update_random_seed()\n\n    def set_obs_builder_config(self, obs_builder_config):\n        for conf in self.env_configs:\n            conf.set_obs_builder_config(obs_builder_config)\n\n    def set_reward_config(self, reward_config):\n        for conf in self.env_configs:\n            conf.set_reward_config(reward_config)\n\n    def create_env(self):\n        return EnvCurriculum(self.env_configs, self.env_episodes)\n\n\nclass EnvCurriculumSampleConfig(EnvConfig):\n    def __init__(self, env_configs, env_probs, obs_builder_config=None, reward_config=None):\n        self.env_configs = env_configs\n        self.env_probs = env_probs\n\n        if obs_builder_config 
is not None:\n            self.set_obs_builder_config(obs_builder_config)\n\n        if reward_config is not None:\n            self.set_reward_config(reward_config)\n\n    def update_random_seed(self):\n        for conf in self.env_configs:\n            conf.update_random_seed()\n\n    def set_obs_builder_config(self, obs_builder_config):\n        for conf in self.env_configs:\n            conf.set_obs_builder_config(obs_builder_config)\n\n    def set_reward_config(self, reward_config):\n        for conf in self.env_configs:\n            conf.set_reward_config(reward_config)\n\n    def create_env(self):\n        return EnvCurriculumSample(self.env_configs, self.env_probs)\n\n\nclass EnvCurriculumPrioritizedSampleConfig(EnvConfig):\n    def __init__(self, env_configs, repeat_random_seed, obs_builder_config=None, reward_config=None):\n        self.env_configs = env_configs\n        self.repeat_random_seed = repeat_random_seed\n\n        if obs_builder_config is not None:\n            self.set_obs_builder_config(obs_builder_config)\n\n        if reward_config is not None:\n            self.set_reward_config(reward_config)\n\n    def update_random_seed(self):\n        for conf in self.env_configs:\n            conf.update_random_seed()\n\n    def set_obs_builder_config(self, obs_builder_config):\n        for conf in self.env_configs:\n            conf.set_obs_builder_config(obs_builder_config)\n\n    def set_reward_config(self, reward_config):\n        for conf in self.env_configs:\n            conf.set_reward_config(reward_config)\n\n    def create_env(self):\n        return EnvCurriculumPrioritizedSample(self.env_configs, self.repeat_random_seed)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/Experiment.py",
    "content": "from configs.Config import Config\n\n\nclass Experiment(Config):  # 这个还没改且里面没有\n    def __init__(self, steps, episodes, random_seed, env_config, controller_config, learner_config):\n        super(Experiment, self).__init__()  # TODO  device 在env里面加入了\n        self.steps = steps\n        self.episodes = episodes\n        self.random_seed = random_seed\n        self.env_config = env_config\n        self.controller_config = controller_config\n        self.learner_config = learner_config\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/ToCM/ToCMAgentConfig.py",
    "content": "from dataclasses import dataclass\n\nimport torch\nimport torch.distributions as td\nimport torch.nn.functional as F\n\nfrom configs.Config import Config\n\nRSSM_STATE_MODE = 'discrete'\n\n\n#\nclass ToCMConfig(Config):  # 从Config继承\n    def __init__(self):\n        super().__init__()\n        self.HIDDEN = 64  # 隐藏层神经元个数\n        self.MODEL_HIDDEN = 64  # 模型隐藏层神经元个数\n        self.EMBED = 64  # 编码器神经元个数\n        self.N_CATEGORICALS = 32  # 分类数\n        self.N_CLASSES = 32  # 类别数\n        self.STOCHASTIC = self.N_CATEGORICALS * self.N_CLASSES  # stochastic:随机的\n        self.DETERMINISTIC = 64  # deterministic:确定的\n        self.FEAT = self.STOCHASTIC + self.DETERMINISTIC  # feat:特征\n        self.GLOBAL_FEAT = self.FEAT + self.EMBED  # global_feat:全局特征\n        self.VALUE_LAYERS = 2  # value_layers:值层\n        self.VALUE_HIDDEN = 64  # value_hidden:值隐藏层\n        self.PCONT_LAYERS = 2  # pcont_layers:概率层\n        self.PCONT_HIDDEN = 64  # pcont_hidden:概率隐藏层\n        self.ACTION_SIZE = 9  # action_size:动作大小\n        self.ACTION_LAYERS = 2  # action_layers:动作层\n        self.ACTION_HIDDEN = 64  # action_hidden:动作隐藏层\n        self.REWARD_LAYERS = 2  # reward_layers:奖励层\n        self.REWARD_HIDDEN = 64  # reward_hidden:奖励隐藏层\n        self.GAMMA = 0.99  # gamma:折扣因子\n        self.DISCOUNT = 0.99  # discount:折扣\n        self.DISCOUNT_LAMBDA = 0.95  # discount_lambda:折扣lambda\n        self.IN_DIM = 30  # in_dim:输入维度\n        self.LOG_FOLDER = 'wandb/'  # log_folder:日志文件夹\n        self.num_agents = 2\n\n\n@dataclass\nclass RSSMStateBase:\n    stoch: torch.Tensor\n    deter: torch.Tensor\n\n    def map(self, func):\n        return RSSMState(**{key: func(val) for key, val in self.__dict__.items()})\n\n    def get_features(self):\n        return torch.cat((self.stoch, self.deter), dim=-1)\n\n    def get_dist(self, *input):\n        pass\n\n    def type(self):\n        return None\n\n\n@dataclass\nclass RSSMStateDiscrete(RSSMStateBase):\n    logits: torch.Tensor\n\n 
   def get_dist(self, batch_shape, n_categoricals, n_classes):\n        return F.softmax(self.logits.reshape(*batch_shape, n_categoricals, n_classes), -1)\n\n\n@dataclass\nclass RSSMStateCont(RSSMStateBase):\n    mean: torch.Tensor\n    std: torch.Tensor\n\n    def get_dist(self, *input):\n        return td.independent.Independent(td.Normal(self.mean, self.std), 1)\n\n\nRSSMState = {'discrete': RSSMStateDiscrete,\n             'cont': RSSMStateCont}[RSSM_STATE_MODE]\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/ToCM/ToCMControllerConfig.py",
    "content": "from agent.controllers.ToCMController import ToCMController\nfrom configs.ToCM.ToCMAgentConfig import ToCMConfig\n\n\n# train->agent_configs = [ToCMControllerConfig(ToCMConfig),] -> class ToCMConfig(Config) -> Config\nclass ToCMControllerConfig(ToCMConfig):\n    def __init__(self, env_name, RANDOM_SEED, device):  # RANDOM_SEED:23 device:'cuda:6' env_name:'3s5z_vs_3s6z'\n        super().__init__()\n\n        self.EXPL_DECAY = 0.9999  # exploration decay rate：探索衰减率\n        self.EXPL_NOISE = 0.  # exploration noise：探索噪声\n        self.EXPL_MIN = 0.  # minimum exploration：最小探索\n        self.DEVICE = device  # TODO\n        self.env_name = env_name    # TODO\n        self.random_seed = RANDOM_SEED  # TODO\n\n    def create_controller(self):\n        return ToCMController(self)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/ToCM/ToCMLearnerConfig.py",
    "content": "from agent.learners.ToCMLearner import ToCMLearner\nfrom configs.ToCM.ToCMAgentConfig import ToCMConfig\n\n\n# train->agent_configs = [ToCMLearnerConfig(ToCMConfig),] -> class ToCMConfig(Config) -> Config\nclass ToCMLearnerConfig(ToCMConfig):  # 从ToCMConfig继承，有输入维度、输出维度、隐层维度、隐层层数、动作维度、动作隐层维度、动作隐层层数、\n    def __init__(self, env_name, RANDOM_SEED, device):\n        super().__init__()\n        self.MODEL_LR = 2e-4\n        self.ACTOR_LR = 7e-4  # TODO\n        self.VALUE_LR = 7e-4  # TODO\n        self.CAPACITY = 500000\n        self.MIN_BUFFER_SIZE = 100 \n        self.MODEL_EPOCHS = 20  # TODO\n        self.EPOCHS = 4  # TODO\n        self.PPO_EPOCHS = 10  # TODO\n        self.MODEL_BATCH_SIZE = 30#40\n        self.BATCH_SIZE = 40\n        self.SEQ_LENGTH = 50\n        self.N_SAMPLES = 1\n        self.TARGET_UPDATE = 128\n        self.GRAD_CLIP = 100.0\n        self.HORIZON = 15\n        self.ENTROPY = 0.001\n        self.ENTROPY_ANNEALING = 0.99998\n        self.GRAD_CLIP_POLICY = 100.\n        self.DEVICE = device  # TODO\n        self.env_name = env_name    # TODO\n        self.random_seed = RANDOM_SEED  # TODO\n        self.num_agents = 2\n\n    def create_learner(self):  # 通过config创建learner\n        return ToCMLearner(self)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/ToCM/optimal/starcraft/AgentConfig.py",
    "content": "from configs.Config import Config\n\n\nclass ToCMConfig(Config):\n    def __init__(self):\n        super().__init__()\n        self.HIDDEN = 256\n        self.MODEL_HIDDEN = 256\n        self.EMBED = 256\n        self.N_CATEGORICALS = 32\n        self.N_CLASSES = 32\n        self.STOCHASTIC = self.N_CATEGORICALS * self.N_CLASSES\n        self.DETERMINISTIC = 256\n        self.FEAT = self.STOCHASTIC + self.DETERMINISTIC\n        self.GLOBAL_FEAT = self.FEAT + self.EMBED\n        self.VALUE_LAYERS = 2\n        self.VALUE_HIDDEN = 256\n        self.PCONT_LAYERS = 2\n        self.PCONT_HIDDEN = 256\n        self.ACTION_SIZE = 9\n        self.ACTION_LAYERS = 2\n        self.ACTION_HIDDEN = 256\n        self.REWARD_LAYERS = 2\n        self.REWARD_HIDDEN = 256\n        self.GAMMA = 0.99\n        self.DISCOUNT = 0.99\n        self.DISCOUNT_LAMBDA = 0.95\n        self.IN_DIM = 30\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/ToCM/optimal/starcraft/LearnerConfig.py",
    "content": "from agent.learners.ToCMLearner import ToCMLearner\nfrom configs.ToCM.ToCMAgentConfig import ToCMConfig\n\n\nclass ToCMLearnerConfig(ToCMConfig):\n    def __init__(self):\n        super().__init__()\n        self.MODEL_LR = 2e-4\n        self.ACTOR_LR = 5e-4\n        self.VALUE_LR = 5e-4\n        self.CAPACITY = 250000\n        self.MIN_BUFFER_SIZE = 500\n        self.MODEL_EPOCHS = 40\n        self.EPOCHS = 4\n        self.PPO_EPOCHS = 10\n        self.MODEL_BATCH_SIZE = 40\n        self.BATCH_SIZE = 40\n        self.SEQ_LENGTH = 20\n        self.N_SAMPLES = 1\n        self.TARGET_UPDATE = 1\n        self.DEVICE = 'cuda:8'\n        self.GRAD_CLIP = 100.0\n        self.HORIZON = 15\n        self.ENTROPY = 0.001\n        self.ENTROPY_ANNEALING = 0.99998\n        self.GRAD_CLIP_POLICY = 100.0\n\n    def create_learner(self):\n        return ToCMLearner(self)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/configs/__init__.py",
    "content": "from .Experiment import Experiment\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/env/mpe/MPE.py",
    "content": "from mpe.MPE_Env import MPEEnv\n\n\nclass MPE:\n\n    def __init__(self, args):\n        self.env = MPEEnv(args)  # TODO args name and random seed\n        # scenario_name=args.scenario_name, benchmark=args.benchmark, num_agents=args.num_agents,\n        # num_adversaries, num_landmarks, episode_length\n        self.env.seed(args.seed)\n\n        self.n_agents = self.env.num_agents\n        self.agents = self.env.agents\n\n    def to_dict(self, l):\n        return {i: e for i, e in enumerate(l)}\n\n    def step(self, action_dict):  # action dict for each agent\n        # print(\"action_dist\", action_dict)\n        obs, reward, done, info = self.env.step(action_dict)  # TODO return four list\n        return {i: obs[i] for i in range(self.n_agents)}, {i: reward[i] for i in range(self.n_agents)}, \\\n            {i: done[i] for i in range(self.n_agents)}, {i: info[i] for i in range(self.n_agents)}\n\n    def reset(self):\n        obs = self.env.reset()\n        return self.to_dict(obs)\n\n    def close(self):\n        self.env.close()\n\n    # no mask and no this usage\n    def get_avail_agent_actions(self, handle):  # available handle is the i th agent, add mask\n        return self.env._get_done(handle)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/env/starcraft/StarCraft.py",
    "content": "from smac.env import StarCraft2Env  # import a package smac\n\n\nclass StarCraft:\n\n    def __init__(self, env_name, random_seed):\n        # map_name ->\n        self.env = StarCraft2Env(map_name=env_name, seed=random_seed, continuing_episode=True, difficulty=\"7\")  # TODO\n        env_info = self.env.get_env_info()\n\n        self.n_obs = env_info[\"obs_shape\"]\n        self.n_actions = env_info[\"n_actions\"]\n        self.n_agents = env_info[\"n_agents\"]\n\n    def to_dict(self, l):\n        return {i: e for i, e in enumerate(l)}\n\n    def step(self, action_dict):\n        reward, done, info = self.env.step(action_dict)\n        return self.to_dict(self.env.get_obs()), {i: reward for i in range(self.n_agents)}, \\\n               {i: done for i in range(self.n_agents)}, info\n\n    def reset(self):\n        self.env.reset()\n        return {i: obs for i, obs in enumerate(self.env.get_obs())}\n\n    def render(self):\n        self.env.render()\n\n    def close(self):\n        self.env.close()\n\n    def get_avail_agent_actions(self, handle):\n        return self.env.get_avail_agent_actions(handle)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/environments.py",
    "content": "from enum import Enum\n\n\nclass Env(str, Enum):\n    STARCRAFT = \"starcraft\"\n    MPE = \"mpe\"\n\n\n\n# RANDOM_SEED = 23\n# ENV_NAME = \"5_agents\"\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/MPE_Env.py",
    "content": "\"\"\"\nCode for creating a multiagent environment with one of the scenarios listed\nin ./scenarios/.\nCan be called by using, for example:\n    env = make_env('simple_speaker_listener')\nAfter producing the env object, can be used similarly to an OpenAI gym\nenvironment.\n\nA policy using this environment must output actions in the form of a list\nfor all agents. Each element of the list should be a numpy array,\nof size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede\ncommunication actions in this array. See environment.py for more details.\n\"\"\"\nfrom .environment import MultiAgentEnv\nfrom .scenarios import load\n\n\ndef MPEEnv(args):\n    \"\"\"\n    Creates a MultiAgentEnv object as env. This can be used similar to a gym\n    environment by calling env.reset() and env.step().\n    Use env.render() to view the environment on the screen.\n\n    Input:\n        scenario_name   :   name of the scenario from ./scenarios/ to be Returns\n                            (without the .py extension)\n        benchmark       :   whether you want to produce benchmarking data\n                            (usually only done during evaluation)\n\n    Some useful env properties (see environment.py):\n        .observation_space  :   Returns the observation space for each agent\n        .action_space       :   Returns the action space for each agent\n        .n                  :   Returns the number of Agents\n    \"\"\"\n\n    # load scenario from script\n    scenario = load(args.env_name + \".py\").Scenario()\n    # create world\n    world = scenario.make_world(args)  # py file and others parameters, use the train parse?\n    # create multi agent environment\n    env = MultiAgentEnv(world, scenario.reset_world,\n                        scenario.reward, scenario.observation)\n\n    return env\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/core.py",
    "content": "import numpy as np\n# import seaborn as sns\n\n# physical/external base state of all entites\nclass EntityState(object):\n    def __init__(self):\n        # physical position\n        self.p_pos = None\n        # physical velocity\n        self.p_vel = None\n\n# state of agents (including communication and internal/mental state)\nclass AgentState(EntityState):\n    def __init__(self):\n        super(AgentState, self).__init__()\n        # communication utterance\n        self.c = None\n\n# action of the agent\nclass Action(object):\n    def __init__(self):\n        # physical action\n        self.u = None\n        # communication action\n        self.c = None\n\n# properties of wall entities\nclass Wall(object):\n    def __init__(self, orient='H', axis_pos=0.0, endpoints=(-1, 1), width=0.1,\n                 hard=True):\n        # orientation: 'H'orizontal or 'V'ertical\n        self.orient = orient\n        # position along axis which wall lays on (y-axis for H, x-axis for V)\n        self.axis_pos = axis_pos\n        # endpoints of wall (x-coords for H, y-coords for V)\n        self.endpoints = np.array(endpoints)\n        # width of wall\n        self.width = width\n        # whether wall is impassable to all agents\n        self.hard = hard\n        # color of wall\n        self.color = np.array([0.0, 0.0, 0.0])\n\n\n# properties and state of physical world entity\nclass Entity(object):\n    def __init__(self):\n        # index among all entities (important to set for distance caching)\n        self.i = 0\n        # name\n        self.name = ''\n        # properties:\n        self.size = 0.050\n        # entity can move / be pushed\n        self.movable = False\n        # entity collides with others\n        self.collide = True\n        # entity can pass through non-hard walls\n        self.ghost = False\n        # material density (affects mass)\n        self.density = 25.0\n        # color\n        self.color = None\n        # max speed and 
accel\n        self.max_speed = None\n        self.accel = None\n        # state: including internal/mental state p_pos, p_vel\n        self.state = EntityState()\n        # mass\n        self.initial_mass = 1.0\n        # commu channel\n        self.channel = None\n\n    @property\n    def mass(self):\n        return self.initial_mass\n\n# properties of landmark entities\nclass Landmark(Entity):\n    def __init__(self):\n        super(Landmark, self).__init__()\n\n# properties of agent entities\nclass Agent(Entity):\n    def __init__(self):\n        super(Agent, self).__init__()\n        # agent are adversary\n        self.adversary = False\n        # agent are dummy\n        self.dummy = False\n        # agents are movable by default\n        self.movable = True\n        # cannot send communication signals\n        self.silent = False\n        # cannot observe the world\n        self.blind = False\n        # physical motor noise amount\n        self.u_noise = None\n        # communication noise amount\n        self.c_noise = None\n        # control range\n        self.u_range = 1.0\n        # state: including communication state(communication utterance) c and internal/mental state p_pos, p_vel\n        self.state = AgentState()\n        # action: physical action u & communication action c\n        self.action = Action()\n        # script behavior to execute\n        self.action_callback = None\n        # zoe 20200420\n        self.goal = None\n\n# multi-agent world\nclass World(object):\n    def __init__(self):\n        # list of agents and entities (can change at execution-time!)\n        self.agents = []\n        self.landmarks = []\n        self.walls = []\n        # communication channel dimensionality\n        self.dim_c = 0\n        # position dimensionality\n        self.dim_p = 2\n        # color dimensionality\n        self.dim_color = 3\n        # simulation timestep\n        self.dt = 0.1\n        # physical damping\n        self.damping = 0.25\n       
 # contact response parameters\n        self.contact_force = 1e+2\n        self.contact_margin = 1e-3\n        # cache distances between all agents (not calculated by default)\n        self.cache_dists = False\n        self.cached_dist_vect = None\n        self.cached_dist_mag = None\n        # zoe 20200420\n        self.world_length = 25\n        self.world_step = 0\n        self.num_agents = 0\n        self.num_landmarks = 0\n\n    # return all entities in the world\n    @property\n    def entities(self):\n        return self.agents + self.landmarks\n\n    # return all agents controllable by external policies\n    @property\n    def policy_agents(self):\n        return [agent for agent in self.agents if agent.action_callback is None]\n\n    # return all agents controlled by world scripts\n    @property\n    def scripted_agents(self):\n        return [agent for agent in self.agents if agent.action_callback is not None]\n\n    def calculate_distances(self):\n        if self.cached_dist_vect is None:\n            # initialize distance data structure\n            self.cached_dist_vect = np.zeros((len(self.entities),\n                                              len(self.entities),\n                                              self.dim_p))\n            # calculate minimum distance for a collision between all entities\n            self.min_dists = np.zeros((len(self.entities), len(self.entities)))\n            for ia, entity_a in enumerate(self.entities):\n                for ib in range(ia + 1, len(self.entities)):\n                    entity_b = self.entities[ib]\n                    min_dist = entity_a.size + entity_b.size\n                    self.min_dists[ia, ib] = min_dist\n                    self.min_dists[ib, ia] = min_dist\n\n        for ia, entity_a in enumerate(self.entities):\n            for ib in range(ia + 1, len(self.entities)):\n                entity_b = self.entities[ib]\n                delta_pos = entity_a.state.p_pos - entity_b.state.p_pos\n   
             self.cached_dist_vect[ia, ib, :] = delta_pos\n                self.cached_dist_vect[ib, ia, :] = -delta_pos\n\n        self.cached_dist_mag = np.linalg.norm(self.cached_dist_vect, axis=2)\n\n        self.cached_collisions = (self.cached_dist_mag <= self.min_dists)\n\n    def assign_agent_colors(self):\n        n_dummies = 0\n        if hasattr(self.agents[0], 'dummy'):\n            n_dummies = len([a for a in self.agents if a.dummy])\n        n_adversaries = 0\n        if hasattr(self.agents[0], 'adversary'):\n            n_adversaries = len([a for a in self.agents if a.adversary])\n        n_good_agents = len(self.agents) - n_adversaries - n_dummies\n        # r g b\n        dummy_colors = [(0.25, 0.75, 0.25)] * n_dummies\n        # sns.color_palette(\"OrRd_d\", n_adversaries)\n        adv_colors = [(0.75, 0.25, 0.25)] * n_adversaries\n        # sns.color_palette(\"GnBu_d\", n_good_agents)\n        good_colors = [(0.25, 0.25, 0.75)] * n_good_agents\n        colors = dummy_colors + adv_colors + good_colors\n        for color, agent in zip(colors, self.agents):\n            agent.color = color\n\n    # landmark color\n    def assign_landmark_colors(self):\n        for landmark in self.landmarks:\n            landmark.color = np.array([0.25, 0.25, 0.25])\n\n    # update state of the world\n    def step(self):\n        self.world_step += 1\n        # set actions for scripted agents\n        for agent in self.scripted_agents:\n            agent.action = agent.action_callback(agent, self)\n        # gather forces applied to entities\n        p_force = [None] * len(self.entities)\n        # apply agent physical controls\n        p_force = self.apply_action_force(p_force)\n        # apply environment forces\n        p_force = self.apply_environment_force(p_force)\n        # integrate physical state\n        self.integrate_state(p_force)\n        # update agent state\n        for agent in self.agents:\n            self.update_agent_state(agent)\n        # 
calculate and store distances between all entities\n        if self.cache_dists:\n            self.calculate_distances()\n\n    # gather agent action forces\n    def apply_action_force(self, p_force):\n        # set applied forces\n        for i, agent in enumerate(self.agents):\n            if agent.movable:\n                noise = np.random.randn(\n                    *agent.action.u.shape) * agent.u_noise if agent.u_noise else 0.0\n                # force = mass * a * action + n\n                p_force[i] = (\n                    agent.mass * agent.accel if agent.accel is not None else agent.mass) * agent.action.u + noise\n        return p_force\n\n    # gather physical forces acting on entities\n    def apply_environment_force(self, p_force):\n        # simple (but inefficient) collision response\n        for a, entity_a in enumerate(self.entities):\n            for b, entity_b in enumerate(self.entities):\n                if(b <= a):\n                    continue\n                [f_a, f_b] = self.get_entity_collision_force(a, b)\n                if(f_a is not None):\n                    if(p_force[a] is None):\n                        p_force[a] = 0.0\n                    p_force[a] = f_a + p_force[a]\n                if(f_b is not None):\n                    if(p_force[b] is None):\n                        p_force[b] = 0.0\n                    p_force[b] = f_b + p_force[b]\n            if entity_a.movable:\n                for wall in self.walls:\n                    wf = self.get_wall_collision_force(entity_a, wall)\n                    if wf is not None:\n                        if p_force[a] is None:\n                            p_force[a] = 0.0\n                        p_force[a] = p_force[a] + wf\n        return p_force\n\n    # integrate physical state\n    def integrate_state(self, p_force):\n        for i, entity in enumerate(self.entities):\n            if not entity.movable:\n                continue\n            entity.state.p_vel = 
entity.state.p_vel * (1 - self.damping)\n            if (p_force[i] is not None):\n                entity.state.p_vel += (p_force[i] / entity.mass) * self.dt\n            if entity.max_speed is not None:\n                speed = np.sqrt(\n                    np.square(entity.state.p_vel[0]) + np.square(entity.state.p_vel[1]))\n                if speed > entity.max_speed:\n                    entity.state.p_vel = entity.state.p_vel / np.sqrt(np.square(entity.state.p_vel[0]) +\n                                                                      np.square(entity.state.p_vel[1])) * entity.max_speed\n            entity.state.p_pos += entity.state.p_vel * self.dt\n\n    def update_agent_state(self, agent):\n        # set communication state (directly for now)\n        if agent.silent:\n            agent.state.c = np.zeros(self.dim_c)\n        else:\n            noise = np.random.randn(*agent.action.c.shape) * \\\n                agent.c_noise if agent.c_noise else 0.0\n            agent.state.c = agent.action.c + noise\n\n    # get collision forces for any contact between two entities\n    def get_entity_collision_force(self, ia, ib):\n        entity_a = self.entities[ia]\n        entity_b = self.entities[ib]\n        if (not entity_a.collide) or (not entity_b.collide):\n            return [None, None]  # not a collider\n        if (not entity_a.movable) and (not entity_b.movable):\n            return [None, None]  # neither entity moves\n        if (entity_a is entity_b):\n            return [None, None]  # don't collide against itself\n        if self.cache_dists:\n            delta_pos = self.cached_dist_vect[ia, ib]\n            dist = self.cached_dist_mag[ia, ib]\n            dist_min = self.min_dists[ia, ib]\n        else:\n            # compute actual distance between entities\n            delta_pos = entity_a.state.p_pos - entity_b.state.p_pos\n            dist = np.sqrt(np.sum(np.square(delta_pos)))\n            # minimum allowable distance\n            
dist_min = entity_a.size + entity_b.size\n        # softmax penetration\n        k = self.contact_margin\n        penetration = np.logaddexp(0, -(dist - dist_min)/k)*k\n        force = self.contact_force * delta_pos / dist * penetration\n        if entity_a.movable and entity_b.movable:\n            # consider mass in collisions\n            force_ratio = entity_b.mass / entity_a.mass\n            force_a = force_ratio * force\n            force_b = -(1 / force_ratio) * force\n        else:\n            force_a = +force if entity_a.movable else None\n            force_b = -force if entity_b.movable else None\n        return [force_a, force_b]\n\n    # get collision forces for contact between an entity and a wall\n    def get_wall_collision_force(self, entity, wall):\n        if entity.ghost and not wall.hard:\n            return None  # ghost passes through soft walls\n        if wall.orient == 'H':\n            prll_dim = 0\n            perp_dim = 1\n        else:\n            prll_dim = 1\n            perp_dim = 0\n        ent_pos = entity.state.p_pos\n        if (ent_pos[prll_dim] < wall.endpoints[0] - entity.size or\n                ent_pos[prll_dim] > wall.endpoints[1] + entity.size):\n            return None  # entity is beyond endpoints of wall\n        elif (ent_pos[prll_dim] < wall.endpoints[0] or\n              ent_pos[prll_dim] > wall.endpoints[1]):\n            # part of entity is beyond wall\n            if ent_pos[prll_dim] < wall.endpoints[0]:\n                dist_past_end = ent_pos[prll_dim] - wall.endpoints[0]\n            else:\n                dist_past_end = ent_pos[prll_dim] - wall.endpoints[1]\n            theta = np.arcsin(dist_past_end / entity.size)\n            dist_min = np.cos(theta) * entity.size + 0.5 * wall.width\n        else:  # entire entity lies within bounds of wall\n            theta = 0\n            dist_past_end = 0\n            dist_min = entity.size + 0.5 * wall.width\n\n        # only need to calculate distance in relevant 
dim\n        delta_pos = ent_pos[perp_dim] - wall.axis_pos\n        dist = np.abs(delta_pos)\n        # softmax penetration\n        k = self.contact_margin\n        penetration = np.logaddexp(0, -(dist - dist_min)/k)*k\n        force_mag = self.contact_force * delta_pos / dist * penetration\n        force = np.zeros(2)\n        force[perp_dim] = np.cos(theta) * force_mag\n        force[prll_dim] = np.sin(theta) * np.abs(force_mag)\n        return force\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/environment.py",
    "content": "import gym\nfrom gym import spaces\nfrom gym.envs.registration import EnvSpec\nimport numpy as np\nfrom .multi_discrete import MultiDiscrete\n\n# update bounds to center around agent\ncam_range = 2\n\n\n# environment for all agents in the multi agent world\n# currently code assumes that no agents will be created/destroyed at runtime!\n\n\nclass MultiAgentEnv(gym.Env):\n    metadata = {\n        'render.modes': ['human', 'rgb_array']\n    }\n\n    def __init__(self, world, reset_callback=None, reward_callback=None,\n                 observation_callback=None, info_callback=None,\n                 done_callback=None, post_step_callback=None,\n                 shared_viewer=True, discrete_action=True):\n\n        self.world = world\n        self.world_length = self.world.world_length  # obs TODO 25\n        self.current_step = 0\n        self.agents = self.world.policy_agents\n        # set required vectorized gym env property\n        self.num_agents = len(world.policy_agents)\n        # scenario callbacks\n        self.reset_callback = reset_callback\n        self.reward_callback = reward_callback\n        self.observation_callback = observation_callback\n        self.info_callback = info_callback\n        self.done_callback = done_callback\n\n        self.post_step_callback = post_step_callback\n\n        # environment parameters\n        # self.discrete_action_space = True\n        self.discrete_action_space = discrete_action  # actions dim TODO\n\n        # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector\n        self.discrete_action_input = False\n        # if true, even the action is continuous, action will be performed discretely\n        self.force_discrete_action = world.discrete_action if hasattr(\n            world, 'discrete_action') else False\n        # in this env, force_discrete_action == False, because world does not have discrete_action\n\n        # if true, every agent has the same reward\n        
self.shared_reward = world.collaborative if hasattr(\n            world, 'collaborative') else False\n        # self.shared_reward = False\n        self.time = 0\n\n        # configure spaces\n        self.action_space = []\n        self.observation_space = []\n        self.share_observation_space = []\n        share_obs_dim = 0\n        for agent in self.agents:\n            total_action_space = []\n            # physical action space\n            if self.discrete_action_space:\n                u_action_space = spaces.Discrete(world.dim_p * 2 + 1)\n            else:\n                u_action_space = spaces.Box(\n                    low=-agent.u_range, high=+agent.u_range, shape=(world.dim_p,), dtype=np.float32)  # [-1,1]\n            if agent.movable:\n                total_action_space.append(u_action_space)\n\n            # communication action space\n            if self.discrete_action_space:\n                c_action_space = spaces.Discrete(world.dim_c)\n            else:\n                c_action_space = spaces.Box(low=0.0, high=1.0, shape=(\n                    world.dim_c,), dtype=np.float32)  # [0,1]\n            # c_action_space = spaces.Discrete(world.dim_c)\n\n            if not agent.silent:\n                total_action_space.append(c_action_space)\n            # total action space\n            if len(total_action_space) > 1:\n                # all action spaces are discrete, so simplify to MultiDiscrete action space\n                if all([isinstance(act_space, spaces.Discrete) for act_space in total_action_space]):\n                    act_space = MultiDiscrete(\n                        [[0, act_space.n - 1] for act_space in total_action_space])\n                else:\n                    act_space = spaces.Tuple(total_action_space)\n                self.action_space.append(act_space)\n            else:\n                self.action_space.append(total_action_space[0])\n            # observation space\n            obs_dim = 
len(observation_callback(agent, self.world))\n            share_obs_dim += obs_dim\n            self.observation_space.append(spaces.Box(\n                low=-np.inf, high=+np.inf, shape=(obs_dim,), dtype=np.float32))  # [-inf,inf]\n            agent.action.c = np.zeros(self.world.dim_c)\n        self.share_observation_space = [spaces.Box(\n            low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32)] * self.num_agents\n        # rendering\n        self.shared_viewer = shared_viewer\n        if self.shared_viewer:\n            self.viewers = [None]\n        else:\n            self.viewers = [None] * self.num_agents\n        self._reset_render()\n\n    def seed(self, seed=None):\n        if seed is None:\n            np.random.seed(1)\n        else:\n            np.random.seed(seed)\n\n    # step  this is  env.step()\n    def step(self, action_n):\n        self.current_step += 1\n        obs_n = []\n        reward_n = []\n        done_n = []\n        info_n = []\n        self.agents = self.world.policy_agents\n        # set action for each agent\n        for i, agent in enumerate(self.agents):\n            self._set_action(action_n[i], agent, self.action_space[i])\n        # advance world state\n        self.world.step()  # core.step()\n        # record observation for each agent\n        for i, agent in enumerate(self.agents):\n            obs_n.append(self._get_obs(agent))\n            reward_n.append([self._get_reward(agent)])\n            done_n.append([self._get_done(agent)])\n            info = {'individual_reward': self._get_reward(agent)}\n            info_n.append(info)\n\n        # all agents get total reward in cooperative case, if shared reward, all agents have the same reward,\n        # and reward is sum\n        reward = np.sum(reward_n)\n        if self.shared_reward:\n            reward_n = [[reward]] * self.num_agents\n\n        if self.post_step_callback is not None:\n            self.post_step_callback(self.world)\n\n        
return obs_n, reward_n, done_n, info_n\n\n    def reset(self):\n        self.current_step = 0\n        # reset world\n        self.reset_callback(self.world)\n        # reset renderer\n        self._reset_render()\n        # record observations for each agent\n        obs_n = []\n        self.agents = self.world.policy_agents\n\n        for agent in self.agents:\n            obs_n.append(self._get_obs(agent))\n\n        return obs_n\n\n    # get info used for benchmarking\n    def _get_info(self, agent):\n        if self.info_callback is None:\n            return {}\n        return self.info_callback(agent, self.world)\n\n    # get observation for a particular agent\n    def _get_obs(self, agent):\n        if isinstance(agent, int):\n            agent = self.agents[agent]\n        if self.observation_callback is None:\n            print(\"Unavailable:\", np.zeros(0))\n            return np.zeros(0)\n        return self.observation_callback(agent, self.world)\n\n    # get dones for a particular agent\n    # unused right now -- agents are allowed to go beyond the viewing screen\n    def _get_done(self, agent):\n        if isinstance(agent, int):\n            agent = self.agents[agent]\n        if self.done_callback is None:\n            if self.current_step >= self.world_length:\n                return True\n            else:\n                return False\n        return self.done_callback(agent, self.world)\n\n    # get reward for a particular agent\n    def _get_reward(self, agent):\n        if self.reward_callback is None:\n            return 0.0\n        return self.reward_callback(agent, self.world)\n\n    # set env action for a particular agent\n    def _set_action(self, action, agent, action_space, time=None):\n        agent.action.u = np.zeros(self.world.dim_p)\n        agent.action.c = np.zeros(self.world.dim_c)\n        # process action\n        if isinstance(action_space, MultiDiscrete):\n            act = []\n            size = action_space.high - 
action_space.low + 1\n            index = 0\n            for s in size:\n                act.append(action[index:(index + s)])\n                index += s\n            action = act\n        else:\n            action = [action]\n        if agent.movable:\n            # physical action\n            if self.discrete_action_input:\n                agent.action.u = np.zeros(self.world.dim_p)\n                # process discrete action\n                if action[0] == 1:\n                    agent.action.u[0] = -1.0\n                if action[0] == 2:\n                    agent.action.u[0] = +1.0\n                if action[0] == 3:\n                    agent.action.u[1] = -1.0\n                if action[0] == 4:\n                    agent.action.u[1] = +1.0\n                d = self.world.dim_p\n            else:\n                if self.discrete_action_space:\n                    agent.action.u[0] += action[0][1] - action[0][2]\n                    agent.action.u[1] += action[0][3] - action[0][4]\n                    d = 5\n                else:\n                    if self.force_discrete_action:\n                        p = np.argmax(action[0][0:self.world.dim_p])\n                        action[0][:] = 0.0\n                        action[0][p] = 1.0\n                    agent.action.u = action[0][0:self.world.dim_p]\n                    d = self.world.dim_p\n\n            sensitivity = 5.0\n            if agent.accel is not None:\n                sensitivity = agent.accel\n            agent.action.u *= sensitivity\n\n            if (not agent.silent) and (not isinstance(action_space, MultiDiscrete)):\n                action[0] = action[0][d:]\n            else:\n                action = action[1:]\n\n        if not agent.silent:\n            # communication action\n            if self.discrete_action_input:\n                agent.action.c = np.zeros(self.world.dim_c)\n                agent.action.c[action[0]] = 1.0\n            else:\n                agent.action.c = 
action[0]\n\n            action = action[1:]\n\n        # make sure we used all elements of action\n        assert len(action) == 0\n\n    def _get_avail_action(self, handle):\n        agent = self.agents[handle]  # TODO\n\n    # reset rendering assets\n    def _reset_render(self):\n        self.render_geoms = None\n        self.render_geoms_xform = None\n\n    def render(self, mode='human', close=True):\n        if close:\n            # close any existic renderers\n            for i, viewer in enumerate(self.viewers):\n                if viewer is not None:\n                    viewer.close()\n                self.viewers[i] = None\n            return []\n\n        if mode == 'human':\n            alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'\n            message = ''\n            for agent in self.world.agents:\n                comm = []\n                for other in self.world.agents:\n                    if other is agent:\n                        continue\n                    if np.all(other.state.c == 0):\n                        word = '_'\n                    else:\n                        word = alphabet[np.argmax(other.state.c)]\n                    message += (other.name + ' to ' +\n                                agent.name + ': ' + word + '   ')\n            print(message)\n\n        for i in range(len(self.viewers)):\n            # create viewers (if necessary)\n\n            if self.viewers[i] is None:\n                # import rendering only if we need it (and don't import for headless machines)\n                # from gym.envs.classic_control import rendering\n                from . import rendering\n                self.viewers[i] = rendering.Viewer(700, 700)\n\n        # create rendering geometry\n        if self.render_geoms is None:\n            # import rendering only if we need it (and don't import for headless machines)\n            # from gym.envs.classic_control import rendering\n            from . 
import rendering\n            self.render_geoms = []\n            self.render_geoms_xform = []\n\n            self.comm_geoms = []\n\n            for entity in self.world.entities:\n                geom = rendering.make_circle(entity.size)\n                xform = rendering.Transform()\n\n                entity_comm_geoms = []\n\n                if 'agent' in entity.name:\n                    geom.set_color(*entity.color, alpha=0.5)\n\n                    if not entity.silent:\n                        dim_c = self.world.dim_c\n                        # make circles to represent communication\n                        for ci in range(dim_c):\n                            comm = rendering.make_circle(entity.size / dim_c)\n                            comm.set_color(1, 1, 1)\n                            comm.add_attr(xform)\n                            offset = rendering.Transform()\n                            comm_size = (entity.size / dim_c)\n                            offset.set_translation(ci * comm_size * 2 -\n                                                   entity.size + comm_size, 0)\n                            comm.add_attr(offset)\n                            entity_comm_geoms.append(comm)\n\n                else:\n                    geom.set_color(*entity.color)\n                    if entity.channel is not None:\n                        dim_c = self.world.dim_c\n                        # make circles to represent communication\n                        for ci in range(dim_c):\n                            comm = rendering.make_circle(entity.size / dim_c)\n                            comm.set_color(1, 1, 1)\n                            comm.add_attr(xform)\n                            offset = rendering.Transform()\n                            comm_size = (entity.size / dim_c)\n                            offset.set_translation(ci * comm_size * 2 -\n                                                   entity.size + comm_size, 0)\n                            
comm.add_attr(offset)\n                            entity_comm_geoms.append(comm)\n                geom.add_attr(xform)\n                self.render_geoms.append(geom)\n                self.render_geoms_xform.append(xform)\n                self.comm_geoms.append(entity_comm_geoms)\n            for wall in self.world.walls:\n                corners = ((wall.axis_pos - 0.5 * wall.width, wall.endpoints[0]),\n                           (wall.axis_pos - 0.5 *\n                            wall.width, wall.endpoints[1]),\n                           (wall.axis_pos + 0.5 *\n                            wall.width, wall.endpoints[1]),\n                           (wall.axis_pos + 0.5 * wall.width, wall.endpoints[0]))\n                if wall.orient == 'H':\n                    corners = tuple(c[::-1] for c in corners)\n                geom = rendering.make_polygon(corners)\n                if wall.hard:\n                    geom.set_color(*wall.color)\n                else:\n                    geom.set_color(*wall.color, alpha=0.5)\n                self.render_geoms.append(geom)\n\n            # add geoms to viewer\n            # for viewer in self.viewers:\n            #     viewer.geoms = []\n            #     for geom in self.render_geoms:\n            #         viewer.add_geom(geom)\n\n            for viewer in self.viewers:\n                viewer.geoms = []\n                for geom in self.render_geoms:\n                    viewer.add_geom(geom)\n                for entity_comm_geoms in self.comm_geoms:\n                    for geom in entity_comm_geoms:\n                        viewer.add_geom(geom)\n\n        results = []\n        for i in range(len(self.viewers)):\n            from . 
import rendering\n\n            if self.shared_viewer:\n                pos = np.zeros(self.world.dim_p)\n            else:\n                pos = self.agents[i].state.p_pos\n            self.viewers[i].set_bounds(\n                pos[0] - cam_range, pos[0] + cam_range, pos[1] - cam_range, pos[1] + cam_range)\n            # update geometry positions\n            for e, entity in enumerate(self.world.entities):\n                self.render_geoms_xform[e].set_translation(*entity.state.p_pos)\n                if 'agent' in entity.name:\n                    self.render_geoms[e].set_color(*entity.color, alpha=0.5)\n\n                    if not entity.silent:\n                        for ci in range(self.world.dim_c):\n                            color = 1 - entity.state.c[ci]\n                            self.comm_geoms[e][ci].set_color(\n                                color, color, color)\n                else:\n                    self.render_geoms[e].set_color(*entity.color)\n                    if entity.channel is not None:\n                        for ci in range(self.world.dim_c):\n                            color = 1 - entity.channel[ci]\n                            self.comm_geoms[e][ci].set_color(\n                                color, color, color)\n\n            # render to display or array\n            results.append(self.viewers[i].render(\n                return_rgb_array=mode == 'rgb_array'))\n\n        return results\n\n    # create receptor field locations in local coordinate frame\n    def _make_receptor_locations(self, agent):\n        receptor_type = 'polar'\n        range_min = 0.05 * 2.0\n        range_max = 1.00\n        dx = []\n        # circular receptive field\n        if receptor_type == 'polar':\n            for angle in np.linspace(-np.pi, +np.pi, 8, endpoint=False):\n                for distance in np.linspace(range_min, range_max, 3):\n                    dx.append(\n                        distance * np.array([np.cos(angle), 
np.sin(angle)]))\n            # add origin\n            dx.append(np.array([0.0, 0.0]))\n        # grid receptive field\n        if receptor_type == 'grid':\n            for x in np.linspace(-range_max, +range_max, 5):\n                for y in np.linspace(-range_max, +range_max, 5):\n                    dx.append(np.array([x, y]))\n        return dx\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/multi_discrete.py",
    "content": "# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)\n# (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)\n\nimport numpy as np\n\nimport gym\n\n\nclass MultiDiscrete(gym.Space):\n    \"\"\"\n    - The multi-discrete action space consists of a series of discrete action spaces with different parameters\n    - It can be adapted to both a Discrete action space or a continuous (Box) action space\n    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space\n    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space\n       where the discrete action space can take any integers from `min` to `max` (both inclusive)\n    Note: A value of 0 always need to represent the NOOP action.\n    e.g. Nintendo Game Controller\n    - Can be conceptualized as 3 discrete action spaces:\n        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4\n        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n    - Can be initialized as\n        MultiDiscrete([ [0,4], [0,1], [0,1] ])\n    \"\"\"\n\n    def __init__(self, array_of_param_array):\n        self.low = np.array([x[0] for x in array_of_param_array])\n        self.high = np.array([x[1] for x in array_of_param_array])\n        self.num_discrete_space = self.low.shape[0]\n\n    def sample(self):\n        \"\"\" Returns a array with one sample from each discrete action space \"\"\"\n        # For each row: round(random .* (max - min) + min, 0)\n        #random_array = prng.np_random.rand(self.num_discrete_space)\n        random_array = np.random.rand(self.num_discrete_space)\n\n        return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), 
random_array) + self.low)]\n\n    def contains(self, x):\n        return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all()\n\n    @property\n    def shape(self):\n        return self.num_discrete_space\n\n    def __repr__(self):\n        return \"MultiDiscrete\" + str(self.num_discrete_space)\n\n    def __eq__(self, other):\n        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/rendering.py",
    "content": "\"\"\"\n2D rendering framework\n\"\"\"\nfrom __future__ import division\nimport os\nimport six  # TODO\nimport sys\n\nif \"Apple\" in sys.version:\n    if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:\n        os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'\n        # (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite\n\nfrom gym.utils import reraise\nfrom gym import error\n\ntry:\n    import pyglet\nexcept ImportError as e:\n    reraise(\n        suffix=\"HINT: you can install pyglet directly via 'pip install pyglet'. But if you really just want to install all Gym dependencies and not have to think about it, 'pip install -e .[all]' or 'pip install gym[all]' will do it.\")\n\ntry:\n    from pyglet.gl import *\nexcept ImportError as e:\n    reraise(prefix=\"Error occured while running `from pyglet.gl import *`\",\n            suffix=\"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. If you're running on a server, you may need a virtual frame buffer; something like this should work: 'xvfb-run -s \\\"-screen 0 1400x900x24\\\" python <your_script.py>'\")\n\nimport math\nimport numpy as np\n\nRAD2DEG = 57.29577951308232\n\n\ndef get_display(spec):\n    \"\"\"Convert a display specification (such as :0) into an actual Display\n    object.\n\n    Pyglet only supports multiple Displays on Linux.\n    \"\"\"\n    if spec is None:\n        return None\n    elif isinstance(spec, six.string_types):\n        return pyglet.canvas.Display(spec)\n    else:\n        raise error.Error(\n            'Invalid display specification: {}. 
(Must be a string like :0 or None.)'.format(spec))\n\n\nclass Viewer(object):\n    def __init__(self, width, height, display=None):\n        display = get_display(display)\n\n        self.width = width\n        self.height = height\n\n        self.window = pyglet.window.Window(\n            width=width, height=height, display=display)\n        self.window.on_close = self.window_closed_by_user\n        self.geoms = []\n        self.onetime_geoms = []\n        self.transform = Transform()\n\n        glEnable(GL_BLEND)\n        # glEnable(GL_MULTISAMPLE)\n        glEnable(GL_LINE_SMOOTH)\n        # glHint(GL_LINE_SMOOTH_HINT, GL_DONT_CARE)\n        glHint(GL_LINE_SMOOTH_HINT, GL_NICEST)\n        glLineWidth(2.0)\n        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)\n\n    def close(self):\n        self.window.close()\n\n    def window_closed_by_user(self):\n        self.close()\n\n    def set_bounds(self, left, right, bottom, top):\n        assert right > left and top > bottom\n        scalex = self.width/(right-left)\n        scaley = self.height/(top-bottom)\n        self.transform = Transform(\n            translation=(-left*scalex, -bottom*scaley),\n            scale=(scalex, scaley))\n\n    def add_geom(self, geom):\n        self.geoms.append(geom)\n\n    def add_onetime(self, geom):\n        self.onetime_geoms.append(geom)\n\n    def render(self, return_rgb_array=False):\n        glClearColor(1, 1, 1, 1)\n        self.window.clear()\n        self.window.switch_to()\n        self.window.dispatch_events()\n        self.transform.enable()\n        for geom in self.geoms:\n            geom.render()\n        for geom in self.onetime_geoms:\n            geom.render()\n        self.transform.disable()\n        arr = None\n        if return_rgb_array:\n            buffer = pyglet.image.get_buffer_manager().get_color_buffer()\n            image_data = buffer.get_image_data()\n            arr = np.fromstring(image_data.data, dtype=np.uint8, sep='')\n            # In 
https://github.com/openai/gym-http-api/issues/2, we\n            # discovered that someone using Xmonad on Arch was having\n            # a window of size 598 x 398, though a 600 x 400 window\n            # was requested. (Guess Xmonad was preserving a pixel for\n            # the boundary.) So we use the buffer height/width rather\n            # than the requested one.\n            arr = arr.reshape(buffer.height, buffer.width, 4)\n            arr = arr[::-1, :, 0:3]\n        self.window.flip()\n        self.onetime_geoms = []\n        return arr\n\n    # Convenience\n    def draw_circle(self, radius=10, res=30, filled=True, **attrs):\n        geom = make_circle(radius=radius, res=res, filled=filled)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def draw_polygon(self, v, filled=True, **attrs):\n        geom = make_polygon(v=v, filled=filled)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def draw_polyline(self, v, **attrs):\n        geom = make_polyline(v=v)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def draw_line(self, start, end, **attrs):\n        geom = Line(start, end)\n        _add_attrs(geom, attrs)\n        self.add_onetime(geom)\n        return geom\n\n    def get_array(self):\n        self.window.flip()\n        image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()\n        self.window.flip()\n        arr = np.fromstring(image_data.data, dtype=np.uint8, sep='')\n        arr = arr.reshape(self.height, self.width, 4)\n        return arr[::-1, :, 0:3]\n\n\ndef _add_attrs(geom, attrs):\n    if \"color\" in attrs:\n        geom.set_color(*attrs[\"color\"])\n    if \"linewidth\" in attrs:\n        geom.set_linewidth(attrs[\"linewidth\"])\n\n\nclass Geom(object):\n    def __init__(self):\n        self._color = Color((0, 0, 0, 1.0))\n        self.attrs = [self._color]\n\n    def 
render(self):\n        for attr in reversed(self.attrs):\n            attr.enable()\n        self.render1()\n        for attr in self.attrs:\n            attr.disable()\n\n    def render1(self):\n        raise NotImplementedError\n\n    def add_attr(self, attr):\n        self.attrs.append(attr)\n\n    def set_color(self, r, g, b, alpha=1):\n        self._color.vec4 = (r, g, b, alpha)\n\n\nclass Attr(object):\n    def enable(self):\n        raise NotImplementedError\n\n    def disable(self):\n        pass\n\n\nclass Transform(Attr):\n    def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1, 1)):\n        self.set_translation(*translation)\n        self.set_rotation(rotation)\n        self.set_scale(*scale)\n\n    def enable(self):\n        glPushMatrix()\n        # translate to GL loc ppint\n        glTranslatef(self.translation[0], self.translation[1], 0)\n        glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)\n        glScalef(self.scale[0], self.scale[1], 1)\n\n    def disable(self):\n        glPopMatrix()\n\n    def set_translation(self, newx, newy):\n        self.translation = (float(newx), float(newy))\n\n    def set_rotation(self, new):\n        self.rotation = float(new)\n\n    def set_scale(self, newx, newy):\n        self.scale = (float(newx), float(newy))\n\n\nclass Color(Attr):\n    def __init__(self, vec4):\n        self.vec4 = vec4\n\n    def enable(self):\n        glColor4f(*self.vec4)\n\n\nclass LineStyle(Attr):\n    def __init__(self, style):\n        self.style = style\n\n    def enable(self):\n        glEnable(GL_LINE_STIPPLE)\n        glLineStipple(1, self.style)\n\n    def disable(self):\n        glDisable(GL_LINE_STIPPLE)\n\n\nclass LineWidth(Attr):\n    def __init__(self, stroke):\n        self.stroke = stroke\n\n    def enable(self):\n        glLineWidth(self.stroke)\n\n\nclass Point(Geom):\n    def __init__(self):\n        Geom.__init__(self)\n\n    def render1(self):\n        glBegin(GL_POINTS)  # draw point\n        
glVertex3f(0.0, 0.0, 0.0)\n        glEnd()\n\n\nclass FilledPolygon(Geom):\n    def __init__(self, v):\n        Geom.__init__(self)\n        self.v = v\n\n    def render1(self):\n        if len(self.v) == 4:\n            glBegin(GL_QUADS)\n        elif len(self.v) > 4:\n            glBegin(GL_POLYGON)\n        else:\n            glBegin(GL_TRIANGLES)\n        for p in self.v:\n            glVertex3f(p[0], p[1], 0)  # draw each vertex\n        glEnd()\n\n        color = (self._color.vec4[0] * 0.5, self._color.vec4[1] *\n                 0.5, self._color.vec4[2] * 0.5, self._color.vec4[3] * 0.5)\n        glColor4f(*color)\n        glBegin(GL_LINE_LOOP)\n        for p in self.v:\n            glVertex3f(p[0], p[1], 0)  # draw each vertex\n        glEnd()\n\n\ndef make_circle(radius=10, res=30, filled=True):\n    points = []\n    for i in range(res):\n        ang = 2*math.pi*i / res\n        points.append((math.cos(ang)*radius, math.sin(ang)*radius))\n    if filled:\n        return FilledPolygon(points)\n    else:\n        return PolyLine(points, True)\n\n\ndef make_polygon(v, filled=True):\n    if filled:\n        return FilledPolygon(v)\n    else:\n        return PolyLine(v, True)\n\n\ndef make_polyline(v):\n    return PolyLine(v, False)\n\n\ndef make_capsule(length, width):\n    l, r, t, b = 0, length, width/2, -width/2\n    box = make_polygon([(l, b), (l, t), (r, t), (r, b)])\n    circ0 = make_circle(width/2)\n    circ1 = make_circle(width/2)\n    circ1.add_attr(Transform(translation=(length, 0)))\n    geom = Compound([box, circ0, circ1])\n    return geom\n\n\nclass Compound(Geom):\n    def __init__(self, gs):\n        Geom.__init__(self)\n        self.gs = gs\n        for g in self.gs:\n            g.attrs = [a for a in g.attrs if not isinstance(a, Color)]\n\n    def render1(self):\n        for g in self.gs:\n            g.render()\n\n\nclass PolyLine(Geom):\n    def __init__(self, v, close):\n        Geom.__init__(self)\n        self.v = v\n        self.close = 
close\n        self.linewidth = LineWidth(1)\n        self.add_attr(self.linewidth)\n\n    def render1(self):\n        glBegin(GL_LINE_LOOP if self.close else GL_LINE_STRIP)\n        for p in self.v:\n            glVertex3f(p[0], p[1], 0)  # draw each vertex\n        glEnd()\n\n    def set_linewidth(self, x):\n        self.linewidth.stroke = x\n\n\nclass Line(Geom):\n    def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):\n        Geom.__init__(self)\n        self.start = start\n        self.end = end\n        self.linewidth = LineWidth(1)\n        self.add_attr(self.linewidth)\n\n    def render1(self):\n        glBegin(GL_LINES)\n        glVertex2f(*self.start)\n        glVertex2f(*self.end)\n        glEnd()\n\n\nclass Image(Geom):\n    def __init__(self, fname, width, height):\n        Geom.__init__(self)\n        self.width = width\n        self.height = height\n        img = pyglet.image.load(fname)\n        self.img = img\n        self.flip = False\n\n    def render1(self):\n        self.img.blit(-self.width/2, -self.height/2,\n                      width=self.width, height=self.height)\n\n# ================================================================\n\n\nclass SimpleImageViewer(object):\n    def __init__(self, display=None):\n        self.window = None\n        self.isopen = False\n        self.display = display\n\n    def imshow(self, arr):\n        if self.window is None:\n            height, width, channels = arr.shape\n            self.window = pyglet.window.Window(\n                width=width, height=height, display=self.display)\n            self.width = width\n            self.height = height\n            self.isopen = True\n        assert arr.shape == (\n            self.height, self.width, 3), \"You passed in an image with the wrong number shape\"\n        image = pyglet.image.ImageData(\n            self.width, self.height, 'RGB', arr.tobytes(), pitch=self.width * -3)\n        self.window.clear()\n        self.window.switch_to()\n        
self.window.dispatch_events()\n        image.blit(0, 0)\n        self.window.flip()\n\n    def close(self):\n        if self.isopen:\n            self.window.close()\n            self.isopen = False\n\n    def __del__(self):\n        self.close()\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenario.py",
    "content": "import numpy as np\n\n\n# defines scenario upon which the world is built\nclass BaseScenario(object):\n    # create elements of the world\n    def make_world(self):\n        raise NotImplementedError()\n    # create initial conditions of the world\n\n    def reset_world(self, world):\n        raise NotImplementedError()\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/__init__.py",
    "content": "import imp\nimport os.path as osp\n\ndef load(name):\n    pathname = osp.join(osp.dirname(__file__), name)\n    return imp.load_source('', pathname)\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/hetero_spread.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\n\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        world.max_steps = 25\n        num_agents = args.num_agents\n        self.n_agent_a = num_agents // 2 # 2\n        self.n_agent_b = num_agents // 2 # 2\n        num_landmarks = args.num_agents\n        world.collaborative = True\n        self.agent_size = 0.10\n        self.n_others = 3\n        self.n_group = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.id = i\n            if i < self.n_agent_a:\n                agent.size = self.agent_size\n                agent.accel = 3.0\n                agent.max_speed = 1.0\n            else:\n                agent.size = self.agent_size / 2\n                agent.accel = 4.0\n                agent.max_speed = 1.3\n\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        world.num_steps = 0\n        self.end_steps = world.max_steps\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            if i < self.n_agent_a:\n                agent.color = np.array([0.35, 0.35, 0.85])\n            else:\n                agent.color = np.array([0.35, 0.85, 0.35])\n        # random properties for landmarks\n        for i, landmark in 
enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        rew = 0\n        collisions = 0\n        occupied_landmarks = 0\n        min_dists = 0\n        for l in world.landmarks:\n            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n            min_dists += min(dists)\n            rew -= min(dists)\n            if min(dists) < 0.1:\n                occupied_landmarks += 1\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n                    collisions += 1\n        return (rew, collisions, min_dists, occupied_landmarks)\n\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions\n        rew = 0\n        shaped_reward = False\n        if shaped_reward:  # distance-based reward\n            for l in world.landmarks:\n                dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos))) for a in world.agents]\n                rew -= min(dists)\n            if agent.collide:\n                for a in world.agents:\n                    if self.is_collision(a, 
agent):\n                        rew -= 1\n            return rew\n        else:\n            win_agents = []\n            for land in world.landmarks:\n                for a in world.agents:\n                    if self.is_collision(a, land):\n                        win_agents.append(a)\n                        break\n            rew += 2 * len(set(win_agents))\n\n            def bound(x):\n                if x > 1.0:\n                    return min(np.exp(2 * x - 2), 10)\n                else:\n                    return 0.0\n            bound_rew = 0.0\n            for p in range(world.dim_p):\n                x = abs(agent.state.p_pos[p])\n                bound_rew -= bound(x)\n            rew += bound_rew\n            return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        other_vel = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent:\n                other_vel.append([0, 0])\n                continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + comm +\n                              other_vel + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_adversary.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario  # TODO\nimport random\n\n\nclass Scenario(BaseScenario):\n\n    def make_world(self, args):\n        world = World()  # from core\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = args.num_agents  # 3\n        world.num_agents = num_agents\n        num_adversaries = 1\n        num_landmarks = num_agents - 1\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.silent = True\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.15\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.08\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        world.assign_agent_colors()\n        # random properties for landmarks\n        world.assign_landmark_colors()\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        goal.color = np.array([0.15, 0.65, 0.15])\n        for agent in world.agents:\n            agent.goal_a = goal\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            
landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        if agent.adversary:\n            return np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))\n        else:\n            dists = []\n            for l in world.landmarks:\n                dists.append(np.sum(np.square(agent.state.p_pos - l.state.p_pos)))\n            dists.append(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))\n            return tuple(dists)\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # Rewarded based on how close any good agent is to the goal landmark, and how far the adversary is from it\n        shaped_reward = True\n        shaped_adv_reward = True\n\n        # Calculate negative reward for adversary\n        adversary_agents = self.adversaries(world)\n        if shaped_adv_reward:  # distance-based adversary reward\n            adv_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in adversary_agents])\n        else:  # proximity-based adversary reward (binary)\n            adv_rew = 0\n            for a in adversary_agents:\n                if np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) < 2 * a.goal_a.size:\n                    adv_rew -= 5\n\n        # Calculate positive reward for agents\n        good_agents = self.good_agents(world)\n        if shaped_reward:  # distance-based 
agent reward\n            pos_rew = -min(\n                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])\n        else:  # proximity-based agent reward (binary)\n            pos_rew = 0\n            if min([np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents]) \\\n                    < 2 * agent.goal_a.size:\n                pos_rew += 5\n            pos_rew -= min(\n                [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in good_agents])\n        return pos_rew + adv_rew\n\n    def adversary_reward(self, agent, world):\n        # Rewarded based on proximity to the goal landmark\n        shaped_reward = True\n        if shaped_reward:  # distance-based reward\n            return -np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))\n        else:  # proximity-based reward (binary)\n            adv_rew = 0\n            if np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos))) < 2 * agent.goal_a.size:\n                adv_rew += 5\n            return adv_rew\n\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        other_pos = []\n        for other in world.agents:\n            if other is agent: continue\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n\n        if not agent.adversary:\n            return np.concatenate([agent.goal_a.state.p_pos - agent.state.p_pos] + entity_pos + other_pos)\n        else:\n            return np.concatenate(entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_crypto.py",
    "content": "\"\"\"\nScenario:\n1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from\nadversary to goal. Adversary is rewarded for its distance to the goal.\n\"\"\"\n\nimport numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\nimport random\n\n\nclass CryptoAgent(Agent):\n    def __init__(self):\n        super(CryptoAgent, self).__init__()\n        self.key = None\n\n\nclass Scenario(BaseScenario):\n\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        num_agents = args.num_agents  # 3\n        num_adversaries = 1\n        num_landmarks = args.num_landmarks  # 2\n        world.dim_c = 4\n        # add agents\n        world.agents = [CryptoAgent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.adversary = True if i < num_adversaries else False\n            agent.speaker = True if i == 2 else False\n            agent.movable = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        for agent in world.agents:\n            agent.color = np.array([0.25, 0.25, 0.25])\n            if agent.adversary:\n                agent.color = np.array([0.75, 0.25, 0.25])\n            agent.key = None\n        # random properties for landmarks\n        color_list = [np.zeros(world.dim_c) for i in world.landmarks]\n        for i, color in enumerate(color_list):\n            color[i] += 1\n        for 
color, landmark in zip(color_list, world.landmarks):\n            landmark.color = color\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        world.agents[1].color = goal.color\n        world.agents[2].key = np.random.choice(world.landmarks).color\n\n        for agent in world.agents:\n            agent.goal_a = goal\n\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return (agent.state.c, agent.goal_a.color)\n\n    # return all agents that are not adversaries\n    def good_listeners(self, world):\n        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot\n        good_listeners = self.good_listeners(world)\n        adversaries = self.adversaries(world)\n        good_rew = 0\n        adv_rew = 0\n        for a in good_listeners:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n      
          good_rew -= np.sum(np.square(a.state.c - agent.goal_a.color))\n        for a in adversaries:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.color))\n                adv_rew += adv_l1\n        return adv_rew + good_rew\n\n    def adversary_reward(self, agent, world):\n        # Adversary (Eve) is rewarded if it can reconstruct original goal\n        rew = 0\n        if not (agent.state.c == np.zeros(world.dim_c)).all():\n            rew -= np.sum(np.square(agent.state.c - agent.goal_a.color))\n        return rew\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = np.zeros(world.dim_color)\n        if agent.goal_a is not None:\n            goal_color = agent.goal_a.color\n\n        # print('goal color in obs is {}'.format(goal_color))\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None) or not other.speaker: continue\n            comm.append(other.state.c)\n\n        confer = np.array([0])\n\n        if world.agents[2].key is None:\n            confer = np.array([1])\n            key = np.zeros(world.dim_c)\n            goal_color = np.zeros(world.dim_c)\n        else:\n            key = world.agents[2].key\n\n        prnt = False  # True  if train use False\n        # speaker\n        if agent.speaker:\n            if prnt:\n                print('speaker')\n                print(agent.state.c)\n                print(np.concatenate([goal_color] + [key] + [confer] + [np.random.randn(1)]))\n            return np.concatenate([goal_color] + [key])\n        # listener\n        if not 
agent.speaker and not agent.adversary:\n            if prnt:\n                print('listener')\n                print(agent.state.c)\n                print(np.concatenate([key] + comm + [confer]))\n            return np.concatenate([key] + comm)\n        if not agent.speaker and agent.adversary:\n            if prnt:\n                print('adversary')\n                print(agent.state.c)\n                print(np.concatenate(comm + [confer]))\n            return np.concatenate(comm)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_crypto_display.py",
    "content": "\"\"\"\nScenario:\n1 speaker, 2 listeners (one of which is an adversary). Good agents rewarded for proximity to goal, and distance from\nadversary to goal. Adversary is rewarded for its distance to the goal.\n\"\"\"\n\nimport numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\nimport random\n\n\nclass CryptoAgent(Agent):\n    def __init__(self):\n        super(CryptoAgent, self).__init__()\n        self.key = None\n\n\nclass Scenario(BaseScenario):\n\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        num_agents = args.num_agents  # 3\n        num_adversaries = 1\n        num_landmarks = args.num_landmarks  # 2\n        world.dim_c = 4\n        # add agents\n        world.agents = [CryptoAgent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.adversary = True if i < num_adversaries else False\n            agent.speaker = True if i == 2 else False\n            agent.movable = False\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        world.assign_agent_colors()\n        for agent in world.agents:\n            if agent.speaker:\n                agent.color = np.array([0.25, 0.75, 0.25])\n            agent.key = None\n        # random properties for landmarks\n        world.assign_landmark_colors()\n        # random properties for landmarks\n        channel_list = [np.zeros(world.dim_c) for i in world.landmarks]\n        for i, channel in 
enumerate(channel_list):\n            channel[i] += 1\n        for channel, landmark in zip(channel_list, world.landmarks):\n            landmark.channel = channel\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        world.agents[1].channel = goal.channel\n        world.agents[2].key = np.random.choice(world.landmarks).channel\n\n        for agent in world.agents:\n            agent.goal_a = goal\n\n        # set random initial states\n        for i, agent in enumerate(world.agents):\n            # agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_pos = np.array([0.0, -0.5 + 1.0 / (len(world.agents) - 1) * i])\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            if landmark is goal:\n                landmark.color = np.array([0.15, 0.15, 0.75])\n            # landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_pos = np.array([0.5, 0.5 - 0.5 / (len(world.landmarks) - 1) * i])\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return (agent.state.c, agent.goal_a.channel)\n\n    # return all agents that are not adversaries\n    def good_listeners(self, world):\n        return [agent for agent in world.agents if not agent.adversary and not agent.speaker]\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, 
agent, world):\n        # Agents rewarded if Bob can reconstruct message, but adversary (Eve) cannot\n        good_listeners = self.good_listeners(world)\n        adversaries = self.adversaries(world)\n        good_rew = 0\n        adv_rew = 0\n        for a in good_listeners:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                good_rew -= np.sum(np.square(a.state.c - agent.goal_a.channel))\n        for a in adversaries:\n            if (a.state.c == np.zeros(world.dim_c)).all():\n                continue\n            else:\n                adv_l1 = np.sum(np.square(a.state.c - agent.goal_a.channel))\n                adv_rew += adv_l1\n        return adv_rew + good_rew\n\n    def adversary_reward(self, agent, world):\n        # Adversary (Eve) is rewarded if it can reconstruct original goal\n        rew = 0\n        if not (agent.state.c == np.zeros(world.dim_c)).all():\n            rew -= np.sum(np.square(agent.state.c - agent.goal_a.channel))\n        return rew\n\n    def observation(self, agent, world):\n        # goal channel\n        goal_channel = np.zeros(world.dim_color)\n        if agent.goal_a is not None:\n            goal_channel = agent.goal_a.channel\n\n        print('goal channel in obs is {}'.format(goal_channel))\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None) or not other.speaker: continue\n            comm.append(other.state.c)\n        confer = np.array([0])\n        if world.agents[2].key is None:\n            confer = np.array([1])\n            key = np.zeros(world.dim_c)\n            goal_channel = np.zeros(world.dim_c)\n        else:\n            key = 
world.agents[2].key\n\n        prnt = True  # if train use False\n        # speaker\n        if agent.speaker:\n            if prnt:\n                print('speaker')\n                print(agent.state.c)\n            #        print(np.concatenate([goal_channel] + [key] + [confer] + [np.random.randn(1)]))\n            return np.concatenate([goal_channel] + [key])\n        # listener\n        if not agent.speaker and not agent.adversary:\n            if prnt:\n                print('listener')\n                print(agent.state.c)\n            #        print(np.concatenate([key] + comm + [confer]))\n            return np.concatenate([key] + comm)\n        if not agent.speaker and agent.adversary:\n            if prnt:\n                print('adversary')\n                print(agent.state.c)\n            #        print(np.concatenate(comm + [confer]))\n            return np.concatenate(comm)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_push.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\nimport random\n\n\n#\n#     # the non-ensemble version of <ensemble_push>\n#\n#\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_agents = args.num_agents  # 2\n        num_adversaries = 1\n        num_landmarks = args.num_landmarks  # 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            if i < num_adversaries:\n                agent.adversary = True\n            else:\n                agent.adversary = False\n            # agent.u_noise = 1e-1\n            # agent.c_noise = 1e-1\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.1, 0.1, 0.1])\n            landmark.color[i + 1] += 0.8\n            landmark.index = i\n        # set goal landmark\n        goal = np.random.choice(world.landmarks)\n        for i, agent in enumerate(world.agents):\n            agent.goal_a = goal\n            agent.color = np.array([0.25, 0.25, 0.25])\n            if agent.adversary:\n                agent.color = np.array([0.75, 0.25, 0.25])\n            else:\n                j = goal.index\n                agent.color[j + 1] += 0.5\n        # set random initial 
states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        return self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n\n    def agent_reward(self, agent, world):\n        # the distance to the goal\n        return -np.sqrt(np.sum(np.square(agent.state.p_pos - agent.goal_a.state.p_pos)))\n\n    def adversary_reward(self, agent, world):\n        # keep the nearest good agents away from the goal\n        agent_dist = [np.sqrt(np.sum(np.square(a.state.p_pos - a.goal_a.state.p_pos))) for a in world.agents if\n                      not a.adversary]\n        pos_rew = min(agent_dist)\n        # nearest_agent = world.good_agents[np.argmin(agent_dist)]\n        # neg_rew = np.sqrt(np.sum(np.square(nearest_agent.state.p_pos - agent.state.p_pos)))\n        neg_rew = np.sqrt(np.sum(np.square(agent.goal_a.state.p_pos - agent.state.p_pos)))\n        # neg_rew = sum([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in world.good_agents])\n        return pos_rew - neg_rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        
comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n        if not agent.adversary:\n            return np.concatenate([agent.state.p_vel] + [agent.goal_a.state.p_pos - agent.state.p_pos] + [\n                agent.color] + entity_pos + entity_color + other_pos)\n        else:\n            # other_pos = list(reversed(other_pos)) if random.uniform(0,1) > 0.5 else other_pos  # randomize position of other agents in adversary network\n            return np.concatenate([agent.state.p_vel] + entity_pos + other_pos)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_reference.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        # world.world_length = args.episode_length\n        world.dim_c = 10\n        world.collaborative = True  # whether agents share rewards\n        # add agents\n        world.num_agents = args.num_agents  # 2\n        assert world.num_agents == 2, (\n            \"only 2 agents is supported, check the config.py.\")\n        world.agents = [Agent() for i in range(world.num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            # agent.u_noise = 1e-1\n            # agent.c_noise = 1e-1\n        # add landmarks\n        world.num_landmarks = args.num_landmarks  # 3\n        world.landmarks = [Landmark() for i in range(world.num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # assign goals to agents\n        for agent in world.agents:\n            agent.goal_a = None\n            agent.goal_b = None\n        # want other agent to go to the goal landmark\n        world.agents[0].goal_a = world.agents[1]\n        world.agents[0].goal_b = np.random.choice(world.landmarks)\n        world.agents[1].goal_a = world.agents[0]\n        world.agents[1].goal_b = np.random.choice(world.landmarks)\n        # random properties for agents\n        world.assign_agent_colors()\n        # random properties for landmarks\n        world.landmarks[0].color = np.array([0.75, 0.25, 0.25])\n        world.landmarks[1].color = np.array([0.25, 0.75, 0.25])\n        
world.landmarks[2].color = np.array([0.25, 0.25, 0.75])\n        # special colors for goals\n        world.agents[0].goal_a.color = world.agents[0].goal_b.color\n        world.agents[1].goal_a.color = world.agents[1].goal_b.color\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def reward(self, agent, world):\n        if agent.goal_a is None or agent.goal_b is None:\n            return 0.0\n        dist2 = np.sum(\n            np.square(agent.goal_a.state.p_pos - agent.goal_b.state.p_pos))\n        return -dist2  # np.exp(-dist2)\n\n    def observation(self, agent, world):\n        # goal positions\n        # goal_pos = [np.zeros(world.dim_p), np.zeros(world.dim_p)]\n        # if agent.goal_a is not None:\n        #     goal_pos[0] = agent.goal_a.state.p_pos - agent.state.p_pos\n        # if agent.goal_b is not None:\n        #     goal_pos[1] = agent.goal_b.state.p_pos - agent.state.p_pos\n        # goal color\n        goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)]\n        # if agent.goal_a is not None:\n        #     goal_color[0] = agent.goal_a.color\n        if agent.goal_b is not None:\n            goal_color[1] = agent.goal_b.color\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of 
all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent:\n                continue\n            comm.append(other.state.c)\n        return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_speaker_listener.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        world.world_length = args.episode_length\n        # set any world properties first\n        world.dim_c = 3\n        world.num_landmarks = args.num_landmarks  # 3\n        world.collaborative = True\n        # add agents\n        world.num_agents = args.num_agents  # 2\n        assert world.num_agents == 2, (\n            \"only 2 agents is supported, check the config.py.\")\n        world.agents = [Agent() for i in range(world.num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = False\n            agent.size = 0.075\n        # speaker\n        world.agents[0].movable = False\n        # listener\n        world.agents[1].silent = True\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(world.num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.04\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # assign goals to agents\n        for agent in world.agents:\n            agent.goal_a = None\n            agent.goal_b = None\n        # want listener to go to the goal landmark\n        world.agents[0].goal_a = world.agents[1]\n        world.agents[0].goal_b = np.random.choice(world.landmarks)\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.25, 0.25, 0.25])\n        # random properties for landmarks\n        world.landmarks[0].color = np.array([0.65, 0.15, 0.15])\n        world.landmarks[1].color = np.array([0.15, 
0.65, 0.15])\n        world.landmarks[2].color = np.array([0.15, 0.15, 0.65])\n        # special colors for goals\n        world.agents[0].goal_a.color = world.agents[0].goal_b.color + \\\n            np.array([0.45, 0.45, 0.45])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        return self.reward(agent, world)\n\n    def reward(self, agent, world):\n        # squared distance from listener to landmark\n        a = world.agents[0]\n        dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos))\n        return -dist2\n\n    def observation(self, agent, world):\n        # goal color\n        goal_color = np.zeros(world.dim_color)\n        if agent.goal_b is not None:\n            goal_color = agent.goal_b.color\n\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        # communication of all other agents\n        comm = []\n        for other in world.agents:\n            if other is agent or (other.state.c is None):\n                continue\n            comm.append(other.state.c)\n\n        # speaker\n        if not agent.movable:\n            return np.concatenate([goal_color])\n        # listener\n        if agent.silent:\n            return np.concatenate([agent.state.p_vel] + entity_pos + comm)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_spread.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        # world.world_length = args.episode_length\n        # set any world properties first\n        world.dim_c = 2\n        world.num_agents = args.num_agents\n        world.num_landmarks = args.num_landmarks  # 3\n        world.collaborative = True\n        # add agents\n        world.agents = [Agent() for i in range(world.num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.size = 0.15\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(world.num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = False\n            landmark.movable = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        world.assign_agent_colors()\n\n        world.assign_landmark_colors()\n\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        rew = 0\n        collisions = 0\n        occupied_landmarks = 0\n        min_dists = 0\n        for l in world.landmarks:\n            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos)))\n                     
for a in world.agents]\n            min_dists += min(dists)\n            rew -= min(dists)\n            if min(dists) < 0.1:\n                occupied_landmarks += 1\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n                    collisions += 1\n        return (rew, collisions, min_dists, occupied_landmarks)\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions\n        rew = 0\n        for l in world.landmarks:\n            dists = [np.sqrt(np.sum(np.square(a.state.p_pos - l.state.p_pos)))\n                     for a in world.agents]\n            rew -= min(dists)\n\n        if agent.collide:\n            for a in world.agents:\n                if self.is_collision(a, agent):\n                    rew -= 1\n        return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # entity colors\n        entity_color = []\n        for entity in world.landmarks:  # world.entities:\n            entity_color.append(entity.color)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        for other in world.agents:\n            if other is agent:\n                continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + 
other_pos + comm)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_tag.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        world.dim_c = 2\n        num_good_agents = args.num_good_agents  # 1\n        num_adversaries = args.num_adversaries  # 3\n        num_agents = num_adversaries + num_good_agents\n        num_landmarks = args.num_landmarks  # 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.silent = True\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.075 if agent.adversary else 0.05\n            agent.accel = 3.0 if agent.adversary else 4.0\n            # agent.accel = 20.0 if agent.adversary else 25.0\n            agent.max_speed = 1.0 if agent.adversary else 1.3\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = True\n            landmark.movable = False\n            landmark.size = 0.2\n            landmark.boundary = False\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def reset_world(self, world):\n        # random properties for agents\n        world.assign_agent_colors()\n        # random properties for landmarks\n        world.assign_landmark_colors()\n        # random properties for landmarks\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark 
in enumerate(world.landmarks):\n            if not landmark.boundary:\n                landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n                landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        # returns data for benchmarking purposes\n        if agent.adversary:\n            collisions = 0\n            for a in self.good_agents(world):\n                if self.is_collision(a, agent):\n                    collisions += 1\n            return collisions\n        else:\n            return 0\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n        return main_reward\n\n    def agent_reward(self, agent, world):\n        # Agents are negatively rewarded if caught by adversaries\n        rew = 0\n        shape = False  # different from openai\n        adversaries = self.adversaries(world)\n        if shape:  # reward can optionally be shaped (increased reward for increased distance from adversary)\n            for adv in adversaries:\n                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))\n        if agent.collide:\n            for a in adversaries:\n                if self.is_collision(a, agent):\n                    rew 
-= 10\n\n        # agents are penalized for exiting the screen, so that they can be caught by the adversaries\n        def bound(x):\n            if x < 0.9:\n                return 0\n            if x < 1.0:\n                return (x - 0.9) * 10\n            return min(np.exp(2 * x - 2), 10)\n\n        for p in range(world.dim_p):\n            x = abs(agent.state.p_pos[p])\n            rew -= bound(x)\n\n        return rew\n\n    def adversary_reward(self, agent, world):\n        # Adversaries are rewarded for collisions with agents\n        rew = 0\n        shape = False  # different from openai\n        agents = self.good_agents(world)\n        adversaries = self.adversaries(world)\n        if shape:  # reward can optionally be shaped (decreased reward for increased distance from agents)\n            for adv in adversaries:\n                rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos))) for a in agents])\n        if agent.collide:\n            for ag in agents:\n                for adv in adversaries:\n                    if self.is_collision(ag, adv):\n                        rew += 10\n        return rew\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not other.adversary:\n                other_vel.append(other.state.p_vel)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/mpe/scenarios/simple_world_comm.py",
    "content": "import numpy as np\nfrom mpe.core import World, Agent, Landmark\nfrom mpe.scenario import BaseScenario\n\n\nclass Scenario(BaseScenario):\n    def make_world(self, args):\n        world = World()\n        # set any world properties first\n        world.dim_c = 4\n        # world.damping = 1\n        num_good_agents = args.num_good_agents  # 2\n        num_adversaries = args.num_adversaries  # 4\n        num_agents = num_adversaries + num_good_agents\n        num_landmarks = args.num_landmarks  # 1\n        num_food = 2\n        num_forests = 2\n        # add agents\n        world.agents = [Agent() for i in range(num_agents)]\n        for i, agent in enumerate(world.agents):\n            agent.name = 'agent %d' % i\n            agent.collide = True\n            agent.leader = True if i == 0 else False\n            agent.silent = True if i > 0 else False\n            agent.adversary = True if i < num_adversaries else False\n            agent.size = 0.075 if agent.adversary else 0.045\n            agent.accel = 3.0 if agent.adversary else 4.0\n            # agent.accel = 20.0 if agent.adversary else 25.0\n            agent.max_speed = 1.0 if agent.adversary else 1.3\n        # add landmarks\n        world.landmarks = [Landmark() for i in range(num_landmarks)]\n        for i, landmark in enumerate(world.landmarks):\n            landmark.name = 'landmark %d' % i\n            landmark.collide = True\n            landmark.movable = False\n            landmark.size = 0.2\n            landmark.boundary = False\n        world.food = [Landmark() for i in range(num_food)]\n        for i, landmark in enumerate(world.food):\n            landmark.name = 'food %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.03\n            landmark.boundary = False\n        world.forests = [Landmark() for i in range(num_forests)]\n        for i, landmark in enumerate(world.forests):\n            landmark.name = 
'forest %d' % i\n            landmark.collide = False\n            landmark.movable = False\n            landmark.size = 0.3\n            landmark.boundary = False\n        world.landmarks += world.food\n        world.landmarks += world.forests\n        # world.landmarks += self.set_boundaries(world)  # world boundaries now penalized with negative reward\n        # make initial conditions\n        self.reset_world(world)\n        return world\n\n    def set_boundaries(self, world):\n        boundary_list = []\n        landmark_size = 1\n        edge = 1 + landmark_size\n        num_landmarks = int(edge * 2 / landmark_size)\n        for x_pos in [-edge, edge]:\n            for i in range(num_landmarks):\n                l = Landmark()\n                l.state.p_pos = np.array([x_pos, -1 + i * landmark_size])\n                boundary_list.append(l)\n\n        for y_pos in [-edge, edge]:\n            for i in range(num_landmarks):\n                l = Landmark()\n                l.state.p_pos = np.array([-1 + i * landmark_size, y_pos])\n                boundary_list.append(l)\n\n        for i, l in enumerate(boundary_list):\n            l.name = 'boundary %d' % i\n            l.collide = True\n            l.movable = False\n            l.boundary = True\n            l.color = np.array([0.75, 0.75, 0.75])\n            l.size = landmark_size\n            l.state.p_vel = np.zeros(world.dim_p)\n\n        return boundary_list\n\n    def reset_world(self, world):\n        # random properties for agents\n        for i, agent in enumerate(world.agents):\n            agent.color = np.array([0.45, 0.95, 0.45]) if not agent.adversary else np.array([0.95, 0.45, 0.45])\n            agent.color -= np.array([0.3, 0.3, 0.3]) if agent.leader else np.array([0, 0, 0])\n            # random properties for landmarks\n        for i, landmark in enumerate(world.landmarks):\n            landmark.color = np.array([0.25, 0.25, 0.25])\n        for i, landmark in enumerate(world.food):\n       
     landmark.color = np.array([0.15, 0.15, 0.65])\n        for i, landmark in enumerate(world.forests):\n            landmark.color = np.array([0.6, 0.9, 0.6])\n        # set random initial states\n        for agent in world.agents:\n            agent.state.p_pos = np.random.uniform(-1, +1, world.dim_p)\n            agent.state.p_vel = np.zeros(world.dim_p)\n            agent.state.c = np.zeros(world.dim_c)\n        for i, landmark in enumerate(world.landmarks):\n            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n        for i, landmark in enumerate(world.food):\n            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n        for i, landmark in enumerate(world.forests):\n            landmark.state.p_pos = 0.8 * np.random.uniform(-1, +1, world.dim_p)\n            landmark.state.p_vel = np.zeros(world.dim_p)\n\n    def benchmark_data(self, agent, world):\n        if agent.adversary:\n            collisions = 0\n            for a in self.good_agents(world):\n                if self.is_collision(a, agent):\n                    collisions += 1\n            return collisions\n        else:\n            return 0\n\n    def is_collision(self, agent1, agent2):\n        delta_pos = agent1.state.p_pos - agent2.state.p_pos\n        dist = np.sqrt(np.sum(np.square(delta_pos)))\n        dist_min = agent1.size + agent2.size\n        return True if dist < dist_min else False\n\n    # return all agents that are not adversaries\n    def good_agents(self, world):\n        return [agent for agent in world.agents if not agent.adversary]\n\n    # return all adversarial agents\n    def adversaries(self, world):\n        return [agent for agent in world.agents if agent.adversary]\n\n    def reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        # 
boundary_reward = -10 if self.outside_boundary(agent) else 0\n        main_reward = self.adversary_reward(agent, world) if agent.adversary else self.agent_reward(agent, world)\n        return main_reward\n\n    def outside_boundary(self, agent):\n        if agent.state.p_pos[0] > 1 or agent.state.p_pos[0] < -1 or agent.state.p_pos[1] > 1 or agent.state.p_pos[\n            1] < -1:\n            return True\n        else:\n            return False\n\n    def agent_reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        rew = 0\n        shape = False\n        adversaries = self.adversaries(world)\n        if shape:\n            for adv in adversaries:\n                rew += 0.1 * np.sqrt(np.sum(np.square(agent.state.p_pos - adv.state.p_pos)))\n        if agent.collide:\n            for a in adversaries:\n                if self.is_collision(a, agent):\n                    rew -= 5\n\n        def bound(x):\n            if x < 0.9:\n                return 0\n            if x < 1.0:\n                return (x - 0.9) * 10\n            return min(np.exp(2 * x - 2), 10)  # 1 + (x - 1) * (x - 1)\n\n        for p in range(world.dim_p):\n            x = abs(agent.state.p_pos[p])\n            rew -= 2 * bound(x)\n\n        for food in world.food:\n            if self.is_collision(agent, food):\n                rew += 2\n        rew += 0.05 * min([np.sqrt(np.sum(np.square(food.state.p_pos - agent.state.p_pos))) for food in world.food])\n\n        return rew\n\n    def adversary_reward(self, agent, world):\n        # Agents are rewarded based on minimum agent distance to each landmark\n        rew = 0\n        shape = True\n        agents = self.good_agents(world)\n        adversaries = self.adversaries(world)\n        if shape:\n            rew -= 0.1 * min([np.sqrt(np.sum(np.square(a.state.p_pos - agent.state.p_pos))) for a in agents])\n            # for adv in adversaries:\n            #    rew -= 0.1 * 
min([np.sqrt(np.sum(np.square(a.state.p_pos - adv.state.p_pos))) for a in agents])\n        if agent.collide:\n            for ag in agents:\n                for adv in adversaries:\n                    if self.is_collision(ag, adv):\n                        rew += 5\n        return rew\n\n    def observation2(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:  # world.entities:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        food_pos = []\n        for entity in world.food:  # world.entities:\n            if not entity.boundary:\n                food_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            other_pos.append(other.state.p_pos - agent.state.p_pos)\n            if not other.adversary:\n                other_vel.append(other.state.p_vel)\n        return np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel)\n\n    def observation(self, agent, world):\n        # get positions of all entities in this agent's reference frame\n        entity_pos = []\n        for entity in world.landmarks:\n            if not entity.boundary:\n                entity_pos.append(entity.state.p_pos - agent.state.p_pos)\n\n        in_forest = [np.array([-1]), np.array([-1])]\n        inf1 = False\n        inf2 = False\n        if self.is_collision(agent, world.forests[0]):\n            in_forest[0] = np.array([1])\n            inf1 = True\n        if self.is_collision(agent, world.forests[1]):\n            in_forest[1] = np.array([1])\n            inf2 = True\n\n        food_pos = []\n        for entity in world.food:\n         
   if not entity.boundary:\n                food_pos.append(entity.state.p_pos - agent.state.p_pos)\n        # communication of all other agents\n        comm = []\n        other_pos = []\n        other_vel = []\n        for other in world.agents:\n            if other is agent: continue\n            comm.append(other.state.c)\n            oth_f1 = self.is_collision(other, world.forests[0])\n            oth_f2 = self.is_collision(other, world.forests[1])\n            if (inf1 and oth_f1) or (inf2 and oth_f2) or (\n                    not inf1 and not oth_f1 and not inf2 and not oth_f2) or agent.leader:  # without forest vis\n                other_pos.append(other.state.p_pos - agent.state.p_pos)\n                if not other.adversary:\n                    other_vel.append(other.state.p_vel)\n            else:\n                other_pos.append([0, 0])\n                if not other.adversary:\n                    other_vel.append([0, 0])\n\n        # to tell the pred when the prey are in the forest\n        prey_forest = []\n        ga = self.good_agents(world)\n        for a in ga:\n            if any([self.is_collision(a, f) for f in world.forests]):\n                prey_forest.append(np.array([1]))\n            else:\n                prey_forest.append(np.array([-1]))\n        # to tell leader when pred are in forest\n        prey_forest_lead = []\n        for f in world.forests:\n            if any([self.is_collision(a, f) for a in ga]):\n                prey_forest_lead.append(np.array([1]))\n            else:\n                prey_forest_lead.append(np.array([-1]))\n\n        comm = [world.agents[0].state.c]\n\n        if agent.adversary and not agent.leader:\n            return np.concatenate(\n                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + other_vel + in_forest + comm)\n        if agent.leader:\n            return np.concatenate(\n                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + 
other_vel + in_forest + comm)\n        else:\n            return np.concatenate(\n                [agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + in_forest + other_vel)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/ToCM/action.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport sys\nimport os\n\n# if '/home/zhaofeifei/.local/lib/python3.8/site-packages' in sys.path:\n#     sys.path.remove('/home/zhaofeifei/.local/lib/python3.8/site-packages')\n\n# sys.path.append('/home/zhaofeifei/mambaSNN_Mpe/networks/ToCM/')\n# sys.path.append(\"/home/zhaofeifei/mambaSNN_Mpe/\")\n\nfrom torch.distributions import OneHotCategorical\nfrom networks.transformer.layers import AttentionEncoder, AttentionActorEncoder\nfrom networks.ToCM.utils import build_model_snn, build_model\nfrom braincog.base.node.node import LIFNode, BaseNode, PLIFNode, DoubleSidePLIFNode\nfrom braincog.base.strategy.surrogate import AtanGrad\n\n\nclass BCNoSpikingLIFNode(LIFNode):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, dv: torch.Tensor):\n        # print(\"dv: \", dv)\n        # print(\"dv.shape: \", dv.shape)\n        self.integral(dv)\n        return self.mem\n\n#SNN\n# class Actor(nn.Module):\n#     def __init__(self, in_dim, out_dim, hidden_size, layers, node='LIFNode', time_window=8,\n#                  norm_in=True, output_style='voltage'):  # 1.激活函数需要改成node # voltage\n#         super().__init__()\n#         # 1.加入SNN的脉冲参数\n#         self._threshold = 0.5\n#         self.v_reset = 0.0\n#         self.tau = 0.5\n#         self._time_window = time_window\n#         # 2.设置输出格式\n#         self.output_style = output_style\n#         # 3.ffn是否归一化\n#         self.norm = norm_in\n#         self.activation = node\n#         self.feedforward_model = build_model_snn(in_dim, out_dim, layers, hidden_size,  # kkkk TODO!!!\n#                                                  th=self._threshold, re=self.v_reset, tau=self.tau,\n#                                                  activation=self.activation, normalize=lambda x: x)  # TODO\n#         if self.output_style == 'ann':\n#             self.out_node = lambda x: x\n#         elif 
self.output_style == 'voltage':\n#             self.out_node = BCNoSpikingLIFNode(tau=1.0)\n#\n#     def forward(self, state_features):\n#         # 5.加入脉冲仿真步长\n#         # print(\"state.shape\", state_features.shape)\n#         self.reset()  # why\n#         for t in range(self._time_window):\n#             x = self.feedforward_model(state_features)\n#             x = self.out_node(x)\n#         # print(\"x\", x.shape)\n#         action_dist = OneHotCategorical(logits=x)\n#         action = action_dist.sample()  # 长度为x，一行默认 tensor([0., 1., 0., 0.])\n#         return action, x\n#\n#     # 调用modules里面node的n_reset\n#     def reset(self):\n#         for mod in self.modules():\n#             if hasattr(mod, 'n_reset'):\n#                 mod.n_reset()\n\n#ANN\nclass Actor(nn.Module):\n    def __init__(self, in_dim, out_dim, hidden_size, layers, activation=nn.ReLU):\n        super().__init__()\n\n        self.feedforward_model = build_model(in_dim, out_dim, layers, hidden_size, activation)\n\n    def forward(self, state_features):\n        x = self.feedforward_model(state_features)\n        action_dist = OneHotCategorical(logits=x)\n        action = action_dist.sample()\n        return action, x\n\nclass AttentionActor(nn.Module):\n    def __init__(self, in_dim, out_dim, hidden_size, layers, node='LIFNode', time_window=16,\n                 norm_in=True, output_style='voltage'):  # 2.激活层\n        super().__init__()\n        # 1.加入SNN的脉冲参数\n        self._threshold = 0.5\n        self.v_reset = 0.0\n        self._time_window = time_window\n        # 2.设置输出格式\n        self.output_style = output_style\n        # 3.ffn是否归一化\n        self.norm = norm_in\n        # 4.改变linear层的激活函数为LIFNode\n        self.activation = node\n\n        # hint: hidden_size = 其他网络的in_dim\n        self.feedforward_model = build_model_snn(hidden_size, out_dim, 2, hidden_size,\n                                                 th=self._threshold, re=self.v_reset,\n                                        
         activation=self.activation, normalize=lambda x: x)  # TODO\n        # build_model_snn(in_dim, out_dim, layers, hidden, activation, normalize=lambda x: x)\n        self._attention_stack = AttentionActorEncoder(1, hidden_size, hidden_size)\n        # no pos_embedding\n        # self._attention_stack = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=in_dim, nhead=1,\n        #                                                                         dim_feedforward=hidden_size,\n        #                                                                         dropout=.0), num_layers=1)  # TODO\n        # n_layers, in_dim, hidden\n        # 使用transformer的编码器，加入位置编码，加入隐藏单元d_hid,返回一个序列,其中第0维度应该是观测变量？\n        self.embed = nn.Linear(in_dim, hidden_size)\n        self.node1 = LIFNode(threshold=self._threshold, v_reset=self.v_reset)\n        self.node2 = LIFNode(threshold=self._threshold, v_reset=self.v_reset)\n        # 5. 定义一个处理linear层的node\n        if self.activation == 'LIFNode':\n            if self.output_style == 'voltage':\n                self.out_node = BCNoSpikingLIFNode(tau=2.0)\n\n    def forward(self, state_features):  # 状态值tensor\n        # print(\"state_feat:\", state_features[0])\n        # attn_embeds = self._attention_stack(state_features)\n        # n_agents = state_features.shape[-2]  # 推测state的维度为[batch_size(m,n), n_agents, in_dim]\n        # batch_size = state_features.shape[:-2]  # 除去最后2维度的维度\n        qs = []\n        self.reset()  # why\n        # print(\"attn_embeds\", attn_embeds[0])\n        # print(\"state.shape\", state_features.shape)\n        attn_embeds = self.embed(state_features)  # Linear\n        for t in range(self._time_window):\n            embeds = self.node1(attn_embeds)  # Node\n            # attn_embeds = embeds.view(-1, n_agents, embeds.shape[-1])\n            # embeds = self.node2(self._attention_stack(embeds).view(*batch_size, n_agents, embeds.shape[-1]))\n            x = self.feedforward_model(embeds)\n         
   x = self.out_node(x)\n            qs.append(x)\n\n        p = torch.zeros(qs[0].shape)\n        if self.output_style == \"sum\":\n            p = sum(qs) / self._time_window\n        elif self.output_style == \"voltage\":\n            p = qs[-1]  # TODO\n\n        # p = F.softmax(p)\n        # print(\"pi:\", p[0])\n        action_dist = OneHotCategorical(logits=p)  # 编码器，长度为p\n\n        action = action_dist.sample()\n        # print(\"actions\", action[0])\n        # 对输出进行采样\n        return action, p  # 返回一个行动序列action为每个位置符合p = x[i]的0,1序列\n\n    # 调用modules里面node的n_reset\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n# aa = AttentionActor(16, 8, 64, 3)  # in_dim, out_dim, hidden_size, layers,\n# state_feature = torch.randn([8, 8, 2, 16])  # 输入变量维度\n# out, x = aa(state_feature)\n# print(aa)\n# print(out)\n# print(out.shape)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/ToCM/critic.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport sys\nsys.path.append('/home/zhaofeifei/mambaSNN/networks/ToCM/')\n\nfrom networks.ToCM.utils import build_model_snn, build_model\nfrom networks.transformer.layers import AttentionEncoder\nfrom braincog.base.node.node import LIFNode\nfrom braincog.base.strategy.surrogate import AtanGrad\n\ndecay = 0.3\nthresh = 0.3\nlens = 0.25\n\n# print(\"File Critic Here\")\n# 0.定义一个返回膜电势的 LIFNode\nclass BCNoSpikingLIFNode(LIFNode):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def forward(self, dv: torch.Tensor):\n        # print(\"dv: \", dv)\n        # print(\"dv.shape: \", dv.shape)\n        self.integral(dv)\n        return self.mem\n\n\nact_fun = AtanGrad(alpha=2., requires_grad=False)\n\n\ndef mem_update(fc, x, mem, spike):\n    mem = mem * decay * (1 - spike) + fc(x)\n    # spike = act_fun(mem)\n    spike = act_fun(x=mem-1)\n    return mem, spike\n\n\nclass Critic(nn.Module):\n    def __init__(self, in_dim, hidden_size, layers=2, node='LIFNode', time_window=16,\n                 norm_in=True, output_style='voltage'):\n        # hint是critic没有输出维度，action的输出维度是action的数量\n        super().__init__()\n\n        # 1.加入SNN的脉冲参数\n        self._threshold = 0.5\n        self.v_reset = 0.0\n        self._time_window = time_window\n        # 2.设置输出格式\n        self.output_style = output_style\n        # 3.ffn是否归一化\n        self.norm = norm_in\n        # if self.norm:\n        #     self.in_norm = nn.BatchNorm1d(in_dim)\n        #     self.in_norm.weight.data.fill_(1)\n        #     self.in_norm.bias.data.zero_()\n        # else:\n        #     self.in_norm = lambda x: x\n        self.in_norm = lambda x: x\n        # 4.改变linear层的激活函数为LIFNode\n        self.activation = node\n\n        self.hidden_size = hidden_size\n        self.layers = layers\n\n        self.feedforward_model = build_model_snn(in_dim, 1, layers, hidden_size,\n                      
                           th=self._threshold, re=self.v_reset,\n                                                 activation=self.activation, normalize=lambda x: x)\n        # 这里feedforward的输出维度为1，其余一样 in_dim, out_dim, layers, hidden\n\n        # 5. 定义输出神经元node\n        if self.output_style == \"sum\":\n            self.out_node = lambda x: x\n\n        elif self.output_style == \"voltage\":\n            self.out_node = BCNoSpikingLIFNode(tau=2.0)\n\n    def forward(self, state_features, actions):\n        # 6.加入脉冲步长模拟\n        qs = []\n        self.reset()  # why\n        # 7.加入第一次输入的归一化，对最前面的输入进行norm\n        state_features = self.in_norm(state_features)\n        for t in range(self._time_window):\n            x = self.feedforward_model(state_features)\n            # 8.linear层之后还得有个node接住。否则如果对于ann来说，linear之后的浮点数就能作为最后的分值了，对于snn不行\n            x = self.out_node(x)\n            qs.append(x)\n\n        if self.output_style == 'sum':\n            value = sum(qs) / self._time_window\n            return value\n        elif self.output_style == 'voltage':\n            value = qs[-1]\n            return value\n\n    # 调用modules里面node的n_reset\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n#SNN\n# class MADDPGCritic(nn.Module):\n#     def __init__(self, in_dim, hidden_size, node='nn.Tanh', time_window=1,  # time_window=16,\n#                  norm_in=True, output_style='ann'):  # in_dim 1280 hidden_size 256\n#         super().__init__()\n#\n#         # 1.加入SNN的脉冲参数\n#         self._threshold = 0.5\n#         self.v_reset = 0.0\n#         self._time_window = time_window\n#         # 2.设置输出格式\n#         self.output_style = output_style\n#         # 3. 
ffn是否归一化\n#         self.norm = norm_in\n#\n#         # TODO no normalize\n#         self.in_norm = lambda x: x\n#         # 4.改变linear层的激活函数为LIFNode\n#         self.activation = node  # TODO!!!!!!!!!!!\n#\n#         self.feedforward_model = build_model_snn(hidden_size, 1, 1, hidden_size,\n#                                                  th=self._threshold, re=self.v_reset,\n#                                                  activation=self.activation, normalize=lambda x: x)\n#         # in_dim, out_dim, layers, hidden\n#         # (in_dim = hidden)->hidden->hidden......-> (out_dim = 1)\n#\n#         self._attention_stack = AttentionEncoder(1, hidden_size, hidden_size)\n#         self.embed = nn.Linear(in_dim, hidden_size)  # 1280 256\n#         self.prior = build_model_snn(in_dim, 1, 3, hidden_size,  # 1280 256\n#                                      th=self._threshold, re=self._threshold,\n#                                      activation=self.activation, normalize=lambda x: x)\n#         # also in_dim, out_dim, layers, hidden\n#         # (in_dim = hidden)->hidden->hidden......-> (out_dim = 1)\n#         # 可能是个决策函数，决策优先选择哪个action\n#\n#         # 5. 
定义输出神经元node\n#         if self.output_style == \"sum\":\n#             self.out_node = lambda x: x\n#         elif self.output_style == \"voltage\":\n#             self.out_node = BCNoSpikingLIFNode(tau=2.0)\n#         elif self.output_style == 'ann':\n#             self.out_node = lambda x: x\n#\n#     def forward(self, state_features, actions):\n#         self.reset()  # reset函数得看看怎么加\n#         n_agents = state_features.shape[-2]\n#         batch_size = state_features.shape[:-2]\n#         # 6.加入第一次输入的归一化\n#         state_features = self.in_norm(state_features)\n#         # 7.暂时不把编码加入模拟时长\n#         embeds = F.relu(self.embed(state_features))\n#         embeds = embeds.view(-1, n_agents, embeds.shape[-1])\n#         attn_embeds = F.relu(self._attention_stack(embeds).view(*batch_size, n_agents, embeds.shape[-1]))\n#\n#         # 7.设置脉冲发放时长模拟，只在ffn层\n#         qs = []\n#         for t in range(self._time_window):\n#             x = self.feedforward_model(attn_embeds)\n#             x = self.out_node(x)\n#             qs.append(x)\n#\n#         value = qs[-1]  # after 16 mem\n#         # x = self.feedforward_model(attn_embeds)\n#         # value = self.out_node(x)  # only mem once\n#         return value\n#\n#     # 调用modules里面node的n_reset\n#     def reset(self):\n#         for mod in self.modules():\n#             if hasattr(mod, 'n_reset'):\n#                 mod.n_reset()\n#ANN\nclass MADDPGCritic(nn.Module):\n    def __init__(self, in_dim, hidden_size, layers=2, activation=nn.ELU):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.layers = layers\n        self.activation = activation\n        self.feedforward_model = build_model(in_dim, 1, layers, hidden_size, activation)\n\n    def forward(self, state_features, actions):\n        return self.feedforward_model(state_features)\n# critic_net = Critic(in_dim=2, hidden_size=32, layers=2, node='LIFNode', time_window=16, norm_in=True,\n#                     output_style='voltage')\n# 
print(critic_net)\n\n# maddpg_critic_net = MADDPGCritic(in_dim=2, hidden_size=32, node='LIFNode', time_window=16, norm_in=True,\n#                                  output_style='voltage')\n# print(maddpg_critic_net)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/ToCM/dense.py",
    "content": "import torch\nimport torch.distributions as td\nimport torch.nn as nn\n\nfrom networks.ToCM.utils import build_model_snn\n\n\nclass DenseModel(nn.Module):\n    def __init__(self, in_dim, out_dim, layers, hidden, activation=\"nn.ELU\"):  # TODO  activation=nn.ELU\n        super().__init__()\n\n        self.model = build_model_snn(in_dim, out_dim, layers, hidden, activation=activation)  # no use activation\n\n    def forward(self, features):\n        return self.model(features)\n\n\nclass DenseBinaryModel(DenseModel):\n    def __init__(self, in_dim, out_dim, layers, hidden, activation=\"nn.ELU\"):  # 1280 7 2 256\n        super().__init__(in_dim, out_dim, layers, hidden, activation=activation)\n\n    def forward(self, features):\n        # for name, p in self.model.named_parameters():\n        #     print(\"name\", name)\n        #     print(\"p\", p.shape)\n\n        # if features.shape[1] != 40:\n        #     print(\"features.shape[0] / 40: \", features.shape[0] / 40)\n        #     features = torch.as_tensor(torch.split(features, int(features.shape[0] / 40), dim=0))\n        # print(\"Dense features: \", features.shape)\n        dist_inputs = self.model(features)  # features.shape 48 40 2 1280\n        # print(\"dist_inputs:\", dist_inputs.shape)\n        return td.independent.Independent(td.Bernoulli(logits=dist_inputs), 1)\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/ToCM/rnns.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.distributions import OneHotCategorical\n\nfrom configs.ToCM.ToCMAgentConfig import RSSMState\nfrom networks.transformer.layers import AttentionEncoder\n\n\ndef stack_states(rssm_states: list, dim):\n    return reduce_states(rssm_states, dim, torch.stack)\n\n\ndef cat_states(rssm_states: list, dim):\n    return reduce_states(rssm_states, dim, torch.cat)\n\n\ndef reduce_states(rssm_states: list, dim, func):\n    return RSSMState(*[func([getattr(state, key) for state in rssm_states], dim=dim)\n                       for key in rssm_states[0].__dict__.keys()])\n\n\nclass DiscreteLatentDist(nn.Module):\n    def __init__(self, in_dim, n_categoricals, n_classes, hidden_size):\n        super().__init__()\n        self.n_categoricals = n_categoricals\n        self.n_classes = n_classes\n        self.dists = nn.Sequential(nn.Linear(in_dim, hidden_size),\n                                   nn.ReLU(),\n                                   nn.Linear(hidden_size, n_classes * n_categoricals))\n\n    def forward(self, x):\n        logits = self.dists(x).view(x.shape[:-1] + (self.n_categoricals, self.n_classes))\n        class_dist = OneHotCategorical(logits=logits)\n        one_hot = class_dist.sample()\n        latents = one_hot + class_dist.probs - class_dist.probs.detach()\n        return logits.view(x.shape[:-1] + (-1,)), latents.view(x.shape[:-1] + (-1,))\n\n\nclass RSSMTransition(nn.Module):\n    def __init__(self, config, hidden_size=200, activation=nn.ReLU):\n        super().__init__()\n        self._stoch_size = config.STOCHASTIC\n        self._deter_size = config.DETERMINISTIC\n        self._hidden_size = hidden_size\n        self._activation = activation\n        self._cell = nn.GRU(hidden_size, self._deter_size)\n        self._attention_stack = AttentionEncoder(3, hidden_size, hidden_size, dropout=0.1)\n        self._rnn_input_model = self._build_rnn_input_model(config.ACTION_SIZE + self._stoch_size)\n       
 self._stochastic_prior_model = DiscreteLatentDist(self._deter_size, config.N_CATEGORICALS, config.N_CLASSES,\n                                                          self._hidden_size)\n\n    def _build_rnn_input_model(self, in_dim):\n        rnn_input_model = [nn.Linear(in_dim, self._hidden_size)]\n        rnn_input_model += [self._activation()]\n        return nn.Sequential(*rnn_input_model)\n\n    def forward(self, prev_actions, prev_states, mask=None):\n        batch_size = prev_actions.shape[0]\n        n_agents = prev_actions.shape[1]\n        stoch_input = self._rnn_input_model(torch.cat([prev_actions, prev_states.stoch], dim=-1))\n        attn = self._attention_stack(stoch_input, mask=mask)\n        deter_state = self._cell(attn.reshape(1, batch_size * n_agents, -1),\n                                 prev_states.deter.reshape(1, batch_size * n_agents, -1))[0].reshape(batch_size, n_agents, -1)\n        logits, stoch_state = self._stochastic_prior_model(deter_state)\n        return RSSMState(logits=logits, stoch=stoch_state, deter=deter_state)\n\n\nclass RSSMRepresentation(nn.Module):\n    def __init__(self, config, transition_model: RSSMTransition):\n        super().__init__()\n        self._transition_model = transition_model\n        self._stoch_size = config.STOCHASTIC\n        self._deter_size = config.DETERMINISTIC\n        self._stochastic_posterior_model = DiscreteLatentDist(self._deter_size + config.EMBED, config.N_CATEGORICALS,\n                                                              config.N_CLASSES, config.HIDDEN)\n\n    def initial_state(self, batch_size, n_agents, **kwargs):\n        return RSSMState(stoch=torch.zeros(batch_size, n_agents, self._stoch_size, **kwargs),\n                         logits=torch.zeros(batch_size, n_agents, self._stoch_size, **kwargs),\n                         deter=torch.zeros(batch_size, n_agents, self._deter_size, **kwargs))\n\n    def forward(self, obs_embed, prev_actions, prev_states, mask=None):\n       
 \"\"\"\n        :param obs_embed: size(batch, n_agents, obs_size)\n        :param prev_actions: size(batch, n_agents, action_size)\n        :param prev_states: size(batch, n_agents, state_size)\n        :return: RSSMState, global_state: size(batch, 1, global_state_size)\n        \"\"\"\n        prior_states = self._transition_model(prev_actions, prev_states, mask)\n        x = torch.cat([prior_states.deter, obs_embed], dim=-1)\n        logits, stoch_state = self._stochastic_posterior_model(x)\n        posterior_states = RSSMState(logits=logits, stoch=stoch_state, deter=prior_states.deter)\n        return prior_states, posterior_states\n\n\ndef rollout_representation(representation_model, steps, obs_embed, action, prev_states, done):\n    \"\"\"\n        Roll out the model with actions and observations from data.\n        :param steps: number of steps to roll out\n        :param obs_embed: size(time_steps, batch_size, n_agents, embedding_size)\n        :param action: size(time_steps, batch_size, n_agents, action_size)\n        :param prev_states: RSSM state, size(batch_size, n_agents, state_size)\n        :return: prior, posterior states. 
size(time_steps, batch_size, n_agents, state_size)\n        \"\"\"\n    priors = []\n    posteriors = []\n    for t in range(steps):\n        prior_states, posterior_states = representation_model(obs_embed[t], action[t], prev_states)\n        prev_states = posterior_states.map(lambda x: x * (1.0 - done[t]))\n        priors.append(prior_states)\n        posteriors.append(posterior_states)\n\n    prior = stack_states(priors, dim=0)\n    post = stack_states(posteriors, dim=0)\n    return prior.map(lambda x: x[:-1]), post.map(lambda x: x[:-1]), post.deter[1:]\n\n\ndef rollout_policy(transition_model, av_action, steps, policy, prev_state, prev_action, config):  # av_action.shape=[49 40 2 7] policy=actor\n    \"\"\"\n        Roll out the model with a policy function.\n        :param steps: number of steps to roll out\n        :param policy: RSSMState -> action\n        :param prev_state: RSSM state, size(batch_size, state_size)\n        :return: next states size(time_steps, batch_size, state_size),\n                 actions size(time_steps, batch_size, action_size)\n        \"\"\"\n    state = prev_state\n    action = prev_action[:-1].reshape((prev_action.shape[0] - 1) * prev_action.shape[1], prev_action.shape[2], -1)   # TODO\n    next_states = []\n    actions = []\n    av_actions = []\n    policies = []\n    obs_preds = []\n    for t in range(steps):\n        feat = state.get_features().detach()\n        obs_pred, _ = transition_model.observation_decoder(feat)  # TODO\n        next_state = transition_model.transition(action, state)  # TODO\n        next_feat = next_state.get_features().detach()  # TODO\n        observations_next_other, _ = transition_model.observation_decoder(next_feat)  # TODO\n        action, pi = policy(torch.cat((obs_pred.detach(), observations_next_other[:, :, -(config.num_agents-1)*4:-(config.num_agents-1)*2]), -1))  # TODO 3 vs 3\n\n        # print(\"feat:\", feat)\n        # print(\"feat_shape:\", feat.shape)  # feat.shape=1920,2,1280\n        
# action, pi = policy(feat)\n        if av_action is not None:\n            # print(\"av_action!\")\n            avail_actions = av_action(feat).sample()\n            pi[avail_actions == 0] = -1e10\n            action_dist = OneHotCategorical(logits=pi)\n            action = action_dist.sample().squeeze(0)\n            av_actions.append(avail_actions.squeeze(0))\n        next_states.append(state)\n        obs_preds.append(obs_pred)\n        policies.append(pi)\n        actions.append(action)\n        state = transition_model.transition(action, state)\n    return {\"imag_states\": stack_states(next_states, dim=0),\n            \"obs_preds\": torch.stack(obs_preds, dim=0),\n            \"actions\": torch.stack(actions, dim=0),\n            \"av_actions\": torch.stack(av_actions, dim=0) if len(av_actions) > 0 else None,\n            \"old_policy\": torch.stack(policies, dim=0)}\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/ToCM/utils.py",
    "content": "import torch.nn as nn\n\nfrom braincog.base.node import LIFNode\nfrom braincog.base.node.node import LIFNode, DoubleSidePLIFNode, PLIFNode\nfrom braincog.base.strategy.surrogate import AtanGrad\nimport torch\n\n\nclass AtanLIFNode(LIFNode):\n    def __init__(self, tau=0.5, *args, **kwargs):\n        super().__init__(tau, *args, **kwargs)\n        self.act_fun = AtanGrad(alpha=1., requires_grad=True)\n\n\nclass BCNoSpikingLIFNode(LIFNode):\n    def __init__(self, tau, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.tau = tau\n\n    def forward(self, dv: torch.Tensor):\n        # print(\"dv: \", dv)\n        # print(\"dv.shape: \", dv.shape)\n        self.integral(dv)\n        return self.mem\n\n\ndef build_model_snn(in_dim, out_dim, layers, hidden, th=0.5, re=0.0, tau=0.5, activation='LIFNode',\n                    normalize=lambda x: x):\n    # print(\"build model snn!\")\n    # 0.activation换成LIFNode...\n    if activation == 'LIFNode':\n        node = LIFNode(threshold=th, tau=tau)\n    elif activation == 'AtanLIFNode':\n        node = AtanLIFNode(tau=tau)\n    elif activation == 'BCNoSpikingLIFNode':\n        node = BCNoSpikingLIFNode(tau=tau)\n    elif activation == 'DoubleSidePLIFNode':\n        node = DoubleSidePLIFNode(tau=tau)\n    elif activation == 'PLIFNode':\n        node = PLIFNode(threshold=th)\n    elif activation == 'nn.ELU':\n        node = nn.ELU()\n    elif activation == 'nn.ReLU':\n        node = nn.ReLU()\n    elif activation == 'nn.Tanh':\n        node = nn.Tanh()\n    # 1.是否norm no norm\n    model = [normalize(nn.Linear(in_dim, hidden))]\n    model += [node]\n    for i in range(layers - 1):\n        model += [normalize(nn.Linear(hidden, hidden))]\n        model += [node]\n    model += [normalize(nn.Linear(hidden, out_dim))]\n    # 使用第二个归一化,node激活之后还要linear，最后的输出应该还得有个node，将out node定义到外面比较合适\n    return nn.Sequential(*model)\n\ndef build_model(in_dim, out_dim, layers, hidden, activation, 
normalize=lambda x: x):\n    model = [normalize(nn.Linear(in_dim, hidden))]\n    model += [activation()]\n    for i in range(layers - 1):\n        model += [normalize(nn.Linear(hidden, hidden))]\n        model += [activation()]\n    model += [normalize(nn.Linear(hidden, out_dim))]\n    return nn.Sequential(*model)"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/ToCM/vae.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\nfrom networks.ToCM.utils import build_model_snn\n\n\nclass Decoder(nn.Module):\n\n    def __init__(self, embed, hidden, out_dim, layers=2):\n        super().__init__()\n        self.fc1 = build_model_snn(embed, hidden, layers, hidden, activation='nn.ReLU')  # activation=nn.ReLU\n        self.fc2 = nn.Linear(hidden, out_dim)\n\n    def forward(self, z):\n        x = F.relu(self.fc1(z))\n        return self.fc2(x), x\n\n\nclass Encoder(nn.Module):\n\n    def __init__(self, in_dim, hidden, embed, layers=2):\n        super().__init__()\n\n        self.fc1 = nn.Linear(in_dim, hidden)\n        self.encoder = build_model_snn(hidden, embed, layers, hidden, activation='nn.ReLU')   # activation=nn.ReLU\n\n    def forward(self, x):\n        embed = F.relu(self.fc1(x))\n        return self.encoder(F.relu(embed))\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/networks/transformer/layers.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\n\n\n#  位置编码\nclass PositionalEncoding(nn.Module):\n    __author__ = \"Yu-Hsiang Huang\"\n\n    def __init__(self, d_hid, n_position=2):\n        super(PositionalEncoding, self).__init__()\n\n        # Not a parameter\n        '''\n        This is typically used to register a buffer that should not to be\n        considered a model parameter. For example, BatchNorm's ``running_mean``\n        is not a parameter, but is part of the module's state.\n        input: buffer's name, buffer's shape 应该是隐藏层之类的\n        '''\n        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))\n        # return x + self.pos_table[:, :x.size(1)].clone().detach() 使用pos_table\n\n    @staticmethod  # 系统提示我这个方法静态\n    def _get_sinusoid_encoding_table(n_position, d_hid):\n        \"\"\" Sinusoid position encoding table \"\"\"\n\n        def get_position_angle_vec(position):  # 获取每个位置的角度向量\n            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n\n        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n        # shape [pos_i, d_hid, position]\n        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i\n        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1\n\n        return torch.FloatTensor(sinusoid_table).unsqueeze(0)  # 增加一个维度\n\n    def forward(self, x):\n        return x + self.pos_table[:, :x.size(1)].clone().detach()\n\n\nclass AttentionEncoder(nn.Module):\n\n    def __init__(self, n_layers, in_dim, hidden, dropout=0.):\n        super().__init__()\n        self.pos_embed = PositionalEncoding(hidden, 30)  # 返回位置编码方案,维度++\n        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=in_dim, nhead=8,\n                                                                        dim_feedforward=hidden,\n                                          
                              dropout=dropout), n_layers)\n\n    def forward(self, enc_input, **kwargs):\n        enc_input = self.pos_embed(enc_input)\n        x = self.encoder(enc_input.permute(1, 0, 2), **kwargs)\n        return x.permute(1, 0, 2)  # 混洗 调换顺序\n\n\nclass AttentionActorEncoder(nn.Module):\n\n    def __init__(self, n_layers, in_dim, hidden, dropout=0.):\n        super().__init__()\n        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=in_dim, nhead=8,\n                                                                        dim_feedforward=hidden,\n                                                                        dropout=dropout), n_layers)\n\n    def forward(self, enc_input, **kwargs):\n        x = self.encoder(enc_input,  **kwargs)\n        return x  # 混洗 调换顺序\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/requirements.txt",
    "content": "numpy~=1.18.5\ntorch~=1.7.0\nray~=1.13.0\ngit+https://github.com/oxwhirl/smac.git\nwandb~=0.13.11\nargparse~=1.4.0\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/run.sh",
    "content": "#!/bin/sh\nseed_max=10\n#for seed in `seq ${seed_max}`;\n#do\n#    echo \"seed is ${seed}:\"\n#    python train.py\n#    kill Main_Thread\n#done\n\n#seed_max=10  # 设置最大的种子值，这里假设为10\n\n#for seed in $(seq 1 $seed_max); do\n#    echo \"seed is $seed:\"\n#    python train.py --seed $seed  # 将当前种子值作为参数传递给 train.py\n#done\npython train.py --seed 50\npkill Main_Thread\npython train.py --seed 50\npkill Main_Thread\n#python train.py --seed 1\n#pkill Main_Thread"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/bin/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/bin/map_list.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom smac.env.starcraft2.maps import smac_maps\n\nfrom pysc2 import maps as pysc2_maps\n\n\ndef main():\n    smac_map_registry = smac_maps.get_smac_map_registry()\n    all_maps = pysc2_maps.get_maps()\n    print(\"{:<15} {:7} {:7} {:7}\".format(\"Name\", \"Agents\", \"Enemies\", \"Limit\"))\n    for map_name, map_params in smac_map_registry.items():\n        map_class = all_maps[map_name]\n        if map_class.path:\n            print(\n                \"{:<15} {:<7} {:<7} {:<7}\".format(\n                    map_name,\n                    map_params[\"n_agents\"],\n                    map_params[\"n_enemies\"],\n                    map_params[\"limit\"],\n                )\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom smac.env.multiagentenv import MultiAgentEnv\nfrom smac.env.starcraft2.starcraft2 import StarCraft2Env\n\n__all__ = [\"MultiAgentEnv\", \"StarCraft2Env\"]\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/multiagentenv.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\nclass MultiAgentEnv(object):\n    def step(self, actions):\n        \"\"\"Returns reward, terminated, info.\"\"\"\n        raise NotImplementedError\n\n    def get_obs(self):\n        \"\"\"Returns all agent observations in a list.\"\"\"\n        raise NotImplementedError\n\n    def get_obs_agent(self, agent_id):\n        \"\"\"Returns observation for agent_id.\"\"\"\n        raise NotImplementedError\n\n    def get_obs_size(self):\n        \"\"\"Returns the size of the observation.\"\"\"\n        raise NotImplementedError\n\n    def get_state(self):\n        \"\"\"Returns the global state.\"\"\"\n        raise NotImplementedError\n\n    def get_state_size(self):\n        \"\"\"Returns the size of the global state.\"\"\"\n        raise NotImplementedError\n\n    def get_avail_actions(self):\n        \"\"\"Returns the available actions of all agents in a list.\"\"\"\n        raise NotImplementedError\n\n    def get_avail_agent_actions(self, agent_id):\n        \"\"\"Returns the available actions for agent_id.\"\"\"\n        raise NotImplementedError\n\n    def get_total_actions(self):\n        \"\"\"Returns the total number of actions an agent could ever take.\"\"\"\n        raise NotImplementedError\n\n    def reset(self):\n        \"\"\"Returns initial observations and states.\"\"\"\n        raise NotImplementedError\n\n    def render(self):\n        raise NotImplementedError\n\n    def close(self):\n        raise NotImplementedError\n\n    def seed(self):\n        raise NotImplementedError\n\n    def save_replay(self):\n        \"\"\"Save a replay.\"\"\"\n        raise NotImplementedError\n\n    def get_env_info(self):\n        env_info = {\n            \"state_shape\": self.get_state_size(),\n            \"obs_shape\": self.get_obs_size(),\n            \"n_actions\": self.get_total_actions(),\n            \"n_agents\": 
self.n_agents,\n            \"episode_limit\": self.episode_limit,\n        }\n        return env_info\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/pettingzoo/StarCraft2PZEnv.py",
    "content": "from smac.env import StarCraft2Env\nfrom gym.utils import EzPickle\nfrom gym.utils import seeding\nfrom gym import spaces\nfrom pettingzoo.utils.env import ParallelEnv\nfrom pettingzoo.utils.conversions import parallel_to_aec as from_parallel_wrapper\nfrom pettingzoo.utils import wrappers\nimport numpy as np\n\nmax_cycles_default = 1000\n\n\ndef parallel_env(max_cycles=max_cycles_default, **smac_args):\n    return _parallel_env(max_cycles, **smac_args)\n\n\ndef raw_env(max_cycles=max_cycles_default, **smac_args):\n    return from_parallel_wrapper(parallel_env(max_cycles, **smac_args))\n\n\ndef make_env(raw_env):\n    def env_fn(**kwargs):\n        env = raw_env(**kwargs)\n        # env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)\n        env = wrappers.AssertOutOfBoundsWrapper(env)\n        env = wrappers.OrderEnforcingWrapper(env)\n        return env\n\n    return env_fn\n\n\nclass smac_parallel_env(ParallelEnv):\n    def __init__(self, env, max_cycles):\n        self.max_cycles = max_cycles\n        self.env = env\n        self.env.reset()\n        self.reset_flag = 0\n        self.agents, self.action_spaces = self._init_agents()\n        self.possible_agents = self.agents[:]\n\n        observation_size = env.get_obs_size()\n        self.observation_spaces = {\n            name: spaces.Dict(\n                {\n                    \"observation\": spaces.Box(\n                        low=-1,\n                        high=1,\n                        shape=(observation_size,),\n                        dtype=\"float32\",\n                    ),\n                    \"action_mask\": spaces.Box(\n                        low=0,\n                        high=1,\n                        shape=(self.action_spaces[name].n,),\n                        dtype=np.int8,\n                    ),\n                }\n            )\n            for name in self.agents\n        }\n        self._reward = 0\n\n    def _init_agents(self):\n        last_type 
 = \"\"\n        agents = []\n        action_spaces = {}\n        self.agents_id = {}\n        i = 0\n        for agent_id, agent_info in self.env.agents.items():\n            unit_action_space = spaces.Discrete(\n                self.env.get_total_actions() - 1\n            )  # no-op in dead units is not an action\n            if agent_info.unit_type == self.env.marine_id:\n                agent_type = \"marine\"\n            elif agent_info.unit_type == self.env.marauder_id:\n                agent_type = \"marauder\"\n            elif agent_info.unit_type == self.env.medivac_id:\n                agent_type = \"medivac\"\n            elif agent_info.unit_type == self.env.hydralisk_id:\n                agent_type = \"hydralisk\"\n            elif agent_info.unit_type == self.env.zergling_id:\n                agent_type = \"zergling\"\n            elif agent_info.unit_type == self.env.baneling_id:\n                agent_type = \"baneling\"\n            elif agent_info.unit_type == self.env.stalker_id:\n                agent_type = \"stalker\"\n            elif agent_info.unit_type == self.env.colossus_id:\n                agent_type = \"colossus\"\n            elif agent_info.unit_type == self.env.zealot_id:\n                agent_type = \"zealot\"\n            else:\n                raise AssertionError(f\"agent type {agent_info.unit_type} not supported\")\n\n            if agent_type == last_type:\n                i += 1\n            else:\n                i = 0\n\n            agents.append(f\"{agent_type}_{i}\")\n            self.agents_id[agents[-1]] = agent_id\n            action_spaces[agents[-1]] = unit_action_space\n            last_type = agent_type\n\n        return agents, action_spaces\n\n    def seed(self, seed=None):\n        if seed is None:\n            self.env._seed = seeding.create_seed(seed, max_bytes=4)\n        else:\n            self.env._seed = seed\n        self.env.full_restart()\n\n    def render(self, mode=\"human\"):\n        
self.env.render(mode)\n\n    def close(self):\n        self.env.close()\n\n    def reset(self):\n        self.env._episode_count = 1\n        self.env.reset()\n\n        self.agents = self.possible_agents[:]\n        self.frames = 0\n        self.all_dones = {agent: False for agent in self.possible_agents}\n        return self._observe_all()\n\n    def get_agent_smac_id(self, agent):\n        return self.agents_id[agent]\n\n    def _all_rewards(self, reward):\n        all_rewards = [reward] * len(self.agents)\n        return {\n            agent: reward for agent, reward in zip(self.agents, all_rewards)\n        }\n\n    def _observe_all(self):\n        all_obs = []\n        for agent in self.agents:\n            agent_id = self.get_agent_smac_id(agent)\n            obs = self.env.get_obs_agent(agent_id)\n            action_mask = self.env.get_avail_agent_actions(agent_id)\n            action_mask = action_mask[1:]\n            action_mask = np.array(action_mask).astype(np.int8)\n            obs = np.asarray(obs, dtype=np.float32)\n            all_obs.append(\n                {\"observation\": obs, \"action_mask\": action_mask}\n            )\n        return {agent: obs for agent, obs in zip(self.agents, all_obs)}\n\n    def _all_dones(self, step_done=False):\n        dones = [True] * len(self.agents)\n        if not step_done:\n            for i, agent in enumerate(self.agents):\n                agent_done = False\n                agent_id = self.get_agent_smac_id(agent)\n                agent_info = self.env.get_unit_by_id(agent_id)\n                if agent_info.health == 0:\n                    agent_done = True\n                dones[i] = agent_done\n        return {agent: bool(done) for agent, done in zip(self.agents, dones)}\n\n    def step(self, all_actions):\n        action_list = [0] * self.env.n_agents\n        for agent in self.agents:\n            agent_id = self.get_agent_smac_id(agent)\n            if agent in all_actions:\n                if 
all_actions[agent] is None:\n                    action_list[agent_id] = 0\n                else:\n                    action_list[agent_id] = all_actions[agent] + 1\n        self._reward, terminated, smac_info = self.env.step(action_list)\n        self.frames += 1\n        done = terminated or self.frames >= self.max_cycles\n\n        all_infos = {agent: {} for agent in self.agents}\n        # all_infos.update(smac_info)\n        all_dones = self._all_dones(done)\n        all_rewards = self._all_rewards(self._reward)\n        all_observes = self._observe_all()\n\n        self.agents = [\n            agent for agent in self.agents if not all_dones[agent]\n        ]\n\n        return all_observes, all_rewards, all_dones, all_infos\n\n    def __del__(self):\n        self.env.close()\n\n\nenv = make_env(raw_env)\n\n\nclass _parallel_env(smac_parallel_env, EzPickle):\n    metadata = {\"render.modes\": [\"human\"], \"name\": \"sc2\"}\n\n    def __init__(self, max_cycles, **smac_args):\n        EzPickle.__init__(self, max_cycles, **smac_args)\n        env = StarCraft2Env(**smac_args)\n        super().__init__(env, max_cycles)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/pettingzoo/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/pettingzoo/test/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/pettingzoo/test/all_test.py",
    "content": "from smac.env.starcraft2.maps import smac_maps\nfrom pysc2 import maps as pysc2_maps\nfrom smac.env.pettingzoo import StarCraft2PZEnv as sc2\nimport pytest\nfrom pettingzoo import test\nimport pickle\n\nsmac_map_registry = smac_maps.get_smac_map_registry()\nall_maps = pysc2_maps.get_maps()\nmap_names = []\nfor map_name in smac_map_registry.keys():\n    map_class = all_maps[map_name]\n    if map_class.path:\n        map_names.append(map_name)\n\n\n@pytest.mark.parametrize((\"map_name\"), map_names)\ndef test_env(map_name):\n    env = sc2.env(map_name=map_name)\n    test.api_test(env)\n    # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to\n    # illegal actions test.seed_test(sc2.env, 50) # not required, sc2 env only\n    # allows reseeding at initialization\n    test.render_test(env)\n\n    recreated_env = pickle.loads(pickle.dumps(env))\n    test.api_test(recreated_env)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/pettingzoo/test/smac_pettingzoo_test.py",
    "content": "import os\nimport sys\nimport inspect\nfrom pettingzoo import test\nfrom smac.env.pettingzoo import StarCraft2PZEnv as sc2\nimport pickle\n\ncurrent_dir = os.path.dirname(\n    os.path.abspath(inspect.getfile(inspect.currentframe()))\n)\nparent_dir = os.path.dirname(current_dir)\nsys.path.insert(0, parent_dir)\n\n\nif __name__ == \"__main__\":\n    env = sc2.env(map_name=\"corridor\")\n    test.api_test(env)\n    # test.parallel_api_test(sc2_v0.parallel_env()) # does not pass it due to\n    # illegal actions test.seed_test(sc2_v0.env, 50) # not required, sc2 env\n    # only allows reseeding at initialization\n\n    recreated_env = pickle.loads(pickle.dumps(env))\n    test.api_test(recreated_env)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/starcraft2/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import flags\n\nFLAGS = flags.FLAGS\nFLAGS([\"main.py\"])\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/starcraft2/maps/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom smac.env.starcraft2.maps import smac_maps\n\n\ndef get_map_params(map_name):\n    map_param_registry = smac_maps.get_smac_map_registry()\n    return map_param_registry[map_name]\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/starcraft2/maps/smac_maps.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom pysc2.maps import lib\n\n\nclass SMACMap(lib.Map):\n    directory = \"SMAC_Maps\"\n    download = \"https://github.com/oxwhirl/smac#smac-maps\"\n    players = 2\n    step_mul = 8\n    game_steps_per_episode = 0\n\n\nmap_param_registry = {\n    \"3m\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 60,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"8m\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 120,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"25m\": {\n        \"n_agents\": 25,\n        \"n_enemies\": 25,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"5m_vs_6m\": {\n        \"n_agents\": 5,\n        \"n_enemies\": 6,\n        \"limit\": 70,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"8m_vs_9m\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 9,\n        \"limit\": 120,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"10m_vs_11m\": {\n        \"n_agents\": 10,\n        \"n_enemies\": 11,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"27m_vs_30m\": {\n        \"n_agents\": 27,\n        \"n_enemies\": 30,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"MMM\": {\n 
       \"n_agents\": 10,\n        \"n_enemies\": 10,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"MMM\",\n    },\n    \"MMM2\": {\n        \"n_agents\": 10,\n        \"n_enemies\": 12,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"MMM\",\n    },\n    \"2s3z\": {\n        \"n_agents\": 5,\n        \"n_enemies\": 5,\n        \"limit\": 120,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"3s5z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"3s5z_vs_3s6z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 9,\n        \"limit\": 170,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"3s_vs_3z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"3s_vs_4z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 4,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"3s_vs_5z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 5,\n        \"limit\": 250,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"1c3s5z\": {\n        \"n_agents\": 9,\n        \"n_enemies\": 9,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        
\"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n    \"2m_vs_1z\": {\n        \"n_agents\": 2,\n        \"n_enemies\": 1,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"corridor\": {\n        \"n_agents\": 6,\n        \"n_enemies\": 24,\n        \"limit\": 400,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"zealots\",\n    },\n    \"6h_vs_8z\": {\n        \"n_agents\": 6,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"Z\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"hydralisks\",\n    },\n    \"2s_vs_1sc\": {\n        \"n_agents\": 2,\n        \"n_enemies\": 1,\n        \"limit\": 300,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"so_many_baneling\": {\n        \"n_agents\": 7,\n        \"n_enemies\": 32,\n        \"limit\": 100,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"zealots\",\n    },\n    \"bane_vs_bane\": {\n        \"n_agents\": 24,\n        \"n_enemies\": 24,\n        \"limit\": 200,\n        \"a_race\": \"Z\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"bane\",\n    },\n    \"2c_vs_64zg\": {\n        \"n_agents\": 2,\n        \"n_enemies\": 64,\n        \"limit\": 400,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"colossus\",\n    },\n}\n\n\ndef get_smac_map_registry():\n    return map_param_registry\n\n\nfor name in map_param_registry.keys():\n    globals()[name] = type(name, (SMACMap,), dict(filename=name))\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/starcraft2/render.py",
    "content": "import numpy as np\nimport re\nimport subprocess\nimport platform\nfrom absl import logging\nimport math\nimport time\nimport collections\nimport os\nimport pygame\nimport queue\n\nfrom pysc2.lib import colors\nfrom pysc2.lib import point\nfrom pysc2.lib.renderer_human import _Surface\nfrom pysc2.lib import transform\nfrom pysc2.lib import features\n\n\ndef clamp(n, smallest, largest):\n    return max(smallest, min(n, largest))\n\n\ndef _get_desktop_size():\n    \"\"\"Get the desktop size.\"\"\"\n    if platform.system() == \"Linux\":\n        try:\n            xrandr_query = subprocess.check_output([\"xrandr\", \"--query\"])\n            sizes = re.findall(\n                r\"\\bconnected primary (\\d+)x(\\d+)\", str(xrandr_query)\n            )\n            if sizes[0]:\n                return point.Point(int(sizes[0][0]), int(sizes[0][1]))\n        except ValueError:\n            logging.error(\"Failed to get the resolution from xrandr.\")\n\n    # Most general, but doesn't understand multiple monitors.\n    display_info = pygame.display.Info()\n    return point.Point(display_info.current_w, display_info.current_h)\n\n\nclass StarCraft2Renderer:\n    def __init__(self, env, mode):\n        os.environ[\"PYGAME_HIDE_SUPPORT_PROMPT\"] = \"hide\"\n\n        self.env = env\n        self.mode = mode\n        self.obs = None\n        self._window_scale = 0.75\n        self.game_info = game_info = self.env._controller.game_info()\n        self.static_data = self.env._controller.data()\n\n        self._obs_queue = queue.Queue()\n        self._game_times = collections.deque(\n            maxlen=100\n        )  # Avg FPS over 100 frames.  
# pytype: disable=wrong-keyword-args\n        self._render_times = collections.deque(\n            maxlen=100\n        )  # pytype: disable=wrong-keyword-args\n        self._last_time = time.time()\n        self._last_game_loop = 0\n        self._name_lengths = {}\n\n        self._map_size = point.Point.build(game_info.start_raw.map_size)\n        self._playable = point.Rect(\n            point.Point.build(game_info.start_raw.playable_area.p0),\n            point.Point.build(game_info.start_raw.playable_area.p1),\n        )\n\n        window_size_px = point.Point(\n            self.env.window_size[0], self.env.window_size[1]\n        )\n        window_size_px = self._map_size.scale_max_size(\n            window_size_px * self._window_scale\n        ).ceil()\n        self._scale = window_size_px.y // 32\n\n        self.display = pygame.Surface(window_size_px)\n\n        if mode == \"human\":\n            self.display = pygame.display.set_mode(window_size_px, 0, 32)\n            pygame.display.init()\n\n            pygame.display.set_caption(\"Starcraft Viewer\")\n        pygame.font.init()\n        self._world_to_world_tl = transform.Linear(\n            point.Point(1, -1), point.Point(0, self._map_size.y)\n        )\n        self._world_tl_to_screen = transform.Linear(scale=window_size_px / 32)\n        self.screen_transform = transform.Chain(\n            self._world_to_world_tl, self._world_tl_to_screen\n        )\n\n        surf_loc = point.Rect(point.origin, window_size_px)\n        sub_surf = self.display.subsurface(\n            pygame.Rect(surf_loc.tl, surf_loc.size)\n        )\n        self._surf = _Surface(\n            sub_surf,\n            None,\n            surf_loc,\n            self.screen_transform,\n            None,\n            self.draw_screen,\n        )\n\n        self._font_small = pygame.font.Font(None, int(self._scale * 0.5))\n        self._font_large = pygame.font.Font(None, self._scale)\n\n    def close(self):\n        
pygame.display.quit()\n        pygame.quit()\n\n    def _get_units(self):\n        for u in sorted(\n            self.obs.observation.raw_data.units,\n            key=lambda u: (u.pos.z, u.owner != 16, -u.radius, u.tag),\n        ):\n            yield u, point.Point.build(u.pos)\n\n    def get_unit_name(self, surf, name, radius):\n        \"\"\"Get a length limited unit name for drawing units.\"\"\"\n        key = (name, radius)\n        if key not in self._name_lengths:\n            max_len = surf.world_to_surf.fwd_dist(radius * 1.6)\n            for i in range(len(name)):\n                if self._font_small.size(name[: i + 1])[0] > max_len:\n                    self._name_lengths[key] = name[:i]\n                    break\n            else:\n                self._name_lengths[key] = name\n        return self._name_lengths[key]\n\n    def render(self, mode):\n        self.obs = self.env._obs\n        self.score = self.env.reward\n        self.step = self.env._episode_steps\n\n        now = time.time()\n        self._game_times.append(\n            (\n                now - self._last_time,\n                max(\n                    1,\n                    self.obs.observation.game_loop\n                    - self._last_game_loop,\n                ),\n            )\n        )\n\n        if mode == \"human\":\n            pygame.event.pump()\n\n        self._surf.draw(self._surf)\n\n        observation = np.array(pygame.surfarray.pixels3d(self.display))\n\n        if mode == \"human\":\n            pygame.display.flip()\n\n        self._last_time = now\n        self._last_game_loop = self.obs.observation.game_loop\n        # self._obs_queue.put(self.obs)\n        return (\n            np.transpose(observation, axes=(1, 0, 2))\n            if mode == \"rgb_array\"\n            else None\n        )\n\n    def draw_base_map(self, surf):\n        \"\"\"Draw the base map.\"\"\"\n        hmap_feature = features.SCREEN_FEATURES.height_map\n        hmap = 
self.env.terrain_height * 255\n        hmap = hmap.astype(np.uint8)\n        if (\n            self.env.map_name == \"corridor\"\n            or self.env.map_name == \"so_many_baneling\"\n            or self.env.map_name == \"2s_vs_1sc\"\n        ):\n            hmap = np.flip(hmap)\n        else:\n            hmap = np.rot90(hmap, axes=(1, 0))\n        if not hmap.any():\n            hmap = hmap + 100  # pylint: disable=g-no-augmented-assignment\n        hmap_color = hmap_feature.color(hmap)\n        out = hmap_color * 0.6\n\n        surf.blit_np_array(out)\n\n    def draw_units(self, surf):\n        \"\"\"Draw the units.\"\"\"\n        unit_dict = None  # Cache the units {tag: unit_proto} for orders.\n        tau = 2 * math.pi\n        for u, p in self._get_units():\n            fraction_damage = clamp(\n                (u.health_max - u.health) / (u.health_max or 1), 0, 1\n            )\n            surf.draw_circle(\n                colors.PLAYER_ABSOLUTE_PALETTE[u.owner], p, u.radius\n            )\n\n            if fraction_damage > 0:\n                surf.draw_circle(\n                    colors.PLAYER_ABSOLUTE_PALETTE[u.owner] // 2,\n                    p,\n                    u.radius * fraction_damage,\n                )\n            surf.draw_circle(colors.black, p, u.radius, thickness=1)\n\n            if self.static_data.unit_stats[u.unit_type].movement_speed > 0:\n                surf.draw_arc(\n                    colors.white,\n                    p,\n                    u.radius,\n                    u.facing - 0.1,\n                    u.facing + 0.1,\n                    thickness=1,\n                )\n\n            def draw_arc_ratio(\n                color, world_loc, radius, start, end, thickness=1\n            ):\n                surf.draw_arc(\n                    color, world_loc, radius, start * tau, end * tau, thickness\n                )\n\n            if u.shield and u.shield_max:\n                draw_arc_ratio(\n                    
colors.blue, p, u.radius - 0.05, 0, u.shield / u.shield_max\n                )\n\n            if u.energy and u.energy_max:\n                draw_arc_ratio(\n                    colors.purple * 0.9,\n                    p,\n                    u.radius - 0.1,\n                    0,\n                    u.energy / u.energy_max,\n                )\n            elif u.orders and 0 < u.orders[0].progress < 1:\n                draw_arc_ratio(\n                    colors.cyan, p, u.radius - 0.15, 0, u.orders[0].progress\n                )\n            if u.buff_duration_remain and u.buff_duration_max:\n                draw_arc_ratio(\n                    colors.white,\n                    p,\n                    u.radius - 0.2,\n                    0,\n                    u.buff_duration_remain / u.buff_duration_max,\n                )\n            if u.attack_upgrade_level:\n                draw_arc_ratio(\n                    self.upgrade_colors[u.attack_upgrade_level],\n                    p,\n                    u.radius - 0.25,\n                    0.18,\n                    0.22,\n                    thickness=3,\n                )\n            if u.armor_upgrade_level:\n                draw_arc_ratio(\n                    self.upgrade_colors[u.armor_upgrade_level],\n                    p,\n                    u.radius - 0.25,\n                    0.23,\n                    0.27,\n                    thickness=3,\n                )\n            if u.shield_upgrade_level:\n                draw_arc_ratio(\n                    self.upgrade_colors[u.shield_upgrade_level],\n                    p,\n                    u.radius - 0.25,\n                    0.28,\n                    0.32,\n                    thickness=3,\n                )\n\n            def write_small(loc, s):\n                surf.write_world(self._font_small, colors.white, loc, str(s))\n\n            name = self.get_unit_name(\n                surf,\n                
self.static_data.units.get(u.unit_type, \"<none>\"),\n                u.radius,\n            )\n\n            if name:\n                write_small(p, name)\n\n            start_point = p\n            for o in u.orders:\n                target_point = None\n                if o.HasField(\"target_unit_tag\"):\n                    if unit_dict is None:\n                        unit_dict = {\n                            t.tag: t\n                            for t in self.obs.observation.raw_data.units\n                        }\n                    target_unit = unit_dict.get(o.target_unit_tag)\n                    if target_unit:\n                        target_point = point.Point.build(target_unit.pos)\n                if target_point:\n                    surf.draw_line(colors.cyan, start_point, target_point)\n                    start_point = target_point\n                else:\n                    break\n\n    def draw_overlay(self, surf):\n        \"\"\"Draw the overlay describing resources.\"\"\"\n        obs = self.obs.observation\n        times, steps = zip(*self._game_times)\n        sec = obs.game_loop // 22.4\n        surf.write_screen(\n            self._font_large,\n            colors.green,\n            (-0.2, 0.2),\n            \"Score: %s, Step: %s, %.1f/s, Time: %d:%02d\"\n            % (\n                self.score,\n                self.step,\n                sum(steps) / (sum(times) or 1),\n                sec // 60,\n                sec % 60,\n            ),\n            align=\"right\",\n        )\n        surf.write_screen(\n            self._font_large,\n            colors.green * 0.8,\n            (-0.2, 1.2),\n            \"APM: %d, EPM: %d, FPS: O:%.1f, R:%.1f\"\n            % (\n                obs.score.score_details.current_apm,\n                obs.score.score_details.current_effective_apm,\n                len(times) / (sum(times) or 1),\n                len(self._render_times) / (sum(self._render_times) or 1),\n            ),\n        
    align=\"right\",\n        )\n\n    def draw_screen(self, surf):\n        \"\"\"Draw the screen area.\"\"\"\n        self.draw_base_map(surf)\n        self.draw_units(surf)\n        self.draw_overlay(surf)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/env/starcraft2/starcraft2.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom smac.env.multiagentenv import MultiAgentEnv\nfrom smac.env.starcraft2.maps import get_map_params\n\nimport atexit\nfrom warnings import warn\nfrom operator import attrgetter\nfrom copy import deepcopy\nimport numpy as np\nimport enum\nimport math\nfrom absl import logging\n\nfrom pysc2 import maps\nfrom pysc2 import run_configs\nfrom pysc2.lib import protocol\n\nfrom s2clientprotocol import common_pb2 as sc_common\nfrom s2clientprotocol import sc2api_pb2 as sc_pb\nfrom s2clientprotocol import raw_pb2 as r_pb\nfrom s2clientprotocol import debug_pb2 as d_pb\n\nraces = {\n    \"R\": sc_common.Random,\n    \"P\": sc_common.Protoss,\n    \"T\": sc_common.Terran,\n    \"Z\": sc_common.Zerg,\n}\n\ndifficulties = {\n    \"1\": sc_pb.VeryEasy,\n    \"2\": sc_pb.Easy,\n    \"3\": sc_pb.Medium,\n    \"4\": sc_pb.MediumHard,\n    \"5\": sc_pb.Hard,\n    \"6\": sc_pb.Harder,\n    \"7\": sc_pb.VeryHard,\n    \"8\": sc_pb.CheatVision,\n    \"9\": sc_pb.CheatMoney,\n    \"A\": sc_pb.CheatInsane,\n}\n\nactions = {\n    \"move\": 16,  # target: PointOrUnit\n    \"attack\": 23,  # target: PointOrUnit\n    \"stop\": 4,  # target: None\n    \"heal\": 386,  # Unit\n}\n\n\nclass Direction(enum.IntEnum):\n    NORTH = 0\n    SOUTH = 1\n    EAST = 2\n    WEST = 3\n\n\nclass StarCraft2Env(MultiAgentEnv):\n    \"\"\"The StarCraft II environment for decentralised multi-agent\n    micromanagement scenarios.\n    \"\"\"\n    def __init__(\n        self,\n        map_name=\"8m\",\n        step_mul=8,\n        move_amount=2,\n        difficulty=\"7\",\n        game_version=None,\n        seed=None,\n        continuing_episode=False,\n        obs_all_health=True,\n        obs_own_health=True,\n        obs_last_action=False,\n        obs_pathing_grid=False,\n        obs_terrain_height=False,\n        obs_instead_of_state=False,\n        obs_timestep_number=False,\n   
     state_last_action=True,\n        state_timestep_number=False,\n        reward_sparse=False,\n        reward_only_positive=True,\n        reward_death_value=10,\n        reward_win=200,\n        reward_defeat=0,\n        reward_negative_scale=0.5,\n        reward_scale=True,\n        reward_scale_rate=20,\n        replay_dir=\"\",\n        replay_prefix=\"\",\n        window_size_x=1920,\n        window_size_y=1200,\n        heuristic_ai=False,\n        heuristic_rest=False,\n        debug=False,\n    ):\n        \"\"\"\n        Create a StarCraftC2Env environment.\n\n        Parameters\n        ----------\n        map_name : str, optional\n            The name of the SC2 map to play (default is \"8m\"). The full list\n            can be found by running bin/map_list.\n        step_mul : int, optional\n            How many game steps per agent step (default is 8). None\n            indicates to use the default map step_mul.\n        move_amount : float, optional\n            How far away units are ordered to move per step (default is 2).\n        difficulty : str, optional\n            The difficulty of built-in computer AI bot (default is \"7\").\n        game_version : str, optional\n            StarCraft II game version (default is None). None indicates the\n            latest version.\n        seed : int, optional\n            Random seed used during game initialisation. This allows to\n        continuing_episode : bool, optional\n            Whether to consider episodes continuing or finished after time\n            limit is reached (default is False).\n        obs_all_health : bool, optional\n            Agents receive the health of all units (in the sight range) as part\n            of observations (default is True).\n        obs_own_health : bool, optional\n            Agents receive their own health as a part of observations (default\n            is False). 
This flag is ignored when obs_all_health == True.\n        obs_last_action : bool, optional\n            Agents receive the last actions of all units (in the sight range)\n            as part of observations (default is False).\n        obs_pathing_grid : bool, optional\n            Whether observations include pathing values surrounding the agent\n            (default is False).\n        obs_terrain_height : bool, optional\n            Whether observations include terrain height values surrounding the\n            agent (default is False).\n        obs_instead_of_state : bool, optional\n            Use combination of all agents' observations as the global state\n            (default is False).\n        obs_timestep_number : bool, optional\n            Whether observations include the current timestep of the episode\n            (default is False).\n        state_last_action : bool, optional\n            Include the last actions of all agents as part of the global state\n            (default is True).\n        state_timestep_number : bool, optional\n            Whether the state include the current timestep of the episode\n            (default is False).\n        reward_sparse : bool, optional\n            Receive 1/-1 reward for winning/loosing an episode (default is\n            False). Whe rest of reward parameters are ignored if True.\n        reward_only_positive : bool, optional\n            Reward is always positive (default is True).\n        reward_death_value : float, optional\n            The amount of reward received for killing an enemy unit (default\n            is 10). This is also the negative penalty for having an allied unit\n            killed if reward_only_positive == False.\n        reward_win : float, optional\n            The reward for winning in an episode (default is 200).\n        reward_defeat : float, optional\n            The reward for loosing in an episode (default is 0). 
This value\n            should be nonpositive.\n        reward_negative_scale : float, optional\n            Scaling factor for negative rewards (default is 0.5). This\n            parameter is ignored when reward_only_positive == True.\n        reward_scale : bool, optional\n            Whether or not to scale the reward (default is True).\n        reward_scale_rate : float, optional\n            Reward scale rate (default is 20). When reward_scale == True, the\n            reward received by the agents is divided by (max_reward /\n            reward_scale_rate), where max_reward is the maximum possible\n            reward per episode without considering the shield regeneration\n            of Protoss units.\n        replay_dir : str, optional\n            The directory to save replays (default is None). If None, the\n            replay will be saved in Replays directory where StarCraft II is\n            installed.\n        replay_prefix : str, optional\n            The prefix of the replay to be saved (default is None). 
If None,\n            the name of the map will be used.\n        window_size_x : int, optional\n            The length of StarCraft II window size (default is 1920).\n        window_size_y: int, optional\n            The height of StarCraft II window size (default is 1200).\n        heuristic_ai: bool, optional\n            Whether or not to use a non-learning heuristic AI (default False).\n        heuristic_rest: bool, optional\n            At any moment, restrict the actions of the heuristic AI to be\n            chosen from actions available to RL agents (default is False).\n            Ignored if heuristic_ai == False.\n        debug: bool, optional\n            Log messages about observations, state, actions and rewards for\n            debugging purposes (default is False).\n        \"\"\"\n        # Map arguments\n        self.map_name = map_name\n        map_params = get_map_params(self.map_name)\n        self.n_agents = map_params[\"n_agents\"]\n        self.n_enemies = map_params[\"n_enemies\"]\n        self.episode_limit = map_params[\"limit\"]\n        self._move_amount = move_amount\n        self._step_mul = step_mul\n        self.difficulty = difficulty\n\n        # Observations and state\n        self.obs_own_health = obs_own_health\n        self.obs_all_health = obs_all_health\n        self.obs_instead_of_state = obs_instead_of_state\n        self.obs_last_action = obs_last_action\n        self.obs_pathing_grid = obs_pathing_grid\n        self.obs_terrain_height = obs_terrain_height\n        self.obs_timestep_number = obs_timestep_number\n        self.state_last_action = state_last_action\n        self.state_timestep_number = state_timestep_number\n        if self.obs_all_health:\n            self.obs_own_health = True\n        self.n_obs_pathing = 8\n        self.n_obs_height = 9\n\n        # Rewards args\n        self.reward_sparse = reward_sparse\n        self.reward_only_positive = reward_only_positive\n        self.reward_negative_scale = 
reward_negative_scale\n        self.reward_death_value = reward_death_value\n        self.reward_win = reward_win\n        self.reward_defeat = reward_defeat\n        self.reward_scale = reward_scale\n        self.reward_scale_rate = reward_scale_rate\n\n        # Other\n        self.game_version = game_version\n        self.continuing_episode = continuing_episode\n        self._seed = seed\n        self.heuristic_ai = heuristic_ai\n        self.heuristic_rest = heuristic_rest\n        self.debug = debug\n        self.window_size = (window_size_x, window_size_y)\n        self.replay_dir = replay_dir\n        self.replay_prefix = replay_prefix\n\n        # Actions\n        self.n_actions_no_attack = 6\n        self.n_actions_move = 4\n        self.n_actions = self.n_actions_no_attack + self.n_enemies\n\n        # Map info\n        self._agent_race = map_params[\"a_race\"]\n        self._bot_race = map_params[\"b_race\"]\n        self.shield_bits_ally = 1 if self._agent_race == \"P\" else 0\n        self.shield_bits_enemy = 1 if self._bot_race == \"P\" else 0\n        self.unit_type_bits = map_params[\"unit_type_bits\"]\n        self.map_type = map_params[\"map_type\"]\n        self._unit_types = None\n\n        self.max_reward = (\n            self.n_enemies * self.reward_death_value + self.reward_win\n        )\n\n        # create lists containing the names of attributes returned in states\n        self.ally_state_attr_names = [\n            \"health\",\n            \"energy/cooldown\",\n            \"rel_x\",\n            \"rel_y\",\n        ]\n        self.enemy_state_attr_names = [\"health\", \"rel_x\", \"rel_y\"]\n\n        if self.shield_bits_ally > 0:\n            self.ally_state_attr_names += [\"shield\"]\n        if self.shield_bits_enemy > 0:\n            self.enemy_state_attr_names += [\"shield\"]\n\n        if self.unit_type_bits > 0:\n            bit_attr_names = [\n                \"type_{}\".format(bit) for bit in range(self.unit_type_bits)\n          
  ]\n            self.ally_state_attr_names += bit_attr_names\n            self.enemy_state_attr_names += bit_attr_names\n\n        self.agents = {}\n        self.enemies = {}\n        self._episode_count = 0\n        self._episode_steps = 0\n        self._total_steps = 0\n        self._obs = None\n        self.battles_won = 0\n        self.battles_game = 0\n        self.timeouts = 0\n        self.force_restarts = 0\n        self.last_stats = None\n        self.death_tracker_ally = np.zeros(self.n_agents)\n        self.death_tracker_enemy = np.zeros(self.n_enemies)\n        self.previous_ally_units = None\n        self.previous_enemy_units = None\n        self.last_action = np.zeros((self.n_agents, self.n_actions))\n        self._min_unit_type = 0\n        self.marine_id = self.marauder_id = self.medivac_id = 0\n        self.hydralisk_id = self.zergling_id = self.baneling_id = 0\n        self.stalker_id = self.colossus_id = self.zealot_id = 0\n        self.max_distance_x = 0\n        self.max_distance_y = 0\n        self.map_x = 0\n        self.map_y = 0\n        self.reward = 0\n        self.renderer = None\n        self.terrain_height = None\n        self.pathing_grid = None\n        self._run_config = None\n        self._sc2_proc = None\n        self._controller = None\n\n        # Try to avoid leaking SC2 processes on shutdown\n        atexit.register(lambda: self.close())\n\n    def _launch(self):\n        \"\"\"Launch the StarCraft II game.\"\"\"\n        self._run_config = run_configs.get(version=self.game_version)\n        _map = maps.get(self.map_name)\n\n        # Setting up the interface\n        interface_options = sc_pb.InterfaceOptions(raw=True, score=False)\n        self._sc2_proc = self._run_config.start(\n            window_size=self.window_size, want_rgb=False\n        )\n        self._controller = self._sc2_proc.controller\n\n        # Request to create the game\n        create = sc_pb.RequestCreateGame(\n            local_map=sc_pb.LocalMap(\n   
             map_path=_map.path,\n                map_data=self._run_config.map_data(_map.path),\n            ),\n            realtime=False,\n            random_seed=self._seed,\n        )\n        create.player_setup.add(type=sc_pb.Participant)\n        create.player_setup.add(\n            type=sc_pb.Computer,\n            race=races[self._bot_race],\n            difficulty=difficulties[self.difficulty],\n        )\n        self._controller.create_game(create)\n\n        join = sc_pb.RequestJoinGame(\n            race=races[self._agent_race], options=interface_options\n        )\n        self._controller.join_game(join)\n\n        game_info = self._controller.game_info()\n        map_info = game_info.start_raw\n        map_play_area_min = map_info.playable_area.p0\n        map_play_area_max = map_info.playable_area.p1\n        self.max_distance_x = map_play_area_max.x - map_play_area_min.x\n        self.max_distance_y = map_play_area_max.y - map_play_area_min.y\n        self.map_x = map_info.map_size.x\n        self.map_y = map_info.map_size.y\n\n        if map_info.pathing_grid.bits_per_pixel == 1:\n            vals = np.array(list(map_info.pathing_grid.data)).reshape(\n                self.map_x, int(self.map_y / 8)\n            )\n            self.pathing_grid = np.transpose(\n                np.array(\n                    [\n                        [(b >> i) & 1 for b in row for i in range(7, -1, -1)]\n                        for row in vals\n                    ],\n                    dtype=np.bool,\n                )\n            )\n        else:\n            self.pathing_grid = np.invert(\n                np.flip(\n                    np.transpose(\n                        np.array(\n                            list(map_info.pathing_grid.data), dtype=np.bool\n                        ).reshape(self.map_x, self.map_y)\n                    ),\n                    axis=1,\n                )\n            )\n\n        self.terrain_height = (\n            
np.flip(\n                np.transpose(\n                    np.array(list(map_info.terrain_height.data)).reshape(\n                        self.map_x, self.map_y\n                    )\n                ),\n                1,\n            )\n            / 255\n        )\n\n    def reset(self):\n        \"\"\"Reset the environment. Required after each full episode.\n        Returns initial observations and states.\n        \"\"\"\n        self._episode_steps = 0\n        if self._episode_count == 0:\n            # Launch StarCraft II\n            self._launch()\n        else:\n            self._restart()\n\n        # Information kept for counting the reward\n        self.death_tracker_ally = np.zeros(self.n_agents)\n        self.death_tracker_enemy = np.zeros(self.n_enemies)\n        self.previous_ally_units = None\n        self.previous_enemy_units = None\n        self.win_counted = False\n        self.defeat_counted = False\n\n        self.last_action = np.zeros((self.n_agents, self.n_actions))\n\n        if self.heuristic_ai:\n            self.heuristic_targets = [None] * self.n_agents\n\n        try:\n            self._obs = self._controller.observe()\n            self.init_units()\n        except (protocol.ProtocolError, protocol.ConnectionError):\n            self.full_restart()\n\n        if self.debug:\n            logging.debug(\n                \"Started Episode {}\".format(self._episode_count).center(\n                    60, \"*\"\n                )\n            )\n\n        return self.get_obs(), self.get_state()\n\n    def _restart(self):\n        \"\"\"Restart the environment by killing all units on the map.\n        There is a trigger in the SC2Map file, which restarts the\n        episode when there are no units left.\n        \"\"\"\n        try:\n            self._kill_all_units()\n            self._controller.step(2)\n        except (protocol.ProtocolError, protocol.ConnectionError):\n            self.full_restart()\n\n    def 
full_restart(self):\n        \"\"\"Full restart. Closes the SC2 process and launches a new one.\"\"\"\n        self._sc2_proc.close()\n        self._launch()\n        self.force_restarts += 1\n\n    def step(self, actions):\n        \"\"\"A single environment step. Returns reward, terminated, info.\"\"\"\n        actions_int = [int(a) for a in actions]\n\n        self.last_action = np.eye(self.n_actions)[np.array(actions_int)]\n\n        # Collect individual actions\n        sc_actions = []\n        if self.debug:\n            logging.debug(\"Actions\".center(60, \"-\"))\n\n        for a_id, action in enumerate(actions_int):\n            if not self.heuristic_ai:\n                sc_action = self.get_agent_action(a_id, action)\n            else:\n                sc_action, action_num = self.get_agent_action_heuristic(\n                    a_id, action\n                )\n                actions[a_id] = action_num\n            if sc_action:\n                sc_actions.append(sc_action)\n\n        # Send action request\n        req_actions = sc_pb.RequestAction(actions=sc_actions)\n        try:\n            self._controller.actions(req_actions)\n            # Make step in SC2, i.e. 
apply actions\n            self._controller.step(self._step_mul)\n            # Observe here so that we know if the episode is over.\n            self._obs = self._controller.observe()\n        except (protocol.ProtocolError, protocol.ConnectionError):\n            self.full_restart()\n            return 0, True, {}\n\n        self._total_steps += 1\n        self._episode_steps += 1\n\n        # Update units\n        game_end_code = self.update_units()\n\n        terminated = False\n        reward = self.reward_battle()\n        info = {\"battle_won\": False}\n\n        # count units that are still alive\n        dead_allies, dead_enemies = 0, 0\n        for _al_id, al_unit in self.agents.items():\n            if al_unit.health == 0:\n                dead_allies += 1\n        for _e_id, e_unit in self.enemies.items():\n            if e_unit.health == 0:\n                dead_enemies += 1\n\n        info[\"dead_allies\"] = dead_allies\n        info[\"dead_enemies\"] = dead_enemies\n\n        if game_end_code is not None:\n            # Battle is over\n            terminated = True\n            self.battles_game += 1\n            if game_end_code == 1 and not self.win_counted:\n                self.battles_won += 1\n                self.win_counted = True\n                info[\"battle_won\"] = True\n                if not self.reward_sparse:\n                    reward += self.reward_win\n                else:\n                    reward = 1\n            elif game_end_code == -1 and not self.defeat_counted:\n                self.defeat_counted = True\n                if not self.reward_sparse:\n                    reward += self.reward_defeat\n                else:\n                    reward = -1\n\n        elif self._episode_steps >= self.episode_limit:\n            # Episode limit reached\n            terminated = True\n            if self.continuing_episode:\n                info[\"episode_limit\"] = True\n            self.battles_game += 1\n            
self.timeouts += 1\n\n        if self.debug:\n            logging.debug(\"Reward = {}\".format(reward).center(60, \"-\"))\n\n        if terminated:\n            self._episode_count += 1\n\n        if self.reward_scale:\n            reward /= self.max_reward / self.reward_scale_rate\n\n        self.reward = reward\n\n        return reward, terminated, info\n\n    def get_agent_action(self, a_id, action):\n        \"\"\"Construct the action for agent a_id.\"\"\"\n        avail_actions = self.get_avail_agent_actions(a_id)\n        assert (\n            avail_actions[action] == 1\n        ), \"Agent {} cannot perform action {}\".format(a_id, action)\n\n        unit = self.get_unit_by_id(a_id)\n        tag = unit.tag\n        x = unit.pos.x\n        y = unit.pos.y\n\n        if action == 0:\n            # no-op (valid only when dead)\n            assert unit.health == 0, \"No-op only available for dead agents.\"\n            if self.debug:\n                logging.debug(\"Agent {}: Dead\".format(a_id))\n            return None\n        elif action == 1:\n            # stop\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"stop\"],\n                unit_tags=[tag],\n                queue_command=False,\n            )\n            if self.debug:\n                logging.debug(\"Agent {}: Stop\".format(a_id))\n\n        elif action == 2:\n            # move north\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x, y=y + self._move_amount\n                ),\n                unit_tags=[tag],\n                queue_command=False,\n            )\n            if self.debug:\n                logging.debug(\"Agent {}: Move North\".format(a_id))\n\n        elif action == 3:\n            # move south\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                
target_world_space_pos=sc_common.Point2D(\n                    x=x, y=y - self._move_amount\n                ),\n                unit_tags=[tag],\n                queue_command=False,\n            )\n            if self.debug:\n                logging.debug(\"Agent {}: Move South\".format(a_id))\n\n        elif action == 4:\n            # move east\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x + self._move_amount, y=y\n                ),\n                unit_tags=[tag],\n                queue_command=False,\n            )\n            if self.debug:\n                logging.debug(\"Agent {}: Move East\".format(a_id))\n\n        elif action == 5:\n            # move west\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x - self._move_amount, y=y\n                ),\n                unit_tags=[tag],\n                queue_command=False,\n            )\n            if self.debug:\n                logging.debug(\"Agent {}: Move West\".format(a_id))\n        else:\n            # attack/heal units that are in range\n            target_id = action - self.n_actions_no_attack\n            if self.map_type == \"MMM\" and unit.unit_type == self.medivac_id:\n                target_unit = self.agents[target_id]\n                action_name = \"heal\"\n            else:\n                target_unit = self.enemies[target_id]\n                action_name = \"attack\"\n\n            action_id = actions[action_name]\n            target_tag = target_unit.tag\n\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=action_id,\n                target_unit_tag=target_tag,\n                unit_tags=[tag],\n                queue_command=False,\n            )\n\n            if self.debug:\n                
logging.debug(\n                    \"Agent {} {}s unit # {}\".format(\n                        a_id, action_name, target_id\n                    )\n                )\n\n        sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))\n        return sc_action\n\n    def get_agent_action_heuristic(self, a_id, action):\n        unit = self.get_unit_by_id(a_id)\n        tag = unit.tag\n\n        target = self.heuristic_targets[a_id]\n        if unit.unit_type == self.medivac_id:\n            if (\n                target is None\n                or self.agents[target].health == 0\n                or self.agents[target].health == self.agents[target].health_max\n            ):\n                min_dist = math.hypot(self.max_distance_x, self.max_distance_y)\n                min_id = -1\n                for al_id, al_unit in self.agents.items():\n                    if al_unit.unit_type == self.medivac_id:\n                        continue\n                    if (\n                        al_unit.health != 0\n                        and al_unit.health != al_unit.health_max\n                    ):\n                        dist = self.distance(\n                            unit.pos.x,\n                            unit.pos.y,\n                            al_unit.pos.x,\n                            al_unit.pos.y,\n                        )\n                        if dist < min_dist:\n                            min_dist = dist\n                            min_id = al_id\n                self.heuristic_targets[a_id] = min_id\n                if min_id == -1:\n                    self.heuristic_targets[a_id] = None\n                    return None, 0\n            action_id = actions[\"heal\"]\n            target_tag = self.agents[self.heuristic_targets[a_id]].tag\n        else:\n            if target is None or self.enemies[target].health == 0:\n                min_dist = math.hypot(self.max_distance_x, self.max_distance_y)\n                min_id = -1\n          
      for e_id, e_unit in self.enemies.items():\n                    if (\n                        unit.unit_type == self.marauder_id\n                        and e_unit.unit_type == self.medivac_id\n                    ):\n                        continue\n                    if e_unit.health > 0:\n                        dist = self.distance(\n                            unit.pos.x, unit.pos.y, e_unit.pos.x, e_unit.pos.y\n                        )\n                        if dist < min_dist:\n                            min_dist = dist\n                            min_id = e_id\n                self.heuristic_targets[a_id] = min_id\n                if min_id == -1:\n                    self.heuristic_targets[a_id] = None\n                    return None, 0\n            action_id = actions[\"attack\"]\n            target_tag = self.enemies[self.heuristic_targets[a_id]].tag\n\n        action_num = self.heuristic_targets[a_id] + self.n_actions_no_attack\n\n        # Check if the action is available\n        if (\n            self.heuristic_rest\n            and self.get_avail_agent_actions(a_id)[action_num] == 0\n        ):\n\n            # Move towards the target rather than attacking/healing\n            if unit.unit_type == self.medivac_id:\n                target_unit = self.agents[self.heuristic_targets[a_id]]\n            else:\n                target_unit = self.enemies[self.heuristic_targets[a_id]]\n\n            delta_x = target_unit.pos.x - unit.pos.x\n            delta_y = target_unit.pos.y - unit.pos.y\n\n            if abs(delta_x) > abs(delta_y):  # east or west\n                if delta_x > 0:  # east\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x + self._move_amount, y=unit.pos.y\n                    )\n                    action_num = 4\n                else:  # west\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x - self._move_amount, y=unit.pos.y\n                
    )\n                    action_num = 5\n            else:  # north or south\n                if delta_y > 0:  # north\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x, y=unit.pos.y + self._move_amount\n                    )\n                    action_num = 2\n                else:  # south\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x, y=unit.pos.y - self._move_amount\n                    )\n                    action_num = 3\n\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=target_pos,\n                unit_tags=[tag],\n                queue_command=False,\n            )\n        else:\n            # Attack/heal the target\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=action_id,\n                target_unit_tag=target_tag,\n                unit_tags=[tag],\n                queue_command=False,\n            )\n\n        sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))\n        return sc_action, action_num\n\n    def reward_battle(self):\n        \"\"\"Reward function when self.reward_spare==False.\n        Returns accumulative hit/shield point damage dealt to the enemy\n        + reward_death_value per enemy unit killed, and, in case\n        self.reward_only_positive == False, - (damage dealt to ally units\n        + reward_death_value per ally unit killed) * self.reward_negative_scale\n        \"\"\"\n        if self.reward_sparse:\n            return 0\n\n        reward = 0\n        delta_deaths = 0\n        delta_ally = 0\n        delta_enemy = 0\n\n        neg_scale = self.reward_negative_scale\n\n        # update deaths\n        for al_id, al_unit in self.agents.items():\n            if not self.death_tracker_ally[al_id]:\n                # did not die so far\n                prev_health = (\n                    
self.previous_ally_units[al_id].health\n                    + self.previous_ally_units[al_id].shield\n                )\n                if al_unit.health == 0:\n                    # just died\n                    self.death_tracker_ally[al_id] = 1\n                    if not self.reward_only_positive:\n                        delta_deaths -= self.reward_death_value * neg_scale\n                    delta_ally += prev_health * neg_scale\n                else:\n                    # still alive\n                    delta_ally += neg_scale * (\n                        prev_health - al_unit.health - al_unit.shield\n                    )\n\n        for e_id, e_unit in self.enemies.items():\n            if not self.death_tracker_enemy[e_id]:\n                prev_health = (\n                    self.previous_enemy_units[e_id].health\n                    + self.previous_enemy_units[e_id].shield\n                )\n                if e_unit.health == 0:\n                    self.death_tracker_enemy[e_id] = 1\n                    delta_deaths += self.reward_death_value\n                    delta_enemy += prev_health\n                else:\n                    delta_enemy += prev_health - e_unit.health - e_unit.shield\n\n        if self.reward_only_positive:\n            reward = abs(delta_enemy + delta_deaths)  # shield regeneration\n        else:\n            reward = delta_enemy + delta_deaths - delta_ally\n\n        return reward\n\n    def get_total_actions(self):\n        \"\"\"Returns the total number of actions an agent could ever take.\"\"\"\n        return self.n_actions\n\n    @staticmethod\n    def distance(x1, y1, x2, y2):\n        \"\"\"Distance between two points.\"\"\"\n        return math.hypot(x2 - x1, y2 - y1)\n\n    def unit_shoot_range(self, agent_id):\n        \"\"\"Returns the shooting range for an agent.\"\"\"\n        return 6\n\n    def unit_sight_range(self, agent_id):\n        \"\"\"Returns the sight range for an agent.\"\"\"\n        return 
9\n\n    def unit_max_cooldown(self, unit):\n        \"\"\"Returns the maximal cooldown for a unit.\"\"\"\n        switcher = {\n            self.marine_id: 15,\n            self.marauder_id: 25,\n            self.medivac_id: 200,  # max energy\n            self.stalker_id: 35,\n            self.zealot_id: 22,\n            self.colossus_id: 24,\n            self.hydralisk_id: 10,\n            self.zergling_id: 11,\n            self.baneling_id: 1,\n        }\n        return switcher.get(unit.unit_type, 15)\n\n    def save_replay(self):\n        \"\"\"Save a replay.\"\"\"\n        prefix = self.replay_prefix or self.map_name\n        replay_dir = self.replay_dir or \"\"\n        replay_path = self._run_config.save_replay(\n            self._controller.save_replay(),\n            replay_dir=replay_dir,\n            prefix=prefix,\n        )\n        logging.info(\"Replay saved at: %s\" % replay_path)\n\n    def unit_max_shield(self, unit):\n        \"\"\"Returns maximal shield for a given unit.\"\"\"\n        if unit.unit_type == 74 or unit.unit_type == self.stalker_id:\n            return 80  # Protoss's Stalker\n        if unit.unit_type == 73 or unit.unit_type == self.zealot_id:\n            return 50  # Protoss's Zaelot\n        if unit.unit_type == 4 or unit.unit_type == self.colossus_id:\n            return 150  # Protoss's Colossus\n\n    def can_move(self, unit, direction):\n        \"\"\"Whether a unit can move in a given direction.\"\"\"\n        m = self._move_amount / 2\n\n        if direction == Direction.NORTH:\n            x, y = int(unit.pos.x), int(unit.pos.y + m)\n        elif direction == Direction.SOUTH:\n            x, y = int(unit.pos.x), int(unit.pos.y - m)\n        elif direction == Direction.EAST:\n            x, y = int(unit.pos.x + m), int(unit.pos.y)\n        else:\n            x, y = int(unit.pos.x - m), int(unit.pos.y)\n\n        if self.check_bounds(x, y) and self.pathing_grid[x, y]:\n            return True\n\n        return False\n\n  
  def get_surrounding_points(self, unit, include_self=False):\n        \"\"\"Returns the surrounding points of the unit in 8 directions.\"\"\"\n        x = int(unit.pos.x)\n        y = int(unit.pos.y)\n\n        ma = self._move_amount\n\n        points = [\n            (x, y + 2 * ma),\n            (x, y - 2 * ma),\n            (x + 2 * ma, y),\n            (x - 2 * ma, y),\n            (x + ma, y + ma),\n            (x - ma, y - ma),\n            (x + ma, y - ma),\n            (x - ma, y + ma),\n        ]\n\n        if include_self:\n            points.append((x, y))\n\n        return points\n\n    def check_bounds(self, x, y):\n        \"\"\"Whether a point is within the map bounds.\"\"\"\n        return 0 <= x < self.map_x and 0 <= y < self.map_y\n\n    def get_surrounding_pathing(self, unit):\n        \"\"\"Returns pathing values of the grid surrounding the given unit.\"\"\"\n        points = self.get_surrounding_points(unit, include_self=False)\n        vals = [\n            self.pathing_grid[x, y] if self.check_bounds(x, y) else 1\n            for x, y in points\n        ]\n        return vals\n\n    def get_surrounding_height(self, unit):\n        \"\"\"Returns height values of the grid surrounding the given unit.\"\"\"\n        points = self.get_surrounding_points(unit, include_self=True)\n        vals = [\n            self.terrain_height[x, y] if self.check_bounds(x, y) else 1\n            for x, y in points\n        ]\n        return vals\n\n    def get_obs_agent(self, agent_id):\n        \"\"\"Returns observation for agent_id. 
The observation is composed of:\n\n        - agent movement features (where it can move to, height information\n            and pathing grid)\n        - enemy features (available_to_attack, health, relative_x, relative_y,\n            shield, unit_type)\n        - ally features (visible, distance, relative_x, relative_y, shield,\n            unit_type)\n        - agent unit features (health, shield, unit_type)\n\n        All of this information is flattened and concatenated into a list,\n        in the aforementioned order. To know the sizes of each of the\n        features inside the final list of features, take a look at the\n        functions ``get_obs_move_feats_size()``,\n        ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and\n        ``get_obs_own_feats_size()``.\n\n        The size of the observation vector may vary, depending on the\n        environment configuration and type of units present in the map.\n        For instance, non-Protoss units will not have shields, movement\n        features may or may not include terrain height and pathing grid,\n        unit_type is not included if there is only one type of unit in the\n        map etc.).\n\n        NOTE: Agents should have access only to their local observations\n        during decentralised execution.\n        \"\"\"\n        unit = self.get_unit_by_id(agent_id)\n\n        move_feats_dim = self.get_obs_move_feats_size()\n        enemy_feats_dim = self.get_obs_enemy_feats_size()\n        ally_feats_dim = self.get_obs_ally_feats_size()\n        own_feats_dim = self.get_obs_own_feats_size()\n\n        move_feats = np.zeros(move_feats_dim, dtype=np.float32)\n        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)\n        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)\n        own_feats = np.zeros(own_feats_dim, dtype=np.float32)\n\n        if unit.health > 0:  # otherwise dead, return all zeros\n            x = unit.pos.x\n            y = unit.pos.y\n            
sight_range = self.unit_sight_range(agent_id)\n\n            # Movement features\n            avail_actions = self.get_avail_agent_actions(agent_id)\n            for m in range(self.n_actions_move):\n                move_feats[m] = avail_actions[m + 2]\n\n            ind = self.n_actions_move\n\n            if self.obs_pathing_grid:\n                move_feats[\n                    ind : ind + self.n_obs_pathing  # noqa\n                ] = self.get_surrounding_pathing(unit)\n                ind += self.n_obs_pathing\n\n            if self.obs_terrain_height:\n                move_feats[ind:] = self.get_surrounding_height(unit)\n\n            # Enemy features\n            for e_id, e_unit in self.enemies.items():\n                e_x = e_unit.pos.x\n                e_y = e_unit.pos.y\n                dist = self.distance(x, y, e_x, e_y)\n\n                if (\n                    dist < sight_range and e_unit.health > 0\n                ):  # visible and alive\n                    # Sight range > shoot range\n                    enemy_feats[e_id, 0] = avail_actions[\n                        self.n_actions_no_attack + e_id\n                    ]  # available\n                    enemy_feats[e_id, 1] = dist / sight_range  # distance\n                    enemy_feats[e_id, 2] = (\n                        e_x - x\n                    ) / sight_range  # relative X\n                    enemy_feats[e_id, 3] = (\n                        e_y - y\n                    ) / sight_range  # relative Y\n\n                    ind = 4\n                    if self.obs_all_health:\n                        enemy_feats[e_id, ind] = (\n                            e_unit.health / e_unit.health_max\n                        )  # health\n                        ind += 1\n                        if self.shield_bits_enemy > 0:\n                            max_shield = self.unit_max_shield(e_unit)\n                            enemy_feats[e_id, ind] = (\n                                
e_unit.shield / max_shield\n                            )  # shield\n                            ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(e_unit, False)\n                        enemy_feats[e_id, ind + type_id] = 1  # unit type\n\n            # Ally features\n            al_ids = [\n                al_id for al_id in range(self.n_agents) if al_id != agent_id\n            ]\n            for i, al_id in enumerate(al_ids):\n\n                al_unit = self.get_unit_by_id(al_id)\n                al_x = al_unit.pos.x\n                al_y = al_unit.pos.y\n                dist = self.distance(x, y, al_x, al_y)\n\n                if (\n                    dist < sight_range and al_unit.health > 0\n                ):  # visible and alive\n                    ally_feats[i, 0] = 1  # visible\n                    ally_feats[i, 1] = dist / sight_range  # distance\n                    ally_feats[i, 2] = (al_x - x) / sight_range  # relative X\n                    ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y\n\n                    ind = 4\n                    if self.obs_all_health:\n                        ally_feats[i, ind] = (\n                            al_unit.health / al_unit.health_max\n                        )  # health\n                        ind += 1\n                        if self.shield_bits_ally > 0:\n                            max_shield = self.unit_max_shield(al_unit)\n                            ally_feats[i, ind] = (\n                                al_unit.shield / max_shield\n                            )  # shield\n                            ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(al_unit, True)\n                        ally_feats[i, ind + type_id] = 1\n                        ind += self.unit_type_bits\n\n                    if self.obs_last_action:\n                        ally_feats[i, 
ind:] = self.last_action[al_id]\n\n            # Own features\n            ind = 0\n            if self.obs_own_health:\n                own_feats[ind] = unit.health / unit.health_max\n                ind += 1\n                if self.shield_bits_ally > 0:\n                    max_shield = self.unit_max_shield(unit)\n                    own_feats[ind] = unit.shield / max_shield\n                    ind += 1\n\n            if self.unit_type_bits > 0:\n                type_id = self.get_unit_type_id(unit, True)\n                own_feats[ind + type_id] = 1\n\n        agent_obs = np.concatenate(\n            (\n                move_feats.flatten(),\n                enemy_feats.flatten(),\n                ally_feats.flatten(),\n                own_feats.flatten(),\n            )\n        )\n\n        if self.obs_timestep_number:\n            agent_obs = np.append(\n                agent_obs, self._episode_steps / self.episode_limit\n            )\n\n        if self.debug:\n            logging.debug(\"Obs Agent: {}\".format(agent_id).center(60, \"-\"))\n            logging.debug(\n                \"Avail. 
actions {}\".format(\n                    self.get_avail_agent_actions(agent_id)\n                )\n            )\n            logging.debug(\"Move feats {}\".format(move_feats))\n            logging.debug(\"Enemy feats {}\".format(enemy_feats))\n            logging.debug(\"Ally feats {}\".format(ally_feats))\n            logging.debug(\"Own feats {}\".format(own_feats))\n\n        return agent_obs\n\n    def get_obs(self):\n        \"\"\"Returns all agent observations in a list.\n        NOTE: Agents should have access only to their local observations\n        during decentralised execution.\n        \"\"\"\n        agents_obs = [self.get_obs_agent(i) for i in range(self.n_agents)]\n        return agents_obs  # return a list of agent obs\n\n    def get_state(self):\n        \"\"\"Returns the global state.\n        NOTE: This functon should not be used during decentralised execution.\n        \"\"\"\n        if self.obs_instead_of_state:\n            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(\n                np.float32\n            )\n            return obs_concat\n\n        state_dict = self.get_state_dict()\n\n        state = np.append(\n            state_dict[\"allies\"].flatten(), state_dict[\"enemies\"].flatten()\n        )\n        if \"last_action\" in state_dict:\n            state = np.append(state, state_dict[\"last_action\"].flatten())\n        if \"timestep\" in state_dict:\n            state = np.append(state, state_dict[\"timestep\"])\n\n        state = state.astype(dtype=np.float32)\n\n        if self.debug:\n            logging.debug(\"STATE\".center(60, \"-\"))\n            logging.debug(\"Ally state {}\".format(state_dict[\"allies\"]))\n            logging.debug(\"Enemy state {}\".format(state_dict[\"enemies\"]))\n            if self.state_last_action:\n                logging.debug(\"Last actions {}\".format(self.last_action))\n\n        return state\n\n    def get_ally_num_attributes(self):\n        return 
len(self.ally_state_attr_names)\n\n    def get_enemy_num_attributes(self):\n        return len(self.enemy_state_attr_names)\n\n    def get_state_dict(self):\n        \"\"\"Returns the global state as a dictionary.\n\n        - allies: numpy array containing agents and their attributes\n        - enemies: numpy array containing enemies and their attributes\n        - last_action: numpy array of previous actions for each agent\n        - timestep: current no. of steps divided by total no. of steps\n\n        NOTE: This function should not be used during decentralised execution.\n        \"\"\"\n\n        # number of features equals the number of attribute names\n        nf_al = self.get_ally_num_attributes()\n        nf_en = self.get_enemy_num_attributes()\n\n        ally_state = np.zeros((self.n_agents, nf_al))\n        enemy_state = np.zeros((self.n_enemies, nf_en))\n\n        center_x = self.map_x / 2\n        center_y = self.map_y / 2\n\n        for al_id, al_unit in self.agents.items():\n            if al_unit.health > 0:\n                x = al_unit.pos.x\n                y = al_unit.pos.y\n                max_cd = self.unit_max_cooldown(al_unit)\n\n                ally_state[al_id, 0] = (\n                    al_unit.health / al_unit.health_max\n                )  # health\n                if (\n                    self.map_type == \"MMM\"\n                    and al_unit.unit_type == self.medivac_id\n                ):\n                    ally_state[al_id, 1] = al_unit.energy / max_cd  # energy\n                else:\n                    ally_state[al_id, 1] = (\n                        al_unit.weapon_cooldown / max_cd\n                    )  # cooldown\n                ally_state[al_id, 2] = (\n                    x - center_x\n                ) / self.max_distance_x  # relative X\n                ally_state[al_id, 3] = (\n                    y - center_y\n                ) / self.max_distance_y  # relative Y\n\n                if self.shield_bits_ally > 
0:\n                    max_shield = self.unit_max_shield(al_unit)\n                    ally_state[al_id, 4] = (\n                        al_unit.shield / max_shield\n                    )  # shield\n\n                if self.unit_type_bits > 0:\n                    type_id = self.get_unit_type_id(al_unit, True)\n                    ally_state[al_id, type_id - self.unit_type_bits] = 1\n\n        for e_id, e_unit in self.enemies.items():\n            if e_unit.health > 0:\n                x = e_unit.pos.x\n                y = e_unit.pos.y\n\n                enemy_state[e_id, 0] = (\n                    e_unit.health / e_unit.health_max\n                )  # health\n                enemy_state[e_id, 1] = (\n                    x - center_x\n                ) / self.max_distance_x  # relative X\n                enemy_state[e_id, 2] = (\n                    y - center_y\n                ) / self.max_distance_y  # relative Y\n\n                if self.shield_bits_enemy > 0:\n                    max_shield = self.unit_max_shield(e_unit)\n                    enemy_state[e_id, 3] = e_unit.shield / max_shield  # shield\n\n                if self.unit_type_bits > 0:\n                    type_id = self.get_unit_type_id(e_unit, False)\n                    enemy_state[e_id, type_id - self.unit_type_bits] = 1\n\n        state = {\"allies\": ally_state, \"enemies\": enemy_state}\n\n        if self.state_last_action:\n            state[\"last_action\"] = self.last_action\n        if self.state_timestep_number:\n            state[\"timestep\"] = self._episode_steps / self.episode_limit\n\n        return state\n\n    def get_obs_enemy_feats_size(self):\n        \"\"\"Returns the dimensions of the matrix containing enemy features.\n        Size is n_enemies x n_features.\n        \"\"\"\n        nf_en = 4 + self.unit_type_bits\n\n        if self.obs_all_health:\n            nf_en += 1 + self.shield_bits_enemy\n\n        return self.n_enemies, nf_en\n\n    def 
get_obs_ally_feats_size(self):\n        \"\"\"Returns the dimensions of the matrix containing ally features.\n        Size is n_allies x n_features.\n        \"\"\"\n        nf_al = 4 + self.unit_type_bits\n\n        if self.obs_all_health:\n            nf_al += 1 + self.shield_bits_ally\n\n        if self.obs_last_action:\n            nf_al += self.n_actions\n\n        return self.n_agents - 1, nf_al\n\n    def get_obs_own_feats_size(self):\n        \"\"\"\n        Returns the size of the vector containing the agents' own features.\n        \"\"\"\n        own_feats = self.unit_type_bits\n        if self.obs_own_health:\n            own_feats += 1 + self.shield_bits_ally\n        if self.obs_timestep_number:\n            own_feats += 1\n\n        return own_feats\n\n    def get_obs_move_feats_size(self):\n        \"\"\"Returns the size of the vector containing the agents's movement-\n        related features.\n        \"\"\"\n        move_feats = self.n_actions_move\n        if self.obs_pathing_grid:\n            move_feats += self.n_obs_pathing\n        if self.obs_terrain_height:\n            move_feats += self.n_obs_height\n\n        return move_feats\n\n    def get_obs_size(self):\n        \"\"\"Returns the size of the observation.\"\"\"\n        own_feats = self.get_obs_own_feats_size()\n        move_feats = self.get_obs_move_feats_size()\n\n        n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()\n        n_allies, n_ally_feats = self.get_obs_ally_feats_size()\n\n        enemy_feats = n_enemies * n_enemy_feats\n        ally_feats = n_allies * n_ally_feats\n\n        return move_feats + enemy_feats + ally_feats + own_feats\n\n    def get_state_size(self):\n        \"\"\"Returns the size of the global state.\"\"\"\n        if self.obs_instead_of_state:\n            return self.get_obs_size() * self.n_agents\n\n        nf_al = 4 + self.shield_bits_ally + self.unit_type_bits\n        nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits\n\n        
enemy_state = self.n_enemies * nf_en\n        ally_state = self.n_agents * nf_al\n\n        size = enemy_state + ally_state\n\n        if self.state_last_action:\n            size += self.n_agents * self.n_actions\n        if self.state_timestep_number:\n            size += 1\n\n        return size\n\n    def get_visibility_matrix(self):\n        \"\"\"Returns a boolean numpy array of dimensions\n        (n_agents, n_agents + n_enemies) indicating which units\n        are visible to each agent.\n        \"\"\"\n        arr = np.zeros(\n            (self.n_agents, self.n_agents + self.n_enemies),\n            dtype=np.bool,\n        )\n\n        for agent_id in range(self.n_agents):\n            current_agent = self.get_unit_by_id(agent_id)\n            if current_agent.health > 0:  # it agent not dead\n                x = current_agent.pos.x\n                y = current_agent.pos.y\n                sight_range = self.unit_sight_range(agent_id)\n\n                # Enemies\n                for e_id, e_unit in self.enemies.items():\n                    e_x = e_unit.pos.x\n                    e_y = e_unit.pos.y\n                    dist = self.distance(x, y, e_x, e_y)\n\n                    if dist < sight_range and e_unit.health > 0:\n                        # visible and alive\n                        arr[agent_id, self.n_agents + e_id] = 1\n\n                # The matrix for allies is filled symmetrically\n                al_ids = [\n                    al_id for al_id in range(self.n_agents) if al_id > agent_id\n                ]\n                for _, al_id in enumerate(al_ids):\n                    al_unit = self.get_unit_by_id(al_id)\n                    al_x = al_unit.pos.x\n                    al_y = al_unit.pos.y\n                    dist = self.distance(x, y, al_x, al_y)\n\n                    if dist < sight_range and al_unit.health > 0:\n                        # visible and alive\n                        arr[agent_id, al_id] = arr[al_id, agent_id] = 
1\n\n        return arr\n\n    def get_unit_type_id(self, unit, ally):\n        \"\"\"Returns the ID of unit type in the given scenario.\"\"\"\n        if ally:  # use new SC2 unit types\n            type_id = unit.unit_type - self._min_unit_type\n        else:  # use default SC2 unit types\n            if self.map_type == \"stalkers_and_zealots\":\n                # id(Stalker) = 74, id(Zealot) = 73\n                type_id = unit.unit_type - 73\n            elif self.map_type == \"colossi_stalkers_zealots\":\n                # id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4\n                if unit.unit_type == 4:\n                    type_id = 0\n                elif unit.unit_type == 74:\n                    type_id = 1\n                else:\n                    type_id = 2\n            elif self.map_type == \"bane\":\n                if unit.unit_type == 9:\n                    type_id = 0\n                else:\n                    type_id = 1\n            elif self.map_type == \"MMM\":\n                if unit.unit_type == 51:\n                    type_id = 0\n                elif unit.unit_type == 48:\n                    type_id = 1\n                else:\n                    type_id = 2\n\n        return type_id\n\n    def get_avail_agent_actions(self, agent_id):\n        \"\"\"Returns the available actions for agent_id.\"\"\"\n        unit = self.get_unit_by_id(agent_id)\n        if unit.health > 0:\n            # cannot choose no-op when alive\n            avail_actions = [0] * self.n_actions\n\n            # stop should be allowed\n            avail_actions[1] = 1\n\n            # see if we can move\n            if self.can_move(unit, Direction.NORTH):\n                avail_actions[2] = 1\n            if self.can_move(unit, Direction.SOUTH):\n                avail_actions[3] = 1\n            if self.can_move(unit, Direction.EAST):\n                avail_actions[4] = 1\n            if self.can_move(unit, Direction.WEST):\n                
avail_actions[5] = 1\n\n            # Can attack only alive units that are alive in the shooting range\n            shoot_range = self.unit_shoot_range(agent_id)\n\n            target_items = self.enemies.items()\n            if self.map_type == \"MMM\" and unit.unit_type == self.medivac_id:\n                # Medivacs cannot heal themselves or other flying units\n                target_items = [\n                    (t_id, t_unit)\n                    for (t_id, t_unit) in self.agents.items()\n                    if t_unit.unit_type != self.medivac_id\n                ]\n\n            for t_id, t_unit in target_items:\n                if t_unit.health > 0:\n                    dist = self.distance(\n                        unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y\n                    )\n                    if dist <= shoot_range:\n                        avail_actions[t_id + self.n_actions_no_attack] = 1\n\n            return avail_actions\n\n        else:\n            # only no-op allowed\n            return [1] + [0] * (self.n_actions - 1)\n\n    def get_avail_actions(self):\n        \"\"\"Returns the available actions of all agents in a list.\"\"\"\n        avail_actions = []\n        for agent_id in range(self.n_agents):\n            avail_agent = self.get_avail_agent_actions(agent_id)\n            avail_actions.append(avail_agent)\n        return avail_actions\n\n    def close(self):\n        \"\"\"Close StarCraft II.\"\"\"\n        if self.renderer is not None:\n            self.renderer.close()\n            self.renderer = None\n        if self._sc2_proc:\n            self._sc2_proc.close()\n\n    def seed(self):\n        \"\"\"Returns the random seed used by the environment.\"\"\"\n        return self._seed\n\n    def render(self, mode=\"human\"):\n        if self.renderer is None:\n            from smac.env.starcraft2.render import StarCraft2Renderer\n\n            self.renderer = StarCraft2Renderer(self, mode)\n        assert (\n            
mode == self.renderer.mode\n        ), \"mode must be consistent across render calls\"\n        return self.renderer.render(mode)\n\n    def _kill_all_units(self):\n        \"\"\"Kill all units on the map.\"\"\"\n        units_alive = [\n            unit.tag for unit in self.agents.values() if unit.health > 0\n        ] + [unit.tag for unit in self.enemies.values() if unit.health > 0]\n        debug_command = [\n            d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=units_alive))\n        ]\n        self._controller.debug(debug_command)\n\n    def init_units(self):\n        \"\"\"Initialise the units.\"\"\"\n        while True:\n            # Sometimes not all units have yet been created by SC2\n            self.agents = {}\n            self.enemies = {}\n\n            ally_units = [\n                unit\n                for unit in self._obs.observation.raw_data.units\n                if unit.owner == 1\n            ]\n            ally_units_sorted = sorted(\n                ally_units,\n                key=attrgetter(\"unit_type\", \"pos.x\", \"pos.y\"),\n                reverse=False,\n            )\n\n            for i in range(len(ally_units_sorted)):\n                self.agents[i] = ally_units_sorted[i]\n                if self.debug:\n                    logging.debug(\n                        \"Unit {} is {}, x = {}, y = {}\".format(\n                            len(self.agents),\n                            self.agents[i].unit_type,\n                            self.agents[i].pos.x,\n                            self.agents[i].pos.y,\n                        )\n                    )\n\n            for unit in self._obs.observation.raw_data.units:\n                if unit.owner == 2:\n                    self.enemies[len(self.enemies)] = unit\n                    if self._episode_count == 0:\n                        self.max_reward += unit.health_max + unit.shield_max\n\n            if self._episode_count == 0:\n                min_unit_type = 
min(\n                    unit.unit_type for unit in self.agents.values()\n                )\n                self._init_ally_unit_types(min_unit_type)\n\n            all_agents_created = len(self.agents) == self.n_agents\n            all_enemies_created = len(self.enemies) == self.n_enemies\n\n            self._unit_types = [\n                unit.unit_type for unit in ally_units_sorted\n            ] + [\n                unit.unit_type\n                for unit in self._obs.observation.raw_data.units\n                if unit.owner == 2\n            ]\n\n            if all_agents_created and all_enemies_created:  # all good\n                return\n\n            try:\n                self._controller.step(1)\n                self._obs = self._controller.observe()\n            except (protocol.ProtocolError, protocol.ConnectionError):\n                self.full_restart()\n                self.reset()\n\n    def get_unit_types(self):\n        if self._unit_types is None:\n            warn(\n                \"unit types have not been initialized yet, please call\"\n                \"env.reset() to populate this and call t1286he method again.\"\n            )\n\n        return self._unit_types\n\n    def update_units(self):\n        \"\"\"Update units after an environment step.\n        This function assumes that self._obs is up-to-date.\n        \"\"\"\n        n_ally_alive = 0\n        n_enemy_alive = 0\n\n        # Store previous state\n        self.previous_ally_units = deepcopy(self.agents)\n        self.previous_enemy_units = deepcopy(self.enemies)\n\n        for al_id, al_unit in self.agents.items():\n            updated = False\n            for unit in self._obs.observation.raw_data.units:\n                if al_unit.tag == unit.tag:\n                    self.agents[al_id] = unit\n                    updated = True\n                    n_ally_alive += 1\n                    break\n\n            if not updated:  # dead\n                al_unit.health = 0\n\n    
    for e_id, e_unit in self.enemies.items():\n            updated = False\n            for unit in self._obs.observation.raw_data.units:\n                if e_unit.tag == unit.tag:\n                    self.enemies[e_id] = unit\n                    updated = True\n                    n_enemy_alive += 1\n                    break\n\n            if not updated:  # dead\n                e_unit.health = 0\n\n        if (\n            n_ally_alive == 0\n            and n_enemy_alive > 0\n            or self.only_medivac_left(ally=True)\n        ):\n            return -1  # lost\n        if (\n            n_ally_alive > 0\n            and n_enemy_alive == 0\n            or self.only_medivac_left(ally=False)\n        ):\n            return 1  # won\n        if n_ally_alive == 0 and n_enemy_alive == 0:\n            return 0\n\n        return None\n\n    def _init_ally_unit_types(self, min_unit_type):\n        \"\"\"Initialise ally unit types. Should be called once from the\n        init_units function.\n        \"\"\"\n        self._min_unit_type = min_unit_type\n        if self.map_type == \"marines\":\n            self.marine_id = min_unit_type\n        elif self.map_type == \"stalkers_and_zealots\":\n            self.stalker_id = min_unit_type\n            self.zealot_id = min_unit_type + 1\n        elif self.map_type == \"colossi_stalkers_zealots\":\n            self.colossus_id = min_unit_type\n            self.stalker_id = min_unit_type + 1\n            self.zealot_id = min_unit_type + 2\n        elif self.map_type == \"MMM\":\n            self.marauder_id = min_unit_type\n            self.marine_id = min_unit_type + 1\n            self.medivac_id = min_unit_type + 2\n        elif self.map_type == \"zealots\":\n            self.zealot_id = min_unit_type\n        elif self.map_type == \"hydralisks\":\n            self.hydralisk_id = min_unit_type\n        elif self.map_type == \"stalkers\":\n            self.stalker_id = min_unit_type\n        elif self.map_type == 
\"colossus\":\n            self.colossus_id = min_unit_type\n        elif self.map_type == \"bane\":\n            self.baneling_id = min_unit_type\n            self.zergling_id = min_unit_type + 1\n\n    def only_medivac_left(self, ally):\n        \"\"\"Check if only Medivac units are left.\"\"\"\n        if self.map_type != \"MMM\":\n            return False\n\n        if ally:\n            units_alive = [\n                a\n                for a in self.agents.values()\n                if (a.health > 0 and a.unit_type != self.medivac_id)\n            ]\n            if len(units_alive) == 0:\n                return True\n            return False\n        else:\n            units_alive = [\n                a\n                for a in self.enemies.values()\n                if (a.health > 0 and a.unit_type != self.medivac_id)\n            ]\n            if len(units_alive) == 1 and units_alive[0].unit_type == 54:\n                return True\n            return False\n\n    def get_unit_by_id(self, a_id):\n        \"\"\"Get unit by ID.\"\"\"\n        return self.agents[a_id]\n\n    def get_stats(self):\n        stats = {\n            \"battles_won\": self.battles_won,\n            \"battles_game\": self.battles_game,\n            \"battles_draw\": self.timeouts,\n            \"win_rate\": self.battles_won / self.battles_game,\n            \"timeouts\": self.timeouts,\n            \"restarts\": self.force_restarts,\n        }\n        return stats\n\n    def get_env_info(self):\n        env_info = super().get_env_info()\n        env_info[\"agent_features\"] = self.ally_state_attr_names\n        env_info[\"enemy_features\"] = self.enemy_state_attr_names\n        return env_info\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/pettingzoo/README.rst",
    "content": "SMAC on PettingZoo\n==================\n\nThis example shows how to run SMAC environments with PettingZoo multi-agent API.\n\nInstructions\n------------\n\nTo get started, first install PettingZoo with ``pip install pettingzoo``.\n\nThe SMAC environment for PettingZoo, ``StarCraft2PZEnv``, can be initialized with two different API templates.\n    * **AEC**: PettingZoo is based in the *Agent Environment Cycle* game model, more information about \"AEC\" can be read in the following `paper <https://arxiv.org/abs/2009.13051>`_. To create a SMAC environment as an \"AEC\" PettingZoo game model use: ::\n        \n        from smac.env.pettingzoo import StarCraft2PZEnv\n        \n        env = StarCraft2PZEnv.env()\n    \n    * **Parallel**: PettingZoo also supports parallel environments where all agents have simultaneous actions and observations. This type of environment can be created as follows: ::\n        \n        from smac.env.pettingzoo import StarCraft2PZEnv\n        \n        env = StarCraft2PZEnv.parallel_env()\n\n`pettingzoo_demo.py` has an example of a SMAC environment being used as a PettingZoo \"AEC\" environment. With `env.render()` it is possible to output an instance of the environment as a frame in pygame. This is useful for debugging purposes.\n\n| See https://www.pettingzoo.ml/api for more documentation.\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/pettingzoo/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/pettingzoo/pettingzoo_demo.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport random\nimport numpy as np\nfrom smac.env.pettingzoo import StarCraft2PZEnv\n\n\ndef main():\n    \"\"\"\n    Runs an env object with random actions.\n    \"\"\"\n    env = StarCraft2PZEnv.env()\n    episodes = 10\n\n    total_reward = 0\n    done = False\n    completed_episodes = 0\n\n    while completed_episodes < episodes:\n        env.reset()\n        for agent in env.agent_iter():\n            env.render()\n\n            obs, reward, done, _ = env.last()\n            total_reward += reward\n            if done:\n                action = None\n            elif isinstance(obs, dict) and \"action_mask\" in obs:\n                action = random.choice(np.flatnonzero(obs[\"action_mask\"]))\n            else:\n                action = env.action_spaces[agent].sample()\n            env.step(action)\n\n        completed_episodes += 1\n\n    env.close()\n\n    print(\"Average total reward\", total_reward / episodes)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/random_agents.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom smac.env import StarCraft2Env\nimport numpy as np\n\n\ndef main():\n    env = StarCraft2Env(map_name=\"8m\")\n    env_info = env.get_env_info()\n\n    n_actions = env_info[\"n_actions\"]\n    n_agents = env_info[\"n_agents\"]\n\n    n_episodes = 10\n\n    for e in range(n_episodes):\n        env.reset()\n        terminated = False\n        episode_reward = 0\n\n        while not terminated:\n            obs = env.get_obs()\n            state = env.get_state()\n            # env.render()  # Uncomment for rendering\n\n            actions = []\n            for agent_id in range(n_agents):\n                avail_actions = env.get_avail_agent_actions(agent_id)\n                avail_actions_ind = np.nonzero(avail_actions)[0]\n                action = np.random.choice(avail_actions_ind)\n                actions.append(action)\n\n            reward, terminated, _ = env.step(actions)\n            episode_reward += reward\n\n        print(\"Total reward in episode {} = {}\".format(e, episode_reward))\n\n    env.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/rllib/README.rst",
    "content": "SMAC on RLlib\n=============\n\nThis example shows how to run SMAC environments with RLlib multi-agent.\n\nInstructions\n------------\n\nTo get started, first install RLlib with ``pip install -U ray[rllib]``. You will also need TensorFlow installed.\n\nIn ``run_ppo.py``, each agent will be controlled by an independent PPO policy (the policies share weights). This setup serves as a single-agent baseline for this task.\n\nIn ``run_qmix.py``, the agents are controlled by the multi-agent QMIX policy. This setup is an example of centralized training and decentralized execution.\n\nSee https://ray.readthedocs.io/en/latest/rllib.html for more documentation.\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/rllib/__init__.py",
    "content": "from smac.examples.rllib.env import RLlibStarCraft2Env\nfrom smac.examples.rllib.model import MaskedActionsModel\n\n__all__ = [\"RLlibStarCraft2Env\", \"MaskedActionsModel\"]\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/rllib/env.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport random\n\nimport numpy as np\n\nfrom gym.spaces import Discrete, Box, Dict\n\nfrom ray import rllib\n\nfrom smac.env import StarCraft2Env\n\n\nclass RLlibStarCraft2Env(rllib.MultiAgentEnv):\n    \"\"\"Wraps a smac StarCraft env to be compatible with RLlib multi-agent.\"\"\"\n\n    def __init__(self, **smac_args):\n        \"\"\"Create a new multi-agent StarCraft env compatible with RLlib.\n\n        Arguments:\n            smac_args (dict): Arguments to pass to the underlying\n                smac.env.starcraft.StarCraft2Env instance.\n\n        Examples:\n            >>> from smac.examples.rllib import RLlibStarCraft2Env\n            >>> env = RLlibStarCraft2Env(map_name=\"8m\")\n            >>> print(env.reset())\n        \"\"\"\n\n        self._env = StarCraft2Env(**smac_args)\n        self._ready_agents = []\n        self.observation_space = Dict(\n            {\n                \"obs\": Box(-1, 1, shape=(self._env.get_obs_size(),)),\n                \"action_mask\": Box(\n                    0, 1, shape=(self._env.get_total_actions(),)\n                ),\n            }\n        )\n        self.action_space = Discrete(self._env.get_total_actions())\n\n    def reset(self):\n        \"\"\"Resets the env and returns observations from ready agents.\n\n        Returns:\n            obs (dict): New observations for each ready agent.\n        \"\"\"\n\n        obs_list, state_list = self._env.reset()\n        return_obs = {}\n        for i, obs in enumerate(obs_list):\n            return_obs[i] = {\n                \"action_mask\": np.array(self._env.get_avail_agent_actions(i)),\n                \"obs\": obs,\n            }\n\n        self._ready_agents = list(range(len(obs_list)))\n        return return_obs\n\n    def step(self, action_dict):\n        \"\"\"Returns observations from ready agents.\n\n        The returns are dicts 
mapping from agent_id strings to values. The\n        number of agents in the env can vary over time.\n\n        Returns\n        -------\n            obs (dict): New observations for each ready agent.\n            rewards (dict): Reward values for each ready agent. If the\n                episode is just started, the value will be None.\n            dones (dict): Done values for each ready agent. The special key\n                \"__all__\" (required) is used to indicate env termination.\n            infos (dict): Optional info values for each agent id.\n        \"\"\"\n\n        actions = []\n        for i in self._ready_agents:\n            if i not in action_dict:\n                raise ValueError(\n                    \"You must supply an action for agent: {}\".format(i)\n                )\n            actions.append(action_dict[i])\n\n        if len(actions) != len(self._ready_agents):\n            raise ValueError(\n                \"Unexpected number of actions: {}\".format(\n                    action_dict,\n                )\n            )\n\n        rew, done, info = self._env.step(actions)\n        obs_list = self._env.get_obs()\n        return_obs = {}\n        for i, obs in enumerate(obs_list):\n            return_obs[i] = {\n                \"action_mask\": self._env.get_avail_agent_actions(i),\n                \"obs\": obs,\n            }\n        rews = {i: rew / len(obs_list) for i in range(len(obs_list))}\n        dones = {i: done for i in range(len(obs_list))}\n        dones[\"__all__\"] = done\n        infos = {i: info for i in range(len(obs_list))}\n\n        self._ready_agents = list(range(len(obs_list)))\n        return return_obs, rews, dones, infos\n\n    def close(self):\n        \"\"\"Close the environment\"\"\"\n        self._env.close()\n\n    def seed(self, seed):\n        random.seed(seed)\n        np.random.seed(seed)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/rllib/model.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n\nfrom ray.rllib.models import Model\nfrom ray.rllib.models.tf.misc import normc_initializer\n\n\nclass MaskedActionsModel(Model):\n    \"\"\"Custom RLlib model that emits -inf logits for invalid actions.\n\n    This is used to handle the variable-length StarCraft action space.\n    \"\"\"\n\n    def _build_layers_v2(self, input_dict, num_outputs, options):\n        action_mask = input_dict[\"obs\"][\"action_mask\"]\n        if num_outputs != action_mask.shape[1].value:\n            raise ValueError(\n                \"This model assumes num outputs is equal to max avail actions\",\n                num_outputs,\n                action_mask,\n            )\n\n        # Standard fully connected network\n        last_layer = input_dict[\"obs\"][\"obs\"]\n        hiddens = options.get(\"fcnet_hiddens\")\n        for i, size in enumerate(hiddens):\n            label = \"fc{}\".format(i)\n            last_layer = tf.layers.dense(\n                last_layer,\n                size,\n                kernel_initializer=normc_initializer(1.0),\n                activation=tf.nn.tanh,\n                name=label,\n            )\n        action_logits = tf.layers.dense(\n            last_layer,\n            num_outputs,\n            kernel_initializer=normc_initializer(0.01),\n            activation=None,\n            name=\"fc_out\",\n        )\n\n        # Mask out invalid actions (use tf.float32.min for stability)\n        inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)\n        masked_logits = inf_mask + action_logits\n\n        return masked_logits, last_layer\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/rllib/run_ppo.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\"\"\"Example of running StarCraft2 with RLlib PPO.\n\nIn this setup, each agent will be controlled by an independent PPO policy.\nHowever the policies share weights.\n\nIncrease the level of parallelism by changing --num-workers.\n\"\"\"\n\nimport argparse\n\nimport ray\nfrom ray.tune import run_experiments, register_env\nfrom ray.rllib.models import ModelCatalog\n\nfrom smac.examples.rllib.env import RLlibStarCraft2Env\nfrom smac.examples.rllib.model import MaskedActionsModel\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-iters\", type=int, default=100)\n    parser.add_argument(\"--num-workers\", type=int, default=2)\n    parser.add_argument(\"--map-name\", type=str, default=\"8m\")\n    args = parser.parse_args()\n\n    ray.init()\n\n    register_env(\"smac\", lambda smac_args: RLlibStarCraft2Env(**smac_args))\n    ModelCatalog.register_custom_model(\"mask_model\", MaskedActionsModel)\n\n    run_experiments(\n        {\n            \"ppo_sc2\": {\n                \"run\": \"PPO\",\n                \"env\": \"smac\",\n                \"stop\": {\n                    \"training_iteration\": args.num_iters,\n                },\n                \"config\": {\n                    \"num_workers\": args.num_workers,\n                    \"observation_filter\": \"NoFilter\",  # breaks the action mask\n                    \"vf_share_layers\": True,  # no separate value model\n                    \"env_config\": {\n                        \"map_name\": args.map_name,\n                    },\n                    \"model\": {\n                        \"custom_model\": \"mask_model\",\n                    },\n                },\n            },\n        }\n    )\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/smac/examples/rllib/run_qmix.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\"\"\"Example of running StarCraft2 with RLlib QMIX.\n\nThis assumes all agents are homogeneous. The agents are grouped and assigned\nto the multi-agent QMIX policy. Note that the default hyperparameters for\nRLlib QMIX are different from pymarl's QMIX.\n\"\"\"\n\nimport argparse\nfrom gym.spaces import Tuple\n\nimport ray\nfrom ray.tune import run_experiments, register_env\n\nfrom smac.examples.rllib.env import RLlibStarCraft2Env\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--num-iters\", type=int, default=100)\n    parser.add_argument(\"--num-workers\", type=int, default=2)\n    parser.add_argument(\"--map-name\", type=str, default=\"8m\")\n    args = parser.parse_args()\n\n    def env_creator(smac_args):\n        env = RLlibStarCraft2Env(**smac_args)\n        agent_list = list(range(env._env.n_agents))\n        grouping = {\n            \"group_1\": agent_list,\n        }\n        obs_space = Tuple([env.observation_space for i in agent_list])\n        act_space = Tuple([env.action_space for i in agent_list])\n        return env.with_agent_groups(\n            grouping, obs_space=obs_space, act_space=act_space\n        )\n\n    ray.init()\n    register_env(\"sc2_grouped\", env_creator)\n\n    run_experiments(\n        {\n            \"qmix_sc2\": {\n                \"run\": \"QMIX\",\n                \"env\": \"sc2_grouped\",\n                \"stop\": {\n                    \"training_iteration\": args.num_iters,\n                },\n                \"config\": {\n                    \"num_workers\": args.num_workers,\n                    \"env_config\": {\n                        \"map_name\": args.map_name,\n                    },\n                },\n            },\n        }\n    )\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/train.py",
    "content": "import argparse\nimport os\nimport sys\n\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\nfrom agent.runners.ToCMRunner import ToCMRunner\nfrom configs import Experiment, SimpleObservationConfig, NearRewardConfig, DeadlockPunishmentConfig, \\\n    RewardsComposerConfig\nfrom configs.EnvConfigs import StarCraftConfig, EnvCurriculumConfig, MPEConfig  # TODO\nfrom configs.ToCM.ToCMControllerConfig import ToCMControllerConfig\nfrom configs.ToCM.ToCMLearnerConfig import ToCMLearnerConfig\nfrom environments import Env\nfrom utils.util import get_dim_from_space, get_cent_act_dim\nimport torch\nimport random\nimport numpy as np\nimport setproctitle\nsetproctitle.setproctitle(\"MPE_obs_2_hetero\")\ndef occumpy_mem(cuda_device):\n    total, used = os.popen(\n        '\"/usr/bin/nvidia-smi\" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader').read().strip().split(\n        \"\\n\")[int(cuda_device)].split(',')\n    total = int(total)\n    used = int(used)\n    cc = 0.85\n    block_mem = int((total - used) * cc)\n    x = torch.cuda.FloatTensor(256, 1024, block_mem)\n    del x\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env', type=str, default=\"mpe\", help='starcraft or mpe')  # TODO\n    parser.add_argument('--env_name', type=str, default=\"hetero_spread\", help='Specific setting')  # TODO\n    # star : 2s_vs_1sc MMM 2s3z 3s_vs_3z 3s5z_vs_3s6z simple_spread\n    parser.add_argument('--n_workers', type=int, default=4, help='Number of workers')\n    parser.add_argument('--device', type=str, default='cuda', help='device')\n    parser.add_argument('--seed', type=int, default=50, help='random')\n    # TODO num_landmarks num_adversaries episode_length num_good_agents\n    parser.add_argument('--num_agents', type=int, default=2, help='mpe_num_agents')  # simple_adversary\n    parser.add_argument('--num_adversaries', type=int, default=None, help='mpe_num_adversaries')\n    
parser.add_argument('--num_good_agents', type=int, default=None, help='mpe_num_good_agents')\n    parser.add_argument('--num_landmarks', type=int, default=2, help='mpe_num_landmarks')\n    parser.add_argument('--episode_length', type=int, default=25, help='mpe_episode_length')\n    parser.add_argument('--num_rollout_threads', type=int, default=128, help='mpe_episode_length')\n    parser.add_argument('--benchmark', type=bool, default=False, help='mpe_use_benchmark')\n    return parser.parse_args()  # 为啥直接跳到prepare_starcraft_configs函数里了\n\n\ndef setup_seed(seed):\n    random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\ndef train_ToCM(exp, n_workers):  # no env.episode_length\n    runner = ToCMRunner(exp.env_config, exp.learner_config, exp.controller_config, n_workers)\n    runner.run(exp.steps, exp.episodes)  # 10 ** 10 50000\n\n\ndef get_env_info(configs, env):\n    for config in configs:\n        config.IN_DIM = env.n_obs  # 17 2s_vs_1sc\n        config.ACTION_SIZE = env.n_actions  # 7 2s_vs_1sc\n    env.close()\n\n\ndef get_env_info_mpe(configs, env):  # add to ToCM controller and worker\n    # TODO cent_obs_dim and cent_action_dim  use share_policy\n    for config in configs:\n        config.CENT_OBS_DIM = get_dim_from_space(env.env.share_observation_space[0])  # 54=num_agents*IN_DIM\n        config.CENT_ACT_DIM = get_cent_act_dim(env.env.action_space)  # 15=num_agents*ACTION_SIZE\n        config.IN_DIM = get_dim_from_space(env.env.observation_space[0])  # dim 18\n        config.ACTION_SIZE = get_dim_from_space(env.env.action_space[0])  # dim 5\n    env.close()\n\n\n\n\ndef prepare_starcraft_configs(env_name, device):\n    # env_name '3s5z_vs_3s6z'  device 'cuda:6'  RANDOM_SEED 1   args.n_workers 2\n    # args.env 'starcraft'    
args.env_name '3s5z_vs_3s6z'    args.device 'cuda:6'    args.seed 1\n    agent_configs = [ToCMControllerConfig(env_name, RANDOM_SEED, device),\n                     ToCMLearnerConfig(env_name, RANDOM_SEED, device)]\n    env_config = StarCraftConfig(env_name, RANDOM_SEED)\n    get_env_info(agent_configs, env_config.create_env())\n    return {\"env_config\": (env_config, 100),\n            \"controller_config\": agent_configs[0],\n            \"learner_config\": agent_configs[1],\n            \"reward_config\": None,\n            \"obs_builder_config\": None}\n\n\ndef prepare_mpe_configs(arg):\n    agent_configs = [ToCMControllerConfig(arg.env_name, RANDOM_SEED, arg.device),\n                     ToCMLearnerConfig(arg.env_name, RANDOM_SEED, arg.device)]\n    env_config = MPEConfig(arg)\n    get_env_info_mpe(agent_configs, env_config.create_env())\n    return {\"env_config\": (env_config, 100),\n            \"controller_config\": agent_configs[0],\n            \"learner_config\": agent_configs[1],\n            \"reward_config\": None,\n            \"obs_builder_config\": None}  # TODO whether has reward config and obs builder config\n\n\n\nif __name__ == \"__main__\":\n    # occumpy_mem(2)\n    # RANDOM_SEED = 23  # RANDOM_SEED 1\n    args = parse_args()\n    # print(\"args=\", args)\n    RANDOM_SEED = args.seed\n    setup_seed(RANDOM_SEED)  # TODO\n    # args.env_name '3s5z_vs_3s6z' args.device 'cuda:6'  args.seed 1  args.n_workers 2\n    if args.env == Env.STARCRAFT:\n        configs = prepare_starcraft_configs(args.env_name, args.device)\n    elif args.env == Env.MPE:\n        configs = prepare_mpe_configs(args)\n        # as env is mpe env_name is simple_adversary\n    else:\n        raise Exception(\"Unknown environment\")\n\n    configs[\"env_config\"][0].ENV_TYPE = Env(args.env)  # 转化为字符串\n    configs[\"learner_config\"].ENV_TYPE = Env(args.env)\n    configs[\"controller_config\"].ENV_TYPE = Env(args.env)\n\n    exp = Experiment(steps=10 ** 10,\n                
     episodes=50000,\n                     random_seed=RANDOM_SEED,\n                     env_config=EnvCurriculumConfig(*zip(configs[\"env_config\"]), Env(args.env), args.device,  # TODO\n                                                    obs_builder_config=configs[\"obs_builder_config\"],\n                                                    reward_config=configs[\"reward_config\"]),\n                     controller_config=configs[\"controller_config\"],\n                     learner_config=configs[\"learner_config\"])\n    # print(\"exp=\", exp)\n    train_ToCM(exp, n_workers=args.n_workers)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/mlp_buffer.py",
    "content": "import numpy as np\nfrom utils.util import get_dim_from_space\nfrom utils.segment_tree import SumSegmentTree, MinSegmentTree\n\n\ndef _cast(x):\n    return x.transpose(1, 0, 2)\n\n\nclass MlpReplayBuffer(object):\n    def __init__(self, policy_info, policy_agents, buffer_size, use_same_share_obs, use_avail_acts,\n                 use_reward_normalization=False):\n        \"\"\"\n        Replay buffer class for training MLP policies.\n\n        :param policy_info: (dict) maps policy id to a dict containing information about corresponding policy.\n        :param policy_agents: (dict) maps policy id to list of agents controled by corresponding policy.\n        :param buffer_size: (int) max number of transitions to store in the buffer.\n        :param use_same_share_obs: (bool) whether all agents share the same centralized observation.\n        :param use_avail_acts: (bool) whether to store what actions are available.\n        :param use_reward_normalization: (bool) whether to use reward normalization.\n        \"\"\"\n\n        self.policy_info = policy_info\n\n        self.policy_buffers = {p_id: MlpPolicyBuffer(buffer_size,\n                                                     len(policy_agents[p_id]),\n                                                     self.policy_info[p_id]['obs_space'],\n                                                     self.policy_info[p_id]['share_obs_space'],\n                                                     self.policy_info[p_id]['act_space'],\n                                                     use_same_share_obs,\n                                                     use_avail_acts,\n                                                     use_reward_normalization)\n                               for p_id in self.policy_info.keys()}\n\n    def __len__(self):\n        return self.policy_buffers['policy_0'].filled_i\n\n    def insert(self, num_insert_steps, obs, share_obs, acts, rewards,\n               next_obs, 
next_share_obs, dones, dones_env, valid_transition,\n               avail_acts, next_avail_acts):\n        \"\"\"\n        Insert  a set of transitions into buffer. If the buffer size overflows, old transitions are dropped.\n\n        :param num_insert_steps: (int) number of transitions to be added to buffer\n        :param obs: (dict) maps policy id to numpy array of observations of agents corresponding to that policy\n        :param share_obs: (dict) maps policy id to numpy array of centralized observation corresponding to that policy\n        :param acts: (dict) maps policy id to numpy array of actions of agents corresponding to that policy\n        :param rewards: (dict) maps policy id to numpy array of rewards of agents corresponding to that policy\n        :param next_obs: (dict) maps policy id to numpy array of next step observations of agents corresponding to that policy\n        :param next_share_obs: (dict) maps policy id to numpy array of next step centralized observations corresponding to that policy\n        :param dones: (dict) maps policy id to numpy array of terminal status of agents corresponding to that policy\n        :param dones_env: (dict) maps policy id to numpy array of terminal status of env\n        :param valid_transition: (dict) maps policy id to numpy array of whether the corresponding transition is valid of agents corresponding to that policy\n        :param avail_acts: (dict) maps policy id to numpy array of available actions of agents corresponding to that policy\n        :param next_avail_acts: (dict) maps policy id to numpy array of next step available actions of agents corresponding to that policy\n\n        :return: (np.ndarray) indexes in which the new transitions were placed.\n        \"\"\"\n        idx_range = None\n        for p_id in self.policy_info.keys():\n            idx_range = self.policy_buffers[p_id].insert(num_insert_steps,\n                                                         np.array(obs[p_id]), 
np.array(share_obs[p_id]),\n                                                         np.array(acts[p_id]), np.array(rewards[p_id]),\n                                                         np.array(next_obs[p_id]), np.array(next_share_obs[p_id]),\n                                                         np.array(dones[p_id]), np.array(dones_env[p_id]),\n                                                         np.array(valid_transition[p_id]),\n                                                         np.array(avail_acts[p_id]), np.array(next_avail_acts[p_id]))\n        return idx_range\n\n    def sample(self, batch_size):\n        \"\"\"\n        Sample a set of transitions from buffer, uniformly at random.\n        :param batch_size: (int) number of transitions to sample from buffer.\n\n        :return: obs: (dict) maps policy id to sampled observations corresponding to that policy\n        :return: share_obs: (dict) maps policy id to sampled observations corresponding to that policy\n        :return: acts: (dict) maps policy id to sampled actions corresponding to that policy\n        :return: rewards: (dict) maps policy id to sampled rewards corresponding to that policy\n        :return: next_obs: (dict) maps policy id to sampled next step observations corresponding to that policy\n        :return: next_share_obs: (dict) maps policy id to sampled next step centralized observations corresponding to that policy\n        :return: dones: (dict) maps policy id to sampled terminal status of agents corresponding to that policy\n        :return: dones_env: (dict) maps policy id to sampled environment terminal status corresponding to that policy\n        :return: valid_transition: (dict) maps policy_id to whether each sampled transition is valid or not (invalid if corresponding agent is dead)\n        :return: avail_acts: (dict) maps policy_id to available actions corresponding to that policy\n        :return: next_avail_acts: (dict) maps policy_id to next step available 
actions corresponding to that policy\n        \"\"\"\n        inds = np.random.choice(len(self), batch_size)\n        obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts = {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}\n        for p_id in self.policy_info.keys():\n            obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id], next_obs[p_id], next_share_obs[p_id], \\\n            dones[p_id], dones_env[p_id], valid_transition[p_id], avail_acts[p_id], next_avail_acts[p_id] = \\\n            self.policy_buffers[p_id].sample_inds(inds)\n\n        return obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts, None, None\n\n\nclass MlpPolicyBuffer(object):\n\n    def __init__(self, buffer_size, num_agents, obs_space, share_obs_space, act_space, use_same_share_obs,\n                 use_avail_acts, use_reward_normalization=False):\n        \"\"\"\n        Buffer class containing buffer data corresponding to a single policy.\n\n        :param buffer_size: (int) max number of transitions to store in buffer.\n        :param num_agents: (int) number of agents controlled by the policy.\n        :param obs_space: (gym.Space) observation space of the environment.\n        :param share_obs_space: (gym.Space) centralized observation space of the environment.\n        :param act_space: (gym.Space) action space of the environment.\n        :use_same_share_obs: (bool) whether all agents share the same centralized observation.\n        :use_avail_acts: (bool) whether to store what actions are available.\n        :param use_reward_normalization: (bool) whether to use reward normalization.\n        \"\"\"\n        self.buffer_size = buffer_size\n        self.num_agents = num_agents\n        self.use_same_share_obs = use_same_share_obs\n        self.use_avail_acts = use_avail_acts\n        self.use_reward_normalization = use_reward_normalization\n        
self.filled_i = 0\n        self.current_i = 0\n\n        # obs\n        if obs_space.__class__.__name__ == 'Box':\n            obs_shape = obs_space.shape\n            share_obs_shape = share_obs_space.shape\n        elif obs_space.__class__.__name__ == 'list':\n            obs_shape = obs_space\n            share_obs_shape = share_obs_space\n        else:\n            raise NotImplementedError\n\n        self.obs = np.zeros(\n            (self.buffer_size, self.num_agents, obs_shape[0]), dtype=np.float32)\n\n        if self.use_same_share_obs:\n            self.share_obs = np.zeros((self.buffer_size, share_obs_shape[0]), dtype=np.float32)\n        else:\n            self.share_obs = np.zeros((self.buffer_size, self.num_agents, share_obs_shape[0]), dtype=np.float32)\n\n        self.next_obs = np.zeros_like(self.obs, dtype=np.float32)\n        self.next_share_obs = np.zeros_like(self.share_obs, dtype=np.float32)\n\n        # action\n        act_dim = np.sum(get_dim_from_space(act_space))\n        self.acts = np.zeros((self.buffer_size, self.num_agents, act_dim), dtype=np.float32)\n        if self.use_avail_acts:\n            self.avail_acts = np.ones_like(self.acts, dtype=np.float32)\n            self.next_avail_acts = np.ones_like(self.avail_acts, dtype=np.float32)\n\n        # rewards\n        self.rewards = np.zeros((self.buffer_size, self.num_agents, 1), dtype=np.float32)\n\n        # default to done being True\n        self.dones = np.ones_like(self.rewards, dtype=np.float32)\n        self.dones_env = np.ones((self.buffer_size, 1), dtype=np.float32)\n        self.valid_transition = np.zeros_like(self.dones, dtype=np.float32)\n\n    def __len__(self):\n        return self.filled_i\n\n    def insert(self, num_insert_steps, obs, share_obs, acts, rewards,\n               next_obs, next_share_obs, dones, dones_env, valid_transition,\n               avail_acts=None, next_avail_acts=None):\n        \"\"\"\n        Insert  a set of transitions corresponding to this 
policy into buffer. If the buffer size overflows, old transitions are dropped.\n\n        :param num_insert_steps: (int) number of transitions to be added to buffer\n        :param obs: (np.ndarray) observations of agents corresponding to this policy.\n        :param share_obs: (np.ndarray) centralized observations of agents corresponding to this policy.\n        :param acts: (np.ndarray) actions of agents corresponding to this policy.\n        :param rewards: (np.ndarray) rewards of agents corresponding to this policy.\n        :param next_obs: (np.ndarray) next step observations of agents corresponding to this policy.\n        :param next_share_obs: (np.ndarray) next step centralized observations of agents corresponding to this policy.\n        :param dones: (np.ndarray) terminal status of agents corresponding to this policy.\n        :param dones_env: (np.ndarray) environment terminal status.\n        :param valid_transition: (np.ndarray) whether each transition is valid or not (invalid if agent was dead during transition)\n        :param avail_acts: (np.ndarray) available actions of agents corresponding to this policy.\n        :param next_avail_acts: (np.ndarray) next step available actions of agents corresponding to this policy.\n\n        :return: (np.ndarray) indexes of the buffer the new transitions were placed in.\n        \"\"\"\n\n        # obs: [step, episode, agent, dim]\n        assert obs.shape[0] == num_insert_steps, (\"different size!\")\n\n        if self.current_i + num_insert_steps <= self.buffer_size:\n            idx_range = np.arange(self.current_i, self.current_i + num_insert_steps)\n        else:\n            num_left_steps = self.current_i + num_insert_steps - self.buffer_size\n            idx_range = np.concatenate((np.arange(self.current_i, self.buffer_size), np.arange(num_left_steps)))\n\n        self.obs[idx_range] = obs.copy()\n        self.share_obs[idx_range] = share_obs.copy()\n        self.acts[idx_range] = acts.copy()\n        
self.rewards[idx_range] = rewards.copy()\n        self.next_obs[idx_range] = next_obs.copy()\n        self.next_share_obs[idx_range] = next_share_obs.copy()\n        self.dones[idx_range] = dones.copy()\n        self.dones_env[idx_range] = dones_env.copy()\n        self.valid_transition[idx_range] = valid_transition.copy()\n        if self.use_avail_acts:\n            self.avail_acts[idx_range] = avail_acts.copy()\n            self.next_avail_acts[idx_range] = next_avail_acts.copy()\n\n        self.current_i = idx_range[-1] + 1\n        self.filled_i = min(self.filled_i + len(idx_range), self.buffer_size)\n\n        return idx_range\n\n    def sample_inds(self, sample_inds):\n        \"\"\"\n        Sample a set of transitions from buffer from the specified indices.\n        :param sample_inds: (np.ndarray) indices of samples to return from buffer.\n\n        :return: obs: (np.ndarray) sampled observations corresponding to that policy\n        :return: share_obs: (np.ndarray) sampled observations corresponding to that policy\n        :return: acts: (np.ndarray) sampled actions corresponding to that policy\n        :return: rewards: (np.ndarray) sampled rewards corresponding to that policy\n        :return: next_obs: (np.ndarray) sampled next step observations corresponding to that policy\n        :return: next_share_obs: (np.ndarray) sampled next step centralized observations corresponding to that policy\n        :return: dones: (np.ndarray) sampled terminal status of agents corresponding to that policy\n        :return: dones_env: (np.ndarray) sampled environment terminal status corresponding to that policy\n        :return: valid_transition: (np.ndarray) whether each sampled transition is valid or not (invalid if corresponding agent is dead)\n        :return: avail_acts: (np.ndarray) sampled available actions corresponding to that policy\n        :return: next_avail_acts: (np.ndarray) sampled next step available actions corresponding to that policy\n        
\"\"\"\n        obs = _cast(self.obs[sample_inds])\n        acts = _cast(self.acts[sample_inds])\n        if self.use_reward_normalization:\n            mean_reward = self.rewards[:self.filled_i].mean()\n            std_reward = self.rewards[:self.filled_i].std()\n            rewards = _cast(\n                (self.rewards[sample_inds] - mean_reward) / std_reward)\n        else:\n            rewards = _cast(self.rewards[sample_inds])\n\n        next_obs = _cast(self.next_obs[sample_inds])\n\n        if self.use_same_share_obs:\n            share_obs = self.share_obs[sample_inds]\n            next_share_obs = self.next_share_obs[sample_inds]\n        else:\n            share_obs = _cast(self.share_obs[sample_inds])\n            next_share_obs = _cast(self.next_share_obs[sample_inds])\n\n        dones = _cast(self.dones[sample_inds])\n        dones_env = self.dones_env[sample_inds]\n        valid_transition = _cast(self.valid_transition[sample_inds])\n\n        if self.use_avail_acts:\n            avail_acts = _cast(self.avail_acts[sample_inds])\n            next_avail_acts = _cast(self.next_avail_acts[sample_inds])\n        else:\n            avail_acts = None\n            next_avail_acts = None\n\n        return obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts\n\n\nclass PrioritizedMlpReplayBuffer(MlpReplayBuffer):\n    def __init__(self, alpha, policy_info, policy_agents, buffer_size, use_same_share_obs, use_avail_acts,\n                 use_reward_normalization=False):\n        \"\"\"Prioritized replay buffer class for training MLP policies. 
See parent class.\"\"\"\n        super(PrioritizedMlpReplayBuffer, self).__init__(policy_info, policy_agents,\n                                                         buffer_size, use_same_share_obs, use_avail_acts,\n                                                         use_reward_normalization)\n        self.alpha = alpha\n        self.policy_info = policy_info\n        it_capacity = 1\n        while it_capacity < buffer_size:\n            it_capacity *= 2\n\n        self._it_sums = {p_id: SumSegmentTree(it_capacity) for p_id in self.policy_info.keys()}\n        self._it_mins = {p_id: MinSegmentTree(it_capacity) for p_id in self.policy_info.keys()}\n        self.max_priorities = {p_id: 1.0 for p_id in self.policy_info.keys()}\n\n    def insert(self, num_insert_steps, obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env,\n               valid_transition, avail_acts=None, next_avail_acts=None):\n        \"\"\"See parent class.\"\"\"\n        idx_range = super().insert(num_insert_steps, obs, share_obs, acts, rewards, next_obs, next_share_obs, dones,\n                                   dones_env, valid_transition, avail_acts, next_avail_acts)\n        for idx in idx_range:\n            for p_id in self.policy_info.keys():\n                self._it_sums[p_id][idx] = self.max_priorities[p_id] ** self.alpha\n                self._it_mins[p_id][idx] = self.max_priorities[p_id] ** self.alpha\n\n        return idx_range\n\n    def _sample_proportional(self, batch_size, p_id=None):\n        total = self._it_sums[p_id].sum(0, len(self) - 1)\n        mass = np.random.random(size=batch_size) * total\n        idx = self._it_sums[p_id].find_prefixsum_idx(mass)\n        return idx\n\n    def sample(self, batch_size, beta=0, p_id=None):\n        \"\"\"\n        Sample a set of transitions from buffer; probability of choosing a given sample is proportional to its priority.\n        :param batch_size: (int) number of transitions to 
sample.\n        :param beta: (float) controls the amount of prioritization to apply.\n        :param p_id: (str) policy which will be updated using the samples.\n\n        :return: See parent class.\n        \"\"\"\n        assert len(self) > batch_size, \"Not enough samples in the buffer!\"\n        assert beta > 0\n\n        batch_inds = self._sample_proportional(batch_size, p_id)\n\n        p_min = self._it_mins[p_id].min() / self._it_sums[p_id].sum()\n        max_weight = (p_min * len(self)) ** (-beta)\n        p_sample = self._it_sums[p_id][batch_inds] / self._it_sums[p_id].sum()\n        weights = (p_sample * len(self)) ** (-beta) / max_weight\n\n        obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts = {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}\n        for p_id in self.policy_info.keys():\n            p_buffer = self.policy_buffers[p_id]\n            obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id], next_obs[p_id], next_share_obs[p_id], dones[p_id], \\\n            dones_env[p_id], valid_transition[p_id], avail_acts[p_id], next_avail_acts[p_id] = p_buffer.sample_inds(batch_inds)\n\n        return obs, share_obs, acts, rewards, next_obs, next_share_obs, dones, dones_env, valid_transition, avail_acts, next_avail_acts, weights, batch_inds\n\n    def update_priorities(self, idxes, priorities, p_id=None):\n        \"\"\"\n        Update priorities of sampled transitions.\n        sets priority of transition at index idxes[i] in buffer\n        to priorities[i].\n        :param idxes: ([int]) List of idxes of sampled transitions\n        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes\n            denoted by variable `idxes`.\n        \"\"\"\n        assert len(idxes) == len(priorities)\n        assert np.min(priorities) > 0\n        assert np.min(idxes) >= 0\n        assert np.max(idxes) < len(self)\n\n        
self._it_sums[p_id][idxes] = priorities ** self.alpha\n        self._it_mins[p_id][idxes] = priorities ** self.alpha\n\n        self.max_priorities[p_id] = max(\n            self.max_priorities[p_id], np.max(priorities))\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/mlp_nstep_buffer.py",
    "content": "import numpy as np\nimport torch\nimport random\n\n\nclass NStepReplayBuffer:\n    def __init__(self, max_size, episode_len, n, policy_ids, agent_ids, policy_agents, policy_obs_dim, policy_act_dim, gamma):\n        self.max_size = max_size\n        self.episode_len = episode_len\n        # n for n-step returns\n        self.n = n\n        self.policy_ids = policy_ids\n        self.agent_ids = agent_ids\n        self.policy_agents = policy_agents\n        self.policy_buffers = {\n            p_id: NStepPolicyBuffer(p_id, self.max_size, episode_len, n, self.policy_agents[p_id], policy_obs_dim[p_id],\n                                    policy_act_dim[p_id], gamma) for p_id in self.policy_ids}\n        self.num_episodes = 0\n        self.num_transitions = 0\n\n    def push(self, t_env, observation_batch, action_batch, reward_batch, next_observation_batch, dones_batch, finish_episodes):\n        batch_size = observation_batch.shape[0]\n\n        observations = {a_id: np.vstack(\n            [obs[a_id] for obs in observation_batch]) for a_id in self.agent_ids}\n        actions = {a_id: np.vstack([act[a_id] for act in action_batch])\n                   for a_id in self.agent_ids}\n        rewards = {a_id: np.vstack([rew[a_id] for rew in reward_batch])\n                   for a_id in self.agent_ids}\n        n_observations = {a_id: np.vstack(\n            [nobs[a_id] for nobs in next_observation_batch]) for a_id in self.agent_ids}\n        if finish_episodes:\n            dones = {a_id: np.ones_like(rewards[a_id]).astype(\n                bool) for a_id in self.agent_ids}\n        else:\n            dones = {a_id: np.vstack(\n                [done[a_id] for done in dones_batch]) for a_id in self.agent_ids}\n\n        for p_id in self.policy_ids:\n            self.policy_buffers[p_id].push(\n                batch_size, t_env, observations, actions, rewards, n_observations, dones, finish_episodes)\n\n        assert len(\n            
set([p_buffer.num_episodes for p_buffer in self.policy_buffers.values()])) == 1\n        assert len(\n            set([p_buffer.num_transitions for p_buffer in self.policy_buffers.values()])) == 1\n\n        self.num_episodes = self.policy_buffers[self.policy_ids[0]].num_episodes\n        self.num_transitions = self.policy_buffers[self.policy_ids[0]\n                                                   ].num_transitions\n\n    def sample(self, batch_size):\n        assert self.num_transitions > batch_size, \"Cannot sample with no completed episodes in the buffer!\"\n\n        chunk_starts = np.random.choice(self.episode_len, batch_size)\n        batch_inds = np.random.choice(self.num_episodes, batch_size)\n\n        obs = {}\n        act = {}\n        rew = {}\n        nobs = {}\n        dones = {}\n\n        for p_id in self.policy_ids:\n            p_buffer = self.policy_buffers[p_id]\n            o, a, r, no, d = p_buffer.get(batch_inds, chunk_starts)\n            obs[p_id] = o\n            act[p_id] = a\n            rew[p_id] = r\n            nobs[p_id] = no\n            dones[p_id] = d\n\n        return obs, act, rew, nobs, dones\n\n\nclass NStepPolicyBuffer:\n    def __init__(self, policy_id, max_size, episode_len, n, policy_agents, obs_dim, act_dim, gamma):\n        self.max_size = max_size\n        self.n = n\n        self.num_agents = len(policy_agents)\n        self.policy_id = policy_id\n        self.episode_len = episode_len\n        self.agent_ids = policy_agents\n        self.gamma = gamma\n        random.shuffle(self.agent_ids)\n\n        self.observations = np.zeros(\n            (self.num_agents, max_size, episode_len, obs_dim))\n        self.actions = np.zeros(\n            (self.num_agents, max_size, episode_len, act_dim))\n        self.rewards = np.zeros(\n            (self.num_agents, max_size, episode_len + n - 1, 1))\n        self.next_observations = np.zeros(\n            (self.num_agents, max_size, episode_len + n - 1, obs_dim))\n        
self.dones = np.ones(\n            (self.num_agents, max_size, episode_len + n - 1, 1))\n\n        self.num_episodes = 0\n        self.num_transitions = 0\n\n    def push(self, num_envs, t_env, observation_batch, action_batch, reward_batch, next_observation_batch,\n             dones_batch, finish_episodes):\n        assert t_env < self.episode_len\n\n        if t_env == 0:\n            # shuffle the agent ids at the start of a new episode batch\n            random.shuffle(self.agent_ids)\n\n        if t_env == 0 and self.num_episodes + num_envs > self.max_size:\n            diff = self.num_episodes + num_envs - self.max_size\n            self.observations = np.roll(self.observations, -diff, axis=1)\n            self.actions = np.roll(self.actions, -diff, axis=1)\n            self.rewards = np.roll(self.rewards, -diff, axis=1)\n            self.next_observations = np.roll(\n                self.next_observations, -diff, axis=1)\n            self.dones = np.roll(self.dones, -diff, axis=1)\n\n            self.num_episodes -= diff\n\n        for i in range(self.num_agents):\n            if finish_episodes:\n                dones = np.ones_like(dones_batch[self.agent_ids[i]])\n            else:\n                dones = dones_batch[self.agent_ids[i]]\n\n            self.observations[i, self.num_episodes: self.num_episodes +\n                              num_envs, t_env, :] = observation_batch[self.agent_ids[i]]\n            self.actions[i, self.num_episodes: self.num_episodes +\n                         num_envs, t_env, :] = action_batch[self.agent_ids[i]]\n            self.rewards[i, self.num_episodes: self.num_episodes +\n                         num_envs, t_env, :] = reward_batch[self.agent_ids[i]]\n            self.next_observations[i, self.num_episodes: self.num_episodes + num_envs, t_env, :] = next_observation_batch[\n                self.agent_ids[i]]\n            self.dones[i, self.num_episodes: self.num_episodes +\n                       num_envs, t_env, :] = 
dones\n\n        self.num_transitions += num_envs\n        if finish_episodes:\n            self.num_episodes += num_envs\n\n    def get(self, batch_inds, start_inds):\n        batch_inds_col = batch_inds[:, None]\n        start_inds_col = start_inds[:, None]\n        nstep_inds = start_inds_col + np.arange(self.n)\n\n        obs = self.observations[:, batch_inds, start_inds, :]\n        acts = self.actions[:, batch_inds, start_inds, :]\n\n        # get the n-step rewards and weight each by exponentiated discounts\n        rews = self.rewards[:, batch_inds_col, nstep_inds, 0]\n        rews = rews * \\\n            np.power((np.ones(self.n) * self.gamma), np.arange(self.n))\n        # sum the n-step rewards: rewards for terminal states are pre-set to 0, so don't need to mask\n        rews = np.sum(rews, axis=2).reshape(\n            self.num_agents, len(batch_inds), 1)\n        # get the nobs of the nth\n        nobs = self.next_observations[:,\n                                      batch_inds, start_inds + self.n - 1, :]\n        dones = self.dones[:, batch_inds, start_inds + self.n - 1, :]\n\n        return torch.from_numpy(obs), torch.from_numpy(acts), torch.from_numpy(rews), torch.from_numpy(\n            nobs), torch.from_numpy(dones)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/popart.py",
    "content": "\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\n\nclass PopArt(nn.Module):\n    \"\"\" Normalize a vector of observations - across the first norm_axes dimensions\"\"\"\n\n    def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device(\"cpu\")):\n        super(PopArt, self).__init__()\n\n        self.input_shape = input_shape\n        self.norm_axes = norm_axes\n        self.epsilon = epsilon\n        self.beta = beta\n        self.per_element_update = per_element_update\n        self.device = device\n        self.tpdv = dict(dtype=torch.float32, device=device)\n\n        self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False).to(self.device)\n        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False).to(self.device)\n        self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False).to(self.device)\n\n    def reset_parameters(self):\n        self.running_mean.zero_()\n        self.running_mean_sq.zero_()\n        self.debiasing_term.zero_()\n\n    def running_mean_var(self):\n        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)\n        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)\n        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(max=self.alpha, min=1e-2)\n        return debiased_mean, debiased_var\n\n    def forward(self, input_vector, train=True):\n        # Make sure input is float32\n        input_vector = input_vector.to(**self.tpdv)\n\n        if train:\n            # Detach input before adding it to running means to avoid backpropping through it on\n            # subsequent batches.\n            \n            detached_input = input_vector.detach()           \n            batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))           
\n            batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))\n            if self.per_element_update:\n                batch_size = np.prod(detached_input.size()[:self.norm_axes])\n                weight = self.beta ** batch_size\n            else:\n                weight = self.beta\n            \n            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))\n            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))\n            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))\n\n        mean, var = self.running_mean_var()\n        out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]\n        return out\n\n    def denormalize(self, input_vector):\n        \"\"\" Transform normalized data back into original distribution \"\"\"\n        input_vector = input_vector.to(**self.tpdv)\n\n        mean, var = self.running_mean_var()\n        out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]\n        return out\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/rec_buffer.py",
    "content": "import numpy as np\nfrom utils.util import get_dim_from_space\nfrom utils.segment_tree import SumSegmentTree, MinSegmentTree\n\n\ndef _cast(x):\n    return x.transpose(2, 0, 1, 3)\n\n\nclass RecReplayBuffer(object):\n    def __init__(self, policy_info, policy_agents, buffer_size, episode_length, use_same_share_obs, use_avail_acts,\n                 use_reward_normalization=False):\n        \"\"\"\n        Replay buffer class for training RNN policies. Stores entire episodes rather than single transitions.\n\n        :param policy_info: (dict) maps policy id to a dict containing information about corresponding policy.\n        :param policy_agents: (dict) maps policy id to list of agents controled by corresponding policy.\n        :param buffer_size: (int) max number of transitions to store in the buffer.\n        :param use_same_share_obs: (bool) whether all agents share the same centralized observation.\n        :param use_avail_acts: (bool) whether to store what actions are available.\n        :param use_reward_normalization: (bool) whether to use reward normalization.\n        \"\"\"\n        self.policy_info = policy_info\n\n        self.policy_buffers = {p_id: RecPolicyBuffer(buffer_size,\n                                                     episode_length,\n                                                     len(policy_agents[p_id]),\n                                                     self.policy_info[p_id]['obs_space'],\n                                                     self.policy_info[p_id]['share_obs_space'],\n                                                     self.policy_info[p_id]['act_space'],\n                                                     use_same_share_obs,\n                                                     use_avail_acts,\n                                                     use_reward_normalization)\n                               for p_id in self.policy_info.keys()}\n\n    def __len__(self):\n        return 
self.policy_buffers['policy_0'].filled_i\n\n    def insert(self, num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts):\n        \"\"\"\n        Insert a set of episodes into buffer. If the buffer size overflows, old episodes are dropped.\n\n        :param num_insert_episodes: (int) number of episodes to be added to buffer\n        :param obs: (dict) maps policy id to numpy array of observations of agents corresponding to that policy\n        :param share_obs: (dict) maps policy id to numpy array of centralized observation corresponding to that policy\n        :param acts: (dict) maps policy id to numpy array of actions of agents corresponding to that policy\n        :param rewards: (dict) maps policy id to numpy array of rewards of agents corresponding to that policy\n        :param dones: (dict) maps policy id to numpy array of terminal status of agents corresponding to that policy\n        :param dones_env: (dict) maps policy id to numpy array of terminal status of env\n        :param valid_transition: (dict) maps policy id to numpy array of whether the corresponding transition is valid of agents corresponding to that policy\n        :param avail_acts: (dict) maps policy id to numpy array of available actions of agents corresponding to that policy\n\n        :return: (np.ndarray) indexes in which the new transitions were placed.\n        \"\"\"\n        for p_id in self.policy_info.keys():\n            idx_range = self.policy_buffers[p_id].insert(num_insert_episodes, np.array(obs[p_id]),\n                                                         np.array(share_obs[p_id]), np.array(acts[p_id]),\n                                                         np.array(rewards[p_id]), np.array(dones[p_id]),\n                                                         np.array(dones_env[p_id]), np.array(avail_acts[p_id]))\n        return idx_range\n\n    def sample(self, batch_size):\n        \"\"\"\n        Sample a set of episodes from buffer, 
uniformly at random.\n        :param batch_size: (int) number of episodes to sample from buffer.\n\n        :return: obs: (dict) maps policy id to sampled observations corresponding to that policy\n        :return: share_obs: (dict) maps policy id to sampled observations corresponding to that policy\n        :return: acts: (dict) maps policy id to sampled actions corresponding to that policy\n        :return: rewards: (dict) maps policy id to sampled rewards corresponding to that policy\n        :return: dones: (dict) maps policy id to sampled terminal status of agents corresponding to that policy\n        :return: dones_env: (dict) maps policy id to sampled environment terminal status corresponding to that policy\n        :return: valid_transition: (dict) maps policy_id to whether each sampled transition is valid or not (invalid if corresponding agent is dead)\n        :return: avail_acts: (dict) maps policy_id to available actions corresponding to that policy\n        \"\"\"\n        inds = np.random.choice(self.__len__(), batch_size)\n        obs, share_obs, acts, rewards, dones, dones_env, avail_acts = {}, {}, {}, {}, {}, {}, {}\n        for p_id in self.policy_info.keys():\n            obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id], dones[p_id], dones_env[p_id], avail_acts[p_id] = \\\n            self.policy_buffers[p_id].sample_inds(inds)\n\n        return obs, share_obs, acts, rewards, dones, dones_env, avail_acts, None, None\n\n\nclass RecPolicyBuffer(object):\n    def __init__(self, buffer_size, episode_length, num_agents, obs_space, share_obs_space, act_space,\n                 use_same_share_obs, use_avail_acts, use_reward_normalization=False):\n        \"\"\"\n        Buffer class containing buffer data corresponding to a single policy.\n\n        :param buffer_size: (int) max number of episodes to store in buffer.\n        :param episode_length: (int) max length of an episode.\n        :param num_agents: (int) number of agents controlled by the 
policy.\n        :param obs_space: (gym.Space) observation space of the environment.\n        :param share_obs_space: (gym.Space) centralized observation space of the environment.\n        :param act_space: (gym.Space) action space of the environment.\n        :use_same_share_obs: (bool) whether all agents share the same centralized observation.\n        :use_avail_acts: (bool) whether to store what actions are available.\n        :param use_reward_normalization: (bool) whether to use reward normalization.\n        \"\"\"\n        self.buffer_size = buffer_size\n        self.episode_length = episode_length\n        self.num_agents = num_agents\n        self.use_same_share_obs = use_same_share_obs\n        self.use_avail_acts = use_avail_acts\n        self.use_reward_normalization = use_reward_normalization\n        self.filled_i = 0\n        self.current_i = 0\n\n        # obs\n        if obs_space.__class__.__name__ == 'Box':\n            obs_shape = obs_space.shape\n            share_obs_shape = share_obs_space.shape\n        elif obs_space.__class__.__name__ == 'list':\n            obs_shape = obs_space\n            share_obs_shape = share_obs_space\n        else:\n            raise NotImplementedError\n\n        self.obs = np.zeros((self.episode_length + 1, self.buffer_size,\n                             self.num_agents, obs_shape[0]), dtype=np.float32)\n\n        if self.use_same_share_obs:\n            self.share_obs = np.zeros((self.episode_length + 1, self.buffer_size, share_obs_shape[0]), dtype=np.float32)\n        else:\n            self.share_obs = np.zeros((self.episode_length + 1, self.buffer_size, self.num_agents, share_obs_shape[0]),\n                                      dtype=np.float32)\n\n        # action\n        act_dim = np.sum(get_dim_from_space(act_space))\n        self.acts = np.zeros((self.episode_length, self.buffer_size, self.num_agents, act_dim), dtype=np.float32)\n        if self.use_avail_acts:\n            self.avail_acts = 
np.ones((self.episode_length + 1, self.buffer_size, self.num_agents, act_dim),\n                                      dtype=np.float32)\n\n        # rewards\n        self.rewards = np.zeros((self.episode_length, self.buffer_size, self.num_agents, 1), dtype=np.float32)\n\n        # default to done being True\n        self.dones = np.ones_like(self.rewards, dtype=np.float32)\n        self.dones_env = np.ones((self.episode_length, self.buffer_size, 1), dtype=np.float32)\n\n    def __len__(self):\n        return self.filled_i\n\n    def insert(self, num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts=None):\n        \"\"\"\n        Insert a set of episodes corresponding to this policy into buffer. If the buffer size overflows, old transitions are dropped.\n\n        :param num_insert_steps: (int) number of transitions to be added to buffer\n        :param obs: (np.ndarray) observations of agents corresponding to this policy.\n        :param share_obs: (np.ndarray) centralized observations of agents corresponding to this policy.\n        :param acts: (np.ndarray) actions of agents corresponding to this policy.\n        :param rewards: (np.ndarray) rewards of agents corresponding to this policy.\n        :param dones: (np.ndarray) terminal status of agents corresponding to this policy.\n        :param dones_env: (np.ndarray) environment terminal status.\n        :param valid_transition: (np.ndarray) whether each transition is valid or not (invalid if agent was dead during transition)\n        :param avail_acts: (np.ndarray) available actions of agents corresponding to this policy.\n\n        :return: (np.ndarray) indexes of the buffer the new transitions were placed in.\n        \"\"\"\n\n        # obs: [step, episode, agent, dim]\n        episode_length = acts.shape[0]\n        assert episode_length == self.episode_length, (\"different dimension!\")\n\n        if self.current_i + num_insert_episodes <= self.buffer_size:\n            
idx_range = np.arange(self.current_i, self.current_i + num_insert_episodes)\n        else:\n            num_left_episodes = self.current_i + num_insert_episodes - self.buffer_size\n            idx_range = np.concatenate((np.arange(self.current_i, self.buffer_size), np.arange(num_left_episodes)))\n\n        if self.use_same_share_obs:\n            # remove agent dimension since all agents share centralized observation\n            share_obs = share_obs[:, :, 0]\n\n        self.obs[:, idx_range] = obs.copy()\n        self.share_obs[:, idx_range] = share_obs.copy()\n        self.acts[:, idx_range] = acts.copy()\n        self.rewards[:, idx_range] = rewards.copy()\n        self.dones[:, idx_range] = dones.copy()\n        self.dones_env[:, idx_range] = dones_env.copy()\n\n        if self.use_avail_acts:\n            self.avail_acts[:, idx_range] = avail_acts.copy()\n\n        self.current_i = idx_range[-1] + 1\n        self.filled_i = min(self.filled_i + len(idx_range), self.buffer_size)\n\n        return idx_range\n\n    def sample_inds(self, sample_inds):\n        \"\"\"\n        Sample a set of transitions from buffer from the specified indices.\n        :param sample_inds: (np.ndarray) indices of samples to return from buffer.\n\n        :return: obs: (np.ndarray) sampled observations corresponding to that policy\n        :return: share_obs: (np.ndarray) sampled observations corresponding to that policy\n        :return: acts: (np.ndarray) sampled actions corresponding to that policy\n        :return: rewards: (np.ndarray) sampled rewards corresponding to that policy\n        :return: dones: (np.ndarray) sampled terminal status of agents corresponding to that policy\n        :return: dones_env: (np.ndarray) sampled environment terminal status corresponding to that policy\n        :return: valid_transition: (np.ndarray) whether each sampled transition in episodes are valid or not (invalid if corresponding agent is dead)\n        :return: avail_acts: (np.ndarray) 
sampled available actions corresponding to that policy\n        \"\"\"\n\n        obs = _cast(self.obs[:, sample_inds])\n        acts = _cast(self.acts[:, sample_inds])\n        if self.use_reward_normalization:\n            # mean std\n            # [length, envs, agents, 1]\n            # [length, envs, 1]\n            all_dones_env = np.tile(np.expand_dims(\n                self.dones_env[:, :self.filled_i], -1), (1, 1, self.num_agents, 1))\n            first_step_dones_env = np.zeros((1, self.filled_i, self.num_agents, 1))\n            curr_dones_env = np.concatenate((first_step_dones_env, all_dones_env[:self.episode_length - 1]))\n            temp_rewards = self.rewards[:, :self.filled_i].copy()\n            temp_rewards[curr_dones_env == 1.0] = np.nan\n\n            mean_reward = np.nanmean(temp_rewards)\n            std_reward = np.nanstd(temp_rewards)\n            rewards = _cast(\n                (self.rewards[:, sample_inds] - mean_reward) / std_reward)\n        else:\n            rewards = _cast(self.rewards[:, sample_inds])\n\n        if self.use_same_share_obs:\n            share_obs = self.share_obs[:, sample_inds]\n        else:\n            share_obs = _cast(self.share_obs[:, sample_inds])\n\n        dones = _cast(self.dones[:, sample_inds])\n        dones_env = self.dones_env[:, sample_inds]\n\n        if self.use_avail_acts:\n            avail_acts = _cast(self.avail_acts[:, sample_inds])\n        else:\n            avail_acts = None\n\n        return obs, share_obs, acts, rewards, dones, dones_env, avail_acts\n\n\nclass PrioritizedRecReplayBuffer(RecReplayBuffer):\n    def __init__(self, alpha, policy_info, policy_agents, buffer_size, episode_length, use_same_share_obs,\n                 use_avail_acts, use_reward_normalization=False):\n        \"\"\"Prioritized replay buffer class for training RNN policies. 
See parent class.\"\"\"\n        super(PrioritizedRecReplayBuffer, self).__init__(policy_info, policy_agents, buffer_size,\n                                                         episode_length, use_same_share_obs, use_avail_acts,\n                                                         use_reward_normalization)\n        self.alpha = alpha\n        self.policy_info = policy_info\n        it_capacity = 1\n        while it_capacity < buffer_size:\n            it_capacity *= 2\n\n        self._it_sums = {p_id: SumSegmentTree(\n            it_capacity) for p_id in self.policy_info.keys()}\n        self._it_mins = {p_id: MinSegmentTree(\n            it_capacity) for p_id in self.policy_info.keys()}\n        self.max_priorities = {p_id: 1.0 for p_id in self.policy_info.keys()}\n\n    def insert(self, num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts=None):\n        \"\"\"See parent class.\"\"\"\n        idx_range = super().insert(num_insert_episodes, obs, share_obs, acts, rewards, dones, dones_env, avail_acts)\n        for idx in idx_range:\n            for p_id in self.policy_info.keys():\n                self._it_sums[p_id][idx] = self.max_priorities[p_id] ** self.alpha\n                self._it_mins[p_id][idx] = self.max_priorities[p_id] ** self.alpha\n\n        return idx_range\n\n    def _sample_proportional(self, batch_size, p_id=None):\n        total = self._it_sums[p_id].sum(0, len(self) - 1)\n        mass = np.random.random(size=batch_size) * total\n        idx = self._it_sums[p_id].find_prefixsum_idx(mass)\n        return idx\n\n    def sample(self, batch_size, beta=0, p_id=None):\n        \"\"\"\n        Sample a set of episodes from buffer; probability of choosing a given episode is proportional to its priority.\n        :param batch_size: (int) number of episodes to sample.\n        :param beta: (float) controls the amount of prioritization to apply.\n        :param p_id: (str) policy which will be 
updated using the samples.\n\n        :return: See parent class.\n        \"\"\"\n        assert len(\n            self) > batch_size, \"Cannot sample with no completed episodes in the buffer!\"\n        assert beta > 0\n\n        batch_inds = self._sample_proportional(batch_size, p_id)\n\n        p_min = self._it_mins[p_id].min() / self._it_sums[p_id].sum()\n        max_weight = (p_min * len(self)) ** (-beta)\n        p_sample = self._it_sums[p_id][batch_inds] / self._it_sums[p_id].sum()\n        weights = (p_sample * len(self)) ** (-beta) / max_weight\n\n        obs, share_obs, acts, rewards, dones, dones_env, avail_acts = {}, {}, {}, {}, {}, {}, {}\n        for p_id in self.policy_info.keys():\n            p_buffer = self.policy_buffers[p_id]\n            obs[p_id], share_obs[p_id], acts[p_id], rewards[p_id], dones[p_id], dones_env[p_id], avail_acts[\n                p_id] = p_buffer.sample_inds(batch_inds)\n\n        return obs, share_obs, acts, rewards, dones, dones_env, avail_acts, weights, batch_inds\n\n    def update_priorities(self, idxes, priorities, p_id=None):\n        \"\"\"\n        Update priorities of sampled transitions.\n        sets priority of transition at index idxes[i] in buffer\n        to priorities[i].\n        :param idxes: ([int]) List of idxes of sampled transitions\n        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes\n            denoted by variable `idxes`.\n        \"\"\"\n        assert len(idxes) == len(priorities)\n        assert np.min(priorities) > 0\n        assert np.min(idxes) >= 0\n        assert np.max(idxes) < len(self)\n\n        self._it_sums[p_id][idxes] = priorities ** self.alpha\n        self._it_mins[p_id][idxes] = priorities ** self.alpha\n\n        self.max_priorities[p_id] = max(\n            self.max_priorities[p_id], np.max(priorities))\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/segment_tree.py",
    "content": "import numpy as np\n\n\ndef unique(sorted_array):\n    \"\"\"\n    More efficient implementation of np.unique for sorted arrays\n    :param sorted_array: (np.ndarray)\n    :return:(np.ndarray) sorted_array without duplicate elements\n    \"\"\"\n    if len(sorted_array) == 1:\n        return sorted_array\n    left = sorted_array[:-1]\n    right = sorted_array[1:]\n    uniques = np.append(right != left, True)\n    return sorted_array[uniques]\n\n\nclass SegmentTree(object):\n    def __init__(self, capacity, operation, neutral_element):\n        \"\"\"\n        Build a Segment Tree data structure.\n        https://en.wikipedia.org/wiki/Segment_tree\n        Can be used as regular array that supports Index arrays, but with two\n        important differences:\n            a) setting item's value is slightly slower.\n               It is O(lg capacity) instead of O(1).\n            b) user has access to an efficient ( O(log segment size) )\n               `reduce` operation which reduces `operation` over\n               a contiguous subsequence of items in the array.\n        :param capacity: (int) Total size of the array - must be a power of two.\n        :param operation: (lambda (Any, Any): Any) operation for combining elements (eg. sum, max) must form a\n            mathematical group together with the set of possible values for array elements (i.e. be associative)\n        :param neutral_element: (Any) neutral element for the operation above. eg. 
float('-inf') for max and 0 for sum.\n        \"\"\"\n        assert capacity > 0 and capacity & (\n            capacity - 1) == 0, \"capacity must be positive and a power of 2.\"\n        self._capacity = capacity\n        self._value = [neutral_element for _ in range(2 * capacity)]\n        self._operation = operation\n        self.neutral_element = neutral_element\n\n    def _reduce_helper(self, start, end, node, node_start, node_end):\n        if start == node_start and end == node_end:\n            return self._value[node]\n        mid = (node_start + node_end) // 2\n        if end <= mid:\n            return self._reduce_helper(start, end, 2 * node, node_start, mid)\n        else:\n            if mid + 1 <= start:\n                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)\n            else:\n                return self._operation(\n                    self._reduce_helper(start, mid, 2 * node, node_start, mid),\n                    self._reduce_helper(\n                        mid + 1, end, 2 * node + 1, mid + 1, node_end)\n                )\n\n    def reduce(self, start=0, end=None):\n        \"\"\"\n        Returns result of applying `self.operation`\n        to a contiguous subsequence of the array.\n            self.operation(arr[start], operation(arr[start+1], operation(... 
arr[end])))\n        :param start: (int) beginning of the subsequence\n        :param end: (int) end of the subsequences\n        :return: (Any) result of reducing self.operation over the specified range of array elements.\n        \"\"\"\n        if end is None:\n            end = self._capacity\n        if end < 0:\n            end += self._capacity\n        end -= 1\n        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)\n\n    def __setitem__(self, idx, val):\n        # indexes of the leaf\n        idxs = idx + self._capacity\n        self._value[idxs] = val\n        if isinstance(idxs, int):\n            idxs = np.array([idxs])\n        # go up one level in the tree and remove duplicate indexes\n        idxs = unique(idxs // 2)\n        while len(idxs) > 1 or idxs[0] > 0:\n            # as long as there are non-zero indexes, update the corresponding values\n            self._value[idxs] = self._operation(\n                self._value[2 * idxs],\n                self._value[2 * idxs + 1]\n            )\n            # go up one level in the tree and remove duplicate indexes\n            idxs = unique(idxs // 2)\n\n    def __getitem__(self, idx):\n        assert np.max(idx) < self._capacity\n        assert 0 <= np.min(idx)\n        return self._value[self._capacity + idx]\n\n\nclass SumSegmentTree(SegmentTree):\n    def __init__(self, capacity):\n        super(SumSegmentTree, self).__init__(\n            capacity=capacity,\n            operation=np.add,\n            neutral_element=0.0\n        )\n        self._value = np.array(self._value)\n\n    def sum(self, start=0, end=None):\n        \"\"\"\n        Returns arr[start] + ... 
+ arr[end]\n        :param start: (int) start position of the reduction (must be >= 0)\n        :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1)\n        :return: (Any) reduction of SumSegmentTree\n        \"\"\"\n        return super(SumSegmentTree, self).reduce(start, end)\n\n    def find_prefixsum_idx(self, prefixsum):\n        \"\"\"\n        Find the highest index `i` in the array such that\n            sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum for each entry in prefixsum\n        if array values are probabilities, this function\n        allows to sample indexes according to the discrete\n        probability efficiently.\n        :param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix\n        :return: (np.ndarray) highest indexes satisfying the prefixsum constraint\n        \"\"\"\n        if isinstance(prefixsum, float):\n            prefixsum = np.array([prefixsum])\n        assert 0 <= np.min(prefixsum)\n        assert np.max(prefixsum) <= self.sum() + 1e-5\n        assert isinstance(prefixsum[0], float)\n\n        idx = np.ones(len(prefixsum), dtype=int)\n        cont = np.ones(len(prefixsum), dtype=bool)\n\n        while np.any(cont):  # while not all nodes are leafs\n            idx[cont] = 2 * idx[cont]\n            prefixsum_new = np.where(\n                self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum)\n            # prepare update of prefixsum for all right children\n            idx = np.where(np.logical_or(\n                self._value[idx] > prefixsum, np.logical_not(cont)), idx, idx + 1)\n            # Select child node for non-leaf nodes\n            prefixsum = prefixsum_new\n            # update prefixsum\n            cont = idx < self._capacity\n            # collect leafs\n        return idx - self._capacity\n\n\nclass MinSegmentTree(SegmentTree):\n    def __init__(self, capacity):\n        super(MinSegmentTree, self).__init__(\n        
    capacity=capacity,\n            operation=np.minimum,\n            neutral_element=float('inf')\n        )\n        self._value = np.array(self._value)\n\n    def min(self, start=0, end=None):\n        \"\"\"\n        Returns min(arr[start], ...,  arr[end])\n        :param start: (int) start position of the reduction (must be >= 0)\n        :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1)\n        :return: (Any) reduction of MinSegmentTree\n        \"\"\"\n        return super(MinSegmentTree, self).reduce(start, end)\n"
  },
  {
    "path": "examples/Social_Cognition/ToCM/utils/util.py",
    "content": "import copy\nimport gym\nimport numpy as np\nfrom gym.spaces import Box, Discrete, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.autograd import Variable\n\ndef to_torch(input):\n    return torch.from_numpy(input) if type(input) == np.ndarray else input\n\ndef to_numpy(x):\n    return x.detach().cpu().numpy()\n\nclass FixedCategorical(torch.distributions.Categorical):\n    def sample(self):\n        return super().sample()\n\n    def log_probs(self, actions):\n        return (\n            super()\n            .log_prob(actions.squeeze(-1))\n            .view(actions.size(0), -1)\n            .sum(-1)\n            .unsqueeze(-1)\n        )\n\n    def mode(self):\n        return self.probs.argmax(dim=-1, keepdim=True)\n\n\nclass MultiDiscrete(gym.Space):\n    \"\"\"\n    - The multi-discrete action space consists of a series of discrete action spaces with different parameters\n    - It can be adapted to both a Discrete action space or a continuous (Box) action space\n    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space\n    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space\n       where the discrete action space can take any integers from `min` to `max` (both inclusive)\n    Note: A value of 0 always need to represent the NOOP action.\n    e.g. 
Nintendo Game Controller\n    - Can be conceptualized as 3 discrete action spaces:\n        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4\n        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n    - Can be initialized as\n        MultiDiscrete([ [0,4], [0,1], [0,1] ])\n    \"\"\"\n\n    def __init__(self, array_of_param_array):\n        self.low = np.array([x[0] for x in array_of_param_array])\n        self.high = np.array([x[1] for x in array_of_param_array])\n        self.num_discrete_space = self.low.shape[0]\n        self.n = np.sum(self.high) + 2\n\n    def sample(self):\n        \"\"\" Returns a array with one sample from each discrete action space \"\"\"\n        # For each row: round(random .* (max - min) + min, 0)\n        random_array = np.random.rand(self.num_discrete_space)\n        return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)]\n\n    def contains(self, x):\n        return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all()\n\n    @property\n    def shape(self):\n        return self.num_discrete_space\n\n    def __repr__(self):\n        return \"MultiDiscrete\" + str(self.num_discrete_space)\n\n    def __eq__(self, other):\n        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)\n\n\nclass DecayThenFlatSchedule():\n    def __init__(self,\n                 start,\n                 finish,\n                 time_length,\n                 decay=\"exp\"):\n\n        self.start = start\n        self.finish = finish\n        self.time_length = time_length\n        self.delta = (self.start - self.finish) / self.time_length\n        self.decay = decay\n\n        if self.decay in [\"exp\"]:\n            self.exp_scaling = (-1) * self.time_length / \\\n      
          np.log(self.finish) if self.finish > 0 else 1\n\n    def eval(self, T):\n        if self.decay in [\"linear\"]:\n            return max(self.finish, self.start - self.delta * T)\n        elif self.decay in [\"exp\"]:\n            return min(self.start, max(self.finish, np.exp(- T / self.exp_scaling)))\n    pass\n\n\ndef huber_loss(e, d):\n    a = (abs(e) <= d).float()\n    b = (abs(e) > d).float()\n    return a*e**2/2 + b*d*(abs(e)-d/2)\n\n\ndef mse_loss(e):\n    return e**2\n\n\ndef init(module, weight_init, bias_init, gain=1):\n    weight_init(module.weight.data, gain=gain)\n    bias_init(module.bias.data)\n    return module\n\n\ndef get_clones(module, N):\n    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n\n# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11\ndef soft_update(target, source, tau):\n    \"\"\"\n    Perform DDPG soft update (move target params toward source based on weight\n    factor tau)\n    Inputs:\n        target (torch.nn.Module): Net to copy parameters to\n        source (torch.nn.Module): Net whose parameters to copy\n        tau (float, 0 < x < 1): Weight factor for update\n    \"\"\"\n    for target_param, param in zip(target.parameters(), source.parameters()):\n        target_param.data.copy_(\n            target_param.data * (1.0 - tau) + param.data * tau)\n\n# https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15\ndef hard_update(target, source):\n    \"\"\"\n    Copy network parameters from source to target\n    Inputs:\n        target (torch.nn.Module): Net to copy parameters to\n        source (torch.nn.Module): Net whose parameters to copy\n    \"\"\"\n    for target_param, param in zip(target.parameters(), source.parameters()):\n        target_param.data.copy_(param.data)\n\n# https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py\ndef average_gradients(model):\n    \"\"\" Gradient averaging. 
\"\"\"\n    size = float(dist.get_world_size())\n    for param in model.parameters():\n        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0)\n        param.grad.data /= size\n\n\ndef onehot_from_logits(logits, avail_logits=None, eps=0.0):\n    \"\"\"\n    Given batch of logits, return one-hot sample using epsilon greedy strategy\n    (based on given epsilon)\n    \"\"\"\n    # get best (according to current policy) actions in one-hot form\n    logits = to_torch(logits)\n    \n    dim = len(logits.shape) - 1\n    if avail_logits is not None:\n        avail_logits = to_torch(avail_logits)\n        logits[avail_logits == 0] = -1e10\n    argmax_acs = (logits == logits.max(dim, keepdim=True)[0]).float()\n    if eps == 0.0:\n        return argmax_acs\n    # get random actions in one-hot form\n    rand_acs = Variable(torch.eye(logits.shape[1])[[np.random.choice(range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False)\n    # chooses between best and random actions using epsilon greedy\n    return torch.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in\n                        enumerate(torch.rand(logits.shape[0]))])\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):\n    \"\"\"Sample from Gumbel(0, 1)\"\"\"\n    U = Variable(tens_type(*shape).uniform_(), requires_grad=False)\n    return -torch.log(-torch.log(U + eps) + eps)\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef gumbel_softmax_sample(logits, avail_logits, temperature, device=torch.device('cpu')):\n    \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n    if str(device) == 'cpu':\n        y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data))\n    else:\n        y = (logits.cpu() + sample_gumbel(logits.shape,\n                                          
tens_type=type(logits.data))).cuda()\n\n    dim = len(logits.shape) - 1\n    if avail_logits is not None:\n        avail_logits = to_torch(avail_logits).to(device)\n        y[avail_logits==0] = -1e10\n    return F.softmax(y / temperature, dim=dim)\n\n# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb\ndef gumbel_softmax(logits, avail_logits=None, temperature=1.0, hard=False, device=torch.device('cpu')):\n    \"\"\"Sample from the Gumbel-Softmax distribution and optionally discretize.\n    Args:\n      logits: [batch_size, n_class] unnormalized log-probs\n      temperature: non-negative scalar\n      hard: if True, take argmax, but differentiate w.r.t. soft sample y\n    Returns:\n      [batch_size, n_class] sample from the Gumbel-Softmax distribution.\n      If hard=True, then the returned sample will be one-hot, otherwise it will\n      be a probabilitiy distribution that sums to 1 across classes\n    \"\"\"\n    y = gumbel_softmax_sample(logits, avail_logits, temperature, device)\n    if hard:\n        y_hard = onehot_from_logits(y)\n        y = (y_hard - y).detach() + y\n    return y\n\n\ndef gaussian_noise(shape, std):\n    return torch.empty(shape).normal_(mean=0, std=std)\n\n\ndef get_obs_shape(obs_space):\n    if obs_space.__class__.__name__ == \"Box\":\n        obs_shape = obs_space.shape\n    elif obs_space.__class__.__name__ == \"list\":\n        obs_shape = obs_space\n    else:\n        raise NotImplementedError\n    \n    return obs_shape\n\ndef get_dim_from_space(space):\n    if isinstance(space, Box):\n        dim = space.shape[0]\n    elif isinstance(space, Discrete):\n        dim = space.n\n    elif isinstance(space, Tuple):\n        dim = sum([get_dim_from_space(sp) for sp in space])\n    elif \"MultiDiscrete\" in space.__class__.__name__:\n        return (space.high - space.low) + 1\n    elif isinstance(space, list):\n        dim = space[0]\n    else:\n        raise Exception(\"Unrecognized 
space: \", type(space))\n    return dim\n\n\ndef get_state_dim(observation_dict, action_dict):\n    combined_obs_dim = sum([get_dim_from_space(space)\n                            for space in observation_dict.values()])\n    combined_act_dim = 0\n    for space in action_dict.values():\n        dim = get_dim_from_space(space)\n        if isinstance(dim, np.ndarray):\n            combined_act_dim += int(sum(dim))\n        else:\n            combined_act_dim += dim\n    return combined_obs_dim, combined_act_dim, combined_obs_dim+combined_act_dim\n\n\ndef get_cent_act_dim(action_space):\n    cent_act_dim = 0\n    for space in action_space:\n        dim = get_dim_from_space(space)\n        if isinstance(dim, np.ndarray):\n            cent_act_dim += int(sum(dim))\n        else:\n            cent_act_dim += dim\n    return cent_act_dim\n\n\ndef is_discrete(space):\n    if isinstance(space, Discrete) or \"MultiDiscrete\" in space.__class__.__name__:\n        return True\n    else:\n        return False\n\n\ndef is_multidiscrete(space):\n    if \"MultiDiscrete\" in space.__class__.__name__:\n        return True\n    else:\n        return False\n\n\ndef make_onehot(int_action, action_dim, seq_len=None):\n    if type(int_action) == torch.Tensor:\n        int_action = int_action.cpu().numpy()\n    if not seq_len:\n        return np.eye(action_dim)[int_action]\n    if seq_len:\n        onehot_actions = []\n        for i in range(seq_len):\n            onehot_action = np.eye(action_dim)[int_action[i]]\n            onehot_actions.append(onehot_action)\n        return np.stack(onehot_actions)\n\n\ndef avail_choose(x, avail_x=None):\n    x = to_torch(x)\n    if avail_x is not None:\n        avail_x = to_torch(avail_x)\n        x[avail_x == 0] = -1e10\n    return x#FixedCategorical(logits=x)\n\n\ndef tile_images(img_nhwc):\n    \"\"\"\n    Tile N images into one big PxQ image\n    (P,Q) are chosen to be as close as possible, and if N\n    is square, then P=Q.\n    input: img_nhwc, 
list or array of images, ndim=4 once turned into array\n        n = batch index, h = height, w = width, c = channel\n    returns:\n        bigim_HWc, ndarray with ndim=3\n    \"\"\"\n    img_nhwc = np.asarray(img_nhwc)\n    N, h, w, c = img_nhwc.shape\n    H = int(np.ceil(np.sqrt(N)))\n    W = int(np.ceil(float(N)/H))\n    img_nhwc = np.array(\n        list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])\n    img_HWhwc = img_nhwc.reshape(H, W, h, w, c)\n    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)\n    img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)\n    return img_Hh_Ww_c\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/BrainArea/PFC_ToM.py",
    "content": "from braincog.base.learningrule.STDP import *\nfrom braincog.base.brainarea.PFC import dlPFC\nfrom utils.Encoder import *\n\n#exploit or explore\nnum_enpop = 6\nnum_depop = 10\ngreedy = 0.8#0.5\n\n#state\nA_state = 4\nN_state = 6\ncell_num = 6\n#action\nC=10\n\nclass PFC_ToM(dlPFC):\n    \"\"\"\n    SNNLinear\n    \"\"\"\n    def __init__(self,\n                 step,\n                 encode_type,\n                 in_features:int,\n                 out_features:int,\n                 bias,\n                 node,\n                 num_state,\n                 greedy=0.8,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, in_features, out_features, bias, *args, **kwargs)\n        self.encoder = PopEncoder(self.step, encode_type)\n        self.encoder.device = torch.device('cpu')\n        self.bias = bias\n        self.in_features = in_features\n        self.out_features = out_features\n        self.node1 = node(threshold=0.5, tau=2.)\n        self.node_name1 = node\n        self.node2 = node(threshold=0.5, tau=2.)\n        self.node_name2 = node\n        self.num_state = num_state\n        self.greedy = greedy\n        self.fc = self._create_fc()\n        self.c = self._rest_c()\n\n\n    def _rest_c(self):\n        c = torch.rand((self.out_features, self.in_features)) # eligibility trace\n        return c\n\n    def _create_fc(self):\n        \"\"\"\n        the connection of the SNN linear\n        @return: nn.Linear\n        \"\"\"\n        fc = nn.Linear(in_features=self.in_features,\n                  out_features=self.out_features, bias=self.bias)\n        return fc\n\n    def update_c(self, c, dw, tau_c=0.2):\n        \"\"\"\n        update the trace of eligibility\n        @param c: a tensor to record eligibility\n        @param dw: the results of STDP\n        @param tau_c: the parameter of trace decay\n        @return: a update tensor to record eligibility\n        Equation:\n        delta_c 
= (-(c / tau_c) + dw) * dela_t\n        c = c + delta_c\n        reference:<Solving the Distal Reward Problem through ...>\n        \"\"\"\n        # delta_c = -(c / tau_c) + dw           #dela_t = 1 ignore\n        # c = c + delta_c\n        c = c + tau_c * dw\n        return c\n\n    def _call_reward(self, R, c, s, T_map):  # eligibility\n        \"\"\"\n        R-STDP\n        @param R: reward\n        @param c: a tensor to record eligibility\n        @param s: weight of network\n        @param T_map: the mapping of the state-action pair\n        @return: update weight of network\n        Equation:\n        delta_s = c * reward\n        s = s + delta_s\n        reference:<Solving the Distal Reward Problem through ...>\n        \"\"\"\n        c[c > 0] =  c[c > 0] * R * 1\n        c[c <= 0] = - c[c <= 0] * R * 1\n        c = c.clamp(min=-1, max=1)\n        # print('before',s[:, torch.where(T_map.gt(0))[1][0]])\n        s = s + c * T_map\n        # # print('after',s[:, torch.where(T_map.gt(0))[1][0]])\n        s = (s - s.min(dim=0).values.unsqueeze(dim=1).T.detach().repeat(s.shape[0], 1)) / (\n                s.max(dim=0).values.unsqueeze(dim=1).T.detach().repeat(s.shape[0], 1) -\n                s.min(dim=0).values.unsqueeze(dim=1).T.detach().repeat(s.shape[0], 1)\n        )\n        # s = s * 0.5\n        return s\n\n    def update_s(self, R, mapping):\n        T_map = torch.zeros((self.out_features, self.in_features))\n        T_map[mapping['action']*C:mapping['action']*C+C,\\\n        torch.where(torch.tensor(self.encoder(mapping['state'],\\\n                                              self.in_features, self.num_state)[:, 0]).gt(0))]=1\n        self.fc.weight.data = self._call_reward(R, self.c, self.fc.weight.data, T_map)\n        # print(mapping, 'mapping')\n    def forward(self, inputs, num_action, episode):\n        \"\"\"\n        decision\n        @param inputs: state\n        @param num_action: num_action # consider to delete\n        @return: action\n 
       \"\"\"\n        inputs = self.encoder(inputs, self.in_features, self.num_state)\n        count_group = torch.zeros(num_action)\n        stdp = STDP(self.node2, self.fc, decay=0.80)\n        # self.c = self._rest_c()\n        # stdp.connection.weight.data = torch.rand((self.out_features, self.in_features))\n        for t in range(self.step):\n            l1_in = torch.tensor(inputs[:, t])\n            l1_out = self.node1(l1_in).unsqueeze(0)  #pre  : l1_out\n            l2_out, dw = stdp(l1_out)   #dw -- STDP\n            self.c = self.update_c(self.c, dw[0])\n\n        # l2_out = l2_out.T\n        for i in range(num_action):\n            count_group[i] = l2_out.T[i * num_depop:(i + 1) * num_depop].sum()\n        # exploration or exploitation\n        epsilon = random.random()\n        if epsilon < self.greedy + episode * 0.004:#:\n            action = count_group.argmax()\n        else:\n            action = torch.tensor(random.randint(0, 3))\n\n        return action.item()\n\n\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/BrainArea/TPJ.py",
    "content": "import torch\nfrom braincog.base.brainarea.Insula import *\nfrom rulebasedpolicy.world_model import *\nfrom BrainArea.dACC import *\nfrom BrainArea.PFC_ToM import *\n\nNPC_1 = 2\nNPC_2 = 3\nAgent = 4\n\n#exploit or explore\nnum_enpop = 6\nnum_depop = 10\ngreedy = 0.8#0.5\n\n#state\nA_state = 4\nN_state = 6\ncell_num = 6\n#action\nC=10\n\nclass ToM:\n    def __init__(self, env):\n        \"\"\"\n\n        @param axis:输入为agent自己的观察到位置信息\n        @param obs:遮挡关系\n        \"\"\"\n        self.axis = None\n        self.obs = None\n        self.NPC_num = None\n        self.env = env\n        self.env.trigger = 0\n\n    def TPJ(self, NPC_num, axis, obs):\n        \"\"\"\n        perspective_taking\n        agent take NPC2's perspective\n        @param NPC_num: which NPC?\n        @return:\n        axis_new:站在other的角度看到其他智能体的遮挡关系,return axis,\n        axis_switch:站在self的角度看到其他智能体的遮挡关系,return axis\n        obs_switch:站在other的角度看到其他智能体的遮挡关系,return obs\n        \"\"\"\n        self.env.trigger = 0\n        axis_switch = [[6,6], [6,6], [6,6]]\n        axis_new = [[6, 6], [6, 6], [6, 6]]\n        self.axis = axis\n        self.obs = obs\n        axis_switch[0], axis_switch[NPC_num] = axis[NPC_num], axis[0]\n        axis_switch[1] = axis[1]\n        obs_switch = big_env(self.obs)\n        obs_switch[self.axis[0][0], self.axis[0][1]], obs_switch[self.axis[NPC_num][0],self.axis[NPC_num][1]] = \\\n            obs_switch[self.axis[NPC_num][0],self.axis[NPC_num][1]],obs_switch[self.axis[0][0], self.axis[0][1]]\n        x = np.argwhere((obs_switch==2)|(obs_switch==8))\n        if self.axis[NPC_num][0] != 6 or self.axis[NPC_num][1] != 6:\n            shelter_obs = shelter_env(obs_switch[1:6,1:6])\n            obs_switch[1:6,1:6], m = self.gain_obs(a=obs_switch[1:6,1:6], aa=shelter_obs, b=axis_switch[1], c=axis_switch[2], bb=2,cc=4)\n            if m == True:\n                axis_switch[1] = [6,6]\n        else:\n            obs_switch = []\n        axis_new[0] = 
axis_switch[NPC_num]\n        axis_new[1] = axis_switch[1]\n        axis_new[NPC_num] = axis_switch[0]\n\n        return  axis_new, axis_switch, obs_switch\n\n    def gain_obs(self, a,aa,b,c,bb,cc):\n        m = False\n        if b!=[6,6]:\n            if aa[b[0]-1, b[1]-1] == 0:\n                a[b[0]-1, b[1]-1] = 1#2\n                m = True\n            else:\n                a[b[0] - 1, b[1] - 1] = bb\n        if aa[c[0]-1, c[1]-1] ==0:\n            # print('-------')\n            a[c[0]-1, c[1]-1] = 1#4\n        else:\n            a[c[0] - 1, c[1] - 1] = cc\n        return a, m\n\n    def belief_reasoning(self, test_x, net_NPC, num_action, episode):\n        output = net_NPC(inputs=test_x, \\\n                              num_action=num_action, \\\n                              episode=episode)\n        return output\n\n    def state_evaluation(self, prediction_next_state):\n        \"\"\"\n        state_evaluation\n        @param prediction_next_state:\n        @return:\n        \"\"\"\n        input = np.array(prediction_next_state)\n        test_x = torch.tensor([[(int(bool(input[0][0] - input[2][0])))*10, (int(bool(input[0][1] - input[2][1])))*10]])\n        T = 5\n        num_popneurons = 2\n        safety = 2\n        dACC_net = dACC(step=T, encode_type='rate', bias=True,\n                        in_features=num_popneurons, out_features=safety,\n                        node=node.LIFNode)\n        dACC_net.load_state_dict(torch.load(os.path.join(sys.path[0], 'BrainArea/checkpoint', 'dACC_net.pth'))['dacc'])\n        output = dACC_net(inputs=test_x, epoch=50)\n        output = bool(int(output[0].cpu().detach().numpy().tolist()))\n        print(output,test_x)\n        return output\n\n    def prediction_state(self, axis_new, axis, action_NPC1, net, num_action, episode):\n        \"\"\"\n        根据当前状态和经验预测下一个状态\n        @return:下一个step的state\n        \"\"\"\n        self.env.trigger = 0\n        action_move = {\n            0: (0, -1),\n            1: 
(0, 1),\n            2: (-1, 0),\n            3: (1, 0),\n            4: (0, 0)\n        }\n        next_axis = [[6,6],[6,6],[6,6]]\n        # inputspike_test = np.array([axis_new[0],axis_new[1],axis_new[2]])\n        inputspike_test = sum(axis_new, [])\n\n        action_NPC2 = self.belief_reasoning(test_x=inputspike_test, net_NPC=net, num_action=num_action, episode=episode)\n        action_agent = 3\n        #NPC_1\n        next_axis[1][0] = axis[1][0] + action_move[action_NPC1][1]\n        next_axis[1][1] = axis[1][1] + action_move[action_NPC1][0]\n        #NPC_2\n\n        if self.obs[axis[2][0] + action_move[action_NPC2][1]-1, axis[2][1] + action_move[action_NPC2][0]-1] != 5:\n            next_axis[2][0] = axis[2][0] + action_move[action_NPC2][1]\n            next_axis[2][1] = axis[2][1] + action_move[action_NPC2][0]\n        #NPC_agent\n        next_axis[0][0] = axis[0][0] + action_move[action_agent][1]\n        next_axis[0][1] = axis[0][1] + action_move[action_agent][0]\n\n        return next_axis\n\n    def altruism(self, axis_switch, axis_NPC, n_actions):\n        \"\"\"\n        假设有一个开关，agent按下去可以让NPC不动\n        Q_bad:NPC的错误观测的有偏差Q\n        Q_good:正确的Q\n        Q_delta:中最小的值就是容易导致NPC出现危险的值\n        找到最小危险中的最大值对应的action\n        @param axis_switch:\n        @param axis_NPC:\n        @param n_actions:\n        @return:下一个step的action\n        \"\"\"\n        actions = list(range(n_actions))\n        action_NPC_list = list(range(n_actions))\n        #others' view\n        data_NPC = pd.read_csv('./data/NPC_assessment.csv', index_col=[0],\n                               dtype={1: np.float64, 2: np.float64, 3: np.float64, 4: np.float64,\n                                      5: np.float64})\n        #self's view\n        data_agent = pd.read_csv('./data/agent_assessment.csv', index_col=[0],\n                               dtype={1: np.float64, 2: np.float64, 3: np.float64, 4: np.float64,\n                                      5: np.float64})\n\n        # 
print(axis_NPC, axis_switch)\n        Q_bad = data_NPC.loc[str(axis_NPC), :]\n        if str(axis_switch) not in data_agent.index:\n            # append new state to q table\n            # print('1')\n            data_agent = data_agent.append(\n                pd.Series(\n                    [0] * len(list(range(self.env.n_actions))),\n                    index=data_agent.columns,\n                    name=str(axis_switch),\n                    )\n            )\n        Q_good = data_agent.loc[str(axis_switch), :]\n        Q_delta = Q_good - Q_bad\n        # print(Q_delta)\n        # max_Q_delta = [None] * n_actions\n        min_Q_delta_set = []\n        #stop\n        for action_a in actions:\n            if action_a == 4:\n                action_NPC_list = [4]\n\n            min_Q_delta = []\n            for i in action_NPC_list:\n                #\n                # print(i)\n                min_Q_delta.append(Q_delta[i])\n\n            min_Q_delta_set.append(min(min_Q_delta))\n        # print(min_Q_delta_set,'---------')\n        action_altruism = min_Q_delta_set.index(max(min_Q_delta_set))\n\n        if action_altruism  == 4:\n            self.env.trigger = 1\n            # print('---------------------------------------------')\n        # env.SHOW()\n        # time.sleep(1.0)\n        # max_Q_delta[action_a] = max(min_Q_delta_set)\n        return action_altruism\n\n    def INS(self, axis1, axis2):\n\n        num_IPLM = axis1.shape[1]\n        num_IPLV = axis1.shape[1]\n        Insula_connection = []\n        # IPLV-Insula\n        con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2\n        Insula_connection.append(CustomLinear(con_matrix0))\n        # STS-Insula\n        con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2\n        Insula_connection.append(CustomLinear(con_matrix1))\n\n        Insula = InsulaNet(Insula_connection)\n\n        confidence = 0\n        Insula.reset()\n        for t in range(2):\n            Insula((axis1-axis2) * 10, 
torch.zeros_like(axis1) * 10)\n        if sum(sum(Insula.out_Insula)) > 0:\n            confidence = confidence + 1\n\n        return confidence\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/BrainArea/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToM/BrainArea/dACC.py",
    "content": "import torch\nimport matplotlib.pyplot as plt\nimport numpy as np\nnp.set_printoptions(threshold=np.inf)\nfrom utils.one_hot import *\nimport os\nimport time\nimport sys\nfrom tqdm import tqdm\n\nfrom braincog.model_zoo.base_module import BaseLinearModule, BaseModule\nfrom braincog.base.learningrule.STDP import *\nimport sys\nsys.path.append(\"..\")\n\nclass dACC(BaseModule):\n    \"\"\"\n    SNNLinear\n    \"\"\"\n    def __init__(self,\n                 step,\n                 encode_type,\n                 in_features:int,\n                 out_features:int,\n                 bias,\n                 node,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.bias = bias\n        self.in_features = in_features\n        self.out_features = out_features\n        self.node1 = node(threshold=0.5, tau=2.)\n        self.node_name1 = node\n        self.node2 = node(threshold=0.1, tau=2.)\n        self.node_name2 = node\n        self.fc = self._create_fc()\n        self.c = self._rest_c()\n\n\n    def _rest_c(self):\n        c = torch.rand((self.out_features, self.in_features)) # eligibility trace\n        return c\n\n    def _create_fc(self):\n        \"\"\"\n        the connection of the SNN linear\n        @return: nn.Linear\n        \"\"\"\n        fc = nn.Linear(in_features=self.in_features,\n                  out_features=self.out_features, bias=self.bias)\n        return fc\n\n    def update_c(self, c, STDP, tau_c=0.2):\n        \"\"\"\n        update the trace of eligibility\n        @param c: a tensor to record eligibility\n        @param STDP: the results of STDP\n        @param tau_c: the parameter of trace decay\n        @return: a update tensor to record eligibility\n        Equation:\n        delta_c = (-(c / tau_c) + STDP) * dela_t\n        c = c + delta_c\n        reference:<Solving the Distal Reward Problem through ...>\n        \"\"\"\n        c = c + tau_c * 
STDP\n        return c\n\n    def forward(self, inputs, epoch):\n        \"\"\"\n        decision\n        @param inputs: state\n        @return: action\n        \"\"\"\n        output = []\n        stdp = STDP(self.node2, self.fc, decay=0.80)\n        self.c = self._rest_c()\n        # stdp.connection.weight.data = torch.rand((self.out_features, self.in_features))\n\n        for i in range(inputs.shape[0]):\n            for t in range(self.step):\n                l1_in = torch.tensor(inputs[i, :])\n                l1_out = self.node1(l1_in).unsqueeze(0)  #pre  : l1_out\n                l2_out, dw = stdp(l1_out)   #dw -- STDP\n                self.c = self.update_c(self.c, dw[0])\n            output.append(torch.min(l2_out))\n            # output.append((l2_out.any() == 0).cpu().detach().numpy().tolist())\n\n        return output\n\n\n# if __name__ == '__main__':\n#     np.random.seed(6)\n#     T = 5\n#     num_popneurons = 2\n#     safety = 2\n#     epoch = 50\n#     file_name = \"/home/zhaozhuoya/braincog/examples/ToM/data/injury_value.txt\"\n#     state = []\n#     with open(file_name) as f:\n#         data = []\n#         data_split = f.readlines()  #\n#         for i in data_split:\n#             state.append(one_hot(int(i[0])))\n#\n#     output = np.array(state)\n#     train_y = output\n#     test_y = output[79:82]#output[12].reshape(1,2)\n#\n#     file_name = \"/home/zhaozhuoya/braincog/examples/ToM/data/injury_memory.txt\"\n#     state = []\n#     with open(file_name) as f:\n#         data_split = f.readlines()\n#         for i in data_split:\n#             data = []\n#             data.append(int(bool(abs(int(i[2]) - int(i[18]))))*10)\n#             data.append(int(bool(abs(int(i[5]) - int(i[21]))))*10)\n#             state.append(data)\n#     input = np.array(state)\n#     train_x = input\n#     test_x = input[79:82]\n#     dACC_net = dACC(step=T, encode_type='rate', bias=True,\n#                         in_features=num_popneurons, out_features=safety,\n# 
                        node=node.LIFNode)\n#     dACC_net.fc.weight.data = torch.rand((safety, num_popneurons))\n#     dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])\n#     output = dACC_net(inputs=train_x, epoch=50)\n#     for i in range(len(output)):\n#         print(output[i], train_x[i])\n    # torch.save({'dacc': dACC_net.state_dict()}, os.path.join('./checkpoint', 'dACC_net.pth'))\n    # dACC_net.load_state_dict(torch.load('./checkpoint/dACC_net.pth')['dacc'])\n    # output = dACC_net(inputs=test_x, epoch=50)\n    # for i in range(len(test_x)):\n    #\n    #     print(output[i],test_x[i])\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/BrainArea/one_hot.py",
    "content": "from numpy import argmax\nimport numpy as np\n\ndef one_hot(value):\n    num = '01'\n    letter = [0 for _ in range(len(num))]\n    letter[value] = 1\n    letter = np.array([letter])\n    return letter\n\n\n\n# print(one_hot(4))\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/BrainArea/test.py",
    "content": "import torch\nfrom braincog.base.connection.CustomLinear import *\nfrom braincog.base.node.node import *\nfrom braincog.base.learningrule.STDP import *\nfrom braincog.base.brainarea.IPL import *\nfrom braincog.base.brainarea.Insula import *\n\nif __name__ == \"__main__\":\n    num_neuron = 4\n    num_vPMC = num_neuron\n    num_STS  = num_neuron\n    num_IPLM = num_neuron\n    num_IPLV = num_neuron\n    num_Insula = num_neuron\n\n    # InsulaNet\n    # connection\n    Insula_connection = []\n    # IPLV-Insula\n    con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2\n    Insula_connection.append(CustomLinear(con_matrix0))\n    # STS-Insula\n    con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2\n    Insula_connection.append(CustomLinear(con_matrix1))\n\n    Insula = InsulaNet(Insula_connection)\n\n    a = torch.tensor([[1.,2.,1.,2.]])\n    b = torch.tensor([[1., 2., 1., 2.]])\n    c = torch.tensor([[2., 2., 4., 2.]])\n\n    confidence = [0, 0]\n    for t in range(2):\n        Insula(a*10, b*10)\n    if sum(sum(Insula.out_Insula)) > 0:\n        confidence[0] = confidence[0] + 1\n    Insula.reset()\n\n    for t in range(2):\n        Insula(a*10, c*10)\n    if sum(sum(Insula.out_Insula)) > 0:\n        confidence[0] = confidence[0] + 1\n    Insula.reset()\n\n\n\n    print(confidence)\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/README.md",
    "content": "# Requirements\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n* pygame\n\n# Run\n## Train \n* the file to be run: main_both.py \n* args:\n    * the path to save net_NPC: --save_net_N\n    * the path to save net_a: --save_net_a\n    * time steps: --T\n\n```bash\npython main_both.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both\n```\n\n## Test\nYou can use the weights saved by training in the test environment.\n\n```bash\npython main_ToM.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both\n```\n# Citation\n```\n@article{zhao2022brain,\n    title     = {A Brain-Inspired Theory of Mind Spiking Neural Network for Reducing Safety Risks of Other Agents},\n    author    = {Zhao, Zhuoya and Lu, Enmeng and Zhao, Feifei and Zeng, Yi and Zhao, Yuxuan},\n    journal   = {Frontiers in neuroscience},\n    pages     = {446},\n    year      = {2022},\n    publisher = {Frontiers}\n}\n```\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToM/data/NPC_assessment.csv",
    "content": ",0,1,2,3,4\n\"[[3, 5], [6, 6], [4, 2]]\",-0.868083432541418,-0.733763528003341,11.6748771364348,-1.13941568407204,-0.640638369281893\n\"[[3, 5], [6, 6], [4, 3]]\",-0.418851253559184,-0.353923546821901,2.45363281923355,-0.450103087254002,-0.37827704665847\n\"[[3, 5], [6, 6], [5, 3]]\",-0.021898491492153,-0.01,-0.01,-0.01,-0.011802114342782\n\"[[4, 5], [6, 6], [5, 2]]\",-0.024908407481436,-0.0199,0,-0.01009,-0.01017991\n\"[[3, 5], [6, 6], [5, 2]]\",-0.023268882562677,-0.0199,-0.029791,-0.0298801,-0.037630851041005\n\"[[3, 4], [6, 6], [5, 3]]\",-0.021197399198277,-0.01,-0.0199,-0.013038149463134,-0.01999\n\"[[3, 3], [6, 6], [5, 3]]\",0,-0.5,0,0,-0.01\n\"[[3, 3], [6, 6], [5, 2]]\",0,0,0,-0.01,0\n\"[[3, 4], [6, 6], [5, 1]]\",-0.01,0,-0.01,-0.01009,0\n\"[[3, 4], [6, 6], [4, 1]]\",-0.0101791,-0.0199,-0.01,-0.01,-0.01\n\"[[3, 3], [6, 6], [4, 1]]\",-0.01,0,-0.01,0,0\n\"[[3, 2], [6, 6], [4, 2]]\",0,-0.01,0,0,0\n\"[[4, 2], [6, 6], [4, 3]]\",0,0,0,-0.01,0\n\"[[4, 3], [6, 6], [4, 4]]\",0.089467366020037,-0.010984993825322,-0.01,-0.013452906086556,-0.011867720274541\n\"[[4, 2], [6, 6], [4, 5]]\",4.09524024945437,-0.090412523006309,-0.09623134888059,-0.104387744707876,-0.070973835764104\n\"[[3, 2], [6, 6], [4, 5]]\",1.40391860495941,0.022997490426814,43.0024182397514,1.53202462973447,3.08888716957364\n\"[[2, 2], [6, 6], [4, 5]]\",-0.019749094852309,0.036781365436129,15.9675319796684,-0.020051902468894,0.04621779219384\n\"[[4, 1], [6, 6], [4, 5]]\",5.50845885876033,-0.030316371018288,-0.01999,0.003924246487139,-0.01999\n\"[[5, 1], [6, 6], [4, 5]]\",-0.04218527817012,-0.089619886118225,-0.089555438562711,-0.09202879886063,-0.089640838741259\n\"[[5, 2], [6, 6], [4, 5]]\",-0.181115480137325,-0.188135727209872,-0.190391091543011,-0.189212077096485,-0.18825217812927\n\"[[4, 3], [6, 6], [4, 5]]\",3.20388938261604,-0.277669046855729,-0.263170616607684,-0.263425702870801,-0.273172354685186\n\"[[4, 4], [6, 6], [4, 
5]]\",0.088264809806085,-0.536047549432865,-0.496801753642635,-1.7227487525,-0.525308942251634\n\"[[3, 3], [6, 6], [4, 5]]\",-0.21666970575642,-0.137772024778711,35.7440209050429,0.074933770913733,1.72993072468866\n\"[[2, 3], [6, 6], [6, 6]]\",-0.059222187527032,-0.030584932730533,1.19896558354272,-0.02997001,-0.039803368321337\n\"[[3, 4], [6, 6], [4, 5]]\",-0.40734083054261,-0.590251991054757,12.0029208279787,-0.553842323501647,-0.473260640862758\n\"[[4, 5], [6, 6], [4, 5]]\",0,0,0,0,0\n\"[[3, 5], [6, 6], [4, 4]]\",-0.481765407319486,-0.511211461721295,-0.564380288792446,-0.492462845628926,-0.469959754386144\n\"[[3, 5], [6, 6], [4, 5]]\",-0.931224461309762,-4.42392016540183,0.839256941677985,-0.924126498596367,-0.93129552223439\n\"[[3, 1], [6, 6], [4, 5]]\",49.811526809228,0.227846091910513,5.03145376471985,3.02887653427164,4.38800221134503\n\"[[2, 1], [6, 6], [4, 5]]\",0,0,0,0,0\n\"[[3, 4], [6, 6], [4, 3]]\",1.19380671730375,-5.00699176831198,20.0406112237339,-0.271487593364538,18.7143725924663\n\"[[3, 4], [6, 6], [4, 4]]\",0.218803966115085,-0.17930117271935,27.2213501282548,-0.25238045900879,0.399605811763158\n\"[[5, 3], [6, 6], [4, 5]]\",-0.375176845280851,-0.382612376673214,-0.387234215280032,-0.391512374343505,-0.38256981360011\n\"[[5, 4], [6, 6], [4, 5]]\",-0.648971759046833,-0.647885303763355,-0.648165623973329,-0.644724798572273,-0.648282307415749\n\"[[4, 5], [6, 6], [3, 2]]\",-0.01999,-0.029791,-0.0199,-0.0199891,-0.029882503677127\n\"[[3, 5], [6, 6], [3, 2]]\",-0.034237585544166,-0.029789209,-0.0297901,-0.031801666387697,-0.023636229278446\n\"[[4, 5], [6, 6], [4, 3]]\",-0.254773250904236,-0.247726142705935,-0.4975,-0.251602554065065,-0.253564419614127\n\"[[4, 5], [6, 6], [4, 4]]\",-0.414539088621474,-0.413180632405154,-0.45683393116936,-0.4975,-0.98509975\n\"[[5, 5], [6, 6], [4, 5]]\",-2.88794812477989,-0.897593650362875,-0.896337578776872,-0.897383653648252,-0.896911728329473\n\"[[3, 4], [6, 6], [4, 
2]]\",-0.041846129714671,-0.0491890501,-0.049128264132635,-0.04752963742507,-0.041122932699059\n\"[[4, 4], [6, 6], [4, 3]]\",-0.010582728786191,-0.0199,-0.019606621899442,-0.023027393843706,-0.25\n\"[[4, 4], [6, 6], [4, 4]]\",0,0,0,0,0\n\"[[1, 3], [6, 6], [6, 6]]\",-0.02997001,-0.037241738974595,-0.038877118273757,-0.02997001,-0.02997001\n\"[[1, 2], [6, 6], [6, 6]]\",-0.01,0.092960550898947,-0.01,-0.01,-0.01\n\"[[1, 1], [6, 6], [4, 5]]\",0,0.25,0,0,-0.01\n\"[[5, 5], [6, 6], [4, 2]]\",-0.011245609205401,-0.010090875348172,-0.01,-0.01,-0.01\n\"[[5, 5], [6, 6], [4, 3]]\",-0.023589076057546,-0.030475766514878,-0.0199,-0.010097260907938,-0.020481507078377\n\"[[5, 4], [6, 6], [4, 4]]\",-0.014561932553648,-0.014123403608977,-0.012977094986838,-0.013958043392411,-0.0102689209\n\"[[5, 5], [6, 6], [3, 3]]\",-0.01,0,0,0,-0.01009\n\"[[5, 5], [6, 6], [4, 4]]\",-0.742525,-0.112519042227942,-0.104809666804011,-0.120899601119155,-0.114009048694518\n\"[[3, 3], [6, 6], [4, 4]]\",-0.010153907821337,-0.013354155304859,3.90887441187658,-0.020453180303816,-0.010444104437097\n\"[[3, 5], [6, 6], [4, 1]]\",-0.010526678655391,-0.01999,-0.0199,-0.013489754345763,-0.020341425501639\n\"[[4, 5], [6, 6], [4, 2]]\",-0.040636256244247,-0.0297901,-0.029794680345807,-0.032266639674154,-0.030144504483964\n\"[[4, 5], [6, 6], [4, 1]]\",-0.022834041173462,-0.029701,-0.029701,-0.030019473640868,-0.029791\n\"[[4, 4], [6, 6], [4, 2]]\",-0.010267309,-0.01999,-0.25,-0.021462046263096,-0.01\n\"[[5, 5], [6, 6], [5, 4]]\",0,-0.01,0,0,-0.010213344196849\n\"[[3, 5], [6, 6], [3, 4]]\",-0.013599152071303,0,-0.010345480742543,-0.01,-0.25\n\"[[3, 5], [6, 6], [3, 5]]\",0,0,0,0,0\n\"[[3, 5], [6, 6], [3, 3]]\",-0.011856953925967,-0.01999,-0.0199,-0.032817580205153,-0.010785322603466\n\"[[4, 4], [6, 6], [5, 2]]\",0,-0.01,-0.01,-0.01009,-0.01\n\"[[5, 4], [6, 6], [4, 2]]\",0,0,-0.01,-0.01,0\n\"[[3, 3], [6, 6], [4, 3]]\",-0.01,-0.01,-0.0199,-0.005562516126618,-0.009541766156335\n\"[[3, 2], [6, 6], [4, 
4]]\",0,0,-0.01,-0.01,0.236388148843983\n\"[[4, 3], [6, 6], [4, 2]]\",0,-0.01,0,-0.01,0\n\"[[5, 3], [6, 6], [4, 3]]\",-0.01,0,0,0,0\n\"[[4, 5], [6, 6], [5, 3]]\",-0.01,-0.01009,-0.0199,-0.010784041942122,-0.01\n\"[[3, 5], [6, 6], [5, 4]]\",-0.012894421101825,-0.012100903162978,-0.007186044736698,0,-0.011442680105686\n\"[[3, 4], [6, 6], [3, 2]]\",-0.01009,-0.01999,-0.0199,-0.01009,-0.010091626994647\n\"[[4, 4], [6, 6], [5, 4]]\",0,0,0,-0.010869859082056,0\n\"[[4, 5], [6, 6], [3, 3]]\",-0.01,-0.01,-0.01999,-0.01,-0.011561635811299\n\"[[5, 4], [6, 6], [4, 3]]\",0,0,-0.01,0,0\n\"[[5, 3], [6, 6], [4, 4]]\",0,-0.012647205121462,0,-0.015190262813025,0\n\"[[3, 2], [6, 6], [4, 3]]\",0,-0.01,0,0,-0.01\n\"[[3, 1], [6, 6], [4, 3]]\",0,-0.01,0,0,0\n\"[[4, 1], [6, 6], [4, 4]]\",0,0,0,0,-0.002916121937907\n\"[[4, 3], [6, 6], [4, 3]]\",0,0,0,0,0\n\"[[5, 5], [6, 6], [3, 4]]\",-0.012969708589991,-0.01,-0.01,-0.010643056796965,0\n\"[[3, 4], [6, 6], [3, 4]]\",0,0,0,0,0\n\"[[2, 3], [6, 6], [4, 4]]\",0,0,0,-0.01,0\n\"[[2, 3], [6, 6], [5, 4]]\",0,0,-0.01,0,0\n\"[[2, 2], [6, 6], [4, 4]]\",-0.01,0,0,-0.008244859693453,0\n\"[[1, 2], [6, 6], [4, 3]]\",0,-0.01,0,0,0\n\"[[3, 4], [6, 6], [5, 2]]\",-0.01999,-0.01999,-0.0199,-0.01999,-0.01035517099859\n\"[[3, 3], [6, 6], [4, 2]]\",0,0,-0.01,-0.010197581295095,0\n\"[[3, 3], [6, 6], [5, 4]]\",0,0,0,-0.01009,0\n\"[[5, 5], [6, 6], [3, 1]]\",-0.01999081,-0.01999,-0.01999,-0.01999,-0.01999\n\"[[5, 5], [6, 6], [2, 1]]\",-0.01,0,-0.01,-0.01,-0.01\n\"[[5, 4], [6, 6], [2, 2]]\",0,0,-0.01,-0.01,-0.01\n\"[[5, 4], [6, 6], [3, 2]]\",-0.01999,-0.0199,-0.01,-0.01,-0.01\n\"[[5, 3], [6, 6], [3, 1]]\",-0.01,-0.01,0,0,0\n\"[[4, 3], [6, 6], [4, 1]]\",-0.01,-0.01,-0.01,-0.01,-0.01\n\"[[4, 2], [6, 6], [5, 1]]\",0,-0.01,0,-0.01,0\n\"[[5, 2], [6, 6], [5, 1]]\",-0.01,0,0,0,0\n\"[[4, 2], [6, 6], [5, 2]]\",0,0,0,0,-0.01\n\"[[5, 4], [6, 6], [3, 3]]\",-0.01,-0.01,-0.01,-0.01,-0.01\n\"[[4, 4], [6, 6], [3, 2]]\",-0.01,-0.01,-0.0199,-0.0200791,-0.0199\n\"[[4, 5], [6, 6], [6, 
6]]\",-0.01009,-0.0199,-0.0199,-0.0199,-0.01999\n\"[[4, 4], [6, 6], [5, 3]]\",0,0,0,-0.01,-0.01\n\"[[5, 5], [6, 6], [2, 2]]\",-0.01,-0.01,-0.01,-0.01,0\n\"[[5, 4], [6, 6], [2, 3]]\",-0.01,-0.01,0,-0.01,-0.01\n\"[[4, 5], [6, 6], [2, 2]]\",-0.010198196967161,-0.01999,-0.01,-0.01009,-0.01\n\"[[4, 5], [6, 6], [1, 1]]\",0,-0.01,0,0,-0.01\n\"[[4, 5], [6, 6], [2, 1]]\",-0.01,-0.01,-0.01,-0.0199,-0.01\n\"[[4, 5], [6, 6], [3, 1]]\",-0.01,-0.01009,-0.0199,-0.01,-0.01009\n\"[[4, 3], [6, 6], [3, 3]]\",0,0,-0.01,0,0\n\"[[4, 2], [6, 6], [3, 4]]\",-0.008998626598505,0,0,0,0\n\"[[4, 5], [6, 6], [5, 1]]\",-0.01009,-0.0199,-0.0199,-0.020016673101577,-0.01999\n\"[[4, 4], [6, 6], [5, 1]]\",-0.01,-0.01,-0.01,-0.01009,-0.01\n\"[[3, 5], [6, 6], [5, 1]]\",-0.01009,-0.01,-0.0199,-0.01999,-0.01999\n\"[[5, 5], [6, 6], [5, 1]]\",-0.01009,0,-0.01,-0.0199,-0.01\n\"[[5, 5], [6, 6], [4, 1]]\",-0.01,-0.01,-0.0199,-0.01,-0.01999\n\"[[4, 4], [6, 6], [3, 1]]\",0,0,0,-0.01,-0.01\n\"[[5, 5], [6, 6], [1, 1]]\",-0.01,0,-0.01,-0.01,0\n\"[[5, 5], [6, 6], [1, 2]]\",-0.01,-0.01,0,-0.01,-0.01\n\"[[4, 4], [6, 6], [2, 1]]\",0,-0.01,-0.01,-0.01,-0.01\n\"[[4, 3], [6, 6], [3, 1]]\",-0.01,0,0,-0.01,-0.01\n\"[[4, 4], [6, 6], [4, 1]]\",-0.010289510270133,-0.01,-0.0199,-0.02007991,-0.01999\n\"[[5, 4], [6, 6], [4, 1]]\",0,0,-0.01,0,-0.01\n\"[[5, 4], [6, 6], [3, 1]]\",-0.01,0,0,-0.01,-0.01\n\"[[3, 4], [6, 6], [2, 2]]\",0,-0.01,0,-0.01,-0.01\n\"[[3, 4], [6, 6], [6, 6]]\",-0.01,-0.01,-0.01,-0.01999,-0.01999\n\"[[3, 3], [6, 6], [1, 2]]\",0,-0.01,0,0,0\n\"[[4, 3], [6, 6], [1, 3]]\",0,0,-0.01,-0.01,0\n\"[[4, 4], [6, 6], [6, 6]]\",-0.01,-0.01,-0.0199,-0.01,-0.01\n\"[[5, 5], [6, 6], [2, 3]]\",-0.01,-0.01999,-0.01999,-0.01999,-0.01\n\"[[5, 4], [6, 6], [1, 3]]\",-0.0199,-0.01,-0.01,-0.01,-0.01\n\"[[5, 4], [6, 6], [1, 2]]\",0,-0.01,-0.01,-0.01,0\n\"[[5, 4], [6, 6], [1, 1]]\",-0.0199,-0.01,-0.01,-0.01,-0.01\n\"[[4, 4], [6, 6], [1, 1]]\",-0.01,-0.01,-0.01,0,0\n\"[[5, 3], [6, 6], [1, 1]]\",0,0,-0.01,-0.01009,0\n\"[[5, 2], [6, 6], 
[1, 1]]\",-0.01,0,0,-0.01,-0.01\n\"[[5, 3], [6, 6], [1, 2]]\",0,-0.01,0,-0.01,0\n\"[[5, 3], [6, 6], [1, 3]]\",-0.01,-0.01,-0.01,0,-0.0199\n\"[[5, 3], [6, 6], [2, 3]]\",0,0,-0.01,0,0\n\"[[5, 2], [6, 6], [2, 3]]\",-0.01,0,0,-0.01,0\n\"[[5, 3], [6, 6], [3, 3]]\",0,0,0,-0.01,0\n\"[[4, 3], [6, 6], [1, 2]]\",-0.01,-0.01,-0.01,-0.01,-0.01\n\"[[3, 3], [6, 6], [2, 2]]\",-0.01,-0.01,-0.01,-0.01,-0.01\n\"[[2, 3], [6, 6], [3, 2]]\",0,0,-0.01,-0.01,0\n\"[[2, 3], [6, 6], [3, 1]]\",-0.01,-0.01,-0.01,-0.01,-0.01\n\"[[2, 3], [6, 6], [2, 1]]\",0,-0.01,0,0,-0.01\n\"[[3, 2], [6, 6], [2, 3]]\",0,-0.01,0,0,0\n\"[[4, 2], [6, 6], [2, 2]]\",-0.01,0,0,-0.01,0\n\"[[3, 2], [6, 6], [2, 1]]\",-0.01,0,0,-0.01,0\n\"[[3, 3], [6, 6], [1, 1]]\",0,0,0,-0.01,-0.01\n\"[[3, 4], [6, 6], [2, 1]]\",-0.01,-0.01,0,0,-0.01\n\"[[4, 4], [6, 6], [2, 2]]\",0,-0.01,-0.01,0,0\n\"[[4, 3], [6, 6], [2, 2]]\",0,0,-0.01,-0.01009,0\n\"[[4, 2], [6, 6], [1, 2]]\",-0.01,0,-0.01,-0.01,-0.01\n\"[[3, 2], [6, 6], [1, 2]]\",0,0,-0.01,0,0\n\"[[3, 1], [6, 6], [1, 1]]\",0,0,-0.01,-0.01,0\n\"[[3, 2], [6, 6], [1, 1]]\",-0.01,-0.01,0,0,0\n\"[[4, 1], [6, 6], [1, 2]]\",-0.01,0,0,0,0\n\"[[3, 1], [6, 6], [1, 3]]\",0.5,0,0,0,-0.01\n\"[[2, 1], [6, 6], [1, 3]]\",0.995,0,0,0,0\n\"[[2, 1], [6, 6], [2, 3]]\",1.489505,0,0,0,0\n\"[[2, 1], [6, 6], [2, 2]]\",0.5,0,0,0,0\n\"[[2, 1], [6, 6], [3, 2]]\",0.5,0,0,0,0\n\"[[2, 1], [6, 6], [4, 2]]\",0.5,0,1.0039955,0,0\n\"[[2, 1], [6, 6], [5, 2]]\",0.995,0,0,0,0\n\"[[2, 1], [6, 6], [5, 3]]\",0.995,0,0,0,0\n\"[[2, 1], [6, 6], [4, 3]]\",0.5045,0,0.5045,0,0\n\"[[2, 1], [6, 6], [5, 4]]\",0.995,0,0,0,0\n\"[[2, 1], [6, 6], [4, 4]]\",0.5,0,0.5,0,0\n\"[[4, 4], [6, 6], [1, 2]]\",-0.01,0,0,-0.01009,0\n\"[[4, 3], [6, 6], [1, 1]]\",-0.01,0,0,0,0\n\"[[3, 3], [6, 6], [2, 1]]\",-0.0199,0,0,0,0\n\"[[1, 3], [6, 6], [3, 1]]\",0,0,-0.01,0,0\n\"[[1, 2], [6, 6], [3, 1]]\",-0.01,0,0,0,-0.01\n\"[[1, 2], [6, 6], [4, 1]]\",-0.01,0,-0.0199,0,0\n\"[[1, 1], [6, 6], [5, 1]]\",0,0,0,-0.01,-0.01\n\"[[1, 2], [6, 6], [5, 
1]]\",0,-0.01,0,-0.01,0\n\"[[1, 3], [6, 6], [5, 1]]\",-0.01,0,0,0,0\n\"[[1, 3], [6, 6], [5, 2]]\",0,0,-0.01,0,-0.01\n\"[[1, 2], [6, 6], [5, 2]]\",-0.01,-0.01,0,-0.01,0\n\"[[2, 2], [6, 6], [5, 1]]\",-0.01,0,0,0,-0.01\n\"[[2, 2], [6, 6], [4, 1]]\",-0.01,0,0,0,0\n\"[[1, 1], [6, 6], [4, 2]]\",-0.01,0,0,0,-0.01\n\"[[1, 1], [6, 6], [4, 1]]\",0,0,0,0,-0.01\n\"[[1, 1], [6, 6], [5, 2]]\",0,0,0,-0.01,-0.01\n\"[[1, 2], [6, 6], [5, 3]]\",0,0,-0.01,0,-0.01\n\"[[1, 1], [6, 6], [5, 3]]\",0,0,0,-0.01,0\n\"[[1, 3], [6, 6], [4, 2]]\",-0.01,0,0,0,-0.01\n\"[[1, 3], [6, 6], [5, 3]]\",0,0,0,0,-0.01\n\"[[1, 3], [6, 6], [5, 4]]\",0,0,0,0,-0.010266711022056\n\"[[5, 5], [6, 6], [3, 2]]\",-0.010094761832099,0,0,0,0\n\"[[3, 5], [6, 6], [3, 1]]\",-0.01,-0.0101791,-0.01,0,-0.01\n\"[[3, 5], [6, 6], [2, 1]]\",-0.01,-0.01009,-0.01,0,0\n\"[[3, 5], [6, 6], [6, 6]]\",-0.01999,-0.0199891,-0.01999,-0.01999,-0.01999\n\"[[3, 4], [6, 6], [1, 1]]\",0,-0.01,-0.01,0,0\n\"[[4, 5], [6, 6], [3, 4]]\",0,0,-0.01,0,0\n\"[[4, 4], [6, 6], [3, 5]]\",0,-0.01,0,0,0\n\"[[5, 4], [6, 6], [3, 4]]\",0,0,0,-0.01,-0.01\n\"[[3, 2], [6, 6], [5, 2]]\",0,0,-0.01,0,0\n\"[[3, 1], [6, 6], [5, 3]]\",0,0,-0.01,-0.01,0\n\"[[4, 2], [6, 6], [4, 4]]\",0,0,0.004567853841478,0,0\n\"[[5, 4], [6, 6], [2, 1]]\",0,0,0,0,-0.01\n\"[[3, 3], [6, 6], [3, 1]]\",-0.01,0,0,-0.01009,0\n\"[[5, 5], [6, 6], [5, 2]]\",0,-0.01,-0.01,0,0\n\"[[5, 4], [6, 6], [5, 1]]\",-0.01009,0,-0.0199,-0.01,-0.01\n\"[[5, 3], [6, 6], [5, 1]]\",-0.01,0,-0.5,0,0\n\"[[5, 2], [6, 6], [5, 2]]\",0,0,0,0,0\n\"[[3, 4], [6, 6], [3, 3]]\",0,0,0,-0.01,0\n\"[[5, 3], [6, 6], [4, 2]]\",0,0,-0.01,-0.01,0\n\"[[5, 2], [6, 6], [4, 3]]\",0,0,-0.01,-0.01,0\n\"[[5, 1], [6, 6], [4, 4]]\",0,0,-0.010759540535295,0,0\n\"[[4, 4], [6, 6], [2, 3]]\",0,-0.01,0,0,0\n\"[[5, 5], [6, 6], [1, 3]]\",-0.01009,0,-0.01,0,0\n\"[[4, 2], [6, 6], [3, 3]]\",0,-0.01,0,0,0\n\"[[3, 3], [6, 6], [5, 1]]\",0,-0.01,0,0,0\n\"[[5, 3], [6, 6], [2, 1]]\",0,-0.01,-0.01,0,-0.01\n\"[[2, 2], [6, 6], [3, 
1]]\",0,0,0,-0.01,-0.01\n\"[[2, 3], [6, 6], [4, 1]]\",-0.01,0,0,-0.01,-0.01\n\"[[1, 3], [6, 6], [4, 1]]\",0,0,-0.01,-0.01,0\n\"[[1, 2], [6, 6], [3, 2]]\",0,0,-0.01,0,0\n\"[[1, 1], [6, 6], [3, 2]]\",0,0,-0.01,0,0\n\"[[1, 1], [6, 6], [3, 1]]\",0,0,0,-0.01,0\n\"[[1, 2], [6, 6], [2, 1]]\",-0.01,0,0,-0.01,-0.01\n\"[[1, 3], [6, 6], [2, 2]]\",0,0,0,0,-0.01\n\"[[1, 3], [6, 6], [2, 1]]\",0,0,-0.01,0,0\n\"[[1, 2], [6, 6], [1, 1]]\",-0.5,0,0,0,0\n\"[[1, 2], [6, 6], [1, 2]]\",0,0,0,0,0\n\"[[2, 3], [6, 6], [4, 2]]\",0,0,-0.01,0,0\n\"[[2, 2], [6, 6], [4, 2]]\",0,0,0.5,-0.01,0\n\"[[2, 3], [6, 6], [4, 3]]\",0,-0.002363469450096,0,0,0\n\"[[4, 4], [6, 6], [3, 3]]\",-0.011265121636209,0,-0.01,0,0\n\"[[4, 3], [6, 6], [2, 3]]\",-0.01,0,0,0,0\n\"[[5, 3], [6, 6], [2, 2]]\",0,0,0,0,-0.01\n\"[[4, 2], [6, 6], [1, 1]]\",-0.01,0,-0.01,0,0\n\"[[4, 1], [6, 6], [1, 1]]\",0,0,0,-0.01,0\n\"[[2, 2], [6, 6], [2, 1]]\",0,-0.01,0,-0.01,0\n\"[[2, 1], [6, 6], [4, 1]]\",0,0,1.988015495,0,0\n\"[[2, 1], [6, 6], [5, 1]]\",0,0,1.988015495,0,0\n\"[[3, 4], [6, 6], [3, 1]]\",0,-0.01009,0,0,0\n\"[[4, 3], [6, 6], [5, 1]]\",0,-0.01,0,-0.01,-0.01\n\"[[5, 3], [6, 6], [4, 1]]\",0,-0.01,0,-0.01,0\n\"[[5, 3], [6, 6], [3, 2]]\",0,0,0,-0.01,0\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/data/agent_assessment.csv",
    "content": ",0,1,2,3,4\n\"[[3, 5], [6, 6], [4, 2]]\",-0.052875711096120555,-0.09588695110964494,-0.06793465209301,-0.05438523750227702,-0.06406147756340548\n\"[[3, 4], [6, 6], [4, 3]]\",-0.01009,-0.5,-0.5,-0.01998648239166538,-0.010090809999999999\n\"[[3, 5], [3, 3], [4, 4]]\",-0.5216067506862471,-1.9701995,-0.2957229880478374,-0.5521303794833469,-0.291912651239883\n\"[[3, 5], [4, 3], [4, 5]]\",-0.17020038264944942,-0.995,-0.17003066056014654,-0.17084067541667597,-0.18061252619011572\n\"[[3, 5], [5, 3], [4, 5]]\",-0.310721791006476,-0.995,-0.3044196312169591,-0.3066034117158839,-0.30258164309452773\n\"[[3, 4], [5, 4], [4, 5]]\",-0.5068849351000232,-0.5215919496170031,0.33272479226972446,-0.529392040004043,-0.5476025113537396\n\"[[3, 5], [5, 4], [4, 5]]\",-0.9653009777091316,-8.274311927495619,-0.9562443360051255,-0.9663825574739041,-0.9577571287880812\n\"[[3, 3], [5, 4], [4, 5]]\",-0.12144345586147146,-0.129151719633336,6.61777080738469,-0.12913417684277412,-0.12187211348591424\n\"[[3, 2], [5, 4], [4, 5]]\",22.462534983722144,-0.037566698284270936,0.0002049650900000019,-0.020492300663398463,0.36064342058027793\n\"[[2, 3], [5, 4], [6, 6]]\",-0.0199,-0.01,1.911710899685282,-0.02507112120629763,-0.01\n\"[[4, 4], [5, 4], [4, 5]]\",-0.38856856345777463,-0.995,-0.3848636851605009,-0.5,-0.39216615065025906\n\"[[4, 3], [5, 4], [4, 5]]\",-0.042040801551562916,-0.17418582965398824,-0.17906219922741637,-0.20272904839771475,-0.17845067061226685\n\"[[2, 2], [5, 4], [4, 5]]\",-0.03593611593419173,0.19699025702983197,41.38750349042489,-0.025347536417836183,0.5235719953054313\n\"[[1, 3], [5, 4], [6, 6]]\",-0.01999,-0.010049500000000001,-0.0199,-0.01999,-0.01999\n\"[[1, 2], [5, 4], [6, 6]]\",-0.01999,1.1413350726366454,-0.01,-0.01009,-0.02969406465902331\n\"[[4, 5], [5, 4], [4, 5]]\",0.0,0.0,0.0,0.0,0.0\n\"[[4, 5], [6, 6], [4, 3]]\",-0.022056368799828086,-0.0199,-0.5,-0.0199,-0.01999\n\"[[4, 5], [3, 3], [4, 
4]]\",-0.05334353898157222,-0.058519850599,-0.058698950598999995,-0.5,-0.5\n\"[[4, 5], [4, 3], [4, 5]]\",-0.5,-1.9701995,0.0,-0.5,-0.5\n\"[[4, 5], [5, 3], [4, 5]]\",-0.995,-3.3967326046505,0.0,-0.995,-0.5\n\"[[3, 5], [6, 6], [4, 3]]\",-0.5052708076706848,-0.6503942485687713,-0.31464561043008055,-0.5567063168667764,-0.3196072406789152\n\"[[3, 4], [4, 3], [4, 5]]\",-0.06538757013815065,-0.06815609254531127,0.07157941051049484,-0.06943525585877257,-0.07243109013283643\n\"[[3, 4], [5, 3], [4, 5]]\",-0.03627763723533166,-0.04878628543714749,0.4425679590136676,-0.04432926021182233,-0.03615058165751049\n\"[[5, 3], [5, 4], [4, 5]]\",-0.1326255476738545,-0.12913304314430477,-0.12670229132838703,-0.995,-0.12922285286285284\n\"[[5, 4], [5, 4], [4, 5]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 4], [3, 3], [4, 4]]\",-0.08912453287009733,-0.514891,-0.08664897275426689,-0.0850562755111616,-0.08029821601054066\n\"[[3, 3], [4, 3], [4, 5]]\",-0.01,-0.01,-0.01,-0.009878977320149305,0.020248904987081453\n\"[[4, 3], [5, 3], [4, 5]]\",0.0,-0.011055227128542065,0.0,-0.01,-0.01101218190086918\n\"[[5, 5], [4, 3], [4, 5]]\",-0.5,-0.01,-0.0199,0.0,-0.01\n\"[[5, 4], [5, 3], [4, 5]]\",0.0,-0.5,-0.01,-0.010717485033705036,-0.5\n\"[[5, 2], [5, 4], [4, 5]]\",-0.07919569212102055,-0.07971975334246782,-0.0787530179209123,-0.08104017906592893,-0.07963109795235838\n\"[[5, 1], [5, 4], [4, 5]]\",-0.04890803822116347,-0.049900099950010005,-0.049900099950010005,-0.05034279750865073,-0.049900099950010005\n\"[[4, 1], [5, 4], [4, 5]]\",0.10944099197973417,-0.01999,-0.01999,-0.02007991,-0.01999\n\"[[4, 2], [5, 4], [4, 5]]\",0.15058860156940262,-0.059846938189381006,-0.06722016746200328,-0.07174124407437676,-0.059850199850059994\n\"[[3, 1], [5, 4], [4, 5]]\",5.680756414193536,-0.01,-0.01,-0.01,-0.01\n\"[[2, 1], [5, 4], [4, 5]]\",0.0,0.0,0.0,0.0,0.0\n\"[[4, 4], [3, 3], [4, 4]]\",0.0,-0.5,-0.5,0.0,0.0\n\"[[4, 4], [4, 3], [4, 4]]\",0.0,-0.5,-0.5,0.0,0.0\n\"[[4, 4], [5, 3], [4, 4]]\",0.0,-0.5,-0.5,0.0,0.0\n\"[[4, 4], [5, 4], 
[4, 4]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 4], [6, 6], [5, 2]]\",0.0,0.0,0.0,0.0,-0.01\n\"[[3, 4], [3, 3], [5, 2]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[4, 4], [4, 3], [5, 2]]\",0.0,0.0,-0.5,0.0,0.0\n\"[[4, 3], [5, 3], [5, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [3, 3], [4, 3]]\",-2.2448043853624212e-05,-0.01,-0.01,-6.208892285322127e-05,-0.01\n\"[[3, 5], [4, 3], [4, 4]]\",-0.01028609422864929,-0.5,-0.016408702021722208,-0.020784796048626684,-0.011281391343046818\n\"[[3, 5], [3, 3], [4, 2]]\",-0.5,-0.5,-0.5,-0.9900500000000001,-0.5\n\"[[4, 5], [4, 3], [4, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [4, 3], [3, 4]]\",-0.01,0.0,0.0,-0.5,0.0\n\"[[3, 5], [5, 3], [3, 5]]\",0.0,0.0,0.0,-0.5,0.0\n\"[[3, 5], [5, 4], [3, 5]]\",0.0,0.0,0.0,0.0,0.0\n\"[[4, 5], [6, 6], [5, 2]]\",0.0,0.0,-0.01,0.0,0.0\n\"[[4, 4], [3, 3], [5, 1]]\",0.0,0.0,0.0,0.0,-0.01\n\"[[4, 4], [4, 3], [4, 1]]\",0.0,0.0,0.0,0.0,-0.01\n\"[[4, 4], [5, 3], [4, 1]]\",0.0,0.0,-0.01,0.0,0.0\n\"[[4, 3], [5, 4], [3, 1]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[4, 4], [5, 4], [3, 1]]\",0.0,-0.5,0.0,0.0,0.0\n\"[[5, 4], [5, 4], [3, 1]]\",0.0,0.0,0.0,0.0,0.0\n\"[[1, 1], [5, 4], [4, 5]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[4, 4], [4, 3], [4, 5]]\",-0.016306998715494642,-0.0199,-0.0199,-0.5,-0.01009945645877131\n\"[[4, 4], [5, 3], [4, 5]]\",-0.02368161102075795,-0.5,-0.022639868292312258,-0.5,-0.026451207697811334\n\"[[5, 5], [3, 3], [4, 4]]\",-0.5,0.0,-0.01,0.0,0.0\n\"[[5, 4], [4, 3], [4, 5]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[5, 5], [5, 3], [4, 5]]\",-0.5,-0.01044910089955009,0.0,-0.01,0.0\n\"[[3, 5], [3, 3], [3, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 3], [5, 3], [4, 5]]\",0.0,-0.01062811314685189,3.5215627346534975,0.0,0.0\n\"[[3, 5], [4, 3], [4, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [3, 3], [5, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [4, 3], [5, 2]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[4, 5], [5, 3], [4, 2]]\",-0.01,0.0,0.0,0.0,0.0\n\"[[3, 5], [5, 4], [4, 3]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[3, 5], [5, 4], [4, 4]]\",-0.015943472966397428,-0.01,0.0,-0.01654250067191671,-0.01\n\"[[4, 
5], [5, 4], [4, 3]]\",0.0,0.0,-0.5,-0.01,-0.01\n\"[[4, 5], [5, 4], [4, 4]]\",-0.01639600436420607,-0.01,0.0,0.0,0.0\n\"[[5, 5], [5, 4], [4, 5]]\",-1.9701995,-0.11934219505791074,-0.5,-0.11934219505791074,-0.10945164670461537\n\"[[3, 4], [4, 3], [5, 4]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[3, 5], [5, 3], [4, 4]]\",0.0,0.0,-0.00819848349845159,-0.014389398124565982,-0.016963089341491665\n\"[[4, 5], [3, 3], [4, 2]]\",0.0,-0.5,-0.5,0.0,0.0\n\"[[4, 4], [4, 3], [4, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [6, 6], [5, 2]]\",0.0,-0.01,-0.01,0.0,0.0\n\"[[5, 5], [4, 3], [4, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [4, 3], [5, 4]]\",0.0,0.0,-0.0199,0.0,0.0\n\"[[3, 4], [5, 3], [4, 4]]\",0.0,0.0,0.0,0.0,-0.013570917448717714\n\"[[3, 4], [4, 3], [4, 4]]\",0.0,0.0,0.0,-0.010943974970380799,0.0\n\"[[4, 5], [4, 3], [4, 4]]\",-0.01,0.0,0.0,0.0,0.0\n\"[[3, 5], [5, 3], [3, 4]]\",-0.01,0.0,0.0,0.0,0.0\n\"[[3, 5], [6, 6], [3, 2]]\",0.0,0.0,-0.01,0.0,0.0\n\"[[3, 5], [3, 3], [3, 2]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [4, 3], [4, 2]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [5, 3], [4, 3]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[3, 5], [6, 6], [4, 1]]\",0.0,0.0,0.0,0.0,-0.01\n\"[[3, 5], [3, 3], [4, 1]]\",0.0,-0.01,-0.01,0.0,0.0\n\"[[3, 4], [4, 3], [5, 1]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[4, 4], [5, 3], [5, 1]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[4, 5], [5, 4], [5, 1]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[5, 5], [5, 4], [5, 1]]\",0.0,0.0,0.0,0.0,-0.01\n\"[[5, 5], [5, 4], [5, 2]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[5, 5], [5, 4], [4, 2]]\",0.0,-0.01,-0.5,0.0,0.0\n\"[[5, 4], [5, 4], [4, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 4], [5, 3], [5, 4]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[5, 5], [4, 3], [3, 4]]\",0.0,0.0,-0.01,0.0,0.0\n\"[[5, 4], [5, 3], [4, 4]]\",-0.5,0.0,0.0,0.0,0.0\n\"[[4, 4], [5, 4], [5, 4]]\",0.0,0.0,0.0,0.0,0.0\n\"[[4, 5], [4, 3], [3, 1]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[5, 5], [5, 3], [4, 1]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[5, 5], [5, 4], [4, 3]]\",0.0,0.0,0.0,-0.01,0.0\n\"[[5, 5], [5, 4], [4, 4]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[3, 4], [4, 3], [4, 
3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 4], [3, 3], [4, 2]]\",-0.5,0.0,-0.5,0.0,0.0\n\"[[3, 3], [4, 3], [4, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 3], [3, 3], [4, 4]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 5], [5, 3], [4, 2]]\",0.0,-0.01,0.0,0.0,0.0\n\"[[3, 5], [4, 3], [5, 3]]\",0.0,-0.5,0.0,0.0,0.0\n\"[[4, 5], [5, 3], [5, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 2], [5, 3], [4, 5]]\",0.0,0.0,0.037097785583577604,0.0,0.0\n\"[[2, 3], [5, 3], [6, 6]]\",0.0,0.0,0.0,0.004391030075795664,0.0\n\"[[4, 5], [3, 3], [3, 3]]\",0.0,0.0,0.0,0.0,0.0\n\"[[3, 4], [3, 3], [4, 3]]\",0.0,0.0,0.0,-0.01009092457551116,0.0\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/data/injury_memory.txt",
    "content": "[[2, 1], [6, 6], [2, 1]]\n[[3, 1], [6, 6], [3, 1]]\n[[3, 1], [6, 6], [3, 1]]\n[[2, 2], [6, 6], [2, 2]]\n[[2, 1], [6, 6], [2, 1]]\n[[3, 2], [6, 6], [3, 2]]\n[[2, 1], [6, 6], [2, 1]]\n[[3, 3], [6, 6], [3, 3]]\n[[5, 3], [6, 6], [5, 3]]\n[[4, 4], [6, 6], [4, 4]]\n[[4, 1], [6, 6], [4, 1]]\n[[3, 2], [6, 6], [3, 2]]\n[[3, 2], [6, 6], [3, 2]]\n[[3, 3], [6, 6], [3, 3]]\n[[4, 4], [6, 6], [4, 4]]\n[[4, 4], [6, 6], [4, 4]]\n[[4, 2], [6, 6], [3, 5]]\n[[4, 3], [6, 6], [3, 4]]\n[[4, 4], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [2, 2]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [2, 2]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [2, 2]]\n[[4, 2], [6, 6], [3, 5]]\n[[4, 3], [6, 6], [3, 4]]\n[[4, 4], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 5]]\n[[4, 5], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [2, 2]]\n[[4, 2], [6, 6], [3, 5]]\n[[4, 3], [6, 6], [3, 5]]\n[[4, 1], [3, 3], [3, 3]]]\n[[3, 2], [5, 4], [3, 2]]]\n[[5, 1], [3, 3], [3, 3]]]\n[[3, 3], [3, 3], [3, 5]]]\n[[2, 2], [3, 3], [3, 3]]]\n[[3, 3], [3, 3], [3, 4]]]\n[[2, 1], [5, 4], [2, 1]]]\n[[4, 2], [5, 4], [5, 4]]]\n[[4, 3], [3, 3], [3, 3]]]\n[[5, 2], [3, 3], [3, 3]]]\n[[3, 1], [3, 3], [3, 3]]]\n[[5, 4], [5, 4], [2, 1]]]\n[[5, 4], [5, 4], [2, 2]]]\n[[4, 1], [5, 4], [4, 1]]]\n[[4, 3], [4, 3], [3, 4]]]\n[[4, 2], [3, 3], [3, 3]]]\n[[5, 4], [5, 4], [4, 1]]]\n[[4, 4], [5, 4], [4, 4]]]\n[[4, 3], [4, 3], [3, 5]]]\n[[5, 3], [3, 3], [3, 3]]]\n[[3, 2], [3, 3], [3, 3]]]\n[[5, 4], [5, 4], [3, 3]]]\n[[4, 3], [4, 3], [3, 3]]]\n[[5, 4], [5, 4], [3, 1]]]\n[[3, 1], [5, 4], [3, 1]]]\n[[4, 3], [4, 3], [1, 5]]]\n[[4, 3], [4, 3], [1, 4]]]\n[[4, 3], [4, 3], [2, 3]]]\n[[3, 1], [5, 4], [3, 1]]]\n[[3, 2], [4, 4], [4, 4]]]\n[[3, 1], [5, 4], [3, 1]]]\n[[4, 2], [3, 3], [3, 3]]]\n[[3, 1], [5, 4], [3, 1]]]\n[[3, 3], [3, 3], [2, 4]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[4, 3], [4, 3], [2, 4]]]\n[[4, 4], [3, 3], [3, 
3]]]\n[[4, 2], [3, 3], [3, 3]]]\n[[3, 3], [3, 3], [2, 5]]]\n[[4, 2], [3, 3], [3, 3]]]\n[[4, 2], [3, 3], [3, 3]]]\n[[4, 5], [5, 4], [4, 5]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[5, 4], [5, 4], [3, 2]]]\n[[4, 4], [4, 4], [2, 4]]]\n[[4, 4], [4, 4], [3, 2]]]\n[[4, 3], [4, 3], [1, 4]]]\n[[4, 3], [4, 3], [2, 3]]]\n[[4, 4], [4, 4], [4, 4]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[4, 4], [4, 4], [3, 3]]]\n[[4, 4], [4, 4], [3, 2]]]\n[[4, 3], [4, 3], [3, 3]]]\n[[4, 3], [4, 3], [3, 3]]]\n[[4, 5], [5, 4], [4, 5]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[4, 4], [4, 4], [4, 4]]]\n[[4, 4], [3, 3], [3, 3]]]\n[[3, 4], [5, 4], [3, 4]]]\n[[4, 5], [5, 4], [4, 5]]]\n[[4, 4], [4, 4], [4, 4]]]\n[[4, 5], [6, 6], [3, 1]]\n[[4, 2], [6, 6], [3, 5]]\n[[4, 3], [6, 6], [3, 5]]\n[[4, 4], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [3, 1]]\n[[4, 2], [6, 6], [3, 5]]\n[[4, 3], [6, 6], [3, 5]]\n[[4, 4], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 5]]\n[[4, 5], [6, 6], [3, 5]]\n[[4, 5], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [2, 2]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [3, 1]]\n[[4, 5], [6, 6], [4, 1]]\n[[4, 5], [6, 6], [4, 2]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 3]]\n[[4, 5], [6, 6], [3, 2]]\n[[4, 5], [6, 6], [3, 1]]\n[[4, 2], [6, 6], [3, 5]]\n[[4, 3], [6, 6], [3, 4]]\n[[4, 4], [6, 6], [3, 4]]\n[[4, 5], [6, 6], [3, 4]]\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/data/injury_value.txt",
    "content": "1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n0\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/data/one_hot.py",
    "content": "from numpy import argmax\nimport numpy as np\n\ndef one_hot(value):\n    num = '01'\n    letter = [0 for _ in range(len(num))]\n    letter[value] = 1\n    letter = np.array([letter])\n    return letter\n\n\n\n# print(one_hot(4))\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/env/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToM/env/env.py",
    "content": "\nfrom numpy import argmax\nimport random, time, pygame, sys\nimport pygame\npygame.init()\nfrom pygame.locals import *\nimport os\n# os.environ['SDL_AUDIODRIVER'] = 'dsp'\n# os.environ['SDL_VIDEODRIVER']='windib'\nos.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\n# os.environ['DISPLAY'] = \"localhost:13.0\"\n\nfrom rulebasedpolicy.world_model import *\nfrom rulebasedpolicy.statedata_pre import *\nfrom rulebasedpolicy.Find_a_way import *\nimport numpy as np\nfrom utils.one_hot import one_hot\n\n# =============================================================================\n# set the value of interface\n# =============================================================================\nFPS = 25\nWinWidth = 340 #window width\nWinHeight = 260 #window width\nBoxSize = 20    #the size of one grid\nGridWidth = 7   #the number of lattices are there in the x-axis\nGridHeight = 7  #the number of lattices are there in the y-axis\n#representation of different objective\nBlankBox = 1\nWall = 5\nObstacle = 5\nobserver = 8\nobeservation_1 = 11\nobeservation_2 = 22\nobeservation_3 = 33\n#Text = None\nXMargin = int((WinWidth - GridWidth * BoxSize)/2)\nTopMargin = int((WinHeight - GridHeight * BoxSize))/2-5\n# =============================================================================\n# set color\n# =============================================================================\nWhite = (255,255,255)\nGray = (185,185,185)\nBlack = (0,0,0)\nRed = (200,0,0)\nGreen = (0,139,0)\nGreen_B = (78, 238, 148)\nLight_A = (233, 232, 170)\nBlue = (30, 144, 255)\npink = (238, 99, 99)\nBoardColor = White\nBGColor = White\nTextColor = White\nTest = []\n# =============================================================================\n# agents - env interactive\n# =============================================================================\nclass FalseBelief_env(object):\n    def __init__(self, reward=10):\n        super(FalseBelief_env, self).__init__()\n        self.action_space = ['up', 
'down', 'left', 'right', 'stay']\n        self.action_move = {\n            0 : (0, -1),\n            1 : (0, 1),\n            2 : (-1, 0),\n            3 : (1, 0),\n            4 :(0, 0)\n        }#[(0, -1), (0, 1), (-1, 0), (1, 0), (0, 0)]\n        self.n_actions = len(self.action_space)\n        self._build_AB()\n        self.board, self.obs = self.getBlankBoard()\n        self._agent_init()\n        self.score = 0\n        self.steps = 0\n        self.n = 0\n        self.R = int(5/2) * (BoxSize - 5)\n        self.trigger = 0\n        self.x = 0\n        self.n_features = 30\n        self.reward = reward\n\n    def _build_AB(self):\n        global FPSCLOCK, DISPLAYSURF, BASICFONT, BIGFONT\n        FPSCLOCK = pygame.time.Clock()\n        DISPLAYSURF = pygame.display.set_mode((WinWidth, WinHeight))\n        BASICFONT = pygame.font.Font('freesansbold.ttf', 18)\n        BIGFONT = pygame.font.Font('freesansbold.ttf', 100)\n        pygame.display.set_caption('AB')\n        pygame.display.update()\n        FPSCLOCK.tick()\n\n    def _agent_init(self):\n        \"\"\"\n        Aim:Initialize the basic information of the agent\n        \"\"\"\n        self.NPC_1 = {\n            'shape' : [['#']],\n            'x' : 3, #row\n            'y' : 1, #column\n            'color' : Blue,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' : None,#,[[1,3],[3,5],[4,2]]\n            'reward' : 0,\n            'Done' : False\n        }\n        self.NPC_2 = {\n            'shape' : [['@']],\n            'x' : 5, #row\n            'y' : 3, #column\n            'color' : pink,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' :None ,#,[[3,5],[1,3],[4,2]]\n            'reward': 0,\n            'Done': False\n        }\n        self.agent = {\n            'shape' : [['$']],\n            'x' : 2, #row\n            'y' : 4, #column\n            'color' : Green_B,\n            'style' : \"circle\",\n            'obs' : None,\n 
           'axis' : None,#[[4,2],[1,3],[3,5]]\n            'reward': 0,\n            'Done': False\n        }\n\n    def actu_obs(self):\n        \"\"\"\n        将状态转化成可以训练的数据形式\n        \"\"\"\n        _, state = self.getBlankBoard()\n        a = state\n        b = state\n        c = state\n        state1 = np.r_[a, np.ones((4, 5))].astype(np.int_)\n        state2 = np.r_[b, np.ones((4, 5))].astype(np.int_)\n        statea = np.r_[c, np.ones((4, 5))].astype(np.int_)\n        NPC_1_state = state1\n        NPC_2_state = state2\n        Agent_state = statea\n\n        NPC_1_state[self.NPC_1['y']-1, self.NPC_1['x']-1] = observer\n        q = shelter_env(NPC_1_state[:5, :])\n        NPC_1_state[:5, :] = shelter_env(NPC_1_state[:5, :])\n\n\n        NPC_2_state[self.NPC_2['y']-1, self.NPC_2['x']-1] = observer\n        r = shelter_env(NPC_2_state[:5, :])\n        NPC_2_state[:5, :] = shelter_env(NPC_2_state[:5, :])\n\n        Agent_state[self.agent['y']-1, self.agent['x']-1] = observer\n        p = shelter_env(Agent_state[:5, :])\n        Agent_state[:5, :] = shelter_env(Agent_state[:5, :])\n        \"\"\"\n        ########### num ############\n        #2-NPC1 in other agents' obs\n        #3-NPC2 in other agents' obs\n        #4-Agent in other agents' obs        \n        \"\"\"\n        self.NPC_1['obs'] = q\n        self.NPC_1['obs'] = self.gain_obs(self.NPC_1['obs'],NPC_1_state,self.NPC_2,self.agent,3,4)\n        self.NPC_1['axis'] = self.gain_axis(self.NPC_1,NPC_1_state,self.NPC_2,self.agent,3,4)\n\n        self.NPC_2['obs'] = r\n        self.NPC_2['obs'] = self.gain_obs(self.NPC_2['obs'], NPC_2_state, self.NPC_1, self.agent, 2, 4)\n        self.NPC_2['axis'] = self.gain_axis(self.NPC_2, NPC_2_state, self.NPC_1, self.agent, 2, 4)\n\n        self.agent['obs'] = p\n        self.agent['obs'] = self.gain_obs(self.agent['obs'], Agent_state, self.NPC_1, self.NPC_2, 2, 3)\n        self.agent['axis'] = self.gain_axis(self.agent, Agent_state, self.NPC_1, self.NPC_2, 2, 3)\n\n 
       return NPC_1_state, NPC_2_state, Agent_state\n\n    def gain_obs(self, a, aa, b, c, bb, cc):\n        \"\"\"\n        获得智能体真正的环境遮挡关系\n        @param a: self - observation\n        @param aa: self - self-axis, other-b-axis, other-c-axis\n        @param b: other-b 遮挡后的可见区域 5*5\n        @param c: other-c 遮挡后的可见区域 5*5\n        @param bb: other-b' num\n        @param cc: other-c' num\n        @return: self - observation\n        \"\"\"\n        if aa[b['y']-1, b['x']-1] == 1:\n            a[b['y']-1, b['x']-1] = bb\n\n        if aa[c['y']-1, c['x']-1] ==1:\n            a[c['y'] - 1, c['x'] - 1] = cc\n\n        return a\n\n    def gain_axis(self,a,aa,b,c,bb,cc):\n        \"\"\"\n        获得坐标，但是看不见的坐标就用6来表示\n        @param a:\n        @param aa:\n        @param b:\n        @param c:\n        @param bb:\n        @param cc:\n        @return:\n        \"\"\"\n        axis = []\n        axis.append([a['y'], a['x']])\n        if aa[b['y']-1, b['x']-1] != 0:\n            axis.append([b['y'], b['x']])\n        else:\n            axis.append([6,6])\n        if aa[c['y']-1, c['x']-1] != 0:\n            axis.append([c['y'] , c['x']])\n        else:\n            axis.append([6, 6])\n        return axis\n\n    def interact(self, action_NPC1, action_NPC2, action_agent):\n        \"\"\"\n        三个智能体进行交互\n        @param action_NPC1: action\n        @param action_NPC2: actionF\n        @param action_agent: action\n        @return:5*5 NPC1遮挡后看见了什么 5*5 NPC2遮挡后看见了什么 5*5 agent遮挡后看见了什么\n        \"\"\"\n        self.agent['reward'] = 0\n        self.NPC_1['reward'] = 0\n        self.NPC_2['reward'] = 0\n        #三个智能体分别会看到什么？\n        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()\n\n        #看到这些状态，智能体们会分别采取什么行为？ ---depend on RL\n        #这些行为对状态的影响  ---首先，影响本身的位置坐标，然后,影响观测\n        base = np.where(np.array(self.board) == obeservation_1)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.NPC_1['Done'] == False:\n            dis1 = 
np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))\n            if self.isNotWall(self.board, self.NPC_1, self.action_move[action_NPC1][0], \\\n                              self.action_move[action_NPC1][1]):\n                self.NPC_1['x'] = self.NPC_1['x'] + self.action_move[action_NPC1][0]\n                self.NPC_1['y'] = self.NPC_1['y'] + self.action_move[action_NPC1][1]\n                dis2 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))\n                self.NPC_1['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)\n            else:\n                self.NPC_1['reward'] = -1 * (1 / dis1)\n\n            if self.board[self.NPC_1['y'], self.NPC_1['x']] == obeservation_1:\n                self.NPC_1['reward'] = 50\n                self.NPC_1['Done'] = True\n\n        base = np.where(np.array(self.board) == obeservation_2)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.NPC_2['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))\n            if self.isNotWall(self.board, self.NPC_2, self.action_move[action_NPC2][0], \\\n                              self.action_move[action_NPC2][1]):\n                self.NPC_2['x'] = self.NPC_2['x'] + self.action_move[action_NPC2][0]\n                self.NPC_2['y'] = self.NPC_2['y'] + self.action_move[action_NPC2][1]\n                dis2 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))\n                self.NPC_2['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)\n                while self.NPC_2['reward'] < 0.5 and self.NPC_2['reward'] > -0.5 :\n                    self.NPC_2['reward'] = self.NPC_2['reward'] * 2\n                if self.NPC_2['reward'] > 1:\n                    self.NPC_2['reward'] = 1\n                elif self.NPC_2['reward'] < -1:\n                    self.NPC_2['reward'] = -1\n\n            else:\n                
self.NPC_2['reward'] = -0.9 #* (1 / dis1)\n\n            if self.board[self.NPC_2['y'], self.NPC_2['x']] == obeservation_2:\n                self.NPC_2['reward'] = self.reward\n                self.NPC_2['Done'] = True\n\n        base = np.where(np.array(self.board) == obeservation_3)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.agent['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))\n            if self.isNotWall(board=self.board, piece=self.agent, xT=self.action_move[action_agent][0], \\\n                              yT=self.action_move[action_agent][1]):\n                self.agent['x'] = self.agent['x'] + self.action_move[action_agent][0]\n                self.agent['y'] = self.agent['y'] + self.action_move[action_agent][1]\n                dis2 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))\n            #     self.agent['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)\n            # else:\n            #     # print('action', action_agent)\n            #     self.agent['reward'] = -1 * (1 / dis1)\n                self.agent['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)\n                while self.agent['reward'] < 0.5 and self.agent['reward'] > -0.5 :\n                    self.agent['reward'] = self.agent['reward'] * 2 - 0.1\n                if self.agent['reward'] > 1:\n                    self.agent['reward'] = 1\n                elif self.agent['reward'] < -1:\n                    self.agent['reward'] = -1\n            else:\n                self.agent['reward'] = -0.9 #* (1 / dis1)\n            if self.board[self.agent['y'], self.agent['x']] == obeservation_3:\n                self.agent['reward'] = self.reward\n                self.agent['Done'] = True\n        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()\n\n        #判断是否会相撞?\n        location = [(self.NPC_1['x'], self.NPC_1['y']), (self.NPC_2['x'], 
self.NPC_2['y']),\\\n                    (self.agent['x'], self.agent['y'])]\n\n        #达到目标或者相撞都会结束该智能体的回合\n        terminal = self.gameover(location)\n\n        if terminal[0] == True and self.NPC_1['Done'] == False:\n            self.NPC_1['Done'] = True\n            self.NPC_1['reward'] = -50\n            self.NPC_1['color'] = Red\n        if terminal[1] == True and self.NPC_2['Done'] == False:\n            self.NPC_2['Done'] = True\n            self.NPC_2['reward'] = -self.reward\n            self.NPC_2['color'] = Red\n        if terminal[2] == True and self.agent['Done'] == False:\n            self.agent['Done'] = True\n            self.agent['reward'] = -self.reward\n            self.agent['color'] = Red\n\n        return NPC_1_state, NPC_2_state, Agent_state\n\n    def SHOW(self):\n        \"\"\"\n        显示函数\n        \"\"\"\n        DISPLAYSURF.fill(BGColor)\n        self.DrawBoard(self.board)\n        self.DrawPiece(self.NPC_1)\n        self.DrawPiece(self.NPC_2)\n        self.DrawPiece(self.agent)\n        pygame.display.update()\n        FPSCLOCK.tick(FPS)\n        # return flag\n\n    def reset(self):\n        self._agent_init()\n\n    def getBlankBoard(self):\n        \"\"\"\n        11 - NPC1-goal\n        22 - NPC2-goal\n        33 - Agent-goal\n        @return:\n        \"\"\"\n        board = np.array([[1,1,1,5,5],[22,1,1,5,5],[1,1,1,1,1],[1,1,1,1,33],[1,1,1,11,1]])\n        state_init = np.array([[1,1,1,5,5],[1,1,1,5,5],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])\n        board = big_env(board)\n        # print(board)\n        #        for x in range(GridWidth):\n        #            for y in range(GridHeight):\n        #                print(board[x][y],x,y)\n        return board, state_init\n\n    def ValidPos(self, piece1, piece2, xT=0, yT=0):\n        \"\"\"\n        to judge the next place vaild or not\n        @param piece1:\n        @param piece2:\n        @param xT:\n        @param yT:\n        @return:\n        \"\"\"\n        if piece1['x'] 
== (piece2['x'] + xT) and piece1['y'] == (piece2['y'] + yT):\n            return True\n        return False\n\n    def isNotWall(self, board, piece, xT=0 , yT=0 ):\n        \"\"\"\n        判断是否到达墙\n        @param board: board\n        @param piece: agent\n        @param xT:\n        @param yT:\n        @return:\n        \"\"\"\n        if board[piece['y'] + yT][piece['x'] + xT] == Wall:#############\n            return False\n        else:\n            return True\n\n    def gameover(self, location):\n        \"\"\"\n        回合是否结束，以及奖励值\n        @param location:目标位置\n        \"\"\"\n        result = False\n        terminal = [False, False, False]\n\n        if location[0] == location[1]:\n            terminal[0] = True\n            terminal[1] = True\n        if location[2] == location[1]:\n            terminal[2] = True\n            terminal[1] = True\n        if location[2] ==location[0]:\n            terminal[2] = True\n            terminal[0] = True\n\n        return terminal\n\n\n    def pixel(self, xbox, ybox):\n        return (XMargin + (xbox * BoxSize)), (TopMargin + (ybox * BoxSize))\n\n    def DrawBox(self, xbox, ybox, color, xpixel=None, ypixel=None):\n        if color == BlankBox:\n            return\n        elif color == obeservation_1:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (60,107,255), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == obeservation_2:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (205, 155, 155), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == obeservation_3:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (154, 205, 50), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize 
- 1))\n        elif color == Wall:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, Gray, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        else:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, color, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n\n    def fun_trigger(self):\n        xpixel, ypixel = self.pixel(3, 5)\n        pygame.draw.line(DISPLAYSURF, Red, (xpixel + BoxSize, ypixel - BoxSize),\n                         (xpixel + BoxSize, ypixel - 2*BoxSize), 5)\n\n    def DrawCircle(self,  xbox, ybox, color, xpixel=None, ypixel=None):\n        \"\"\"\n        画圆\n        @param xbox:\n        @param ybox:\n        @param color:\n        @param xpixel:\n        @param ypixel:\n        \"\"\"\n        pygame.draw.circle(DISPLAYSURF,\n                           color,\n                           (int(xpixel+BoxSize/2), int(ypixel+BoxSize/2)),\n                           int(0.3 * self.R))\n\n    def DrawPiece(self, piece, xpixel=None, ypixel=None):\n        if xpixel == None and ypixel == None:\n            xpixel, ypixel = self.pixel(piece['x'], piece['y'])\n        if piece['style'] == \"circle\":\n            self.DrawCircle(None, None, piece['color'], xpixel, ypixel)\n        else:\n            self.DrawBox(None, None, piece['color'], xpixel, ypixel)\n\n    def DrawBoard(self, board):\n        pygame.draw.rect(DISPLAYSURF, BoardColor,\n                         (XMargin - 3, TopMargin - 7, (GridWidth * BoxSize) + 8, (GridHeight * BoxSize) + 8), 5)\n        pygame.draw.rect(DISPLAYSURF, BGColor, (XMargin, TopMargin, GridWidth * BoxSize, GridHeight * BoxSize))\n        for x in range(GridWidth):\n            for y in range(GridHeight):\n                self.DrawBox(y, x, board[x][y])\n        if self.trigger == 1:\n            
self.fun_trigger()\n    def ShowScore(self, score):\n        scoreSurf = BASICFONT.render('Score : %s' % score, True, TextColor)\n        scoreRect = scoreSurf.get_rect()\n        scoreRect.topleft = (WinWidth - 250, 20)\n        DISPLAYSURF.blit(scoreSurf, scoreRect)\n\n    def Terminal(self, piece1, piece2, piece1_old, piece2_old):\n        # print(piece1['x'],piece1['y'],piece2_old[0],piece2_old[1],'/',piece2['x'],piece2['y'],piece1_old[0],piece1_old[1])\n        # if self.steps == 1:# wrong!!!!\n        if piece1['x'] == piece2_old[0] and piece1['y'] == piece2_old[1] and piece2['x'] == piece1_old[0] and piece2['y'] == piece1_old[1]:\n            return 1\n        else:\n            return 2\n\n    def Paint(self, board, piece, color):\n        board[piece['x']][piece['y']] = color\n        piece['color'] = color\n        return board\n\n# if __name__ == \"__main__\":\n#     env0 = FalseBelief_env0()\n#     action_agent = 0\n#     action_NPC2 = 1\n#     action_NPC1 = 4\n#     for i in range(10):\n#         if i > 8:\n#             break\n#         else:\n#             env0.interact(action_NPC1, action_NPC2, action_agent)\n#\n#             env0.SHOW()\n#             time.sleep(2)\n#     pygame.quit()\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/env/env3_train_env00.py",
    "content": "\nfrom numpy import argmax\nimport random, time, pygame, sys\nimport pygame\npygame.init()\nfrom pygame.locals import *\nimport os\nos.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\n\nfrom rulebasedpolicy.world_model import *\nfrom rulebasedpolicy.statedata_pre import *\nfrom rulebasedpolicy.Find_a_way import *\nimport numpy as np\nfrom utils.one_hot import one_hot\n\n# =============================================================================\n# set the value of interface\n# =============================================================================\nFPS = 25\nWinWidth = 340 #window width\nWinHeight = 260 #window width\nBoxSize = 20    #the size of one grid\nGridWidth = 7   #the number of lattices are there in the x-axis\nGridHeight = 7  #the number of lattices are there in the y-axis\n#representation of different objective\nBlankBox = 1\nWall = 5\nObstacle = 5\nobserver = 8\nobeservation_1 = 11\nobeservation_2 = 22\nobeservation_3 = 33\n#Text = None\nXMargin = int((WinWidth - GridWidth * BoxSize)/2)\nTopMargin = int((WinHeight - GridHeight * BoxSize))/2-5\n# =============================================================================\n# set color\n# =============================================================================\nWhite = (255,255,255)\nGray = (185,185,185)\nBlack = (0,0,0)\nRed = (200,0,0)\nGreen = (0,139,0)\nGreen_B = (78, 238, 148)\nLight_A = (233, 232, 170)\nBlue = (30, 144, 255)\npink = (238, 99, 99)\nBoardColor = White\nBGColor = White\nTextColor = White\nTest = []\n# =============================================================================\n# agents - env interactive\n# =============================================================================\nclass FalseBelief_env0(object):\n    def __init__(self, reward=10):\n        super(FalseBelief_env0, self).__init__()\n        self.action_space = ['up', 'down', 'left', 'right', 'stay']\n        self.action_move = {\n            0 : (0, -1),\n            1 : (0, 1),\n            
2 : (-1, 0),\n            3 : (1, 0),\n            4 :(0, 0)\n        }#[(0, -1), (0, 1), (-1, 0), (1, 0), (0, 0)]\n        self.n_actions = len(self.action_space)\n        self._build_AB()\n        self.board, self.obs = self.getBlankBoard()\n        self._agent_init()\n        self.score = 0\n        self.steps = 0\n        self.n = 0\n        self.R = int(5/2) * (BoxSize - 5)\n        self.x = 0\n        self.n_features = 30\n        self.reward = reward\n\n    def _build_AB(self):\n        global FPSCLOCK, DISPLAYSURF, BASICFONT, BIGFONT\n        FPSCLOCK = pygame.time.Clock()\n        DISPLAYSURF = pygame.display.set_mode((WinWidth, WinHeight))\n        BASICFONT = pygame.font.Font('freesansbold.ttf', 18)\n        BIGFONT = pygame.font.Font('freesansbold.ttf', 100)\n        pygame.display.set_caption('AB')\n        pygame.display.update()\n        FPSCLOCK.tick()\n\n    def _agent_init(self):\n        \"\"\"\n        Aim:Initialize the basic information of the agent\n        \"\"\"\n        self.NPC_1 = {\n            'shape' : [['#']],\n            'x' : 3, #row\n            'y' : 1, #column\n            'color' : Blue,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' : None,#,[[1,3],[3,5],[4,2]]\n            'reward' : 0,\n            'Done' : False\n        }\n        self.NPC_2 = {\n            'shape' : [['@']],\n            'x' : 5, #row\n            'y' : 3, #column\n            'color' : pink,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' :None ,#,[[3,5],[1,3],[4,2]]\n            'reward': 0,\n            'Done': False\n        }\n        self.agent = {\n            'shape' : [['$']],\n            'x' : 2, #row\n            'y' : 4, #column\n            'color' : Green_B,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' : None,#[[4,2],[1,3],[3,5]]\n            'reward': 0,\n            'Done': False\n        }\n\n    def actu_obs(self):\n        
\"\"\"\n        将状态转化成可以训练的数据形式\n        \"\"\"\n        _, state = self.getBlankBoard()\n        a = state\n        b = state\n        c = state\n        state1 = np.r_[a, np.ones((4, 5))].astype(np.int_)\n        state2 = np.r_[b, np.ones((4, 5))].astype(np.int_)\n        statea = np.r_[c, np.ones((4, 5))].astype(np.int_)\n        NPC_1_state = state1\n        NPC_2_state = state2\n        Agent_state = statea\n\n        NPC_1_state[self.NPC_1['y']-1, self.NPC_1['x']-1] = observer\n        q = shelter_env(NPC_1_state[:5, :])\n        NPC_1_state[:5, :] = shelter_env(NPC_1_state[:5, :])\n\n\n        NPC_2_state[self.NPC_2['y']-1, self.NPC_2['x']-1] = observer\n        r = shelter_env(NPC_2_state[:5, :])\n        NPC_2_state[:5, :] = shelter_env(NPC_2_state[:5, :])\n\n        Agent_state[self.agent['y']-1, self.agent['x']-1] = observer\n        p = shelter_env(Agent_state[:5, :])\n        Agent_state[:5, :] = shelter_env(Agent_state[:5, :])\n        \"\"\"\n        ########### num ############\n        #2-NPC1 in other agents' obs\n        #3-NPC2 in other agents' obs\n        #4-Agent in other agents' obs        \n        \"\"\"\n        self.NPC_1['obs'] = q\n        self.NPC_1['obs'] = self.gain_obs(self.NPC_1['obs'],NPC_1_state,self.NPC_2,self.agent,3,4)\n        self.NPC_1['axis'] = self.gain_axis(self.NPC_1,NPC_1_state,self.NPC_2,self.agent,3,4)\n\n        self.NPC_2['obs'] = r\n        self.NPC_2['obs'] = self.gain_obs(self.NPC_2['obs'], NPC_2_state, self.NPC_1, self.agent, 2, 4)\n        self.NPC_2['axis'] = self.gain_axis(self.NPC_2, NPC_2_state, self.NPC_1, self.agent, 2, 4)\n\n        self.agent['obs'] = p\n        self.agent['obs'] = self.gain_obs(self.agent['obs'], Agent_state, self.NPC_1, self.NPC_2, 2, 3)\n        self.agent['axis'] = self.gain_axis(self.agent, Agent_state, self.NPC_1, self.NPC_2, 2, 3)\n\n        return NPC_1_state, NPC_2_state, Agent_state\n\n    def gain_obs(self, a, aa, b, c, bb, cc):\n        \"\"\"\n        获得智能体真正的环境遮挡关系\n     
   @param a: self - observation\n        @param aa: self - self-axis, other-b-axis, other-c-axis\n        @param b: other-b 遮挡后的可见区域 5*5\n        @param c: other-c 遮挡后的可见区域 5*5\n        @param bb: other-b' num\n        @param cc: other-c' num\n        @return: self - observation\n        \"\"\"\n        if aa[b['y']-1, b['x']-1] == 1:\n            a[b['y']-1, b['x']-1] = bb\n\n        if aa[c['y']-1, c['x']-1] ==1:\n            a[c['y'] - 1, c['x'] - 1] = cc\n\n        return a\n\n    def gain_axis(self,a,aa,b,c,bb,cc):\n        \"\"\"\n        获得坐标，但是看不见的坐标就用6来表示\n        @param a:\n        @param aa:\n        @param b:\n        @param c:\n        @param bb:\n        @param cc:\n        @return:\n        \"\"\"\n        axis = []\n        axis.append([a['y'], a['x']])\n        if aa[b['y']-1, b['x']-1] != 0:\n            axis.append([b['y'], b['x']])\n        else:\n            axis.append([6,6])\n        if aa[c['y']-1, c['x']-1] != 0:\n            axis.append([c['y'] , c['x']])\n        else:\n            axis.append([6, 6])\n        return axis\n\n    def interact(self, action_NPC1, action_NPC2, action_agent):\n        \"\"\"\n        三个智能体进行交互\n        @param action_NPC1: action\n        @param action_NPC2: actionF\n        @param action_agent: action\n        @return:5*5 NPC1遮挡后看见了什么 5*5 NPC2遮挡后看见了什么 5*5 agent遮挡后看见了什么\n        \"\"\"\n        self.agent['reward'] = 0\n        self.NPC_1['reward'] = 0\n        self.NPC_2['reward'] = 0\n        #三个智能体分别会看到什么？\n        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()\n\n        #看到这些状态，智能体们会分别采取什么行为？ ---depend on RL\n        #这些行为对状态的影响  ---首先，影响本身的位置坐标，然后,影响观测\n        base = np.where(np.array(self.board) == obeservation_1)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.NPC_1['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))\n            if self.isNotWall(self.board, self.NPC_1, 
self.action_move[action_NPC1][0], \\\n                              self.action_move[action_NPC1][1]):\n                self.NPC_1['x'] = self.NPC_1['x'] + self.action_move[action_NPC1][0]\n                self.NPC_1['y'] = self.NPC_1['y'] + self.action_move[action_NPC1][1]\n                dis2 = np.sqrt(np.square(base_x - self.NPC_1['y']) + np.square(base_y - self.NPC_1['x']))\n                self.NPC_1['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)\n            else:\n                self.NPC_1['reward'] = -1 * (1 / dis1)\n\n            if self.board[self.NPC_1['y'], self.NPC_1['x']] == obeservation_1:\n                self.NPC_1['reward'] = 50\n                self.NPC_1['Done'] = True\n\n        base = np.where(np.array(self.board) == obeservation_2)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.NPC_2['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))\n            if self.isNotWall(self.board, self.NPC_2, self.action_move[action_NPC2][0], \\\n                              self.action_move[action_NPC2][1]):\n                self.NPC_2['x'] = self.NPC_2['x'] + self.action_move[action_NPC2][0]\n                self.NPC_2['y'] = self.NPC_2['y'] + self.action_move[action_NPC2][1]\n                dis2 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))\n                self.NPC_2['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)\n                while self.NPC_2['reward'] < 0.5 and self.NPC_2['reward'] > -0.5 :\n                    self.NPC_2['reward'] = self.NPC_2['reward'] * 2\n                if self.NPC_2['reward'] > 1:\n                    self.NPC_2['reward'] = 1\n                elif self.NPC_2['reward'] < -1:\n                    self.NPC_2['reward'] = -1\n\n            else:\n                self.NPC_2['reward'] = -0.9 #* (1 / dis1)\n\n            if self.board[self.NPC_2['y'], self.NPC_2['x']] == obeservation_2:\n               
 self.NPC_2['reward'] = self.reward\n                self.NPC_2['Done'] = True\n\n        base = np.where(np.array(self.board) == obeservation_3)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.agent['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))\n            if self.isNotWall(board=self.board, piece=self.agent, xT=self.action_move[action_agent][0], \\\n                              yT=self.action_move[action_agent][1]):\n                self.agent['x'] = self.agent['x'] + self.action_move[action_agent][0]\n                self.agent['y'] = self.agent['y'] + self.action_move[action_agent][1]\n                dis2 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))\n            #     self.agent['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)\n            # else:\n            #     # print('action', action_agent)\n            #     self.agent['reward'] = -1 * (1 / dis1)\n                self.agent['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)\n                while self.agent['reward'] < 0.5 and self.agent['reward'] > -0.5 :\n                    self.agent['reward'] = self.agent['reward'] * 2 - 0.1\n                if self.agent['reward'] > 1:\n                    self.agent['reward'] = 1\n                elif self.agent['reward'] < -1:\n                    self.agent['reward'] = -1\n            else:\n                self.agent['reward'] = -0.9 #* (1 / dis1)\n            if self.board[self.agent['y'], self.agent['x']] == obeservation_3:\n                self.agent['reward'] = self.reward\n                self.agent['Done'] = True\n        NPC_1_state, NPC_2_state, Agent_state = self.actu_obs()\n\n        #判断是否会相撞?\n        location = [(self.NPC_1['x'], self.NPC_1['y']), (self.NPC_2['x'], self.NPC_2['y']),\\\n                    (self.agent['x'], self.agent['y'])]\n\n        #达到目标或者相撞都会结束该智能体的回合\n        terminal = 
self.gameover(location)\n\n        if terminal[0] == True and self.NPC_1['Done'] == False:\n            self.NPC_1['Done'] = True\n            self.NPC_1['reward'] = -50\n            self.NPC_1['color'] = Red\n        if terminal[1] == True and self.NPC_2['Done'] == False:\n            self.NPC_2['Done'] = True\n            self.NPC_2['reward'] = -self.reward\n            self.NPC_2['color'] = Red\n        if terminal[2] == True and self.agent['Done'] == False:\n            self.agent['Done'] = True\n            self.agent['reward'] = -self.reward\n            self.agent['color'] = Red\n\n        return NPC_1_state, NPC_2_state, Agent_state\n\n    def SHOW(self):\n        \"\"\"\n        显示函数\n        \"\"\"\n        DISPLAYSURF.fill(BGColor)\n        self.DrawBoard(self.board)\n        self.DrawPiece(self.NPC_1)\n        self.DrawPiece(self.NPC_2)\n        self.DrawPiece(self.agent)\n        pygame.display.update()\n        FPSCLOCK.tick(FPS)\n        # return flag\n\n    def reset(self):\n        self._agent_init()\n\n    def getBlankBoard(self):\n        \"\"\"\n        11 - NPC1-goal\n        22 - NPC2-goal\n        33 - Agent-goal\n        @return:\n        \"\"\"\n        # board = data_transfer('env_1.txt','env_11.txt')\n        board = np.array([[1,1,1,1,1],[22,1,1,1,1],[1,1,1,1,1],[1,1,1,1,33],[1,1,1,11,1]])\n        state_init = np.array([[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])\n        board = big_env(board)\n        # print(board)\n        #        for x in range(GridWidth):\n        #            for y in range(GridHeight):\n        #                print(board[x][y],x,y)\n        return board, state_init\n\n    def ValidPos(self, piece1, piece2, xT=0, yT=0):\n        \"\"\"\n        to judge the next place vaild or not\n        @param piece1:\n        @param piece2:\n        @param xT:\n        @param yT:\n        @return:\n        \"\"\"\n        if piece1['x'] == (piece2['x'] + xT) and piece1['y'] == (piece2['y'] + yT):\n        
    return True\n        return False\n\n    def isNotWall(self, board, piece, xT=0 , yT=0 ):\n        \"\"\"\n        判断是否到达墙\n        @param board: board\n        @param piece: agent\n        @param xT:\n        @param yT:\n        @return:\n        \"\"\"\n        if board[piece['y'] + yT][piece['x'] + xT] == Wall:#############\n            return False\n        else:\n            return True\n\n    def gameover(self, location):\n        \"\"\"\n        回合是否结束，以及奖励值\n        @param location:目标位置\n        \"\"\"\n        result = False\n        terminal = [False, False, False]\n\n        if location[0] == location[1]:\n            terminal[0] = True\n            terminal[1] = True\n        if location[2] == location[1]:\n            terminal[2] = True\n            terminal[1] = True\n        if location[2] ==location[0]:\n            terminal[2] = True\n            terminal[0] = True\n\n        return terminal\n\n\n    def pixel(self, xbox, ybox):\n        return (XMargin + (xbox * BoxSize)), (TopMargin + (ybox * BoxSize))\n\n    def DrawBox(self, xbox, ybox, color, xpixel=None, ypixel=None):\n        if color == BlankBox:\n            return\n        elif color == obeservation_1:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (60,107,255), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == obeservation_2:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (205, 155, 155), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == obeservation_3:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (154, 205, 50), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == Wall:\n            if xpixel == None and 
ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, Gray, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        else:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, color, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n\n    def DrawCircle(self,  xbox, ybox, color, xpixel=None, ypixel=None):\n        \"\"\"\n        画圆\n        @param xbox:\n        @param ybox:\n        @param color:\n        @param xpixel:\n        @param ypixel:\n        \"\"\"\n        pygame.draw.circle(DISPLAYSURF,\n                           color,\n                           (int(xpixel+BoxSize/2), int(ypixel+BoxSize/2)),\n                           int(0.3 * self.R))\n\n    def DrawPiece(self, piece, xpixel=None, ypixel=None):\n        if xpixel == None and ypixel == None:\n            xpixel, ypixel = self.pixel(piece['x'], piece['y'])\n        if piece['style'] == \"circle\":\n            self.DrawCircle(None, None, piece['color'], xpixel, ypixel)\n        else:\n            self.DrawBox(None, None, piece['color'], xpixel, ypixel)\n\n    def DrawBoard(self, board):\n        pygame.draw.rect(DISPLAYSURF, BoardColor,\n                         (XMargin - 3, TopMargin - 7, (GridWidth * BoxSize) + 8, (GridHeight * BoxSize) + 8), 5)\n        pygame.draw.rect(DISPLAYSURF, BGColor, (XMargin, TopMargin, GridWidth * BoxSize, GridHeight * BoxSize))\n        for x in range(GridWidth):\n            for y in range(GridHeight):\n                self.DrawBox(y, x, board[x][y])\n\n    def ShowScore(self, score):\n        scoreSurf = BASICFONT.render('Score : %s' % score, True, TextColor)\n        scoreRect = scoreSurf.get_rect()\n        scoreRect.topleft = (WinWidth - 250, 20)\n        DISPLAYSURF.blit(scoreSurf, scoreRect)\n\n    def Terminal(self, piece1, piece2, piece1_old, piece2_old):\n        # 
print(piece1['x'],piece1['y'],piece2_old[0],piece2_old[1],'/',piece2['x'],piece2['y'],piece1_old[0],piece1_old[1])\n        # if self.steps == 1:# wrong!!!!\n        if piece1['x'] == piece2_old[0] and piece1['y'] == piece2_old[1] and piece2['x'] == piece1_old[0] and piece2['y'] == piece1_old[1]:\n            return 1\n        else:\n            return 2\n\n    def Paint(self, board, piece, color):\n        board[piece['x']][piece['y']] = color\n        piece['color'] = color\n        return board\n\n# if __name__ == \"__main__\":\n#     env0 = FalseBelief_env0()\n#     action_agent = 0\n#     action_NPC2 = 1\n#     action_NPC1 = 4\n#     for i in range(10):\n#         if i > 8:\n#             break\n#         else:\n#             env0.interact(action_NPC1, action_NPC2, action_agent)\n#\n#             env0.SHOW()\n#             time.sleep(2)\n#     pygame.quit()\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/env/env3_train_env01.py",
    "content": "\"\"\"\nZoe Zhao 2022.5\nEnv Demo\n\"\"\"\n\nfrom numpy import argmax\nimport random, time, pygame, sys\nimport pygame\npygame.init()\nfrom pygame.locals import *\n# import os\n# os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\n\nfrom rulebasedpolicy.world_model import *\nfrom rulebasedpolicy.statedata_pre import *\nfrom rulebasedpolicy.Find_a_way import *\nimport numpy as np\nfrom utils.one_hot import one_hot\n\n# =============================================================================\n# set the value of interface\n# =============================================================================\nFPS = 25\nWinWidth = 340 #window width\nWinHeight = 260 #window width\nBoxSize = 20    #the size of one grid\nGridWidth = 7   #the number of lattices are there in the x-axis\nGridHeight = 7  #the number of lattices are there in the y-axis\n#representation of different objective\nBlankBox = 1\nWall = 5\nObstacle = 5\nobserver = 8\nobeservation_1 = 11\nobeservation_2 = 22\nobeservation_3 = 33\n#Text = None\nXMargin = int((WinWidth - GridWidth * BoxSize)/2)\nTopMargin = int((WinHeight - GridHeight * BoxSize))/2-5\n# =============================================================================\n# set color\n# =============================================================================\nWhite = (255,255,255)\nGray = (185,185,185)\nBlack = (0,0,0)\nRed = (200,0,0)\nGreen = (0,139,0)\nGreen_B = (78, 238, 148)\nLight_A = (233, 232, 170)\nBlue = (30, 144, 255)\npink = (238, 99, 99)\nBoardColor = White\nBGColor = White\nTextColor = White\nTest = []\n# =============================================================================\n# agents - env interactive\n# =============================================================================\nclass FalseBelief_env1(object):\n    def __init__(self, reward=10):\n        super(FalseBelief_env1, self).__init__()\n        self.action_space = ['up', 'down', 'left', 'right', 'stay']\n        self.action_move = {\n            0 : 
(0, -1),\n            1 : (0, 1),\n            2 : (-1, 0),\n            3 : (1, 0),\n            4 :(0, 0)\n        }#[(0, -1), (0, 1), (-1, 0), (1, 0), (0, 0)]\n        self.n_actions = len(self.action_space)\n        self._build_AB()\n        self.board, self.obs = self.getBlankBoard()\n        self._agent_init()\n        self.score = 0\n        self.steps = 0\n        self.n = 0\n        self.R = int(5/2) * (BoxSize - 5)\n        self.x = 0\n        self.n_features = 30\n        self.reward = reward\n\n    def _build_AB(self):\n        global FPSCLOCK, DISPLAYSURF, BASICFONT, BIGFONT\n        pygame.init()\n        FPSCLOCK = pygame.time.Clock()\n        DISPLAYSURF = pygame.display.set_mode((WinWidth, WinHeight))\n        BASICFONT = pygame.font.Font('freesansbold.ttf', 18)\n        BIGFONT = pygame.font.Font('freesansbold.ttf', 100)\n        pygame.display.set_caption('AB')\n        pygame.display.update()\n        FPSCLOCK.tick()\n\n    def _agent_init(self):\n        \"\"\"\n        Aim:Initialize the basic information of the agent\n        \"\"\"\n        self.NPC_2 = {\n            'shape' : [['@']],\n            'x' : 5, #row\n            'y' : 3, #column\n            'color' : pink,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' :None ,#,[[3,5],[1,3],[4,2]]\n            'reward': 0,\n            'Done': False\n        }\n        self.agent = {\n            'shape' : [['$']],\n            'x' : 2, #row\n            'y' : 4, #column\n            'color' : Green_B,\n            'style' : \"circle\",\n            'obs' : None,\n            'axis' : None,#[[4,2],[1,3],[3,5]]\n            'reward': 0,\n            'Done': False\n        }\n\n    def actu_obs(self):\n        \"\"\"\n        将状态转化成可以训练的数据形式\n        \"\"\"\n        _, state = self.getBlankBoard()\n        a = state\n        b = state\n        c = state\n        state2 = np.r_[b, np.ones((4, 5))].astype(np.int)\n        statea = np.r_[c, np.ones((4, 
5))].astype(np.int)\n        NPC_2_state = state2\n        Agent_state = statea\n\n        NPC_2_state[self.NPC_2['y']-1, self.NPC_2['x']-1] = observer\n        r = shelter_env(NPC_2_state[:5, :])\n        NPC_2_state[:5, :] = shelter_env(NPC_2_state[:5, :])\n\n        Agent_state[self.agent['y']-1, self.agent['x']-1] = observer\n        p = shelter_env(Agent_state[:5, :])\n        Agent_state[:5, :] = shelter_env(Agent_state[:5, :])\n\n\n        self.NPC_2['obs'] = r\n        self.NPC_2['obs'] = self.gain_obs(self.NPC_2['obs'], NPC_2_state, self.agent, 4)\n        self.NPC_2['axis'] = self.gain_axis(self.NPC_2, NPC_2_state, 6, self.agent, 2, 4)\n\n        self.agent['obs'] = p\n        self.agent['obs'] = self.gain_obs(self.agent['obs'], Agent_state,  self.NPC_2, 3)\n        self.agent['axis'] = self.gain_axis(self.agent, Agent_state, 6, self.NPC_2, 2, 3)\n\n        return NPC_2_state, Agent_state#NPC_1_state,\n\n    def gain_obs(self, a,aa,c,cc):\n        if aa[c['y']-1, c['x']-1] ==1:\n            a[c['y'] - 1, c['x'] - 1] = cc\n\n        return a\n\n    def gain_axis(self,a,aa,b,c,bb,cc):\n        axis = []\n        axis.append([a['y'], a['x']])\n        if b == 6:\n            axis.append([6, 6])\n        else:\n            axis.append([6, 6])\n        if aa[c['y']-1, c['x']-1] != 0:\n            axis.append([c['y'] , c['x']])\n        else:\n            axis.append([6, 6])\n        return axis\n\n\n    def interact(self, action_NPC2, action_agent):\n        self.agent['reward'] = 0\n        self.NPC_2['reward'] = 0\n        #三个智能体分别会看到什么？\n        NPC_2_state, Agent_state = self.actu_obs()# NPC_1_state,\n        #看到这些状态，智能体们会分别采取什么行为？ ---depend on RL\n        #这些行为对状态的影响  ---首先，影响本身的位置坐标，然后,影响观测\n        base = np.where(np.array(self.board) == obeservation_2)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.NPC_2['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - 
self.NPC_2['x']))\n            if self.isNotWall(self.board, self.NPC_2, self.action_move[action_NPC2][0], \\\n                              self.action_move[action_NPC2][1]):\n                self.NPC_2['x'] = self.NPC_2['x'] + self.action_move[action_NPC2][0]\n                self.NPC_2['y'] = self.NPC_2['y'] + self.action_move[action_NPC2][1]\n                dis2 = np.sqrt(np.square(base_x - self.NPC_2['y']) + np.square(base_y - self.NPC_2['x']))\n                self.NPC_2['reward'] = (((dis1 - dis2) * 10 - 1 / 2) / dis1)\n                while self.NPC_2['reward'] < 0.5 and self.NPC_2['reward'] > -0.5:\n                    self.NPC_2['reward'] = self.NPC_2['reward'] * 2\n                if self.NPC_2['reward'] > 1:\n                    self.NPC_2['reward'] = 1\n                elif self.NPC_2['reward'] < -1:\n                    self.NPC_2['reward'] = -1\n            else:\n                self.NPC_2['reward'] = -0.9  # * (1 / dis1)\n\n            if self.board[self.NPC_2['y'], self.NPC_2['x']] == obeservation_2:\n                self.NPC_2['reward'] = self.reward\n                self.NPC_2['Done'] = True\n\n        base = np.where(np.array(self.board) == obeservation_3)\n        base_x = int(base[0])\n        base_y= int(base[1])\n        if self.agent['Done'] == False:\n            dis1 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))\n            if self.isNotWall(board=self.board, piece=self.agent, xT=self.action_move[action_agent][0], \\\n                              yT=self.action_move[action_agent][1]):\n\n                self.agent['x'] = self.agent['x'] + self.action_move[action_agent][0]\n                self.agent['y'] = self.agent['y'] + self.action_move[action_agent][1]\n                dis2 = np.sqrt(np.square(base_x - self.agent['y']) + np.square(base_y - self.agent['x']))\n            #     self.agent['reward'] = (((dis1 - dis2) * 2 - 1) / dis1)\n            #\n            # else:\n            #     # 
print('action', action_agent)\n            #     self.agent['reward'] = -1 * (1 / dis1)\n                self.agent['reward'] = (((dis1 - dis2)*10 - 1/2) / dis1)\n                while self.agent['reward'] < 0.5 and self.agent['reward'] > -0.5 :\n                    self.agent['reward'] = self.agent['reward'] * 2 - 0.1\n                if self.agent['reward'] > 1:\n                    self.agent['reward'] = 1\n                elif self.agent['reward'] < -1:\n                    self.agent['reward'] = -1\n            else:\n                self.agent['reward'] = -0.9 #* (1 / dis1)\n\n            if self.board[self.agent['y'], self.agent['x']] == obeservation_3:\n                self.agent['reward'] = self.reward\n                self.agent['Done'] = True\n\n        NPC_2_state, Agent_state = self.actu_obs()\n\n        #判断是否会相撞?\n        location = [(self.NPC_2['x'], self.NPC_2['y']),\\\n                    (self.agent['x'], self.agent['y'])]\n\n        #达到目标或者相撞都会结束该智能体的回合\n        terminal = self.gameover(location)\n        if self.agent['Done'] == False and terminal[1] == True:\n            self.agent['Done'] = True\n            self.agent['reward'] = -self.reward\n            self.agent['color'] = Red\n        if self.NPC_2['Done'] == False and terminal[0] == True:\n            self.NPC_2['Done'] = True\n            self.NPC_2['reward'] = -self.reward\n            self.NPC_2['color'] = Red\n\n        return  NPC_2_state, Agent_state\n\n    def SHOW(self):\n        DISPLAYSURF.fill(BGColor)\n        self.DrawBoard(self.board)\n        # self.DrawPiece(self.NPC_1)\n        self.DrawPiece(self.NPC_2)\n        self.DrawPiece(self.agent)\n        pygame.display.update()\n        FPSCLOCK.tick(FPS)\n        # return flag\n\n    def reset(self):\n        self._agent_init()\n\n    def getBlankBoard(self):\n        # board = data_transfer('env_1.txt','env_11.txt')\n        board = np.array([[1,1,1,5,5],[22,1,1,5,5],[1,1,1,1,1],[1,1,1,1,33],[1,1,1,11,1]])\n        
state_init = np.array([[1,1,1,5,5],[1,1,1,5,5],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])\n        board = big_env(board)\n        # print(board)\n        #        for x in range(GridWidth):\n        #            for y in range(GridHeight):\n        #                print(board[x][y],x,y)\n        return board, state_init\n\n    def ValidPos(self, piece1, piece2, xT=0, yT=0):\n        \"\"\"\n        to judge the next place vaild or not\n        @param piece1:\n        @param piece2:\n        @param xT:\n        @param yT:\n        @return:\n        \"\"\"\n        if piece1['x'] == (piece2['x'] + xT) and piece1['y'] == (piece2['y'] + yT):\n            return True\n        return False\n\n    def isNotWall(self, board, piece, xT=0 , yT=0 ):\n        if board[piece['y'] + yT][piece['x'] + xT] == Wall:#############\n            return False\n        else:\n            return True\n\n    def gameover(self, location):\n        \"\"\"\n        回合是否结束，以及奖励值\n        @param location:目标位置\n        \"\"\"\n        result = False\n        terminal = [False, False]   #NPC_2, agent\n        for r in range(len(location) - 1):\n            for c in range(r + 1, len(location)):\n                if location[r] == location[c]:\n                    result = True       ####相撞会带来一个巨大的副奖励，并且结束该回合\n                    boom = location[r] ###############相撞################\n                if result == True:\n                    terminal[r] = True\n                    terminal[c] = True\n        return terminal\n\n    def pixel(self, xbox, ybox):\n        return (XMargin + (xbox * BoxSize)), (TopMargin + (ybox * BoxSize))\n\n    def DrawBox(self, xbox, ybox, color, xpixel=None, ypixel=None):\n        if color == BlankBox:\n            return\n        elif color == obeservation_1:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (60,107,255), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n 
       elif color == obeservation_2:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (205, 155, 155\t), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == obeservation_3:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, (154, 205, 50), (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        elif color == Wall:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, Gray, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n        else:\n            if xpixel == None and ypixel == None:\n                xpixel, ypixel = self.pixel(xbox, ybox)\n            pygame.draw.rect(DISPLAYSURF, color, (xpixel + 1, ypixel + 1, BoxSize - 1, BoxSize - 1))\n\n    def DrawCircle(self,  xbox, ybox, color, xpixel=None, ypixel=None):\n        pygame.draw.circle(DISPLAYSURF,\n                           color,\n                           (int(xpixel+BoxSize/2), int(ypixel+BoxSize/2)),\n                           int(0.3 * self.R))\n\n    def DrawPiece(self, piece, xpixel=None, ypixel=None):\n        if xpixel == None and ypixel == None:\n            xpixel, ypixel = self.pixel(piece['x'], piece['y'])\n        if piece['style'] == \"circle\":\n            self.DrawCircle(None, None, piece['color'], xpixel, ypixel)\n        else:\n            self.DrawBox(None, None, piece['color'], xpixel, ypixel)\n\n    def DrawBoard(self, board):\n        pygame.draw.rect(DISPLAYSURF, BoardColor,\n                         (XMargin - 3, TopMargin - 7, (GridWidth * BoxSize) + 8, (GridHeight * BoxSize) + 8), 5)\n        pygame.draw.rect(DISPLAYSURF, BGColor, (XMargin, TopMargin, GridWidth * BoxSize, GridHeight * BoxSize))\n        for x in range(GridWidth):\n            for y in 
range(GridHeight):\n                self.DrawBox(y, x, board[x][y])\n\n    def ShowScore(self, score):\n        scoreSurf = BASICFONT.render('Score : %s' % score, True, TextColor)\n        scoreRect = scoreSurf.get_rect()\n        scoreRect.topleft = (WinWidth - 250, 20)\n        DISPLAYSURF.blit(scoreSurf, scoreRect)\n\n    def Terminal(self, piece1, piece2, piece1_old, piece2_old):\n        # print(piece1['x'],piece1['y'],piece2_old[0],piece2_old[1],'/',piece2['x'],piece2['y'],piece1_old[0],piece1_old[1])\n        # if self.steps == 1:# wrong!!!!\n        if piece1['x'] == piece2_old[0] and piece1['y'] == piece2_old[1] and piece2['x'] == piece1_old[0] and piece2['y'] == piece1_old[1]:\n            return 1\n        else:\n            return 2\n\n    def Paint(self, board, piece, color):\n        board[piece['x']][piece['y']] = color\n        piece['color'] = color\n        return board\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/main_ToM.py",
    "content": "\"\"\"\nZoe Zhao 2022.5\nToM Demo\n\"\"\"\nimport argparse\nimport copy\nimport numpy as np\nimport torch\n\nnp.set_printoptions(threshold=np.inf)\ntorch.set_printoptions(threshold=np.inf)\n\nimport matplotlib\nimport pygame\npygame.init()\nmatplotlib.rcParams.update({'font.size': 12})\nimport os\nos.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\n\nfrom BrainArea.PFC_ToM import PFC_ToM\nfrom BrainArea.TPJ import ToM\nfrom BrainArea.dACC import *\nfrom rulebasedpolicy.Find_a_way import *\nfrom env.env import FalseBelief_env\nfrom braincog.base.encoder.encoder import *\nfrom braincog.base.node import node\n\n#NPC2\n#state\nN_state = 6\ncell_num = 6\n\n# action\nN_action = 5\nNC=10 #50 cells represent one character\n\n#synapstic\nbfs = pow(cell_num, N_state) #before synapstic\nafs = N_action * NC\n\n#agent\nC=10\nA_state = 4\nabfs = pow(cell_num, A_state) #agent before synapstic\naafs = N_action * C\n\nparser = argparse.ArgumentParser(description='sequence character (policy inference)')\nparser.add_argument('--mode', type=str, default='test')\nparser.add_argument('--task', type=str, default='both')\nparser.add_argument('--logdir', type=str, default='checkpoint')\nparser.add_argument('--save_net_a', type=str, default='net_NPC_11.pth', help='save the parameters of net_agent')\nparser.add_argument('--save_net_N', type=str, default='net_NPC_11.pth', help='save the parameters of net_NPC')\nparser.add_argument('--device', default='cpu', help='device')  # cuda:0\nparser.add_argument('--T', default=40, type=int, help='simulating time-steps')  # 模拟时长\nparser.add_argument('--dt', default=1, type=int, help='simulating dt')  # 模拟dt\nparser.add_argument('--episodes', default=25, type=int, help='episodes')\nparser.add_argument('--trajectories', default=10, type=int, help='trajectories')\nparser.add_argument('--greedy', default=0.8, type=int, help='exploration or exploitation')\nparser.add_argument('--num_enpop', default=6, type=int, help='the number of one population in 
the encoding layer')  #\nparser.add_argument('--num_depop', default=10, type=int, help='the number of one population in the decoding layer')  #\nparser.add_argument('--num_stateA', default=2, type=int, help='the number of states')\nparser.add_argument('--num_stateN', default=6, type=int, help='the number of states')\nparser.add_argument('--num_action', default=5, type=int, help='the number of actions')\nparser.add_argument('--reward', default=10, type=float, help='environment parameter reward')\nargs = parser.parse_args()\n\n\ndef update(env, net_agent_belief, net_NPC, episodes, trajectories):\n    \"\"\"\n    agents learn to reach the goal without collision\n    update agents' positions\n    @param env:\n    @param env1:\n    @param net_agent_belief: the SNN network of agent\n    @param net_NPC: the SNN network of NPC\n    @param episodes: train times\n    @return: None\n    \"\"\"\n    for episode in tqdm(range(episodes)):\n        timer = 0\n        env.reset()\n        env.actu_obs()\n        scores = {\n            'agent_0': 0,\n            'NPC2_0' : 0,\n            'agent_1': 0,\n            'NPC2_1': 0,\n        }\n        Done_agent_0 = Done_agent_1 = False\n        Done_NPC2_0  = Done_NPC2_1  = False\n        action_agent  = 3\n        action_NPC2   = 2\n        action_NPC1   = 1\n        action_agent1 = 4\n        # the start position are the same in two envs\n        # mapping_a = {'state': sum(env.agent['axis'], []),\n        #              'action': action_agent}\n        mapping_N = {'state': sum(env.NPC_2['axis'], []),\n                     'action': action_NPC2}\n        while True and timer < trajectories:\n            timer = timer + 1\n            NPC_1_state, NPC_2_state, Agent_state \\\n                = env.interact(action_NPC1, action_NPC2, action_agent)\n            env.SHOW()\n            # time.sleep(2)\n\n            # NPC_1 selects action by pp\n            if env.NPC_1['Done'] == False:\n                action_seq1 = 
Find_a_way(size=5, board=NPC_1_state, \\\n                                         start_x=env.NPC_1['x'] - 1, \\\n                                         start_y=4 - (env.NPC_1['y'] - 1), \\\n                                         end_x=3, end_y=4 - 4)\n                action_NPC1 = list(env.action_move.keys())[ \\\n                    list(env.action_move.values()).index(\n                        (action_seq1[1][0] - (action_seq1[0][0]), -action_seq1[1][1] + (action_seq1[0][1])))]\n            # agent selects action on purpose\n            # Agent_obs = sum(env.agent['axis'], [])\n            if env.agent['Done'] == False:\n                axis_new, axis_switch, obs_switch = ToM.TPJ(NPC_num=2, axis=env.agent['axis'], obs=env.agent['obs'], )\n                if axis_new == env.agent['axis']:\n                    '''\n                    没有遮挡关系 have teached\n                    '''\n                    action_agent = 3\n                else:\n                    '''\n                    有遮挡关系\n                    '''\n\n                    Agent_obs_NPC2 = sum(env.NPC_2['axis'], [])\n                    action_agent = net_agent_belief(inputs=Agent_obs_NPC2,\n                                                    num_action=args.num_action,\n                                                    episode=episode)\n                    prediction_next_state = ToM.prediction_state(axis_new, env.agent['axis'], action_NPC1, net_NPC,\n                                                                 num_action=args.num_action,\n                                                                 episode = episode)\n                    if ToM.state_evaluation(prediction_next_state=prediction_next_state) == False:\n                        print(False)\n                        action_agent = ToM.altruism(axis_switch=axis_switch , axis_NPC=env.NPC_2['axis'], n_actions = env.n_actions)\n                        env.trigger = 1\n                    else:\n                        
action_agent = 3\n            # NPC_2 selects action by E-STDP\n            NPC2_obs = sum(env.NPC_2['axis'], [])\n            if Done_NPC2_0 == False:\n                if action_agent == 4 and env.agent['Done'] == False:\n                    action_NPC2 = 4\n                else:\n                    action_NPC2 = net_NPC(inputs=NPC2_obs, \\\n                                          num_action=args.num_action, \\\n                                          episode=episode)\n                    state_NPC2 = copy.deepcopy(NPC2_obs)\n                    Done_NPC2_0 = copy.deepcopy(env.NPC_2['Done'])\n                    # mapping_N = {'state': state_NPC2,  # at time t\n                    #              'action': action_NPC2}\n\n\ndef train():\n    print('train mode loading ... ')\n\n    if not os.path.isdir(args.logdir):\n        os.mkdir(args.logdir)\n\n    bfs = pow(args.num_enpop, args.num_stateN)  # before synapstic\n    afs = args.num_action * args.num_depop\n    #agent\n    # abfs = pow(args.num_enpop, args.num_stateA)  # agent before synapstic\n    # aafs = args.num_action * args.num_depop\n    net_agent_belief = PFC_ToM(step=args.T, encode_type='rate', bias=True,\n                        in_features=bfs, out_features=afs,\n                        node=node.LIFNode, num_state=args.num_stateN,\n                        greedy=args.greedy)    #out_features the kinds of policies\n    net_agent_belief.to(args.device)\n    net_agent_belief.fc.weight.data = torch.rand((afs, bfs))\n    # net_agent_belief.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_N))['model'])\n    #NPC\n    net_NPC = PFC_ToM(step=args.T, encode_type='rate', bias=True,\n                     in_features=bfs, out_features=afs,\n                     node=node.LIFNode, num_state=args.num_stateN,\n                        greedy=args.greedy)    #out_features the kinds of policies\n    net_NPC.to(args.device)\n\n    net_NPC.load_state_dict(torch.load(os.path.join(args.logdir, 
args.save_net_N))['model'])\n    total_scores = update(env, net_agent_belief, net_NPC, args.episodes,\\\n                          args.trajectories)\n\n    # torch.save({'model': net_agent.state_dict()}, os.path.join(args.logdir, args.save_net_a))\n    torch.save({'model': net_NPC.state_dict()}, os.path.join(args.logdir, args.save_net_N))\n\n    time_end = time.time()\n    print('totally cost',time_end-time_start)\n\nif __name__ == \"__main__\":\n    time_start = time.time()\n    env = FalseBelief_env(args.reward)\n    ToM = ToM(env=env)\n    # args.task = 'both'#'zero'\n    # args.mode = 'test'#'train'\n    # args.save_net_N = 'net_NPC_3.pth'\n    # args.save_net_a = 'net_agent_3.pth'\n    # args.greedy = 111\n    train()\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/main_both.py",
    "content": "import argparse\nimport time\nimport copy\nimport numpy as np\nimport torch\nnp.set_printoptions(threshold=np.inf)\ntorch.set_printoptions(threshold=np.inf)\nfrom tqdm import *\nimport matplotlib\nimport seaborn as sns\nimport pygame\npygame.init()\nsns.set(style='ticks', palette='Set2')\nmatplotlib.rcParams.update({'font.size': 12})\nimport os\nos.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\nfrom BrainArea.PFC_ToM import PFC_ToM\nfrom rulebasedpolicy.Find_a_way import *\nfrom env.env3_train_env00 import FalseBelief_env0   #3\nfrom env.env3_train_env01 import FalseBelief_env1   #2\nfrom braincog.base.encoder.encoder import *\nfrom braincog.base.node import node\ntorch.manual_seed(1)\n#NPC2\n#state\nN_state = 6\ncell_num = 6\n# action\nN_action = 5\nNC=10 #50 cells represent one character\n#synapstic\nbfs = pow(cell_num, N_state) #before synapstic\nafs = N_action * NC\n#agent\nC=10\nA_state = 4\nabfs = pow(cell_num, A_state) #agent before synapstic\naafs = N_action * C\nparser = argparse.ArgumentParser(description='sequence character (policy inference)')\nparser.add_argument('--mode', type=str, default='train')\nparser.add_argument('--task', type=str, default='both')\nparser.add_argument('--logdir', type=str, default='checkpoint')\nparser.add_argument('--save_net_a', type=str, default='net_agent_4.pth', help='save the parameters of net_agent')\nparser.add_argument('--save_net_N', type=str, default='net_NPC_4.pth', help='save the parameters of net_NPC')\nparser.add_argument('--device', default='cpu', help='device')  # cuda:0\nparser.add_argument('--T', default=40, type=int, help='simulating time-steps')  # 模拟时长\nparser.add_argument('--dt', default=1, type=int, help='simulating dt')  # 模拟dt\nparser.add_argument('--episodes', default=25, type=int, help='episodes')\nparser.add_argument('--trajectories', default=10, type=int, help='trajectories')\nparser.add_argument('--greedy', default=0.8, type=float, help='exploration or 
exploitation')\nparser.add_argument('--num_enpop', default=6, type=int, help='the number of one population in the encoding layer')  #\nparser.add_argument('--num_depop', default=10, type=int, help='the number of one population in the decoding layer')  #\nparser.add_argument('--num_stateA', default=2, type=int, help='the number of states, (X, Y)')\nparser.add_argument('--num_stateN', default=6, type=int, help='the number of states, [(X, Y), (X, Y), (X, Y)]')\nparser.add_argument('--num_action', default=5, type=int, help='the number of actions')\nparser.add_argument('--reward', default=10, type=float, help='environment parameter reward')\nargs = parser.parse_args()\n\ndef reward_plot(episodes, scores, Note):\n    fig = plt.figure(figsize=(7.5, 4.5))\n    ax1 = fig.add_subplot(111)\n    ax1.set_title('Reward Plot')\n    plt.xlim(1, episodes)\n    plt.grid(ls='--', c='gray')\n    plt.xlabel('Epoch')\n    plt.ylabel('Reward')\n    episodes_list = list(range(1,episodes+1))\n    plt.plot(episodes_list, scores['be observed agent without the ToM'], label='be observed agent without the ToM')\n    plt.legend()\n    plt.savefig('reward_plot_' + str(episodes) + '.png')\n\ndef update(env0, env1, net_agent, net_NPC, episodes, trajectories, task):\n    \"\"\"\n    agents learn to reach the goal without collision\n    update agents' positions\n    @param env0:\n    @param env1:\n    @param net_agent: the SNN network of agent\n    @param net_NPC: the SNN network of NPC\n    @param episodes: train times\n    @return: None\n    \"\"\"\n    scores_agent = []\n    scores_NPC2  = []\n    for episode in tqdm(range(episodes)):\n        timer0 = 0\n        timer1 = 0\n        env0.reset()\n        env1.reset()\n        env0.actu_obs()\n        env1.actu_obs()\n        scores = {\n            'agent_0': 0,\n            'NPC2_0' : 0,\n            'agent_1': 0,\n            'NPC2_1': 0,\n        }\n        Done_agent_0 = Done_agent_1 = False\n        Done_NPC2_0  = Done_NPC2_1  = False\n       
 action_agent  = 3\n        action_NPC2   = 2\n        action_NPC1   = 4\n        action_agent1 = 4\n        # the start position are the same in two envs\n        mapping_a = {'state': sum(env0.agent['axis'], []),\n                     'action': action_agent}\n        mapping_N = {'state': sum(env0.NPC_2['axis'], []),\n                     'action': action_NPC2}\n        if task == 'both' or task == 'zero':\n            while True and timer0 < trajectories:\n                timer0 = timer0 + 1\n                NPC_1_state, NPC_2_state, Agent_state \\\n                    = env0.interact(action_NPC1, action_NPC2, action_agent)\n                env0.SHOW()\n                # time.sleep(2)\n                #NPC_1 selects action by pp\n                if env0.NPC_1['Done'] == False:\n                    action_seq1 = Find_a_way(size=5, board=NPC_1_state,\\\n                                             start_x=env0.NPC_1['x']-1,\\\n                                             start_y=4-(env0.NPC_1['y']-1),\\\n                                             end_x=3, end_y=4-4)\n                    action_NPC1 = list(env0.action_move.keys())[\\\n                        list(env0.action_move.values()).index((action_seq1[1][0]-(action_seq1[0][0]), -action_seq1[1][1]+(action_seq1[0][1])))]\n                #agent selects action by E-STDP\n                Agent_obs = sum(env0.agent['axis'], [])\n                if Done_agent_0 == False:\n                    action_agent = 3\n                    # net_agent.update_s(R = env0.agent['reward'],\\\n                    #                    mapping=mapping_a)\n                    # action_agent = net_agent(inputs = Agent_obs,\\\n                    #                          num_action = args.num_action,\\\n                    #                          episode = episode)\n                    # state_agent = copy.deepcopy(Agent_obs)\n                    # Done_agent_0 = copy.deepcopy(env0.agent['Done'])\n                    # 
mapping_a = {'state': state_agent, # at time t\n                    #            'action': action_agent}\n                #NPC_2 selects action by E-STDP\n                NPC2_obs = sum(env0.NPC_2['axis'], [])\n                if Done_NPC2_0 == False:\n                    net_NPC.update_s(R = env0.NPC_2['reward'], \\\n                                       mapping=mapping_N)\n                    action_NPC2 = net_NPC(inputs = NPC2_obs,\\\n                                          num_action = args.num_action,\\\n                                             episode = episode)\n                    state_NPC2 = copy.deepcopy(NPC2_obs)\n                    Done_NPC2_0 = copy.deepcopy(env0.NPC_2['Done'])\n                    mapping_N = {'state': state_NPC2, # at time t\n                               'action': action_NPC2}\n                    # continue\n                scores['agent_0'] += env0.agent['reward']\n                scores['NPC2_0'] += env0.NPC_2['reward']\n                if env0.NPC_1['Done'] == env0.NPC_2['Done'] == env0.agent['Done'] == True:\n                    break\n            scores_agent.append(scores['agent_0'])\n            scores_NPC2.append(scores['NPC2_0'])\n######################\n        if task == 'both' or task == 'one':\n            while True and timer1 < trajectories:\n                timer1 = timer1 + 1\n                NPC_2_state, Agent_state \\\n                    = env1.interact(action_NPC2, action_agent)\n                env1.SHOW()\n                # time.sleep(2)\n                # agent selects action by E-STDP\n                Agent_obs = sum(env1.agent['axis'], [])\n                if Done_agent_1 == False:\n                    action_agent = 3\n                    # net_agent.update_s(R=env1.agent['reward'], \\\n                    #                    mapping=mapping_a)\n                    # scores['agent_1'] += env1.agent['reward']\n                    # action_agent = net_agent(inputs=Agent_obs, \\\n                  
  #                          num_action=args.num_action,\\\n                    #                          episode = episode)\n                    # state_agent = copy.deepcopy(Agent_obs)\n                    # Done_agent_1 = copy.deepcopy(env1.agent['Done'])\n                    # mapping_a = {'state': state_agent,  # at time t\n                    #              'action': action_agent}\n                # NPC_2 selects action by E-STDP\n                NPC2_obs = sum(env1.NPC_2['axis'], [])\n                if Done_NPC2_1 == False:\n                    net_NPC.update_s(R=env1.NPC_2['reward'], \\\n                                     mapping=mapping_N)\n                    scores['NPC2_1'] += env1.NPC_2['reward']\n                    action_NPC2 = net_NPC(inputs=NPC2_obs, \\\n                                          num_action=args.num_action,\\\n                                             episode = episode)\n                    state_NPC2 = copy.deepcopy(NPC2_obs)\n                    Done_NPC2_1 = copy.deepcopy(env1.NPC_2['Done'])\n                    mapping_N = {'state': state_NPC2,  # at time t\n                                 'action': action_NPC2}\n                scores['agent_1'] += env1.agent['reward']\n                scores['NPC2_1'] += env1.NPC_2['reward']\n                if env1.NPC_2['Done'] == env1.agent['Done'] == True:\n                    break\n            scores_agent.append(scores['agent_1'])\n            scores_NPC2.append(scores['NPC2_1'])\n    total_scores = {\n        'the agent with the ToM': scores_agent,\n        'be observed agent without the ToM' : scores_NPC2\n    }\n    return total_scores\n\ndef train():\n    print('train mode loading ... 
')\n    if not os.path.isdir(args.logdir):\n        os.mkdir(args.logdir)\n    #agent\n    abfs = pow(args.num_enpop, args.num_stateA)  # agent before synapstic\n    aafs = args.num_action * args.num_depop\n    net_agent = PFC_ToM(step=args.T, encode_type='rate', bias=True,\n                        in_features=abfs, out_features=aafs,\n                        node=node.LIFNode, num_state=args.num_stateA,\n                        greedy=args.greedy)    #out_features the kinds of policies\n    net_agent.to(args.device)\n    net_agent.fc.weight.data = torch.rand((aafs, abfs))\n    # net_agent.load_state_dict(torch.load('./checkpoint/net_agent_12.pth')['model'])\n    #NPC\n    bfs = pow(args.num_enpop, args.num_stateN)  # before synapstic\n    afs = args.num_action * args.num_depop\n    net_NPC = PFC_ToM(step=args.T, encode_type='rate', bias=True,\n                     in_features=bfs, out_features=afs,\n                     node=node.LIFNode, num_state=args.num_stateN,\n                        greedy=args.greedy)    #out_features the kinds of policies\n    net_NPC.to(args.device)\n    net_NPC.fc.weight.data = torch.rand((afs, bfs))\n    # net_NPC.load_state_dict(torch.load('./checkpoint/net_NPC_12.pth')['model'])\n    total_scores = update(env0, env1, net_agent, net_NPC, args.episodes,\\\n                          args.trajectories, args.task)\n    torch.save({'model': net_agent.state_dict()}, os.path.join(args.logdir, args.save_net_a))\n    torch.save({'model': net_NPC.state_dict()}, os.path.join(args.logdir, args.save_net_N))\n    time_end = time.time()\n    print('totally cost',time_end-time_start)\n    if args.task == 'zero' or args.task == 'one':\n        reward_plot(args.episodes, total_scores, 'Scores')\n    elif args.task == 'both':\n        reward_plot(args.episodes * 2, total_scores, 'Scores')\n    plt.show()\n\ndef test():\n    args.greedy = 1\n    print('test mode loading ... 
')\n    print('greedy :', args.greedy)\n    #agent\n    abfs = pow(args.num_enpop, args.num_stateA)  # agent before synapstic\n    aafs = args.num_action * args.num_depop\n    net_agent = PFC_ToM(step=args.T, encode_type='rate', bias=True,\n                     in_features=abfs, out_features=aafs,\n                     node=node.LIFNode, num_state=args.num_stateA,\n                          greedy=args.greedy)\n    net_agent.to(args.device)\n    # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html\n    net_agent.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_a))['model'])      #out_features the kinds of policies\n    #NPC\n    bfs = pow(args.num_enpop, args.num_stateN)  # before synapstic\n    afs = args.num_action * args.num_depop\n    net_NPC = PFC_ToM(step=args.T, encode_type='rate', bias=True,\n                     in_features=bfs, out_features=afs,\n                     node=node.LIFNode, num_state=args.num_stateN,\n                        greedy=args.greedy)\n    net_NPC.to(args.device)\n    net_NPC.load_state_dict(torch.load(os.path.join(args.logdir, args.save_net_N))['model'])   #out_features the kinds of policies\n    total_scores = update(env0, env1, net_agent, net_NPC, args.episodes,\n                          args.trajectories, args.task)\n    time_end = time.time()\n    print('totally cost',time_end-time_start)\n    if args.task == 'zero' or args.task == 'one':\n        reward_plot(args.episodes, total_scores, 'Scores')\n    elif args.task == 'both':\n        reward_plot(args.episodes * 2, total_scores, 'Scores')\n    plt.show()\nif __name__==\"__main__\":\n    time_start = time.time()\n    env0 = FalseBelief_env0(args.reward)\n    env1 = FalseBelief_env1(args.reward)\n    # args.task = 'both'#'zero'\n    # args.mode = 'test'#'train'\n    # args.save_net_N = 'net_NPC_3.pth'\n    # args.save_net_a = 'net_agent_3.pth'\n    if args.mode == 'train':\n        train()\n    elif args.mode == 'test':\n        
test()\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/Find_a_way.py",
    "content": "# main.py\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom matplotlib.patches import Rectangle\n\nfrom rulebasedpolicy.random_map import *\nfrom rulebasedpolicy.a_star import *\nfrom env.env3_train_env01 import FalseBelief_env1\n\n\ndef Find_a_way(size, board, start_x, start_y, end_x, end_y):\n    map = RandomMap(size=size, board=board)\n\n    for i in range(map.size):\n        for j in range(map.size):\n            if map.IsObstacle(i,j):\n                rec = Rectangle((i, j), width=1, height=1, color='gray')\n            else:\n                rec = Rectangle((i, j), width=1, height=1, edgecolor='gray', facecolor='w')\n\n    rec = Rectangle((start_x, start_y), width = 1, height = 1, facecolor='b')\n\n    rec = Rectangle((end_x, end_y), width = 1, height = 1, facecolor='r')\n\n    A_star = AStar(map)\n    action_seq = A_star.RunAndSaveImage( start_x, start_y, end_x, end_y)#ax, plt,\n    return action_seq\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/a_star.py",
    "content": "import sys\nimport time\n\nimport numpy as np\n\nfrom matplotlib.patches import Rectangle\n\nfrom rulebasedpolicy.point import *\nfrom rulebasedpolicy.random_map import *\n\nclass AStar:\n    def __init__(self, map):\n        self.map=map\n        self.open_set = []\n        self.close_set = []\n\n    def BaseCost(self, p):\n        x_dis = p.x\n        y_dis = p.y\n        # Distance to start point\n        return x_dis + y_dis + (np.sqrt(2) - 2) * min(x_dis, y_dis)\n\n    def HeuristicCost(self, p):\n        x_dis = self.map.size - 1 - p.x\n        y_dis = self.map.size - 1 - p.y\n        # Distance to end point\n        return x_dis + y_dis + (np.sqrt(2) - 2) * min(x_dis, y_dis)\n\n    def TotalCost(self, p):\n        return self.BaseCost(p) + self.HeuristicCost(p)\n\n    def IsValidPoint(self, x, y):\n        if x < 0 or y < 0:\n            return False\n        if x >= self.map.size or y >= self.map.size:\n            return False\n        return not self.map.IsObstacle(x, y)\n\n    def IsInPointList(self, p, point_list):\n        for point in point_list:\n            if point.x == p.x and point.y == p.y:\n                return True\n        return False\n\n    def IsInOpenList(self, p):\n        return self.IsInPointList(p, self.open_set)\n\n    def IsInCloseList(self, p):\n        return self.IsInPointList(p, self.close_set)\n\n    def IsStartPoint(self, p, start_x, start_y):\n        return p.x == start_x and p.y ==start_y\n\n    def IsEndPoint(self, p, end_x, end_y):\n        return p.x == end_x and p.y == end_y###############\n\n    def SaveImage(self, plt):\n        millis = int(round(time.time() * 1000))\n        filename = './' + str(millis) + '.png'\n        plt.savefig(filename)\n\n    def ProcessPoint(self, x, y, parent):\n        if not self.IsValidPoint(x, y):\n            return # Do nothing for invalid point\n        p = Point(x, y)\n        if self.IsInCloseList(p):\n            return # Do nothing for visited point\n        # 
print('Process Point [', p.x, ',', p.y, ']', ', cost: ', p.cost)\n        if not self.IsInOpenList(p):\n            p.parent = parent\n            p.cost = self.TotalCost(p)\n            self.open_set.append(p)\n\n    def SelectPointInOpenList(self):\n        index = 0\n        selected_index = -1\n        min_cost = sys.maxsize\n        for p in self.open_set:\n            cost = self.TotalCost(p)\n            if cost < min_cost:\n                min_cost = cost\n                selected_index = index\n            index += 1\n        return selected_index\n\n    def BuildPath(self, p,  start_time, start_x, start_y, end_x, end_y):#ax, plt,\n        path = []\n        record = []\n        while True:\n            path.insert(0, p) # Insert first\n            if self.IsStartPoint(p, start_x, start_y):\n                break\n            else:\n                p = p.parent\n        p_x=start_x\n        p_y=start_y\n        for p in path:\n            if abs(p.x-p_x) == abs(p.y-p_y) == 1:\n                # rec = Rectangle((p_x, p.y), 1, 1, color='g')\n                # rec = Rectangle((p.x, p.y), 1, 1, color='g')\n                # ax.add_patch(rec)\n                # plt.draw()\n                # self.SaveImage(plt)\n                if abs(end_x - start_x) >= abs(end_y - start_y):\n                    record.append((p.x, p_y))\n                    record.append((p.x, p.y))\n                else:\n                    record.append((p_x, p.y))\n                    record.append((p.x, p.y))\n            else:\n                rec = Rectangle((p.x, p.y), 1, 1, color='g')\n                # ax.add_patch(rec)\n                # plt.draw()\n                # self.SaveImage(plt)\n                record.append((p.x, p.y))\n            p_x = p.x\n            p_y = p.y\n\n        end_time = time.time()\n        # print('===== Algorithm finish in', int(end_time-start_time), ' seconds')\n        return record\n\n    def RunAndSaveImage(self,  start_x, start_y, end_x, end_y):#ax, 
plt,\n        start_time = time.time()\n\n        start_point = Point(start_x, start_y)############################\n        start_point.cost = 0\n        self.open_set.append(start_point)\n\n        while True:\n            index = self.SelectPointInOpenList()\n            if index < 0:\n                print('No path found, algorithm failed!!!')\n                # self.SaveImage(plt)\n                return\n            p = self.open_set[index]\n            # rec = Rectangle((p.x, p.y), 1, 1, color='c')\n            # ax.add_patch(rec)\n            # self.SaveImage(plt)\n\n            if self.IsEndPoint(p, end_x, end_y):\n                return self.BuildPath(p,  start_time, start_x, start_y, end_x, end_y)#ax, plt,\n\n            del self.open_set[index]\n            self.close_set.append(p)\n\n            # Process all neighbors\n            x = p.x\n            y = p.y\n            self.ProcessPoint(x-1, y+1, p)\n            self.ProcessPoint(x-1, y, p)\n            self.ProcessPoint(x-1, y-1, p)\n            self.ProcessPoint(x, y-1, p)\n            self.ProcessPoint(x+1, y-1, p)\n            self.ProcessPoint(x+1, y, p)\n            self.ProcessPoint(x+1, y+1, p)\n            self.ProcessPoint(x, y+1, p)\n\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/load_statedata.py",
    "content": "import random\nimport os\nimport numpy as np\n\nfrom torch.utils.data import Dataset, DataLoader\nimport torchvision.transforms as transforms\n\n# from torch.autograd import  Variable\n\n\nclass StateDataset:\n    # initial\n    def __init__(self, mode, num):\n        self.state = np.loadtxt(mode, dtype=np.int)\n        self.num = 1\n        self.state = self.state.reshape(num, 5*self.num , -1)\n\n    #data:A label:B\n    def __getitem__(self, item):\n        state = self.state[item]\n\n        state_A = state[:,0:5*self.num]\n        state_A = np.expand_dims(state_A, axis=0)\n        state_B = state[:, 5*self.num:10*self.num]\n        state_B = np.expand_dims(state_B, axis=0)\n\n        return {\"A\":state_A, \"B\":state_B}\n\n    #the number of data\n    def __len__(self):\n        return len(self.state)\n\n# def get_dataloader(self):\n\ndef get_dataloader(mode, num, batch):\n    train_dataset = StateDataset(mode, num)\n    train_loader = DataLoader(train_dataset, batch, shuffle=True)\n    return train_loader\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/point.py",
    "content": "import sys\n\nclass Point:\n    def __init__(self, x, y):\n        self.x = x\n        self.y = y\n        self.cost = sys.maxsize"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/random_map.py",
    "content": "import numpy as np\nfrom rulebasedpolicy.point import *\n\n\nclass RandomMap:\n    def __init__(self, size, board):\n        self.size = size\n        self.board = board\n        self.obstacle = size//8\n        self.GenerateObstacle()\n\n    def GenerateObstacle(self):\n        self.obstacle_point = []\n        # Generate an obstacle in the middle\n        for i in range(self.size):\n            for j in range(self.size):\n                if self.board[i,j] == 5:\n                    self.obstacle_point.append(Point(j, 4-i))\n\n\n\n    def IsObstacle(self, i ,j):\n        for p in self.obstacle_point:\n            if i==p.x and j==p.y:\n                return True\n        return False"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/statedata_pre.py",
    "content": "import numpy as np\n\ndef data_transfer(B_txt, A_txt):\n    \"\"\"\n    Aim:读取训练数据，并将其转换为可以处理的形式\n    @param B_txt:Before processing -txt\n    @param A_txt:After processing -txt\n    @return:After processing -data\n    \"\"\"\n    with open(B_txt, 'r') as f:  #'dataA_B.txt'\n        data_all =[]\n        data_1 = []\n        data_2 = []\n        data_3 = []\n        data = f.read() #Read all the data in txt  ...str\n        data_split = data.split('\\n\\n') #Divide the data with '\\n\\n'\n        for i in range(len(data_split)-1):  #There are (len(data_split)-1) sets of valid data\n            data_split[i] = data_split[i].split('\\n')   #Remove '\\n' from each set of data\n            for j in range(len(data_split[i])):\n                # Split number\n                data_split[i][j] = \" \".join(data_split[i][j])\n                data_split[i][j] = data_split[i][j].split(' ')\n                data_split[i][j] = list(map(int, data_split[i][j])) #str-int\n\n            data_split[i] = np.array(data_split[i])    #list-np.array\n            # Data expansion\n            data_all.append(data_split[i])\n            data_1.append(np.flipud(data_split[i])) #上下对称\n            # data_2_split = data_split[i][:, [5, 6, 7, 8, 9, 0, 1, 2, 3, 4]]\n            # data_2.append(np.fliplr(data_2_split))    #左右对称\n            # data_3_split = data_split[i][:, [5, 6, 7, 8, 9, 0, 1, 2, 3, 4]]\n            # data_3.append(np.fliplr(data_3_split))\n\n        data_all.extend(data_1)\n        # data_all.extend(data_2)\n        # data_all.extend(data_3)\n\n    data_all = np.array(data_all)\n    data_all = data_all.reshape(data_all.shape[0]*data_all.shape[1], data_all.shape[2])\n    data_all = data_all.astype(int)\n\n    # new_data = np.repeat(data_all, repeats=num, axis=0)\n    # new_data = np.repeat(new_data, repeats=num, axis=1)\n\n    np.savetxt(A_txt, data_all, fmt='%i')  #'train.txt'\n\n    # Read TXT data into numpy\n    state = np.loadtxt(A_txt, dtype = np.int)\n    
print(data_all.shape)\n    return state\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/train.txt",
    "content": "8 5 1 1 1 8 5 0 0 0\n1 1 1 1 1 1 0 0 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 8 5 1 1 1 8 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 8 5 1 1 1 8 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 8 5 1 1 1 8 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 8 5 1 1 1 8 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 0 0\n1 1 8 5 1 1 1 8 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 0 0 0\n1 8 5 1 1 1 8 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 0 0 0 0\n8 5 1 1 1 8 5 0 0 0\n1 1 1 1 1 1 0 0 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 8 5 1 1 1 8 5\n1 1 1 1 1 1 1 1 1 0\n8 1 5 1 1 8 1 5 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 8 1 5 1 1 8 1 5 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 0 0\n8 1 5 1 1 8 1 5 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 8 1 5 1 1 8 1 5 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 5 1 8 1 1 5 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 5 1 8 1 1 5 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 1 1 8 1 1 1 1\n1 5 1 1 1 1 5 0 0 1\n1 1 1 1 1 1 0 0 0 0\n1 1 1 1 1 1 0 0 0 0\n1 1 1 1 1 1 1 0 0 0\n1 8 1 1 1 1 8 1 1 1\n1 1 5 1 1 1 1 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 8 1 1 1 1 8 1 1\n1 1 1 5 1 1 1 1 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 8 1 1 1 1 8 1\n1 1 1 1 5 1 1 1 1 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 
1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 1 1 8 1 1 1 1\n1 5 1 1 1 1 5 0 0 1\n1 1 1 1 1 1 0 0 0 0\n1 1 1 1 1 1 0 0 0 0\n1 1 1 1 1 1 1 1 1 1\n1 8 1 1 1 1 8 1 1 1\n1 1 5 1 1 1 1 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 8 1 1 1 1 8 1 1\n1 1 1 5 1 1 1 1 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 8 1 1 1 1 8 1\n1 1 1 1 5 1 1 1 1 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 0\n8 1 1 1 1 8 1 1 1 1\n1 1 5 1 1 1 1 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 8 1 1 1 1 8 1 1 1\n1 1 1 5 1 1 1 1 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 8 1 1 1 1 8 1 1\n1 1 1 1 5 1 1 1 1 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 1 1 8 1 1 1 1\n1 1 5 1 1 1 1 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 1\n1 8 1 1 1 1 8 1 1 1\n1 1 1 5 1 1 1 1 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 8 1 1 1 1 8 1 1\n1 1 1 1 5 1 1 1 1 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n8 1 1 1 1 8 1 1 1 1\n1 1 1 5 1 1 1 1 5 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 0 0 0 0\n8 5 1 1 1 8 5 0 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 0 0 0\n1 8 5 1 1 1 8 5 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 0 0\n1 1 8 5 1 1 1 8 5 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 8 5 1 1 1 8 5\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 8 5 1 1 1 8 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 0 0\n1 1 8 5 1 1 1 8 5 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 0 0 0\n1 8 5 1 1 1 8 5 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 0 0 0\n1 1 1 1 1 1 0 0 0 0\n8 5 1 1 1 8 5 0 0 0\n1 1 1 1 1 1 0 
0 0 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 8 5 1 1 1 8 5\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 0 0\n8 1 5 1 1 8 1 5 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 8 1 5 1 1 8 1 5 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 0 0\n8 1 5 1 1 8 1 5 0 0\n1 1 1 1 1 1 1 1 0 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 0\n1 8 1 5 1 1 8 1 5 0\n1 1 1 1 1 1 1 1 1 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 5 1 8 1 1 5 0\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n1 1 1 1 1 1 1 1 1 1\n8 1 1 5 1 8 1 1 5 0\n1 1 1 1 1 1 1 1 1 1\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/rulebasedpolicy/world_model.py",
    "content": "import os\nimport sys\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom rulebasedpolicy.load_statedata import *\nimport math\nnp.set_printoptions(threshold = np.inf)\n\ndef data():\n    batch_size = 45\n    # read data\n    txt = os.path.join(sys.path[0],'rulebasedpolicy', 'train.txt')\n    train_loader=get_dataloader(mode=txt, num=batch_size ,batch=batch_size)\n\n    for data in train_loader:\n        A = data[\"A\"].numpy()\n        B = data[\"B\"].numpy()\n        B = B.reshape(batch_size, -1, 5, 5)\n        A = A.reshape(batch_size, 5, 5)\n    # distant between A and Wall(B)计算智能体与墙之间的距离\n    A_train = np.sum(np.square(np.argwhere(A==8)-np.argwhere(A==5)), axis = 1)\n    dist_AW = 1               #指定一个特定的距离1,2,4,5,9,10 distant between agent and wall############\n    o_idx = np.argwhere(A_train == dist_AW) #找到固定距离对应的所有矩阵Find all matrices corresponding to a fixed distance\n    return B, A_train\n\n\ndef flip180(arr):\n    \"\"\"\n    翻转180度\n    @param arr:\n    @return:\n    \"\"\"\n    new_arr = arr.reshape(arr.size)\n    new_arr = new_arr[::-1]\n    new_arr = new_arr.reshape(arr.shape)\n    return new_arr\n\ndef flip90_left(arr):\n    \"\"\"\n    向左翻转90度逆时针\n    @param arr:\n    @return:\n    \"\"\"\n    new_arr = np.transpose(arr)\n    new_arr = new_arr[::-1]\n    return new_arr\n\ndef flip90_right(arr):\n    \"\"\"\n    向右翻转90度顺时针\n    @param arr:\n    @return:\n    \"\"\"\n    new_arr = arr.reshape(arr.size)\n    new_arr = new_arr[::-1]\n    new_arr = new_arr.reshape(arr.shape)\n    new_arr = np.transpose(new_arr)[::-1]\n    return new_arr\n\ndef gain_env(obs, agent, wall):\n    \"\"\"\n    Aim:可以根据多个部分观察在一起拼成一个大的环境\n    @param obs:多组观测\n    @param agent:代表智能体的参数，这里默认用8来表示\n    @param wall:代表墙的参数，这里默认用5来表示\n    @return:拼成的环境\n    obs_random:随便选择一个矩阵作为based环境\n    obs[i]:其他用来补全的矩阵\n    x_a, y_a:based环境-智能体坐标\n    x_w, y_w:based环境-wall坐标\n    x_t, y_t:其他环境-智能体坐标\n    x_tt, y_tt:其他环境-wall坐标\n    \"\"\"\n\n    obs_random = obs[0]\n    
for i in range(1, obs.shape[0]):\n        #based env\n        x_a, y_a = np.argwhere(obs_random == agent)[0]\n        x_w, y_w = np.argwhere(obs_random == wall)[0]\n        #external env\n        x_t, y_t = np.argwhere(obs[i] == agent)[0]\n        x_tt, y_tt = np.argwhere(obs[i] == wall)[0]\n\n        h, l = obs_random.shape\n        delta_up = max(x_t - x_a,0)\n        delta_down = max(5-x_tt - (h-x_w),0)\n        delta_left = max(y_t - y_a,0)\n        delta_right = max(5-y_tt - (l-y_w),0)\n\n        obs_random = np.r_[np.ones((delta_up, l)), obs_random] if delta_up != 0 else obs_random\n        h, l = obs_random.shape\n        obs_random = np.r_[obs_random, np.ones((delta_down, l))] if delta_down != 0 else obs_random\n        h, l = obs_random.shape\n        obs_random = np.c_[np.ones((h, delta_left)), obs_random] if delta_left != 0 else obs_random\n        h, l = obs_random.shape\n        obs_random = np.c_[obs_random, np.ones((h, delta_right))] if delta_right != 0 else obs_random\n        obs_random = obs_random.astype(np.int)\n        #based env\n        x_a, y_a = np.argwhere(obs_random == agent)[0]\n        up = x_a - x_t\n        left = y_a - y_t\n        obs_random[up:up+5, left:left+5] = obs_random[up:up+5, left:left+5] & obs[i]\n\n    return obs_random\n\ndef shelter_env(obs):\n    \"\"\"\n    Aim:用gain_env环境中的图，来描述更复杂的环境的遮挡关系\n    @param obs:复杂的环境\n    @return:环境的遮挡关系\n    \"\"\"\n    # print(obs,'----------------')\n    position_A = np.argwhere(obs==8)\n    position_W = np.argwhere(obs==5)\n    # print(position_W)\n    position = np.sum(np.square(np.argwhere(obs == 8) - np.argwhere(obs == 5)), axis=1) #numpy  (walls,)\n    # print(position_W, position_A,position)\n    shelter_env_i = np.ones((5,5)).astype(np.int)\n    B, A_train = data()\n    for i in range(position.size):\n        o_idx = np.argwhere(A_train == position[i])\n        if o_idx.size == 0:\n            break\n        else:\n            model = gain_env(B[o_idx].reshape(-1, 5, 5), 8 
,5).astype(np.int)\n\n            # print(model,'=============')\n\n            if (position_A[0,0] > position_W[i,0] and position_A[0,1] < position_W[i,1] and \\\n                    position_A[0, 0] - position_W[i, 0] < -position_A[0, 1] + position_W[i, 1])\\\n                or\\\n                (position_A[0, 0] < position_W[i, 0] and position_A[0, 1] < position_W[i, 1] and \\\n                 -position_A[0, 0] + position_W[i, 0] > -position_A[0, 1] + position_W[i, 1])\\\n                or\\\n                (position_A[0, 0] > position_W[i, 0] and position_A[0, 1] > position_W[i, 1] and \\\n                 position_A[0, 0] - position_W[i, 0] > position_A[0, 1] - position_W[i, 1])\\\n                or\\\n                (position_A[0, 0] < position_W[i, 0] and position_A[0, 1] > position_W[i, 1] and \\\n                 -position_A[0, 0] + position_W[i, 0] < position_A[0, 1] - position_W[i, 1]):\n\n                model = np.flip(model, 0)\n                # print(model, '-=-=-=-=-=-=-====')\n                model = flip90_right(model)\n                # print(model,'-=-=-=-=-=-=-====')\n            if position_A[0, 0] >= position_W[i, 0] and position_A[0, 1] > position_W[i, 1]:\n                model = flip180(model)\n            elif position_A[0, 0] > position_W[i, 0] and position_A[0, 1] <= position_W[i, 1]:\n                model = flip90_left(model)\n            elif position_A[0, 0] < position_W[i, 0] and position_A[0, 1] >= position_W[i, 1]:\n                model = flip90_right(model)\n            else:\n                model = model\n\n            x_t, y_t = np.argwhere(model == 8)[0]\n            if y_t<position_A[0, 1]:\n                model = np.c_[np.ones((model.shape[0], position_A[0, 1]-y_t)).astype(np.int), model]\n\n            if x_t < position_A[0, 0]:\n                model = np.r_[np.ones((position_A[0, 0] - x_t, model.shape[1])).astype(np.int), model]\n\n            if model.shape[0] - x_t < 5 - position_A[0, 0]:\n                
model = np.r_[model, np.ones((5 - position_A[0, 0] - model.shape[0] + x_t, \\\n                                              model.shape[1])).astype(np.int)]\n\n            if model.shape[1] - y_t < 5 - position_A[0, 1]:\n                model = np.c_[model, np.ones((model.shape[0], \\\n                                              5 - position_A[0, 1]-model.shape[1] + y_t)).astype(np.int)]\n            model = model.astype(np.int)\n            # print(model,'...........')\n            x_t, y_t = np.argwhere(model == 8)[0]\n            shelter_env_i = model[x_t-position_A[0, 0]:x_t-position_A[0, 0]+5, \\\n                           y_t-position_A[0, 1]:y_t-position_A[0, 1]+5] & shelter_env_i\n            # print(shelter_env_i,'66666666666666')\n\n    shelter_env_i[np.argwhere(obs==5)[:,0], np.argwhere(obs==5)[:,1]] = 5\n    shelter_env_i[position_A[0,0], position_A[0, 1]] = 8\n    shelter_env_i = shelter_env_i.astype(np.int)\n\n    return shelter_env_i\n\ndef big_env(env):\n    h, l = np.shape(env)\n    res = np.ones((h+2, l+2))*5\n    res[1:6, 1:6] = env\n    return res.astype(np.int)\n\n# model = gain_env(B[o_idx].reshape(-1, 5, 5), 8 ,5)\n\n# obs_esti = shelter_env(np.array([[1,1,1,5,5],[1,1,1,5,5],[1,1,1,1,5],[1,8,1,1,1],[1,1,1,1,1]]))\n# env = big_env(obs_esti)\n# print(env)\n\n"
  },
  {
    "path": "examples/Social_Cognition/ToM/utils/Encoder.py",
    "content": "import numpy as np\nimport torch.nn as nn\n\n\n#exploit or explore\nnum_enpop = 6\nnum_depop = 10\n\nclass PopEncoder(nn.Module):\n    \"\"\"\n    One kind of population coding\n    \"\"\"\n    def __init__(self, step, encode_type):\n        super(PopEncoder, self).__init__()\n        self.step = step\n        self.fun = getattr(self, encode_type)\n        self.encode_type = encode_type\n\n    def forward(self, inputs, *args, **kwargs):\n        outputs = self.fun(inputs, *args, **kwargs)\n        return outputs\n\n    def rate(self, inputs, pop , num_state):\n        I = np.zeros((pow(num_enpop, num_state), self.step)) #将每一个状态都用一个神经元表示\n        #obs /in [1,2,3,4,5] ; obs_py /in [0,1,2,3,4]\n        obs_py = []\n        for i in range(len(inputs)):\n            obs_py.append(inputs[i]-1)\n        # six进制\n        ind = 0\n        for j in range(num_state): #cell_num\n            ind += pow(num_enpop, num_state - j - 1) * obs_py[j]\n\n        I[ind, 0: self.step] = 2\n\n        return I"
  },
  {
    "path": "examples/Social_Cognition/ToM/utils/one_hot.py",
    "content": "import numpy as np\n\ndef one_hot(value):\n    num = '12345'\n    letter = [0 for _ in range(len(num))]\n    letter[value-1] = 1\n    letter = np.array([letter])\n    # print(letter)\n    return letter\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BAE-SNN/BAESNN.py",
    "content": "import imageio\nfrom env_poly import Maze\nfrom env_two_poly import Maze2\nimport numpy as np\nimport pandas as pd\nimport matplotlib\nimport matplotlib.pyplot as plt\n\n\nimport torch, os, sys\nfrom torch import nn\nfrom torch.nn import Parameter\nimport abc\nimport math\nfrom abc import ABC\nimport torch.nn.functional as F\nfrom braincog.base.node.node import *\nfrom braincog.base.learningrule.STDP import *\nfrom braincog.base.connection.CustomLinear import *\n\n\nclass BrainArea(nn.Module, abc.ABC):\n    \"\"\"\n    脑区基类\n    \"\"\"\n\n    @abc.abstractmethod\n    def __init__(self):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n    @abc.abstractmethod\n    def forward(self, x):\n        \"\"\"\n        计算前向传播过程\n        :return:x是脉冲\n        \"\"\"\n\n        return x\n\n    def reset(self):\n        \"\"\"\n        计算前向传播过程\n        :return:x是脉冲\n        \"\"\"\n\n        pass\n\n\nclass BAESNN(BrainArea):\n    \"\"\"\n    情感共情网络\n    \"\"\"\n\n    def __init__(self,):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n\n        self.node = [IFNode() for i in range(5)]\n       \n        \n        self.connection = []\n        \n        con_matrix0 = torch.eye(40, 40)*6\n        self.connection.append(CustomLinear(con_matrix0))#input-emotion\n        \n        con_matrix1 = torch.zeros((40, 50), dtype=torch.float)\n        for j in range(50):\n            if j in np.arange(0,20,1):\n                for i in np.arange(0, 20, 1):\n                    con_matrix1[i,j] =2\n            if j in np.arange(30,50,1):\n                for i in np.arange(20, 40, 1):\n                    con_matrix1[i,j] =2\n            if j in np.arange(20,30,1):\n                for i in np.arange(0, 40, 1):\n                    con_matrix1[i,j] = 2     \n        self.connection.append(CustomLinear(con_matrix1))#emotion-ifg\n        \n        con_matrix2 = torch.zeros((40, 50), dtype=torch.float)  \n        
self.connection.append(CustomLinear(con_matrix2))#perception-ifg\n        \n        con_matrix3 = torch.eye(40, 40)*6\n        self.connection.append(CustomLinear(con_matrix3))#input-perception\n        \n        con_matrix4=torch.zeros((40,10), dtype=torch.float)\n        for j in range(10):\n            if j in np.arange(0,5,1):\n                for i in np.arange(0, 20, 1):\n                    con_matrix4[i,j] =2\n            if j in np.arange(5,10,1):\n                for i in np.arange(20, 40, 1):\n                    con_matrix4[i,j] =2\n        self.connection.append(CustomLinear(con_matrix4))#emotion-sma\n        \n        con_matrix5=torch.zeros((40,10), dtype=torch.float)\n        self.connection.append(CustomLinear(con_matrix5))#perception-m1\n        \n        con_matrix6 = torch.eye(10, 10)*6\n        self.connection.append(CustomLinear(con_matrix6))#sma-m1\n        \n        self.stdp = []\n        self.stdp.append(STDP(self.node[0], self.connection[0]))#0\n        self.stdp.append(STDP(self.node[2], self.connection[3]))#1\n        self.stdp.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))#2\n        self.stdp.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[5]]))#3\n        self.stdp.append(STDP(self.node[4], self.connection[6]))#4\n        self.stdp.append(STDP(self.node[1],self.connection[2]))#5\n        self.stdp.append(STDP(self.node[3],self.connection[5]))#6\n    def forward(self, x1,x2):\n        \"\"\"\n        计算前向传播过程\n        :return:x是脉冲\n        \"\"\"\n        out__m, dw0 = self.stdp[0](x1)#node0\n        out__p, dw3 = self.stdp[1](x2)#node2\n        out__ifg,dw_p_i=self.stdp[2](out__m,out__p)#node1\n        out__sma,dw_p_s=self.stdp[3](out__m,out__p)#node3\n        out__m1,dw1=self.stdp[4](out__sma)#node4\n    \n        return dw_p_i,dw_p_s,out__ifg,out__sma,out__m1\n    \n    def empathy(self,x3):\n        out_p,dw2=self.stdp[1](x3)#node2\n        
out_ifg,dw4=self.stdp[5](out_p)#node1\n        out_sma,dw5=self.stdp[6](out_p)#node3\n        out_m1,dw6=self.stdp[4](out_sma)#node4\n        return out_ifg,out_sma,out_m1\n        \n    def UpdateWeight(self, i, dw, delta):\n        \"\"\"\n        更新第i组连接的权重 根据传入的dw值\n        :param i: 要更新的连接的索引\n        :param dw: 更新的量\n        :return: None\n        \"\"\"\n        self.connection[i].update(dw*delta)\n        self.connection[i].weight.data= torch.clamp(self.connection[i].weight.data,-1,4)\n        \n    def reset(self):\n        \"\"\"\n        reset神经元或学习法则的中间量\n        :return: None\n        \"\"\"\n        for i in range(5):\n            self.node[i].n_reset()\n        for i in range(len(self.stdp)):\n            self.stdp[i].reset()\n\ndef BAESNN_train():  \n    s = env.reset()\n    env._set_danger()\n    env._set_wall()\n    pain=0\n    i=0\n    set_pain=0\n    env._set_switch()\n    for i in range(100):\n        a.reset()\n        T=20\n        print('step:',i)\n        env.render()\n        \n        action = np.random.choice(list(range(env.n_actions)))\n        s_,s_pre,s_color = env.step(s, action, pain)\n        env.render()\n\n        if env.open_door == 1:\n            env.render()\n        \n        true_s_1 = np.array(s_)\n        predict_s_1=np.array(s_pre)\n        error = true_s_1 - predict_s_1\n        error = sum([c * c for c in error])\n        if error>=3200:\n            error=3200\n        if error>0:\n            pain=1\n        if error==0:\n            pain=0\n\n        if pain==0:\n            X1=torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,\n                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,\n                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,\n                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]])\n            X2=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 
0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])\n            env.render()\n            \n        if pain==1:\n            set_pain = 1\n            X1=torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,\n                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,\n                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,\n                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])\n            X2=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])\n            env.render()\n            \n        for i in range(T):\n            if i>=2:\n                X2=X1\n            OUTPUT = a(X1,X2)\n            a.UpdateWeight(2,OUTPUT[0][1],0.01)\n            a.UpdateWeight(5,OUTPUT[1][1],-0.1)\n        if OUTPUT[2][0][0]==1:\n            env.canvas.itemconfig(env.rect, fill=\"red\", outline='red')\n        if OUTPUT[2][0][40]==1:\n            env.canvas.itemconfig(env.rect, fill=\"green\", outline='green')\n        env.render()\n        \n        print('out_ifg:',OUTPUT[2])\n        print('out_sma:',OUTPUT[3])\n        print('out_m1:',OUTPUT[4])\n        # print('con2:',a.connection[2].weight.data)\n        # print('con5:',a.connection[5].weight.data)\n        \n        s = s_\n        if set_pain==1 and pain==0:\n            env.render()\n    env.destroy()\n                \n\ndef BAESNN_test():\n    a.reset()\n    s1,s=env2.reset()\n    pain=0\n    pain1 = 0\n    i=0\n    set_pain=0\n    \n    for i in range(1000):\n        env2.render()\n        \n        s_now = env2.canvas.coords(env2.agent1)\n\n        action1 = np.random.choice([0,1,2,3], p=[0.2, 0.3, 
0.3, 0.2])\n        if env2.open_door==1 and s_now[0] <(9 / 2) * 40:\n            action1 = np.random.choice([0,1,2,3], p=[0.5, 0.0, 0.0, 0.5])\n\n        s1_, s1_pre,s1_color = env2.step1(action1,pain)\n        print('s1_color:',s1_color)\n\n        if env2.open_door == 1 :\n            env2.render()\n\n        true_s1_1 = np.array(s1_)\n        predict_s1_1=np.array(s1_pre)\n        error1 = true_s1_1 - predict_s1_1\n        error1 = sum([c * c for c in error1])\n        if error1>=3200:\n            error1=3200\n\n        if error1>0:\n            pain=1\n            set_pain=1\n            \n        if error1==0:\n            pain=0\n        \n        env2.generate_expression1(pain)\n        \n        if s1_color==\"red\":\n            X3=torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,\n                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,\n                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,\n                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])\n        if s1_color==\"blue\":\n            X3=torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,\n                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,\n                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,\n                                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]])\n            \n        a.reset()\n        for i in range(20):\n            OUT=a.empathy(X3)\n            print(OUT)\n        if pain==1:\n            env2.agent_help()\n            \n        s1 = s1_\n        env2.render()\n\n        if pain==0 and set_pain==1:\n            env2.render()\n            break\n  \n    # env2.destroy()\n  \n\n\n\n\n\nif __name__ == \"__main__\":\n    env = Maze() \n    a = BAESNN() \n    BAESNN_train()\n    env.mainloop()\n    \n    env2 = Maze2()\n    BAESNN_test()\n    env2.mainloop()"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BAE-SNN/README.md",
    "content": "# Requirements\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n\n\n# Run\n## Train \n* the file to be run: BAESNN.py \n\n\n# Citation\n```\n@ARTICLE{Hui2022,\n        AUTHOR={Feng, Hui and Zeng, Yi and Lu, Enmeng},   \n        TITLE={Brain-Inspired Affective Empathy Computational Model and Its Application on Altruistic Rescue Task},      \t\n        JOURNAL={Frontiers in Computational Neuroscience},      \t\n        VOLUME={16},           \t\n        YEAR={2022},      \t  \n        URL={https://www.frontiersin.org/articles/10.3389/fncom.2022.784967},    \n        DOI={10.3389/fncom.2022.784967},      \n        ISSN={1662-5188}     \n         }\n   \n```\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BAE-SNN/env_poly.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\nUNIT = 40   # pixels\nMAZE_H = 9  # grid height\nMAZE_W = 4 # grid width\n\n\nclass Maze(tk.Tk, object):\n    def __init__(self):\n        super(Maze, self).__init__()\n        self.action_space = ['u', 'd', 'l', 'r']\n        self.n_actions = len(self.action_space)\n        self.title('self-pain')\n        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))\n        self._build_maze()\n        self.danger=0\n        self.action_hurt=0\n        self.sensory_hurt = 0\n        self.open_door = 0\n        self.pain_state=0\n\n    # create environment\n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                           height=MAZE_W * UNIT,\n                           width=MAZE_H * UNIT)\n\n        # create grids\n        for c in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT\n            self.canvas.create_line(x0, y0, x1, y1)\n        for r in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r\n            self.canvas.create_line(x0, y0, x1, y1)\n\n\n        # create agent\n        self.orgin=[20,20]\n        # 上\n        self.points0 = [\n            # 右下\n            self.orgin[0]+15,#35\n            self.orgin[1]+15,#35\n            # 左下\n            self.orgin[0]-15,#5\n            self.orgin[1]+15,#35\n            # 左上+\n            self.orgin[0]-15,#5\n            self.orgin[1],#20\n            # 顶点\n            self.orgin[0],#20\n            self.orgin[1]-15,#5\n            # 右上+\n            self.orgin[0]+15,#35\n            self.orgin[1],#20\n        ]\n        # self.rect0 = self.canvas.create_polygon(self.points0, fill=\"green\")\n        # self.agent_action0 = self.canvas.coords(self.rect0)\n\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin[0]-15,#5\n            self.orgin[1]-15,#5\n       
     # 右上\n            self.orgin[0]+15,#35\n            self.orgin[1]-15,#5\n            # 右下+\n            self.orgin[0]+15,#35\n            self.orgin[1],#20\n            # 顶点\n            self.orgin[0],#20\n            self.orgin[1]+15,#35\n            # 左下+\n            self.orgin[0]-15,#5\n            self.orgin[1],#20\n        ]\n        self.rect = self.canvas.create_polygon(self.points1, fill=\"green\")\n        # self.agent_action1 = self.canvas.coords(self.rect1)\n\n        # 右\n        self.points2 = [\n            # 左下\n            self.orgin[0]-15,#5\n            self.orgin[1]+15,#35\n            # 左上\n            self.orgin[0]-15,#5\n            self.orgin[1]-15,#5\n            # 右上+\n            self.orgin[0],#20\n            self.orgin[1]-15,#5\n            # 顶点\n            self.orgin[0]+15,#35\n            self.orgin[1],#20\n            # 右下+\n            self.orgin[0],#20\n            self.orgin[1]+15,#35\n        ]\n        # self.rect2 = self.canvas.create_polygon(self.points2, fill=\"green\")\n        # self.agent_action2 = self.canvas.coords(self.rect2)\n\n        # 左\n        self.points3 = [\n            # 右上\n            self.orgin[0]+15,#20+15\n            self.orgin[1]-15,#20-15\n            # 右下\n            self.orgin[0]+15,#20+15\n            self.orgin[1]+15,#20+15\n            # 左下+\n            self.orgin[0],#20\n            self.orgin[1]+15,#20+15\n            # 顶点\n            self.orgin[0]-15,#20-15\n            self.orgin[1],#20\n            # 左上+\n            self.orgin[0],#20\n            self.orgin[1]-15,#20-15\n\n\n        ]\n        # self.rect3 = self.canvas.create_polygon(self.points3, fill=\"green\")\n        # self.agent_action3 = self.canvas.coords(self.rect3)\n        self.canvas.pack()\n\n\n    #reset agent location\n    def reset(self):\n        self.open_door = 0\n        self.update()\n        time.sleep(0.5)\n        self.canvas.delete(self.rect)\n        self.orgin = [20, 20]\n        # 下\n        self.points1 
= [\n            # 左上\n            self.orgin[0] - 15,  # 5\n            self.orgin[1] - 15,  # 5\n            # 右上\n            self.orgin[0] + 15,  # 35\n            self.orgin[1] - 15,  # 5\n            # 右下+\n            self.orgin[0] + 15,  # 35\n            self.orgin[1],  # 20\n            # 顶点\n            self.orgin[0],  # 20\n            self.orgin[1] + 15,  # 35\n            # 左下+\n            self.orgin[0] - 15,  # 5\n            self.orgin[1],  # 20\n        ]\n        self.rect = self.canvas.create_polygon(self.points1, fill=\"green\")\n        # self.agent_action1 = self.canvas.coords(self.rect1)\n        return self.canvas.coords(self.rect)\n\n    def step(self, s, action, pain):\n        s = self.canvas.coords(self.rect)\n        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n\n        # danger or switch\n        if self.danger==1:\n            if all(self.centre == self.oval_center):\n                s_color = 'yellow'\n                self.canvas.delete(self.wall[3])\n                self.render()\n                # self.getter(self.canvas)\n                self.render()\n                # self.getter(self.canvas)#figure8 ,figure3.1 all red changed to green\n                self.open_door = 1\n\n                move = np.array([80, 0])\n                self.canvas.move(self.rect, move[0], move[1])\n\n                s = self.canvas.coords(self.rect)\n                self.render()\n                # self.getter(self.canvas)\n            elif all(self.centre == self.hell1_center):\n                s_color = 'black'\n                self.action_hurt = 1\n                self.render()\n                # self.getter(self.canvas)#figure4\n                self.render()\n            else:\n                s_color = 'white'\n\n\n\n\n        # modify current state\n        self.canvas.delete(self.rect)# 主要为开关那几步考虑，所以重复写了\n        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n\n        if action==0:\n            self.points0 = [\n             
   # 右下\n                self.centre[0] + 15,  # 35\n                self.centre[1] + 15,  # 35\n                # 左下\n                self.centre[0] - 15,  # 5\n                self.centre[1] + 15,  # 35\n                # 左上+\n                self.centre[0] - 15,  # 5\n                self.centre[1],  # 20\n                # 顶点\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 5\n                # 右上+\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points0, fill=color)\n        if action==1:\n            self.points1 = [\n                # 左上\n                self.centre[0] - 15,  # 5\n                self.centre[1] - 15,  # 5\n                # 右上\n                self.centre[0] + 15,  # 35\n                self.centre[1] - 15,  # 5\n                # 右下+\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n                # 顶点\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 35\n                # 左下+\n                self.centre[0] - 15,  # 5\n                self.centre[1],  # 20\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points1, fill=color)\n        if action==2:\n            self.points2 = [\n                # 左下\n                self.centre[0] - 15,  # 5\n                self.centre[1] + 15,  # 35\n                # 左上\n                self.centre[0] - 15,  # 5\n                self.centre[1] - 15,  # 5\n                # 右上+\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 5\n                # 顶点\n                self.centre[0] + 15,  # 35\n                
self.centre[1],  # 20\n                # 右下+\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 35\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points2, fill=color)\n        if action==3:\n            self.points3 = [\n                # 右上\n                self.centre[0] + 15,  # 20+15\n                self.centre[1] - 15,  # 20-15\n                # 右下\n                self.centre[0] + 15,  # 20+15\n                self.centre[1] + 15,  # 20+15\n                # 左下+\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 20+15\n                # 顶点\n                self.centre[0] - 15,  # 20-15\n                self.centre[1],  # 20\n                # 左上+\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 20-15\n\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points3, fill=color)\n        s = self.canvas.coords(self.rect)\n        self.render()#显示当前的动作指令是什么\n        # self.getter(self.canvas)#figure5 after figure4\n\n\n        if s[0] > (9 / 2) * 40:\n            self.action_hurt = 0\n        # ensure ture action\n        base_action = np.array([0, 0])\n        if self.action_hurt == 0:\n            true_action = action\n        else:\n            if action == 0:\n                true_action = 1\n            if action == 1:\n                true_action = 0\n            if action == 2:\n                true_action = 3\n            if action == 3:\n                true_action = 2\n\n        # predict next state\n        b = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]\n        if self.centre[0] <= ((MAZE_H - 1) / 2 +1) * UNIT:#120\n            if action == 0:  # up\n                if self.centre[1] > UNIT:\n       
             b = [0, -40, 0, -40,0, -40, 0, -40,0, -40]\n            elif action == 1:  # down\n                if self.centre[1] < (MAZE_W - 1) * UNIT:\n                    b = [0, 40, 0, 40,0, 40, 0, 40, 0, 40]\n            elif action == 2:  # right\n                if self.centre[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    b = [40, 0, 40, 0,40, 0, 40, 0,40, 0]\n            elif action == 3:  # left\n                if self.centre[0] > UNIT:\n                    b = [-40, 0, -40, 0,-40, 0, -40, 0,-40, 0]\n        else:\n            if action == 0:  # up\n                if self.centre[1] > UNIT:\n                    b = [0, -40, 0, -40,0, -40, 0, -40,0, -40]\n            elif action == 1:  # down\n                if self.centre[1] < (MAZE_W - 1) * UNIT:\n                    b = [0, 40, 0, 40,0, 40, 0, 40, 0, 40]\n            elif action == 2:  # right\n                if self.centre[0] < (MAZE_H - 1) * UNIT:\n                    b = [40, 0, 40, 0,40, 0, 40, 0,40, 0]\n            elif action == 3:  # left\n                if self.centre[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    b = [-40, 0, -40, 0,-40, 0, -40, 0,-40, 0]\n        s_predict = []\n        for i in range(len(b)):\n            s_predict1 = s[i] + b[i]\n            s_predict.append(s_predict1)\n\n\n        # true next state\n        if self.centre[0]<=((MAZE_H - 1) / 2 +1) * UNIT:\n            if true_action == 0:  # up\n                if self.centre[1] > UNIT:\n                    base_action[1] -= UNIT\n            elif true_action == 1:  # down\n                if self.centre[1] < (MAZE_W - 1) * UNIT:\n                    base_action[1] += UNIT\n            elif true_action == 2:  # right\n                if self.centre[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    base_action[0] += UNIT\n            elif true_action == 3:  # left\n                if self.centre[0] > UNIT:\n                    base_action[0] -= UNIT\n        else:\n            if true_action 
== 0:  # up\n                if self.centre[1] > UNIT:\n                    base_action[1] -= UNIT\n            elif true_action == 1:  # down\n                if self.centre[1] < (MAZE_W - 1) * UNIT:\n                    base_action[1] += UNIT\n            elif true_action == 2:  # right\n                if self.centre[0] < (MAZE_H - 1) * UNIT:\n                    base_action[0] += UNIT\n            elif true_action == 3:  # left\n                if self.centre[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    base_action[0] -= UNIT\n        self.canvas.move(self.rect, base_action[0], base_action[1])\n        s_ = self.canvas.coords(self.rect)\n\n        return s_, s_predict, s_color\n\n    def step_RL1(self, action):\n        s = self.canvas.coords(self.rect)\n        base_action = np.array([0, 0])\n\n        if s[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if action == 0:  # up\n                if s[1] > UNIT:\n                    base_action[1] -= UNIT\n            elif action == 1:  # down\n                if s[1] < (MAZE_W - 1) * UNIT:\n                    base_action[1] += UNIT\n            elif action == 2:  # right\n                if s[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    base_action[0] += UNIT\n            elif action == 3:  # left\n                if s[0] > UNIT:\n                    base_action[0] -= UNIT\n        else:\n            if action == 0:  # up\n                if s[1] > UNIT:\n                    base_action[1] -= UNIT\n            elif action == 1:  # down\n                if s[1] < (MAZE_W - 1) * UNIT:\n                    base_action[1] += UNIT\n            elif action == 2:  # right\n                if s[0] < (MAZE_H - 1) * UNIT:\n                    base_action[0] += UNIT\n            elif action == 3:  # left\n                if s[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    base_action[0] -= UNIT\n\n        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent\n        
s_ = self.canvas.coords(self.rect)  # next state\n\n        if s_==self.canvas.coords(self.hell1):\n            self.canvas.itemconfig(self.rect, fill=\"red\", outline='red')\n            reward = -1\n            self.pain_state=1\n        else:\n            reward = 0\n        return s_, reward,self.pain_state\n\n    def step_RL2(self, action):\n        s = self.canvas.coords(self.rect)\n        if s == self.canvas.coords(self.oval):\n            self.canvas.delete(self.wall[3])\n            self.open_door = 1\n            move = np.array([40, 0])\n            self.canvas.move(self.rect, move[0], move[1])\n            move = np.array([40, 0])\n            self.canvas.move(self.rect, move[0], move[1])\n            self.render()\n            self.canvas.itemconfig(self.rect, fill=\"green\", outline='green')\n            self.render()\n\n        base_action = np.array([0, 0])\n\n        if s[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if action == 0:  # up\n                if s[1] > UNIT:\n                    base_action[1] -= UNIT\n            elif action == 1:  # down\n                if s[1] < (MAZE_W - 1) * UNIT:\n                    base_action[1] += UNIT\n            elif action == 2:  # right\n                if s[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    base_action[0] += UNIT\n            elif action == 3:  # left\n                if s[0] > UNIT:\n                    base_action[0] -= UNIT\n        else:\n            if action == 0:  # up\n                if s[1] > UNIT:\n                    base_action[1] -= UNIT\n            elif action == 1:  # down\n                if s[1] < (MAZE_W - 1) * UNIT:\n                    base_action[1] += UNIT\n            elif action == 2:  # right\n                if s[0] < (MAZE_H - 1) * UNIT:\n                    base_action[0] += UNIT\n            elif action == 3:  # left\n                if s[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    base_action[0] -= UNIT\n\n        
self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent\n        s_ = self.canvas.coords(self.rect)  # next state\n\n        if s_ == self.canvas.coords(self.oval):\n\n\n            self.open_door = 1\n\n            if self.pain_state == 0:\n                reward = 0\n            if self.pain_state == 1:\n                reward = 1\n                self.pain_state = 0\n\n        elif s_ == self.canvas.coords(self.hell1):\n            self.canvas.itemconfig(self.rect, fill=\"red\", outline='red')\n            reward = -1\n            self.pain_state = 1\n            self.render()\n\n        else:\n            reward = 0\n\n        return s_, reward, self.pain_state\n\n    def _set_danger(self):\n        self.hell1_center = np.array([60, 60])\n        self.hell1 = self.canvas.create_oval(\n            self.hell1_center[0] - 15, self.hell1_center[1] - 15,\n            self.hell1_center[0] + 15, self.hell1_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n        self.hell = self.canvas.coords(self.hell1)\n        self.canvas.pack()\n        self.danger=1\n\n    def _set_switch(self):\n        self.oval_center = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        self.oval = self.canvas.create_oval(\n            self.oval_center[0] - 15, self.oval_center[1] - 15,\n            self.oval_center[0] + 15, self.oval_center[1] + 15,\n            fill='yellow')\n        self.switch = self.canvas.coords(self.oval)\n        self.canvas.pack()\n\n\n    def _set_wall(self):\n        wall_center=[]\n        self.wall=[]\n        for a in range(MAZE_W):\n            wall_center.append([0,0])\n            self.wall.append([])\n        for b in range(MAZE_W):\n            wall_center[b]=np.array([(MAZE_H*UNIT)/2,((b)*UNIT)+UNIT/2])\n            self.wall[b] = self.canvas.create_rectangle(\n                wall_center[b][0] - 20, wall_center[b][1] - 20,\n                
wall_center[b][0] + 20, wall_center[b][1] + 20,\n                fill='grey')\n        self.wall0 = self.canvas.coords(self.wall[0])\n        self.wall1 = self.canvas.coords(self.wall[1])\n        self.wall2 = self.canvas.coords(self.wall[2])\n        self.wall3 = self.canvas.coords(self.wall[3])\n\n        # self.canvas.pack()\n\n    def generate_expression(self,pain):\n        if pain==1:\n            self.canvas.itemconfig(self.rect, fill=\"red\", outline='red')\n            # self.canvas.pack()\n        if pain == 0:\n            self.canvas.itemconfig(self.rect, fill=\"green\", outline='green')\n            # self.canvas.pack()\n\n    def render(self):\n        time.sleep(0.01)\n        self.update()\n\n    # def getter(self, widget):\n    #     widget.update()\n    #     x = tk.Tk.winfo_rootx(self) + widget.winfo_x()\n    #     y = tk.Tk.winfo_rooty(self) + widget.winfo_y()\n    #     x1 = x + widget.winfo_width()\n    #     y1 = y + widget.winfo_height()\n    #     ImageGrab.grab().crop((x, y, x1, y1)).save(\"first.jpg\")\n    #     return ImageGrab.grab().crop((x, y, x1, y1))\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BAE-SNN/env_two_poly.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\n\nUNIT = 40   # pixels\nMAZE_H = 9  # grid height\nMAZE_W = 4 # grid width\n\n\nclass Maze2(tk.Tk, object):\n    def __init__(self):\n        super(Maze2, self).__init__()\n        self.action_space = ['u', 'd', 'l', 'r']\n        self.action_space1 = ['u', 'd', 'l', 'r']\n        self.n_actions = len(self.action_space)\n        self.n_actions1 = len(self.action_space1)\n        self.title('two_agent_empathy')\n        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))\n        self._build_maze()\n        self.danger=0\n        self.action_hurt=0\n        self.sensory_hurt = 0\n        self.action_hurt1 = 0\n        self.sensory_hurt1 = 0\n        self.open_door=0\n\n    # create environment\n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                           height=MAZE_W * UNIT,\n                           width=MAZE_H * UNIT)\n\n        # create grids\n        for c in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT\n            self.canvas.create_line(x0, y0, x1, y1)\n        for r in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r\n            self.canvas.create_line(x0, y0, x1, y1)\n\n        # create switch\n        self.oval_center = np.array([(MAZE_H * UNIT)/2-UNIT+80, ((MAZE_W+4) * UNIT)/2-UNIT/2-80])\n        self.oval = self.canvas.create_oval(\n            self.oval_center[0] - 15, self.oval_center[1] - 15,\n            self.oval_center[0] + 15, self.oval_center[1] + 15,\n            fill='yellow')\n        self.switch = self.canvas.coords(self.oval)\n\n        self.orgin1 = np.array([20, 20])\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1] - 15,  # 5\n            # 右上\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1] - 15,  
# 5\n            # 右下+\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1],  # 20\n            # 顶点\n            self.orgin1[0],  # 20\n            self.orgin1[1] + 15,  # 35\n            # 左下+\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1],  # 20\n        ]\n        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill=\"blue\")\n\n        self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])\n        # 下\n        self.points = [\n            # 左上\n            self.orgin[0] - 15,  # 5\n            self.orgin[1] - 15,  # 5\n            # 右上\n            self.orgin[0] + 15,  # 35\n            self.orgin[1] - 15,  # 5\n            # 右下+\n            self.orgin[0] + 15,  # 35\n            self.orgin[1],  # 20\n            # 顶点\n            self.orgin[0],  # 20\n            self.orgin[1] + 15,  # 35\n            # 左下+\n            self.orgin[0] - 15,  # 5\n            self.orgin[1],  # 20\n        ]\n        self.agent = self.canvas.create_polygon(self.points, fill=\"green\")\n\n        wall_center = []\n        self.wall = []\n        for i in range(MAZE_W):\n            wall_center.append([])\n            self.wall.append([])\n        for i in range(MAZE_W):\n            wall_center[i] = np.array([(MAZE_H * UNIT) / 2, ((i) * UNIT) + UNIT / 2])\n            self.wall[i] = self.canvas.create_rectangle(\n                wall_center[i][0] - 20, wall_center[i][1] - 20,\n                wall_center[i][0] + 20, wall_center[i][1] + 20,\n                fill='grey')\n\n        self.hell1_center = np.array([100, 20])\n        self.hell1 = self.canvas.create_oval(\n            self.hell1_center[0] - 15, self.hell1_center[1] - 15,\n            self.hell1_center[0] + 15, self.hell1_center[1] + 15,\n            fill='black')\n        self.hell2_center = np.array([60, 100])\n        self.hell2 = self.canvas.create_oval(\n            self.hell2_center[0] - 15, self.hell2_center[1] - 15,\n            self.hell2_center[0] + 15, 
self.hell2_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n\n        self.danger = 1\n\n        self.canvas.pack()\n\n    #reset agent location\n    def reset(self):\n        self.update()\n        time.sleep(0.5)\n        self.canvas.delete(self.agent1)\n        self.canvas.delete(self.agent)\n        self.orgin1 = np.array([20, 20])\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1] - 15,  # 5\n            # 右上\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1] - 15,  # 5\n            # 右下+\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1],  # 20\n            # 顶点\n            self.orgin1[0],  # 20\n            self.orgin1[1] + 15,  # 35\n            # 左下+\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1],  # 20\n        ]\n        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill=\"blue\")\n\n        self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])\n        # 下\n        self.points = [\n            # 左上\n            self.orgin[0] - 15,  # 5\n            self.orgin[1] - 15,  # 5\n            # 右上\n            self.orgin[0] + 15,  # 35\n            self.orgin[1] - 15,  # 5\n            # 右下+\n            self.orgin[0] + 15,  # 35\n            self.orgin[1],  # 20\n            # 顶点\n            self.orgin[0],  # 20\n            self.orgin[1] + 15,  # 35\n            # 左下+\n            self.orgin[0] - 15,  # 5\n            self.orgin[1],  # 20\n        ]\n        self.agent = self.canvas.create_polygon(self.points, fill=\"green\")\n\n        return self.canvas.coords(self.agent1),self.canvas.coords(self.agent)\n\n\n    # move agent1\n    def step1(self, action1,pain):\n        s1 = self.canvas.coords(self.agent1)\n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if all(self.centre1 == self.hell1_center):\n            
self.action_hurt1 = 1\n        if all(self.centre1 == self.hell2_center):\n            self.action_hurt1 = 1\n        \n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 ==self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 ==self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 == self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*3, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 == self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*4, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 == self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        \n        \n        self.canvas.delete(self.agent1)  # 主要为开关那几步考虑，所以重复写了\n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if action1==0:\n            
self.points0 = [\n                # 右下\n                self.centre1[0] + 15,  # 35\n                self.centre1[1] + 15,  # 35\n                # 左下\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] + 15,  # 35\n                # 左上+\n                self.centre1[0] - 15,  # 5\n                self.centre1[1],  # 20\n                # 顶点\n                self.centre1[0],  # 20\n                self.centre1[1] - 15,  # 5\n                # 右上+\n                self.centre1[0] + 15,  # 35\n                self.centre1[1],  # 20\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points0, fill=color,outline='black')\n        if action1==1:\n            self.points1 = [\n                # 左上\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] - 15,  # 5\n                # 右上\n                self.centre1[0] + 15,  # 35\n                self.centre1[1] - 15,  # 5\n                # 右下+\n                self.centre1[0] + 15,  # 35\n                self.centre1[1],  # 20\n                # 顶点\n                self.centre1[0],  # 20\n                self.centre1[1] + 15,  # 35\n                # 左下+\n                self.centre1[0] - 15,  # 5\n                self.centre1[1],  # 20\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points1, fill=color,outline='black')\n        if action1==2:\n            self.points2 = [\n                # 左下\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] + 15,  # 35\n                # 左上\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] - 15,  # 5\n                # 右上+\n                self.centre1[0],  # 20\n                self.centre1[1] - 15,  # 5\n   
             # 顶点\n                self.centre1[0] + 15,  # 35\n                self.centre1[1],  # 20\n                # 右下+\n                self.centre1[0],  # 20\n                self.centre1[1] + 15,  # 35\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points2, fill=color,outline='black')\n        if action1==3:\n            self.points3 = [\n                # 右上\n                self.centre1[0] + 15,  # 20+15\n                self.centre1[1] - 15,  # 20-15\n                # 右下\n                self.centre1[0] + 15,  # 20+15\n                self.centre1[1] + 15,  # 20+15\n                # 左下+\n                self.centre1[0],  # 20\n                self.centre1[1] + 15,  # 20+15\n                # 顶点\n                self.centre1[0] - 15,  # 20-15\n                self.centre1[1],  # 20\n                # 左上+\n                self.centre1[0],  # 20\n                self.centre1[1] - 15,  # 20-15\n\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points3, fill=color,outline='black')\n        s1 = self.canvas.coords(self.agent1)\n        self.render()#显示当前的动作指令是什么\n\n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if self.centre1[0] > (9 / 2) * 40:\n            self.action_hurt1 = 0\n\n        # whether hurt\n        if self.action_hurt1 == 0:\n            true_action1 = action1\n        else:\n            if action1 == 0:\n                true_action1 = 1\n            if action1 == 1:\n                true_action1 = 0\n            if action1 == 2:\n                true_action1 = 3\n            if action1 == 3:\n                true_action1 = 2\n\n        # predict next state\n        b = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n        if self.centre1[0] <= 
((MAZE_H - 1) / 2 + 1) * UNIT:  # 120\n            if action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    b = [0, -40, 0, -40, 0, -40, 0, -40, 0, -40]\n            elif action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    b = [0, 40, 0, 40, 0, 40, 0, 40, 0, 40]\n            elif action1 == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    b = [40, 0, 40, 0, 40, 0, 40, 0, 40, 0]\n            elif action1 == 3:  # left\n                if self.centre1[0] > UNIT:\n                    b = [-40, 0, -40, 0, -40, 0, -40, 0, -40, 0]\n        else:\n            if action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    b = [0, -40, 0, -40, 0, -40, 0, -40, 0, -40]\n            elif action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    b = [0, 40, 0, 40, 0, 40, 0, 40, 0, 40]\n            elif action1 == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    b = [40, 0, 40, 0, 40, 0, 40, 0, 40, 0]\n            elif action1 == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    b = [-40, 0, -40, 0, -40, 0, -40, 0, -40, 0]\n        s_predict = []\n        for i in range(len(b)):\n            s_predict1 = s1[i] + b[i]\n            s_predict.append(s_predict1)\n\n        base_action1 = np.array([0, 0])\n        \n        \n        # true next state\n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if true_action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    base_action1[1] -= UNIT\n            elif true_action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    base_action1[1] += UNIT\n            elif true_action1 == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * 
UNIT:\n                    base_action1[0] += UNIT\n            elif true_action1 == 3:  # left\n                if self.centre1[0] > UNIT:\n                    base_action1[0] -= UNIT\n        else:\n            if true_action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    base_action1[1] -= UNIT\n            elif true_action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    base_action1[1] += UNIT\n            elif true_action1 == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    base_action1[0] += UNIT\n            elif true_action1 == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    base_action1[0] -= UNIT\n        self.canvas.move(self.agent1, base_action1[0], base_action1[1])\n        s1_ = self.canvas.coords(self.agent1)\n\n        return s1_, s_predict,color\n\n\n    def agent_help(self):\n        s = self.canvas.coords(self.agent)\n        self.centre2= [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n        if all(self.centre2 == self.oval_center):\n            self.canvas.delete(self.wall[3])\n            self.render()\n            self.open_door=1\n        else:    \n            self.canvas.move(self.agent, -40, 0)  # move agent\n            self.render()\n            self.canvas.move(self.agent, -40, 0)\n            self.render()\n            self.canvas.move(self.agent, -40, 0)\n            self.render()\n            self.canvas.move(self.agent, 0, 40)\n            self.render()\n        s_ = self.canvas.coords(self.agent)  # next state\n        \n        return s_\n\n    def _set_danger(self):\n        hell1_center = np.array([140, 60])\n        self.hell1 = self.canvas.create_rectangle(\n            hell1_center[0] - 15, hell1_center[1] - 15,\n            hell1_center[0] + 15, hell1_center[1] + 15,\n            fill='black')\n        hell2_center = np.array([100, 140])\n        self.hell2 = 
self.canvas.create_rectangle(\n            hell2_center[0] - 15, hell2_center[1] - 15,\n            hell2_center[0] + 15, hell2_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n        self.canvas.pack()\n        self.danger=1\n\n\n    def _set_wall(self):\n        wall_center=[]\n        self.wall=[]\n        for i in range(MAZE_W):\n            wall_center.append([])\n            self.wall.append([])\n        for i in range(MAZE_W):\n            wall_center[i]=np.array([(MAZE_H*UNIT)/2,((i)*UNIT)+UNIT/2])\n            self.wall[i] = self.canvas.create_rectangle(\n                wall_center[i][0] - 20, wall_center[i][1] - 20,\n                wall_center[i][0] + 20, wall_center[i][1] + 20,\n                fill='grey')\n        self.canvas.pack()\n\n    \n    def generate_expression1(self,pain1):\n        if pain1==1:\n            self.canvas.itemconfig(self.agent1, fill=\"red\", outline='black')\n            self.canvas.pack()\n        if pain1 ==0:\n            self.canvas.itemconfig(self.agent1, fill=\"blue\", outline='black')\n            self.canvas.pack()\n    def render(self):\n        time.sleep(0.2)\n        self.update()\n\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/BEEAD-SNN.py",
    "content": "import os\nimport sys\nimport imageio\nfrom env_poly_SNN import Maze\nfrom env import Maze2\nfrom RL_brain import QLearningTable\nimport numpy as np\nimport pandas as pd\nimport matplotlib\nimport matplotlib.pyplot as plt\nnp.random.seed(1)\nfrom torch.utils.tensorboard import SummaryWriter\nfrom sklearn.preprocessing import MinMaxScaler\nimport torch, os, sys\nfrom torch import nn\nfrom torch.nn import Parameter\nimport abc\nimport math\nfrom abc import ABC\nimport torch.nn.functional as F\nfrom braincog.base.node.node import *\nfrom braincog.base.learningrule.STDP import *\nfrom braincog.base.connection.CustomLinear import *\nfrom braincog.base.utils.visualization import spike_rate_vis, spike_rate_vis_1d#, spike_vis_2, spike_vis_5\n\nclass BrainArea(nn.Module, abc.ABC):\n    @abc.abstractmethod\n    def __init__(self):\n        super().__init__()\n      \n    @abc.abstractmethod\n    def forward(self, x):\n        \"\"\"\n        Calculate the forward propagation process\n        :return:x is spike\n        \"\"\"\n        return x\n\n    def reset(self):\n        \"\"\"\n        Calculate the forward propagation process\n        :return:x is spike\n        \"\"\"\n        pass\n\nclass BAESNN(BrainArea):\n    \"\"\"\n    Affactive Empathy Network\n    \"\"\"\n\n    def __init__(self,):\n        super().__init__()\n\n        self.node = [IFNode() for i in range(5)]\n        \n        self.connection = []\n        \n        con_matrix0 = torch.eye(24, 24)*6\n        self.connection.append(CustomLinear(con_matrix0))#input-emotion\n        \n        con_matrix1 = torch.eye(24, 24)\n       \n        self.connection.append(CustomLinear(con_matrix1))#emotion-ifg\n        \n        con_matrix2 = torch.zeros((24, 24), dtype=torch.float)   \n        self.connection.append(CustomLinear(con_matrix2))#perception-ifg\n        \n        con_matrix3 = torch.eye(24, 24)*6\n        self.connection.append(CustomLinear(con_matrix3))#input-perception\n        \n      
  con_matrix4=torch.zeros((24,10), dtype=torch.float)\n        for j in range(10):\n            if j in np.arange(0,5,1):\n                for i in np.arange(0, 12, 1):\n                    con_matrix4[i,j] =2\n            if j in np.arange(5,10,1):\n                for i in np.arange(12, 24, 1):\n                    con_matrix4[i,j] =2\n        self.connection.append(CustomLinear(con_matrix4))#emotion-sma\n        \n        con_matrix5=torch.zeros((24,10), dtype=torch.float)\n        self.connection.append(CustomLinear(con_matrix5))#perception-m1\n        \n        con_matrix6 = torch.eye(10, 10)*6\n        self.connection.append(CustomLinear(con_matrix6))#sma-m1\n        \n        self.stdp = []\n        self.stdp.append(STDP(self.node[0], self.connection[0]))#0 node0 emotion\n        self.stdp.append(STDP(self.node[2], self.connection[3]))#1 node2 perception\n        self.stdp.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))#2 node1 ifg\n        self.stdp.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[5]]))#3 node3 sma\n        self.stdp.append(STDP(self.node[4], self.connection[6]))#4 node4 m1\n        self.stdp.append(STDP(self.node[1],self.connection[2]))#5 node1 ifg\n        self.stdp.append(STDP(self.node[3],self.connection[5]))#6 node3 sma\n    def forward(self, x1,x2):\n        out__m, dw0 = self.stdp[0](x1)#node0 emotion\n        out__p, dw3 = self.stdp[1](x2)#node2 perception\n        out__ifg,dw_p_i=self.stdp[2](out__m,out__p)#node1 ifg   \n        out__sma,dw_p_s=self.stdp[3](out__m,out__p)#node3  sma\n        out__m1,dw1=self.stdp[4](out__sma)#node4 m1\n    \n        return dw_p_i,dw_p_s,out__ifg,out__sma,out__m1,out__m,out__p\n    \n    def empathy(self,x3):\n        out_p,dw2=self.stdp[1](x3)#node2 perception\n        out_ifg,dw4=self.stdp[5](out_p)#node1 ifg\n        out_sma,dw5=self.stdp[6](out_p)#node3 sma\n        out_m1,dw6=self.stdp[4](out_sma)#node4 m1\n        return 
out_ifg,out_sma,out_m1,out_p\n        \n    def UpdateWeight(self, i, dw, delta):\n        self.connection[i].update(dw*delta)\n        self.connection[i].weight.data= torch.clamp(self.connection[i].weight.data,-1,4)\n        \n    def reset(self):\n        for i in range(5):\n            self.node[i].n_reset()\n        for i in range(len(self.stdp)):\n            self.stdp[i].reset()\n\nclass DopamineArea(BrainArea):\n    \"\"\"\n    Dopamine brain area with a group of spiking neurons, computes reward prediction error.\n    \"\"\"\n    def __init__(self, n_neurons, beta=0.2):\n        super().__init__()\n        self.n_neurons = n_neurons\n        self.beta = beta\n        self.node = [IFNode() for _ in range(n_neurons)]\n        self.P = np.zeros(n_neurons)  # prediction for each neuron\n    def forward(self, spikes):\n        out_spikes = []\n        for i in range(self.n_neurons):\n            spike = self.node[i](torch.tensor([spikes[i]], dtype=torch.float32))\n            out_spikes.append(spike)\n        S = torch.stack(out_spikes).mean().item()\n        delta = S - self.P\n        self.P = self.P + self.beta * delta\n        return delta, out_spikes\n    def reset(self):\n        self.P = np.zeros(self.n_neurons)\n        for n in self.node:\n            n.n_reset()\n            \ndef BAESNN_train():  \n    s = env.reset()\n    env._set_danger()\n    env._set_wall()\n    pain = 0\n    i = 0\n    set_pain = 0\n    env._set_switch()\n    for i in range(100):\n        snn2.reset()\n        T = 100\n        pain = 0\n        print('**************step:', i)\n        env.render()\n        \n        action = np.random.choice(list(range(env.n_actions)))\n        print('action:', action)\n        d, d_pre, s_, sss = env.step(s, action, pain)\n        print('d:', d, 'd_pre:', d_pre, 'sss:', sss)\n        env.render()\n        \n        while (d == np.array([0, 0])).all():\n            action = np.random.choice(list(range(env.n_actions)))\n            print('action:', 
action)\n            d, d_pre, s_, sss = env.step(s, action, pain)\n            print('d:', d, 'd_pre:', d_pre, 'sss:', sss)\n            env.render()\n        \n        # Use env.is_agent1_in_danger to set OUT_PAIN, pain, emotion\n        if env.is_agent_in_danger():\n            OUT_PAIN = torch.ones(24)\n            pain = 1\n            set_pain = 1\n            emotion = -1\n        else:\n            OUT_PAIN = torch.zeros(24)\n            pain = 0\n            emotion = 0\n\n        print(\"OUT_PAIN:\", OUT_PAIN)\n        print(\"pain:\", pain)\n        print(\"emotion:\", emotion)\n        \n        T2 = 20\n        X1 = OUT_PAIN\n        X2 = torch.zeros(24)\n        X3 = torch.cat([torch.ones(12) * 0.1, torch.zeros(12)])\n        print('X1,X2:', X1, X2)\n        spike_emotion = []\n        spike_ifg = []\n        spike_sma = []\n        spike_m1 = []\n        spike_per = []\n        for i in range(T2):\n            if i >= 2:\n                X2 = X3\n            OUTPUT = snn2(X1, X2)\n            snn2.UpdateWeight(2, OUTPUT[0][1], 0.01)\n            snn2.UpdateWeight(5, OUTPUT[1][1], -0.1)\n            if OUTPUT[2][0] == 1:\n                env.canvas.itemconfig(env.rect, fill=\"red\", outline='red')\n            if OUTPUT[2][0] == 0:\n                env.canvas.itemconfig(env.rect, fill=\"green\", outline='green')\n            spike_emotion.append(OUTPUT[5])\n            spike_per.append(OUTPUT[6])\n            spike_ifg.append(OUTPUT[2])\n            spike_sma.append(OUTPUT[3])\n            spike_m1.append(OUTPUT[4])\n        \n        print('out_ifg:', OUTPUT[2])\n        print('out_sma:', OUTPUT[3])\n        print('out_m1:', OUTPUT[4])\n        print('con2:', snn2.connection[2].weight.data)\n        print('con5:', snn2.connection[5].weight.data)\n        \n        spike_emotion = torch.stack(spike_emotion)\n        spike_per = torch.stack(spike_per)\n        spike_ifg = torch.stack(spike_ifg)\n        spike_sma = torch.stack(spike_sma)\n        
spike_m1 = torch.stack(spike_m1)\n        print(spike_emotion.shape)\n        env.render()\n        \n        s = s_\n        if set_pain == 1 and pain == 0:\n            env.render()\n            break\n    env.destroy()\n\ndef BAESNN_train_alstruism(lamda, E):\n    global writer\n    for episode in range(E):\n        print('*******************episode:', episode, ',factor:', lamda, '*********************************')\n        s1,s2=env2.reset()\n        env2._set_wall()\n        pain1 = 0\n        pain2 = 0\n        i = 0\n        set_pain = 0\n        env2.emotion = 0\n        env2.empathy_emotion = 0\n        env2.empathy_emotion_t_1 = 0\n        \n        rr = 0\n        hh = 0\n        g = 0\n        a = []\n        if episode < 200:\n            e_greedy = 0.5\n        elif episode < 500:\n            e_greedy = 0.7\n        elif episode < 700:\n            e_greedy = 0.9\n        else:\n            e_greedy = 1\n            \n        r = np.random.uniform()\n        \n        for i in range(100):\n            env2.render()\n            done = False\n            env2.empathy_emotion_t_1=env2.empathy_emotion\n            action2 = RL.choose_action(str([(s2[4] + s2[8]) / 2, (s2[5] + s2[9]) / 2,env2.empathy_emotion]),e_greedy=e_greedy)\n            s2_, done, done_oval = env2.step2(action2)\n            env2.render()\n            T=100\n            print('**************step:',i)\n            env2.render()\n            \n            action1 = np.random.choice(list(range(env.n_actions)))\n            print('action1:',action1)\n            if r<= 0.25:\n                if i==0:\n                    action1=2\n            if 0.25<r<=0.5:\n                if i==0:\n                    action1=1\n                if i==1:\n                    action1=1\n            if 0.5<r<=0.75:\n                if i==0:\n                    action1=0\n                if i==1:\n                    action1=1\n                if i==2:\n                    action1=1\n            if 
0.75<r<=1.0:\n                if i==0:\n                    action1=0\n                if i==1:\n                    action1=3\n                if i==2:\n                    action1=1\n                if i==3:\n                    action1=1\n            if env2.emotion==0 and set_pain==1:\n                pass\n            else:     \n                d,d_pre,s1_,sss = env2.step1(s1, action1,env2.emotion)\n            print('d:',d,'d_pre:',d_pre,'sss:',sss)\n            env2.render()\n            if env2.is_agent1_in_danger():\n                OUT_PAIN = torch.ones(24)\n                pain1 = 1\n                env2.emotion = -1\n                set_pain = 1\n            else:\n                OUT_PAIN = torch.zeros(24)\n                pain1 = 0\n                env2.emotion = 0\n\n            print(\"OUT_PAIN:\", OUT_PAIN)\n            print(\"pain1:\", pain1)\n            print(\"emotion:\", env2.emotion)\n            env2.generate_expression1(env2.emotion)\n            \n            snn2.reset()\n            T2 = 20\n            X3 = OUT_PAIN.view(1, -1)\n            for i in range(T2):\n                OUT = snn2.empathy(X3)\n                print('out_ifg:', OUT[0])\n            \n            if OUT[0][0][0] == 1:\n                env2.empathy_emotion = -1            \n            if OUT[0][0][0] == 0:\n                env2.empathy_emotion = 0    \n                \n            env2.generate_expression2(env2.empathy_emotion)\n            env2.render()\n\n            delta,_ = snn1(OUT[0][0])\n            reward1 = delta\n            _, reward2 = env2.reward2()   \n            rr += reward2\n                \n            RL.learn_A(str([(s2[4] + s2[8]) / 2, (s2[5] + s2[9]) / 2, env2.empathy_emotion_t_1]), action2, env2.lamda * reward1 + reward2, str([(s2_[4] + s2_[8]) / 2, (s2_[5] + s2_[9]) / 2, env2.empathy_emotion]), done_oval)\n            \n            s1 = s1_\n            env2.render()\n            if env2.empathy_emotion == 0 and set_pain == 1:\n        
        a.append(i)\n            s2 = s2_  \n            if done:\n                g = 1\n                break\n        if a != []:\n            hh = 1\n        helpnumber.append(hh)\n        totalreward.append(rr)\n        goalnumber.append(g)\n        print('totalreward:\\n', totalreward)\n        print('goalnumber:\\n', goalnumber)\n        print('helpnumber:\\n', helpnumber)\n        writer.add_scalar('totalreward', rr, episode)\n        writer.add_scalar('helpnumber', hh, episode)\n\nif __name__ == \"__main__\":\n    env = Maze() \n    snn1 = DopamineArea(n_neurons=24, beta=0.2)  \n    snn2 = BAESNN() \n    writer = SummaryWriter(log_dir='runs/BAESNN')\n    BAESNN_train()\n    env.mainloop()\n    K=3\n    factor=3.0\n    Episode=1000\n    REWARD=[]\n    totalreward=[]\n    helpnumber=[]  \n    helpave=[]  \n    goalnumber=[]\n    env2 = Maze2(lamda=factor)\n    RL = QLearningTable(actions=list(range(env.n_actions)))\n    BAESNN_train_alstruism(factor,Episode)\n    helpnumber=np.array(helpnumber)\n    for jj in range(Episode//10):\n        helpave.append(np.mean(helpnumber[jj*10:(jj+1)*10])*10)\n    plt.figure(1,figsize=[18,9])\n    axes = plt.gca()\n    axes.set_ylim([-30,10])\n    plt.plot(totalreward,label=factor)\n    plt.legend(loc='lower right')\n    plt.title('totalreward')\n    plt.savefig('{}-totalreward.jpg'.format(K))\n    plt.figure(2,figsize=[18,9])\n    axes = plt.gca()\n    axes.set_ylim([0,12])\n    plt.plot(helpave,label=factor)\n    plt.legend(loc='lower right')\n    plt.title('helpave')\n    plt.savefig('{}-helpave.jpg'.format(K))\n    env.mainloop()\n    writer.close()\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/README.md",
    "content": "# Citation\n\n```bibtex\n@article{FeifeiZhao2025BuildingAA,\n  title={Building Altruistic and Moral AI Agent with Brain-inspired Emotional Empathy Mechanisms},\n  author={Feifei Zhao and Hui Feng and Haibo Tong and Zhengqiang Han and Erliang Lin and Enmeng Lu and Yinqian Sun and Yi Zeng},\n  journal={IEEE Transactions on Affective Computing},\n  year={2025},\n  volume={19 1},\n  doi={10.1109/taffc.2025.3627936}\n}\n```\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/RL_Brain.py",
    "content": "import numpy as np\nimport pandas as pd\n\nclass QLearningTable:\n    def __init__(self, actions, learning_rate=0.1, reward_decay=0.9):\n        self.actions = actions  # a list\n        self.lr = learning_rate\n        self.gamma = reward_decay\n        self.q_tableA = pd.DataFrame(columns=self.actions)\n       \n        self.colorblack=0\n        self.coloryellow=0\n      \n    def choose_action(self, observation,e_greedy):\n        self.A_chack_q_table_A(observation)\n       \n        # action selection\n        if np.random.uniform() < e_greedy:\n            # choose best action\n            state_actionA = self.q_tableA.loc[observation, :]\n            \n            state_action=state_actionA\n         \n            state_action= state_action.astype(float)\n            # print(state_action)\n            \n            action = state_action.argmax()\n            #print('best action',action)\n            \n        else:\n            action = np.random.choice(self.actions)\n            #print('random action',action)\n        return action\n\n    def learn_A(self, s, a, r, s_,done_oval):\n        self.A_chack_q_table_A(s_)\n       \n        if done_oval==0:\n            q_predict = self.q_tableA.loc[s, a]\n            q_target = r + self.gamma * self.q_tableA.loc[s_, :].max()  \n            self.q_tableA.loc[s, a] += self.lr * (q_target - q_predict)  # update\n            #print('self.q_tableA:\\n',self.q_tableA)\n        else:\n            pass\n        \n    def A_chack_q_table_A(self, state):\n        if state not in self.q_tableA.index:\n            # append new state to q table\n            # self.q_tableA = self.q_tableA.append( # append方法被新版本的pandas弃用\n            #     pd.Series(\n            #         [0]*len(self.actions),\n            #         index=self.q_tableA.columns,\n            #         name=state,\n            #     )\n            # )\n            self.q_tableA = pd.concat([\n                self.q_tableA,\n                
pd.Series(\n                    [0] * len(self.actions),\n                    index=self.q_tableA.columns,\n                    name=state,\n                ).to_frame().T\n            ])\n            \nclass EnvModel:\n    \"\"\"Similar to the memory buffer in DQN, you can store past experiences in here.\n    Alternatively, the model can generate next state and reward signal accurately.\"\"\"\n    def __init__(self, actions):\n        # the simplest case is to think about the model is a memory which has all past transition information\n        self.actions = actions\n        self.database = pd.DataFrame(columns=actions, dtype=np.object)\n\n    def store_transition(self, s, a, r, s_):\n        if s not in self.database.index:\n            # self.database = self.database.append( append方法被新版本的pandas弃用\n            #     pd.Series(\n            #         [None] * len(self.actions),\n            #         index=self.database.columns,\n            #         name=s,\n            #     ))\n            self.database = pd.concat([\n                self.database,\n                pd.Series(\n                    [None] * len(self.actions),\n                    index=self.database.columns,\n                    name=s,\n                ).to_frame().T\n            ])\n        self.database.set_value(s, a, (r, s_))\n\n    def sample_s_a(self):\n        s = np.random.choice(self.database.index)\n        a = np.random.choice(self.database.loc[s].dropna().index)    # filter out the None value\n        return s, a\n\n    def get_r_s_(self, s, a):\n        r, s_ = self.database.loc[s, a]\n        return r, s_\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/env.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\nUNIT = 40   # pixels\nMAZE_H = 11  # grid height\nMAZE_W = 5 # grid width\n\nclass Maze2(tk.Tk, object):\n    def __init__(self,lamda=0):\n        super(Maze2, self).__init__()\n        self.action_space = ['u', 'd', 'l', 'r']\n        self.action_space1 = ['u', 'd', 'l', 'r']\n        self.n_actions = len(self.action_space)\n        self.n_actions1 = len(self.action_space1)\n        self.title('pain_empathy')\n        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))\n        self._build_maze()\n        self.action_hurt1 = 0\n        self.emotion=0\n        self.empathy_emotion=0 \n        self.delta=0.5\n        self.set_pain=0\n        self.help_signal=0\n        self.lamda=lamda\n        self.empathy_emotion_t_1=0\n            \n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                           height=MAZE_W * UNIT,\n                           width=MAZE_H * UNIT)\n        \n        for c in range(0, MAZE_H * UNIT, UNIT):# create grids\n            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT\n            self.canvas.create_line(x0, y0, x1, y1)\n        for r in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r\n            self.canvas.create_line(x0, y0, x1, y1)\n        \n        # self.oval_center = np.array([(MAZE_H * UNIT)/2+80, UNIT/2+UNIT])# create switch\n        # self.oval = self.canvas.create_oval(\n        #     self.oval_center[0] - 15, self.oval_center[1] - 15,\n        #     self.oval_center[0] + 15, self.oval_center[1] + 15,\n        #     fill='yellow')\n        # self.help = self.canvas.coords(self.oval)\n\n\n\n\n        self.orgin1 = np.array([20, 20])\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1] - 15,  # 5\n            # 右上\n            self.orgin1[0] + 15,  # 35\n     
       self.orgin1[1] - 15,  # 5\n            # 右下+\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1],  # 20\n            # 顶点\n            self.orgin1[0],  # 20\n            self.orgin1[1] + 15,  # 35\n            # 左下+\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1],  # 20\n        ]\n        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill=\"blue\")# left agent\n\n\n\n\n        self.orgin2 = np.array([MAZE_H * UNIT - UNIT / 2, 20])\n        # 下\n        self.points2 = [\n            # 左上\n            self.orgin2[0] - 15,  # 5\n            self.orgin2[1] - 15,  # 5\n            # 右上\n            self.orgin2[0] + 15,  # 35\n            self.orgin2[1] - 15,  # 5\n            # 右下+\n            self.orgin2[0] + 15,  # 35\n            self.orgin2[1],  # 20\n            # 顶点\n            self.orgin2[0],  # 20\n            self.orgin2[1] + 15,  # 35\n            # 左下+\n            self.orgin2[0] - 15,  # 5\n            self.orgin2[1],  # 20\n        ]\n        self.agent2 = self.canvas.create_polygon(self.points2, fill=\"green\")# right agent\n        \n        \n        \n        \n        self.goal_centre = np.array([(MAZE_H/2) * UNIT + UNIT, (MAZE_W/2) * UNIT])\n        # 下\n        self.points3 = [\n            # 左上\n            self.goal_centre[0] - 15,  # 5\n            self.goal_centre[1] ,  # 5\n            # 右上\n            self.goal_centre[0] ,  # 35\n            self.goal_centre[1] - 15,  # 5\n            # 右下+\n            self.goal_centre[0] + 15,  # 35\n            self.goal_centre[1],  # 20\n            # 顶点\n            self.goal_centre[0],  # 20\n            self.goal_centre[1] + 15,  # 35 \n        ]\n        self.target = self.canvas.create_polygon(self.points3, fill=\"purple\") # goal\n       \n        \n       \n        \n        # self.food=self.canvas.create_arc(((MAZE_H-2) * UNIT +10 ,160,(MAZE_H) * UNIT-5 ,220), \n        #                                  start=0, extent=60, 
fill='red', outline='orange', width=2)#food\n       \n        \n        \n        self.hell1_center = np.array([60, 20])\n        self.hell1 = self.canvas.create_oval(\n            self.hell1_center[0] - 15, self.hell1_center[1] - 15,\n            self.hell1_center[0] + 15, self.hell1_center[1] + 15,\n            fill='black')\n        self.hell2_center = np.array([20, 100])\n        self.hell2 = self.canvas.create_oval(\n            self.hell2_center[0] - 15, self.hell2_center[1] - 15,\n            self.hell2_center[0] + 15, self.hell2_center[1] + 15,\n            fill='black')\n        self.hell3_center = np.array([140, 140])\n        self.hell3 = self.canvas.create_oval(\n            self.hell3_center[0] - 15, self.hell3_center[1] - 15,\n            self.hell3_center[0] + 15, self.hell3_center[1] + 15,\n            fill='black')\n        self.hell4_center = np.array([140, 60])\n        self.hell4 = self.canvas.create_oval(\n            self.hell4_center[0] - 15, self.hell4_center[1] - 15,\n            self.hell4_center[0] + 15, self.hell4_center[1] + 15,\n            fill='black')\n     \n\n        self.canvas.pack()\n\n    #reset agent location\n    def reset(self):\n        self.update()\n        time.sleep(0.5)\n        self.help_signal=0\n        self.action_hurt1 = 0\n        self.empathy_emotion = 0  # 修正拼写\n        \n        \n        self.canvas.delete(self.agent1)\n        self.canvas.delete(self.agent2)\n        self.orgin1 = np.array([20, 20])\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1] - 15,  # 5\n            # 右上\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1] - 15,  # 5\n            # 右下+\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1],  # 20\n            # 顶点\n            self.orgin1[0],  # 20\n            self.orgin1[1] + 15,  # 35\n            # 左下+\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1],  # 20\n     
   ]\n        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill=\"blue\")\n\n        self.orgin2 = np.array([MAZE_H * UNIT - UNIT / 2, 20])\n        # 下\n        self.points2 = [\n            # 左上\n            self.orgin2[0] - 15,  # 5\n            self.orgin2[1] - 15,  # 5\n            # 右上\n            self.orgin2[0] + 15,  # 35\n            self.orgin2[1] - 15,  # 5\n            # 右下+\n            self.orgin2[0] + 15,  # 35\n            self.orgin2[1],  # 20\n            # 顶点\n            self.orgin2[0],  # 20\n            self.orgin2[1] + 15,  # 35\n            # 左下+\n            self.orgin2[0] - 15,  # 5\n            self.orgin2[1],  # 20\n        ]\n        self.agent2 = self.canvas.create_polygon(self.points2, fill=\"green\")\n        return self.canvas.coords(self.agent1),self.canvas.coords(self.agent2)\n\n\n    # move agent1\n    def step1(self, s1, action1,emotion):\n        s1 = self.canvas.coords(self.agent1)\n        self.help_signal = 0 \n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if all(self.centre1 == self.hell1_center):\n            self.action_hurt1 = 1\n        if all(self.centre1 == self.hell2_center):\n            self.action_hurt1 = 1\n        if all(self.centre1 == self.hell3_center):\n            self.action_hurt1 = 1   \n        if all(self.centre1 == self.hell4_center):\n            self.action_hurt1 = 1  \n        # if self.help_signal:\n        #     self.action_hurt1 = 0\n        \n            \n        self.canvas.delete(self.agent1)  \n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if action1==0:\n            self.points0 = [\n                # 右下\n                self.centre1[0] + 15,  # 35\n                self.centre1[1] + 15,  # 35\n                # 左下\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] + 15,  # 35\n                # 左上+\n                self.centre1[0] - 15,  # 5\n                self.centre1[1],  # 
20\n                # 顶点\n                self.centre1[0],  # 20\n                self.centre1[1] - 15,  # 5\n                # 右上+\n                self.centre1[0] + 15,  # 35\n                self.centre1[1],  # 20\n            ]\n            if emotion==0:\n                color=\"blue\"\n            if emotion == -1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points0, fill=color,outline='black')\n        if action1==1:\n            self.points1 = [\n                # 左上\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] - 15,  # 5\n                # 右上\n                self.centre1[0] + 15,  # 35\n                self.centre1[1] - 15,  # 5\n                # 右下+\n                self.centre1[0] + 15,  # 35\n                self.centre1[1],  # 20\n                # 顶点\n                self.centre1[0],  # 20\n                self.centre1[1] + 15,  # 35\n                # 左下+\n                self.centre1[0] - 15,  # 5\n                self.centre1[1],  # 20\n            ]\n            if emotion==0:\n                color=\"blue\"\n            if emotion == -1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points1, fill=color,outline='black')\n        if action1==2:\n            self.points2 = [\n                # 左下\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] + 15,  # 35\n                # 左上\n                self.centre1[0] - 15,  # 5\n                self.centre1[1] - 15,  # 5\n                # 右上+\n                self.centre1[0],  # 20\n                self.centre1[1] - 15,  # 5\n                # 顶点\n                self.centre1[0] + 15,  # 35\n                self.centre1[1],  # 20\n                # 右下+\n                self.centre1[0],  # 20\n                self.centre1[1] + 15,  # 35\n            ]\n            if emotion==0:\n                color=\"blue\"\n            if emotion == -1:\n            
    color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points2, fill=color,outline='black')\n        if action1==3:\n            self.points3 = [\n                # 右上\n                self.centre1[0] + 15,  # 20+15\n                self.centre1[1] - 15,  # 20-15\n                # 右下\n                self.centre1[0] + 15,  # 20+15\n                self.centre1[1] + 15,  # 20+15\n                # 左下+\n                self.centre1[0],  # 20\n                self.centre1[1] + 15,  # 20+15\n                # 顶点\n                self.centre1[0] - 15,  # 20-15\n                self.centre1[1],  # 20\n                # 左上+\n                self.centre1[0],  # 20\n                self.centre1[1] - 15,  # 20-15\n            ]\n            if emotion==0:\n                color=\"blue\"\n            if emotion == -1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points3, fill=color,outline='black')\n        s1 = self.canvas.coords(self.agent1)\n        self.render()#显示当前的动作指令是什么\n\n        # whether hurt\n        if self.action_hurt1 == 0:\n            true_action1 = action1\n        else:\n            if action1 == 0:\n                true_action1 = 1\n            if action1 == 1:\n                true_action1 = 0\n            if action1 == 2:\n                true_action1 = 3\n            if action1 == 3:\n                true_action1 = 2\n                \n                \n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2] \n        \n        # predict next state\n        pre_displacement1 = np.array([0, 0])\n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:  # 120\n            if action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif 
action1 == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action1 == 3:  # left\n                if self.centre1[0] > UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        else:\n            if action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif action1 == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action1 == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n       \n        \n        # true next state\n        displacement1 = np.array([0, 0])\n        \n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if true_action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1= np.array([0, -40])\n            elif true_action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1= np.array([0,40])\n            elif true_action1 == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    displacement1= np.array([40,0])\n            elif true_action1 == 3:  # left\n                if self.centre1[0] > UNIT:\n                    displacement1= np.array([-40,0])\n        else:\n            if true_action1 == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1= np.array([0,-40])\n            elif true_action1 == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1= 
np.array([0,40])\n            elif true_action1 == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    displacement1= np.array([40,0])\n            elif true_action1 == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    displacement1= np.array([-40,0])\n        self.canvas.move(self.agent1, displacement1[0], displacement1[1])\n        s1_ = self.canvas.coords(self.agent1)\n        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]\n\n        return displacement1, pre_displacement1,s1_,sss\n    \n    def is_agent1_in_danger(self):\n        \"\"\"\n        whether in danger(hell1~hell4)\n        :return: True/False\n        \"\"\"\n        s1 = self.canvas.coords(self.agent1)\n        print([(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2])\n        print(f'hell1_center: {self.hell1_center}, hell2_center: {self.hell2_center}, hell3_center: {self.hell3_center}, hell4_center: {self.hell4_center}')\n        agent1_pos = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        danger_centers = [self.hell1_center, self.hell2_center, self.hell3_center, self.hell4_center]\n        for center in danger_centers:\n            if all(np.isclose(agent1_pos, center)):\n                return True\n        return False\n\n    def step2(self, action):\n        s = self.canvas.coords(self.agent2)\n        \n        s=[(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n        if all(s == self.oval_center)and self.empathy_emotion==-1:\n            done_oval=1\n        else:\n            done_oval=0\n       \n        base_action = np.array([0, 0])\n        if action == 0:  # up\n            if s[1] > UNIT:\n                base_action[1] -= UNIT\n        elif action == 1:  # down\n            if s[1] < (MAZE_W - 1) * UNIT:\n                base_action[1] += UNIT\n        elif action == 2:  # right\n            if s[0] < (MAZE_H - 1) * UNIT:\n                base_action[0] += UNIT\n        elif action == 3:  # left\n   
         if s[0] > ( (MAZE_H - 1)/2+2) * UNIT:\n                base_action[0] -= UNIT\n\n        self.canvas.move(self.agent2, base_action[0], base_action[1])  # move agent\n        s_ = self.canvas.coords(self.agent2)  # next state\n        self.centre2= [(s_[4] + s_[8]) / 2, (s_[5] + s_[9]) / 2]\n        \n        \n        \n        if all(self.centre2 == self.oval_center) and self.empathy_emotion==-1:\n           \n            self.help_signal=1\n      \n           \n        if all(self.centre2 == self.goal_centre):\n        \n            done = True\n        else:\n           \n            done = False\n            \n            \n            \n            \n        return s_, done,done_oval\n\n\n    def reward2(self):\n        \n        s_ = self.canvas.coords(self.agent2) \n        self.centre2= [(s_[4] + s_[8]) / 2, (s_[5] + s_[9]) / 2]\n        \n        if (self.empathy_emotion - self.empathy_emotion_t_1)==-1:\n            reward1=0\n        \n        elif (self.empathy_emotion - self.empathy_emotion_t_1)==1:\n            reward1=10\n        else:\n            reward1=0\n       \n            \n        if all(self.centre2 == self.goal_centre):\n            reward2 = 10    \n        else:\n            reward2 = -1\n            \n               \n            \n        return  reward1,reward2\n\n\n    def _set_wall(self):\n        \n        self.oval_center = np.array([(MAZE_H * UNIT)-20, ((MAZE_W)*UNIT-20)])# [(MAZE_H * UNIT)/2+80, UNIT/2+UNIT]\n        self.oval = self.canvas.create_oval(\n            self.oval_center[0] - 15, self.oval_center[1] - 15,\n            self.oval_center[0] + 15, self.oval_center[1] + 15,\n            fill='yellow')\n        self.help = self.canvas.coords(self.oval)\n        wall_center = []\n        self.wall = []\n        for i in range(MAZE_W):\n            wall_center.append([])\n            self.wall.append([])\n        for i in range(MAZE_W):\n            wall_center[i] = np.array([(MAZE_H * UNIT) / 2, ((i) * UNIT) + UNIT 
/ 2])# wall\n            self.wall[i] = self.canvas.create_rectangle(\n                    wall_center[i][0] - 20, wall_center[i][1] - 20,\n                    wall_center[i][0] + 20, wall_center[i][1] + 20,\n                    fill='grey')\n        self.canvas.pack()\n\n\n\n\n\n    def generate_expression1(self,emotion):\n        if emotion==-1:\n            self.canvas.itemconfig(self.agent1, fill=\"red\", outline='black')\n            self.canvas.pack()\n        if emotion==0:\n            self.canvas.itemconfig(self.agent1, fill=\"blue\", outline='black')\n            self.canvas.pack()\n \n    \n    def generate_expression2(self,emotion):\n        if emotion==-1:\n            self.canvas.itemconfig(self.agent2, fill=\"red\")\n            self.canvas.pack()\n        if emotion==0:\n            self.canvas.itemconfig(self.agent2, fill=\"green\")\n            self.canvas.pack()\n \n \n    def render(self):\n        time.sleep(0.000001)\n        self.update()\n\n    # def getter(self,widget):\n    #     widget.update()\n    #     x = tk.Tk.winfo_rootx(self) + widget.winfo_x()\n    #     y = tk.Tk.winfo_rooty(self) + widget.winfo_y()\n    #     x1 = x + widget.winfo_width()\n    #     y1 = y + widget.winfo_height()\n    #     ImageGrab.grab().crop((x, y, x1, y1)).save(\"first.jpg\")\n    #     return ImageGrab.grab().crop((x, y, x1, y1))\n\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/env_poly_SNN.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\nUNIT = 40   # pixels\nMAZE_H = 9  # grid height\nMAZE_W = 4 # grid width\n\nclass Maze(tk.Tk, object):\n    def __init__(self):\n        super(Maze, self).__init__()\n        self.action_space = ['u', 'd', 'l', 'r']\n        self.n_actions = len(self.action_space)\n        self.title('self-pain')\n        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))\n        self._build_maze()\n        self.danger=0\n        self.action_hurt=0\n        self.sensory_hurt = 0\n        self.open_door = 0\n        self.pain_state=0\n\n    # create environment\n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                           height=MAZE_W * UNIT,\n                           width=MAZE_H * UNIT)\n\n        # create grids\n        for c in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT\n            self.canvas.create_line(x0, y0, x1, y1)\n        for r in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r\n            self.canvas.create_line(x0, y0, x1, y1)\n\n        self.orgin=[20,20]\n        # create agent\n        self.points1 = [\n            self.orgin[0]-15,self.orgin[1]-15,\n            self.orgin[0]+15,self.orgin[1]-15,\n            self.orgin[0]+15,self.orgin[1],\n            self.orgin[0],self.orgin[1]+15,\n            self.orgin[0]-15,self.orgin[1],\n        ]\n        self.rect = self.canvas.create_polygon(self.points1, fill=\"green\")\n        self.canvas.pack()\n\n    #reset agent location\n    def reset(self):\n        self.open_door = 0\n        self.update()\n        time.sleep(0.5)\n        self.canvas.delete(self.rect)\n        self.orgin = [20, 20]\n        # 下\n        self.points1 = [\n            self.orgin[0] - 15,self.orgin[1] - 15,\n            self.orgin[0] + 15,self.orgin[1] - 15,\n            self.orgin[0] + 
15,self.orgin[1],\n            self.orgin[0],self.orgin[1]+15,\n            self.orgin[0] - 15,self.orgin[1],\n        ]\n        self.rect = self.canvas.create_polygon(self.points1, fill=\"green\")\n        return self.canvas.coords(self.rect)\n\n    def step(self, s, action, pain):\n        s = self.canvas.coords(self.rect)\n        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n\n        # danger or switch\n        if self.danger==1:\n            if all(self.centre == self.oval_center):\n                s_color = 'yellow'\n                self.canvas.delete(self.wall[3])\n                self.render()\n               \n                self.open_door = 1\n\n                move = np.array([80, 0])\n                self.canvas.move(self.rect, move[0], move[1])\n\n                s = self.canvas.coords(self.rect)\n                self.render()\n                \n            elif all(self.centre == self.hell1_center):\n                s_color = 'black'\n                self.action_hurt = 1\n                self.render()\n            else:\n                s_color = 'white'\n\n        # modify current state\n        self.canvas.delete(self.rect)\n        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n\n        if action==0:\n            self.points0 = [\n                self.centre[0] + 15,self.centre[1] + 15,\n                self.centre[0] - 15,self.centre[1] + 15,\n                self.centre[0] - 15,self.centre[1],\n                self.centre[0],self.centre[1] - 15,\n                self.centre[0] + 15,self.centre[1],\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points0, fill=color)\n        if action==1:\n            self.points1 = [\n                self.centre[0] - 15,self.centre[1] - 15,\n                self.centre[0] + 15,self.centre[1] - 15,\n                self.centre[0] + 15,self.centre[1],\n     
           self.centre[0],self.centre[1] + 15,\n                self.centre[0] - 15,self.centre[1],\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points1, fill=color)\n        if action==2:\n            self.points2 = [\n                self.centre[0] - 15,self.centre[1] + 15,\n                self.centre[0] - 15,self.centre[1] - 15,\n                self.centre[0],self.centre[1] - 15,\n                self.centre[0] + 15,self.centre[1],\n                self.centre[0],self.centre[1] + 15,\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points2, fill=color)\n        if action==3:\n            self.points3 = [\n                self.centre[0] + 15,\n                self.centre[1] - 15,\n                self.centre[0] + 15,self.centre[1] + 15,\n                self.centre[0],self.centre[1] + 15,\n                self.centre[0] - 15,self.centre[1],\n                self.centre[0],self.centre[1] - 15,\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points3, fill=color)\n        s = self.canvas.coords(self.rect)\n        self.render()\n\n        if s[0] > (9 / 2) * 40:\n            self.action_hurt = 0\n            \n        base_action = np.array([0, 0])\n        if self.action_hurt == 0:\n            true_action = action\n        else:\n            if action == 0:\n                true_action = 1\n            if action == 1:\n                true_action = 0\n            if action == 2:\n                true_action = 3\n            if action == 3:\n                true_action = 2\n        # predict next state\n        self.centre1 = [(s[4] + s[8]) / 
2, (s[5] + s[9]) / 2]\n        pre_displacement1 = np.array([0, 0])\n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:  # 120\n            if action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif action == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action == 3:  # left\n                if self.centre1[0] > UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        else:\n            if action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif action == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        \n        # true next state\n        displacement1 = np.array([0, 0])\n        \n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if true_action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1=np.array([0,-40])\n            elif true_action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1=np.array([0,40])\n            elif true_action == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    displacement1=np.array([40,0])\n         
   elif true_action == 3:  # left\n                if self.centre1[0] > UNIT:\n                    displacement1=np.array([-40,0])\n        else:\n            if true_action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1=np.array([0,-40])\n            elif true_action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1=np.array([0,40])\n            elif true_action == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    displacement1=np.array([40,0])\n            elif true_action == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    displacement1=np.array([-40,0])\n        self.canvas.move(self.rect, displacement1[0], displacement1[1])\n        s1_ = self.canvas.coords(self.rect)\n        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]\n\n\n        return displacement1, pre_displacement1,s1_,sss\n\n    def _set_danger(self):\n        self.hell1_center = np.array([60, 60])\n        self.hell1 = self.canvas.create_oval(\n            self.hell1_center[0] - 15, self.hell1_center[1] - 15,\n            self.hell1_center[0] + 15, self.hell1_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n        self.hell = self.canvas.coords(self.hell1)\n        self.canvas.pack()\n        self.danger=1\n\n    def _set_switch(self):\n        self.oval_center = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        self.oval = self.canvas.create_oval(\n            self.oval_center[0] - 15, self.oval_center[1] - 15,\n            self.oval_center[0] + 15, self.oval_center[1] + 15,\n            fill='yellow')\n        self.switch = self.canvas.coords(self.oval)\n        self.canvas.pack()\n\n    def _set_wall(self):\n        wall_center=[]\n        self.wall=[]\n        for a in range(MAZE_W):\n            
wall_center.append([0,0])\n            self.wall.append([])\n        for b in range(MAZE_W):\n            wall_center[b]=np.array([(MAZE_H*UNIT)/2,((b)*UNIT)+UNIT/2])\n            self.wall[b] = self.canvas.create_rectangle(\n                wall_center[b][0] - 20, wall_center[b][1] - 20,\n                wall_center[b][0] + 20, wall_center[b][1] + 20,\n                fill='grey')\n        self.wall0 = self.canvas.coords(self.wall[0])\n        self.wall1 = self.canvas.coords(self.wall[1])\n        self.wall2 = self.canvas.coords(self.wall[2])\n        self.wall3 = self.canvas.coords(self.wall[3])\n\n    def generate_expression(self,pain):\n        if pain==1:\n            self.canvas.itemconfig(self.rect, fill=\"red\", outline='red')\n        if pain == 0:\n            self.canvas.itemconfig(self.rect, fill=\"green\", outline='green')\n\n    def render(self):\n        time.sleep(0.1)\n        self.update()\n\n    def is_agent_in_danger(self):\n        \"\"\"\n        Check if the agent is in a danger zone (hell1).\n        Returns: True/False\n        \"\"\"\n        s = self.canvas.coords(self.rect)\n        agent_pos = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n        if hasattr(self, 'hell1_center'):\n            if all(np.isclose(agent_pos, self.hell1_center)):\n                return True\n        return False\n\n    # def getter(self, widget):\n    #     widget.update()\n    #     x = tk.Tk.winfo_rootx(self) + widget.winfo_x()\n    #     y = tk.Tk.winfo_rooty(self) + widget.winfo_y()\n    #     x1 = x + widget.winfo_width()\n    #     y1 = y + widget.winfo_height()\n    #     ImageGrab.grab().crop((x, y, x1, y1)).save(\"first.jpg\")\n    #     return ImageGrab.grab().crop((x, y, x1, y1))\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/rsnn.py",
    "content": "\nimport torch\nfrom torch import nn\n\nfrom braincog.base.node.node import IFNode\nfrom braincog.base.learningrule.STDP import STDP,MutliInputSTDP\nfrom braincog.base.connection.CustomLinear import CustomLinear\n\n\nfrom collections import deque\nfrom random import randint\n\nclass RSNN(nn.Module):\n    def __init__(self,num_state,num_action):\n        super().__init__()\n        # parameters\n        rsnn_mask=[]\n        rsnn_con=[]\n        con_matrix1 = torch.ones((num_state,num_action), dtype=torch.float)\n        rsnn_mask.append(con_matrix1)\n        # rsnn_con.append(CustomLinear(torch.randint(2,size=(num_state,num_action))*0.1))\n        rsnn_con.append(CustomLinear(torch.ones(num_state,num_action)*0.1))\n        self.num_subR=2\n        self.connection = rsnn_con\n        self.mask=rsnn_mask\n        self.node = [IFNode() for i in range(self.num_subR)]\n        self.learning_rule = []\n        self.learning_rule.append(MutliInputSTDP(self.node[1], [self.connection[0]]))\n\n        self.weight_trace = torch.zeros(con_matrix1.shape, dtype=torch.float)\n        \n        self.out_in = torch.zeros((num_state), dtype=torch.float)\n        self.out = torch.zeros((self.connection[0].weight.size()[1]), dtype=torch.float)\n        self.dw = torch.zeros((self.connection[0].weight.size()), dtype=torch.float)\n\n    def forward(self, input):\n        input=torch.tensor(input, dtype=torch.float)\n        self.out_in=self.node[0](input)\n        self.out,self.dw = self.learning_rule[0](self.out_in)\n        return self.out,self.dw\n\n    def UpdateWeight(self,reward,a,C,n):\n        self.weight_trace[0:n,:]=0\n        self.weight_trace[n+1:, :] = 0\n        self.weight_trace[:, :a * C] = 0\n        self.weight_trace[:, (a + 1) * C:] = 0\n        self.weight_trace[self.weight_trace>0]=self.weight_trace[self.weight_trace>0]*reward\n        self.weight_trace[self.weight_trace < 0] = -1*self.weight_trace[self.weight_trace < 0] * reward\n        
self.connection[0].update(self.weight_trace)\n        # self.connection[0].weight.data = torch.clamp(self.connection[0].weight.data, -1, 10)\n        # for i in range(self.connection[0].weight.size()[1]):\n        #     self.connection[0].weight.data[:, i] = (self.connection[0].weight.data[:, i] - torch.min(self.connection[0].weight.data[:, i])) / (torch.max(self.connection[0].weight.data[:, i]) - torch.min(self.connection[0].weight.data[:, i]))\n        # self.connection[0].weight.data= self.connection[0].weight.data * 0.5\n       \n        self.weight_trace = torch.zeros((64,5*C), dtype=torch.float)\n    def reset(self):\n        for i in range(self.num_subR):\n            self.node[i].n_reset()\n        for i in range(len(self.learning_rule)):\n            self.learning_rule[i].reset()\n    def getweight(self):\n        return self.connection\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/sd_env.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\nUNIT = 40   # pixels\nMAZE_H = 11  # grid horizontal\nMAZE_W = 5 # grid vertical\n\nclass Snowdrift(tk.Tk, object):\n    def __init__(self, n_agents=3, n_snowdrifts=4):\n        super(Snowdrift, self).__init__()\n        self.action_space = ['up', 'down', 'left', 'right', 'clean'] \n        self.n_actions = len(self.action_space)\n        self.n_agents = n_agents\n        self.n_snowdrifts = n_snowdrifts\n        self.UNIT = 40\n        self.MAZE_H = 8\n        self.MAZE_W = 8\n        self.title('Snowdrift Game')\n        self.geometry('{0}x{1}'.format(self.MAZE_H * self.UNIT, self.MAZE_W * self.UNIT)) # canvas\n        self.agents = []\n        self.agents_pos = []\n        self.agents_emotion = [-1] * n_agents \n        self.snowdrifts = []\n        self.snowdrifts_pos = []\n        self.cleaned = []  \n        self.empathy_emotion = 0\n        self.empathy_emotion_t_1 = 0\n        self.help_signal = 0\n        self.lamda = 1.0  \n        \n        self._build_maze()\n\n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                              height=self.MAZE_H * self.UNIT,\n                              width=self.MAZE_W * self.UNIT)\n\n        # Create agents\n        colors = ['red', 'blue', 'green']\n        for i in range(self.n_agents):\n            pos = np.array([np.random.randint(0, self.MAZE_W) * self.UNIT + self.UNIT/2,\n                          np.random.randint(0, self.MAZE_H) * self.UNIT + self.UNIT/2])\n            agent = self.canvas.create_oval(\n                pos[0] - 15, pos[1] - 15,\n                pos[0] + 15, pos[1] + 15,\n                fill=colors[i])\n            self.agents.append(agent)\n            self.agents_pos.append(pos)\n            if self.agents_emotion[i] == -1:\n                self.canvas.itemconfig(agent, fill='gray')\n\n        # Create snowdrifts\n        for _ in 
range(self.n_snowdrifts):\n            pos = np.array([np.random.randint(0, self.MAZE_W) * self.UNIT + self.UNIT/2,\n                          np.random.randint(0, self.MAZE_H) * self.UNIT + self.UNIT/2])\n            points = [\n                pos[0], pos[1] - 15,  \n                pos[0] - 15, pos[1] + 15,  \n                pos[0] + 15, pos[1] + 15   \n            ]\n            snowdrift = self.canvas.create_polygon(points, fill='black')\n            self.snowdrifts.append(snowdrift)\n            self.snowdrifts_pos.append(pos)\n            self.cleaned.append(False)\n\n        self.canvas.pack()\n\n    def reset(self, agent_id):\n        \"\"\"Reset environment and return initial state index\"\"\"\n        self.update()\n        time.sleep(0.001)\n        \n        # Reset all states\n        self.empathy_emotion = 0\n        self.empathy_emotion_t_1 = 0\n        self.help_signal = 0\n        \n        # Reset agents\n        for i in range(self.n_agents):\n            self.canvas.delete(self.agents[i])\n            pos = np.array([np.random.randint(0, self.MAZE_W) * self.UNIT + self.UNIT/2,\n                          np.random.randint(0, self.MAZE_H) * self.UNIT + self.UNIT/2])\n            self.agents_pos[i] = pos\n            self.agents[i] = self.canvas.create_oval(\n                pos[0] - 15, pos[1] - 15,\n                pos[0] + 15, pos[1] + 15,\n                fill='gray') # ['red', 'blue', 'green'][i]\n            self.agents_emotion[i] = -1\n\n        # Reset snowdrifts\n        for i in range(self.n_snowdrifts):\n            if hasattr(self, 'snowdrifts') and len(self.snowdrifts) > i:\n                self.canvas.delete(self.snowdrifts[i])\n            pos = self.snowdrifts_pos[i]\n            points = [\n                pos[0], pos[1] - 15,  \n                pos[0] - 15, pos[1] + 15,  \n                pos[0] + 15, pos[1] + 15   \n            ]\n            snowdrift = self.canvas.create_polygon(points, fill='black')\n            if not 
hasattr(self, 'snowdrifts') or len(self.snowdrifts) <= i:\n                self.snowdrifts.append(snowdrift)\n            else:\n                self.snowdrifts[i] = snowdrift\n        self.cleaned = [False] * self.n_snowdrifts\n        \n        # Calculate initial state index\n        init_state = self._get_state_index(agent_id)\n        return init_state\n\n    def step_all(self, actions):\n        \"\"\"Multi-agent environment step\n        \n        Args:\n            actions: List[int] - List of actions for each agent\n        Returns:\n            next_states: List[int] - Next state index for each agent\n            rewards: List[float] - Rewards obtained by each agent\n            done: bool - Whether the episode is finished\n            info: dict - Additional information\n        \"\"\"\n        rewards = [0] * self.n_agents\n        empathtrewards_t = [0] * self.n_agents\n        next_states = []\n        cleaned_this_step = []  # Record snowdrifts cleaned in this step\n\n        # 1. 
Move phase - all agents move simultaneously\n        for agent_id, action in enumerate(actions):\n            s = self.agents_pos[agent_id]\n            base_action = np.array([0, 0])\n            \n            if action < 4:  # Move actions\n                if action == 0:   # up\n                    if s[1] > self.UNIT:\n                        base_action[1] -= self.UNIT\n                elif action == 1:   # down\n                    if s[1] < (self.MAZE_H - 1) * self.UNIT:\n                        base_action[1] += self.UNIT\n                elif action == 2:   # right\n                    if s[0] < (self.MAZE_W - 1) * self.UNIT:\n                        base_action[0] += self.UNIT\n                elif action == 3:   # left\n                    if s[0] > self.UNIT:\n                        base_action[0] -= self.UNIT\n                        \n                self.canvas.move(self.agents[agent_id], base_action[0], base_action[1])\n                self.agents_pos[agent_id] = self.agents_pos[agent_id] + base_action\n\n        # 2. 
Cleaning phase - handle all cleaning actions\n        for agent_id, action in enumerate(actions):\n            if action == 4:  \n                s = self.agents_pos[agent_id]\n                for i, pos in enumerate(self.snowdrifts_pos):\n                    if all(s == pos) and not self.cleaned[i] and i not in cleaned_this_step:\n                        self.canvas.itemconfig(self.snowdrifts[i], fill='') \n                        self.cleaned[i] = True\n                        cleaned_this_step.append(agent_id)\n                        rewards[agent_id] += 2\n                        self.agents_emotion[agent_id] = -1\n                        self.canvas.itemconfig(self.agents[agent_id], fill='gray')\n                        for j in range(self.n_agents):\n                            if j != agent_id:\n                                rewards[j] += 6 \n                                if self.agents_emotion[j] == -1:\n                                    self.agents_emotion[j] = 0\n                                    self.canvas.itemconfig(self.agents[j], fill=['red', 'blue', 'green'][j])\n                                    empathtrewards_t[agent_id] += 6 \n\n        # 3. Calculate next state for each agent\n        for agent_id in range(self.n_agents):\n            next_state = self._get_state_index(agent_id)\n            next_states.append(next_state)\n            empathtrewards_t[agent_id]\n\n        # 4. 
Check if finished\n        done = all(self.cleaned)\n        \n        info = {\n            'cleaned_positions': cleaned_this_step,\n            'agent_emotions': self.agents_emotion.copy()\n        }\n        \n        return next_states, rewards, empathtrewards_t, done, info\n\n    def _get_state_index(self, agent_id):\n        \"\"\"Convert state to index value\"\"\"\n        state_index = 0\n        pos = self.agents_pos[agent_id]\n        x = int(pos[0] / (self.MAZE_W * self.UNIT) * 8)\n        y = int(pos[1] / (self.MAZE_H * self.UNIT) * 8) \n        state_index += x + y * 8\n        return state_index\n\n    def render(self):\n        \"\"\"Render environment\"\"\"\n        time.sleep(0.00001)\n        self.update()\n\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BEEAD-SNN/snowdrift_main.py",
    "content": "import time\nimport datetime\nimport os\nimport random\nimport numpy as np\nimport torch\nfrom sd_env import Snowdrift\nfrom rsnn import RSNN\nfrom matplotlib import pyplot as plt\nfrom torch.utils.tensorboard import SummaryWriter\n\n# Global parameters\nN_action = 5  # up, down, left, right, clean\nN_state = 64  # 8*8 grid\nC = 50\nruntime = 100\ntrace_decay = 0.8\n\ntorch.manual_seed(42)\nnp.random.seed(42)\n\ndef encode(n, e):\n    z = torch.zeros(N_state, 100) \n    z[n, :] = 1\n    z = z * 0.51\n    return z\n\ndef aoencode(n, e, env, agent_id):\n    z = torch.zeros(N_state, 100)\n    z[n, :] = 1\n    for i in range(len(env.agents_pos)):\n        if i != agent_id:\n            other_pos = env.agents_pos[i]\n            x = int(other_pos[0] / (env.MAZE_W * env.UNIT) * 8)\n            y = int(other_pos[1] / (env.MAZE_H * env.UNIT) * 8)\n            other_state_idx = x + y * 8\n            z[other_state_idx, :] += 0.3\n    for i, snow_pos in enumerate(env.snowdrifts_pos):\n        if not env.cleaned[i]:\n            x = int(snow_pos[0] / (env.MAZE_W * env.UNIT) * 8)\n            y = int(snow_pos[1] / (env.MAZE_H * env.UNIT) * 8)\n            snow_state_idx = x + y * 8\n            z[snow_state_idx, :] += 0.6\n    z = z * 0.51\n    return z\n\ndef poencode(n, e, env, agent_id):\n    \"\"\"\n    Encode state as partially observable representation.\n    Args:\n        n: state index of current agent position\n        e: emotion state\n        env: environment object\n        agent_id: agent ID\n    \"\"\"\n    z = torch.zeros(N_state, 100)\n    agent_pos = env.agents_pos[agent_id]\n    cur_x = int(agent_pos[0] / (env.MAZE_W * env.UNIT) * 8)\n    cur_y = int(agent_pos[1] / (env.MAZE_H * env.UNIT) * 8)\n    obs_range = 1  # observable grid range\n    z[n, :] = 1\n    for i in range(len(env.agents_pos)):\n        if i != agent_id:\n            other_pos = env.agents_pos[i]\n            x = int(other_pos[0] / (env.MAZE_W * env.UNIT) * 8)\n            y = 
int(other_pos[1] / (env.MAZE_H * env.UNIT) * 8)\n            if abs(x - cur_x) <= obs_range and abs(y - cur_y) <= obs_range:\n                other_state_idx = x + y * 8\n                z[other_state_idx, :] += 0.3\n    for i, snow_pos in enumerate(env.snowdrifts_pos):\n        if not env.cleaned[i]:\n            x = int(snow_pos[0] / (env.MAZE_W * env.UNIT) * 8)\n            y = int(snow_pos[1] / (env.MAZE_H * env.UNIT) * 8)\n            if abs(x - cur_x) <= obs_range and abs(y - cur_y) <= obs_range:\n                snow_state_idx = x + y * 8\n                z[snow_state_idx, :] += 0.6\n    z = z * 0.51\n    return z\n\ndef chooseAct(Net, input, explore, n, env, agent_id):\n    count_group = np.zeros(N_action)\n    count_output = np.zeros(N_action * C)\n    for i_train in range(runtime):\n        out, dw = Net(input[:, i_train])\n        Net.weight_trace *= trace_decay\n        Net.weight_trace += dw[0]\n        count_output = count_output + np.array(out)\n        for i in range(N_action):\n            count_group[i] = count_output[i*C:(i+1)*C].sum()\n    agent_pos = env.agents_pos[agent_id]\n    at_snowdrift = False\n    for i, snow_pos in enumerate(env.snowdrifts_pos):\n        if not env.cleaned[i] and all(agent_pos == snow_pos):\n            at_snowdrift = True\n            break\n    if not at_snowdrift:\n        count_group[4] = float('-inf')\n    if np.random.uniform() < explore:\n        if not at_snowdrift:\n            action = np.random.randint(0, 4)\n        else:\n            if count_group.max() > float('-inf'):\n                action = count_group.argmax()\n            else:\n                action = np.random.randint(0, N_action)\n    else:\n        if not at_snowdrift:\n            action = np.random.randint(0, 4)\n        else:\n            action = np.random.randint(0, N_action)\n    return action, Net, dw[0], 0\n\ndef train_model(n_agents, lamdas, episodes):\n    # TensorBoard writer\n    current_time = 
datetime.datetime.now().strftime('%Y%m%d-%H%M%S')\n    log_dir = os.path.join('run33obs', f'sd_partobs_a{n_agents}_l{lamdas[0]}{lamdas[1]}{lamdas[2]}_e{episodes}_{current_time}', f'')\n    writer = SummaryWriter(log_dir)\n    nets = [RSNN(N_state, N_action*C) for _ in range(n_agents)]\n    learn_steps = [[] for _ in range(n_agents)]\n    weight_marks = [np.zeros((N_state, N_action)) for _ in range(n_agents)]\n    update_stops = [0 for _ in range(n_agents)]\n    empathy_rewards_t = [0] * n_agents\n    total_rewards = [0] * n_agents\n    env = Snowdrift(n_agents=n_agents, n_snowdrifts=10)\n    episode_cleaned_counts = []\n    agent_cleaned_counts = [0] * n_agents\n\n    for episode in range(episodes):\n        print(f'Episode: {episode}, Lambda: {lamdas}')\n        cleaned_count = 0 \n        episode_agent_cleaned_count = [0] * n_agents\n        states = []\n        emotion_t = [-1] * n_agents\n        for i in range(n_agents):\n            state = env.reset(i)\n            states.append(state)\n        episode_rewards = [0 for _ in range(n_agents)]\n        episode_total_rewards = [0 for _ in range(n_agents)]\n        if episode < 100:\n            e_greedy = 0.2\n        elif episode < 300:\n            e_greedy = 0.5\n        elif episode < 900:\n            e_greedy = 0.9\n        else:\n            for i in range(n_agents):\n                if update_stops[i] == 0:\n                    update_stops[i] = 1\n            e_greedy = 1\n        for t in range(100):\n            emotion_tt = emotion_t.copy()\n            emotion_t = env.agents_emotion.copy()\n            actions = []\n            for i in range(n_agents):\n                input_state = poencode(states[i], env.agents_emotion[i], env, i)\n                action, nets[i], dw, _ = chooseAct(nets[i], input_state, e_greedy, states[i], env, i)\n                actions.append(action)\n            next_states, rewards, empathy_rewards, done, info = env.step_all(actions)\n            cleaned_count += 
len(info['cleaned_positions'])\n            if 'cleaned_by_agent' in info:\n                for snow_idx, agent_idx in info['cleaned_by_agent'].items():\n                    episode_agent_cleaned_count[agent_idx] += 1\n                    agent_cleaned_counts[agent_idx] += 1\n            print(f'intereaction {t} :')\n            for i in range(n_agents):\n                if env.agents_emotion[i] == emotion_t[i] and env.agents_emotion[i] == emotion_tt[i] and (emotion_tt[0]==0 or emotion_tt[1]==0 or emotion_tt[2]==0):\n                    total_rewards[i] = lamdas[i] * (empathy_rewards[i] - empathy_rewards_t[i]) + rewards[i]\n                elif env.agents_emotion[0]==-1 and env.agents_emotion[1]==-1 and env.agents_emotion[2]==-1:\n                    total_rewards[i] = lamdas[i] * (empathy_rewards[i] - empathy_rewards_t[i]) + rewards[i]\n                else:\n                    total_rewards[i] = lamdas[i] * empathy_rewards[i] + rewards[i]\n            print(f'Actions: {actions}, Rewards: {rewards}, Empathy Rewards: {empathy_rewards} Total Rewards: {total_rewards} emotion: {env.agents_emotion}')\n            empathy_rewards_t = empathy_rewards\n            env.render()\n            for i in range(n_agents):\n                if update_stops[i] == 0:\n                    nets[i].UpdateWeight(total_rewards[i], actions[i], C, states[i])\n            states = next_states\n            for i in range(n_agents):\n                episode_rewards[i] += rewards[i]\n                episode_total_rewards[i] += total_rewards[i]\n            if done:\n                break\n        for i in range(n_agents):\n            writer.add_scalar(f'ERewards/Agent_{i+1}', episode_rewards[i], episode)\n            writer.add_scalar(f'totalRewards/Agent_{i+1}', episode_total_rewards[i], episode)\n            writer.add_scalar(f'Cleaned/Agent_{i+1}', episode_agent_cleaned_count[i], episode)\n        for i in range(n_agents):\n            learn_steps[i].append(episode_rewards[i])\n        # 
Save weights\n        if episode == episodes-1:\n            for i in range(n_agents):\n                torch.save(nets[i].connection[0].weight.data, f'weight_agent{i}_lambda{lamdas}_episode{episode}.pth')\n        episode_cleaned_counts.append(cleaned_count)\n        writer.add_scalar('Performance/Cleaned_Snowdrifts', cleaned_count, episode)\n        print(f'Cleaned Count: {cleaned_count}')\n    writer.close()  \n    return learn_steps, episode_cleaned_counts\n\nif __name__ == \"__main__\":\n    n_agents = 3\n    n_snowdrifts = 10\n    self_factors = [[1.51, 1.51, 1.51]]\n    all_learn_steps = []\n    all_cleaned_counts = []\n    for ii in range(len(self_factors)):\n        all_learn_steps.append([[] for _ in range(n_agents)])\n        all_cleaned_counts.append([])\n    for iii, lamdas in enumerate(self_factors):\n        steps, cleaned_counts = train_model(n_agents, lamdas, 1000)\n        for i in range(n_agents):\n            all_learn_steps[iii][i].extend(steps[i])\n        all_cleaned_counts[iii] = cleaned_counts\n    # Save plot images\n    save_dir = 'results'\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')\n    plt.figure(figsize=[24, 8])\n    # Plot reward curves for each agent\n    plt.subplot(1, 3, 1)\n    colors = ['red', 'blue', 'green']\n    labels = ['Agent 1', 'Agent 2', 'Agent 3']\n    for i in range(n_agents):\n        plt.plot(all_learn_steps[0][i], label=labels[i], color=colors[i])\n    plt.legend(loc='lower right')\n    plt.title('Rewards per Agent')\n    plt.xlabel('Episode')\n    plt.ylabel('Reward')\n    \n    # Plot number of cleaned snowdrifts\n    plt.subplot(1, 3, 2)\n    plt.plot(all_cleaned_counts[0], label='Cleaned Snowdrifts', color='black')\n    plt.axhline(y=n_snowdrifts, color='r', linestyle='--', label='Total Snowdrifts')\n    plt.legend(loc='lower right')\n    plt.title('Number of Cleaned Snowdrifts per Episode')\n    plt.xlabel('Episode')\n  
  plt.ylabel('Count')\n    plt.ylim([0, n_snowdrifts + 1])\n\n    # Add total reward curve\n    plt.subplot(1, 3, 3)\n    total_rewards_per_episode = np.sum(all_learn_steps[0], axis=0)  # Calculate total reward for each episode\n    plt.plot(total_rewards_per_episode, label='Total Rewards', color='purple')\n    plt.legend(loc='lower right')\n    plt.title('Total Rewards of All Agents')\n    plt.xlabel('Episode')\n    plt.ylabel('Total Reward')\n    \n    plt.tight_layout()\n    # Save image\n    save_path = os.path.join(save_dir, f'training_results_{timestamp}.png')\n    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n    plt.close()  # Close the figure to free memory\n    \n    print(f'Image saved to: {save_path}')\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BRP-SNN/BRP-SNN.py",
    "content": "import os\nimport sys\n# 把当前文件所在文件夹的父文件夹路径加入到PYTHONPATH\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\n\n\nimport imageio\nfrom env_poly_SNN import Maze\nfrom env_two_poly_SNN import Maze2\nimport numpy as np\nimport pandas as pd\nimport matplotlib\nimport matplotlib.pyplot as plt\n\n\nfrom sklearn.preprocessing import MinMaxScaler\nimport torch, os, sys\nfrom torch import nn\nfrom torch.nn import Parameter\nimport abc\nimport math\nfrom abc import ABC\nimport torch.nn.functional as F\nfrom braincog.base.node.node import *\nfrom braincog.base.learningrule.STDP import *\nfrom braincog.base.connection.CustomLinear import *\n\n\n\nX=np.array([[0],\n            [1],\n            [2],\n            [3]])\nY=np.array([[0,-40],\n            [0,40],\n            [40,0],\n            [-40,0]\n            ])\n\nclass BrainArea(nn.Module, abc.ABC):\n    \"\"\"\n    脑区基类\n    \"\"\"\n\n    @abc.abstractmethod\n    def __init__(self):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n    @abc.abstractmethod\n    def forward(self, x):\n        \"\"\"\n        计算前向传播过程\n        :return:x是脉冲\n        \"\"\"\n\n        return x\n\n    def reset(self):\n        \"\"\"\n        计算前向传播过程\n        :return:x是脉冲\n        \"\"\"\n\n        pass\n\n\nclass BNESNN(BrainArea):\n    \"\"\"\n    负面情绪网络\n    \"\"\"\n    def __init__(self,):\n        super().__init__()\n        self.node = [IFNode() for i in range(5)]\n        self.connection = []\n        \n        con_matrix0 = torch.eye(12, 12)*6\n        self.connection.append(CustomLinear(con_matrix0))#input-state\n        \n        con_matrix1 = torch.zeros((12, 24), dtype=torch.float)   \n        self.connection.append(CustomLinear(con_matrix1))#state-prediction\n        \n        con_matrix2 = torch.eye(24, 24)*6\n        self.connection.append(CustomLinear(con_matrix2))#input-prediction\n        \n        con_matrix3 = torch.eye(24, 24)*6\n        
self.connection.append(CustomLinear(con_matrix3))#input-sensory\n        \n        con_matrix4 = torch.eye(24, 24)*6\n        self.connection.append(CustomLinear(con_matrix4))#sensory-error\n        \n        con_matrix5 = torch.eye(24, 24)*(-6)\n        self.connection.append(CustomLinear(con_matrix5))#prediction-error\n        \n        con_matrix6 = torch.zeros((24, 24), dtype=torch.float)   \n        p=0.5        \n        if p==0.25:\n            con_matrix6[:,0:3]=1\n        if p==0.5:\n            con_matrix6[:,0:6]=1\n        if p==0.75:\n            con_matrix6[:,0:9]=1\n        if p==1:\n            con_matrix6[:,0:12]=1\n        self.connection.append(CustomLinear(con_matrix6))#error-pain\n        \n        self.stdp = []\n        self.stdp.append(STDP(self.node[0], self.connection[0]))#node0-state,stdp0\n        self.stdp.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))#node1-prediction,stdp1\n        self.stdp.append(STDP(self.node[3], self.connection[3]))#node3-sensory,stdp2\n        self.stdp.append(MutliInputSTDP(self.node[2], [self.connection[4], self.connection[5]]))#node2-error,stdp3\n        self.stdp.append(STDP(self.node[1], self.connection[1]))#node1-prediction,stdp4\n        self.stdp.append(STDP(self.node[4], self.connection[6]))#node4-pain,stdp5\n    def forward(self, x1,x2):\n        \"\"\"\n        计算前向传播过程,训练过程\n        \"\"\"\n        out__s, dw0 = self.stdp[0](x1)#node0\n        out__p,dw = self.stdp[1](out__s,x2)#node1\n    \n        return dw,out__s,out__p\n    \n    def calculate_error(self, x1,x2):\n        \"\"\"\n        测试过程\n        \"\"\"\n        out__s,dw = self.stdp[0](x1)#node0-state,stdp0\n        out__pre,dw= self.stdp[4](out__s)#node1-prediction,stdp1\n        out__sensory,dw = self.stdp[2](x2)#node3-sensory,stdp2\n        out__error,dw = self.stdp[3](out__sensory,out__pre)#node2-error,stdp3\n        out__pain,dw = self.stdp[5](out__error)#node4-pain,stdp5\n        \n    \n        return 
out__s,out__pre,out__sensory,out__error,out__pain\n        \n    def UpdateWeight(self, i, dw, delta):\n        \"\"\"\n        更新第i组连接的权重 根据传入的dw值\n        :param i: 要更新的连接的索引\n        :param dw: 更新的量\n        :return: None\n        \"\"\"\n        self.connection[i].update(dw*delta)\n        self.connection[i].weight.data= torch.clamp(self.connection[i].weight.data,0,6)\n        \n    def reset(self):\n        \"\"\"\n        reset神经元或学习法则的中间量\n        :return: None\n        \"\"\"\n        for i in range(5):\n            self.node[i].n_reset()\n        for i in range(len(self.stdp)):\n            self.stdp[i].reset()\n\n         \ndef GRF(X,N):\n    gauss_neuron = 12  \n    center = np.ones((gauss_neuron, 1))\n    width = 1 / 15\n\n    for i in range(len(center)):  \n        center[i] = (2 * i - 3) / 20  \n    x = np.arange(0, 1, 0.0001)  \n\n    num_features = N\n\n    gauss_recpt_field = np.zeros((gauss_neuron, len(x)))  \n    for i in range(gauss_neuron):\n        gauss_recpt_field[i, :] = np.exp(-(x - center[i]) ** 2 / (2 * width * width))  \n\n    def gauss_response(inputs,num_features):\n        spike_time = np.zeros((gauss_neuron, num_features))\n        # input: shape [1, features]\n        # output: shape [gaussian neurons*features] spiking time\n        for i in range(num_features):\n            for j in range(gauss_neuron):\n                spike_time[j, i] = gauss_recpt_field[j, inputs[i]]  #entry gauss function\n        spikes = []\n        for i in range(spike_time.shape[1]):\n            spikes.extend(spike_time[:, i])\n        return np.array(spikes)\n\n\n    gauss_neurons = gauss_neuron * N\n   \n    scaler = MinMaxScaler()\n    X = scaler.fit_transform(X)\n    X = (X * 10000).astype(int)  #10000\n    X[X == 10000] = 9999 \n    input_spike = np.zeros((X.shape[0], gauss_neurons))  \n    for i in range(X.shape[0]):\n        input_spike[i, :] = gauss_response(X[i, :],num_features)\n    input_spike[input_spike < 0.1] = 0  \n    input_spike = 
np.around(100 * (1 - input_spike))  \n    input_spike[input_spike == 0] = 1\n    input_spike[input_spike == 100] = 0\n    state=[]\n    for i in range(len(X)):\n        aa=[]\n        for j in range(gauss_neurons):\n            if input_spike[i][j] != 0:\n                number=input_spike[i][j]\n                aa.append((int(number),j))\n        state.append(aa)\n        \n    return state\n             \n                \ndef encode(input,n_neuron):\n    a=len(input)\n    input_encode = []\n    for i in range(n_neuron):\n        temp = np.zeros([100, ])\n        input_encode.append(temp)\n    for j in range(a):\n        s=input[j][0]\n        n=input[j][1]\n        input_encode[n][s]=1\n\n    return input_encode\n\n\nclass BAESNN(BrainArea):\n    \"\"\"\n    情感共情网络\n    \"\"\"\n\n    def __init__(self,):\n        \"\"\"\n        \"\"\"\n        super().__init__()\n\n\n        self.node = [IFNode() for i in range(5)]\n       \n        \n        self.connection = []\n        \n        con_matrix0 = torch.eye(24, 24)*6\n        self.connection.append(CustomLinear(con_matrix0))#input-emotion\n        \n        con_matrix1 = torch.zeros((24, 50), dtype=torch.float)\n        for j in range(50):\n            if j in np.arange(0,25,1):\n                for i in np.arange(0, 12, 1):\n                    con_matrix1[i,j] =2\n            if j in np.arange(25,50,1):\n                for i in np.arange(12, 24, 1):\n                    con_matrix1[i,j] =2    \n        self.connection.append(CustomLinear(con_matrix1))#emotion-ifg\n        \n        con_matrix2 = torch.zeros((24, 50), dtype=torch.float)  \n        self.connection.append(CustomLinear(con_matrix2))#perception-ifg\n        \n        con_matrix3 = torch.eye(24, 24)*6\n        self.connection.append(CustomLinear(con_matrix3))#input-perception\n        \n        con_matrix4=torch.zeros((24,10), dtype=torch.float)\n        for j in range(10):\n            if j in np.arange(0,5,1):\n                for i in 
np.arange(0, 12, 1):\n                    con_matrix4[i,j] =2\n            if j in np.arange(5,10,1):\n                for i in np.arange(12, 24, 1):\n                    con_matrix4[i,j] =2\n        self.connection.append(CustomLinear(con_matrix4))#emotion-sma\n        \n        con_matrix5=torch.zeros((24,10), dtype=torch.float)\n        self.connection.append(CustomLinear(con_matrix5))#perception-m1\n        \n        con_matrix6 = torch.eye(10, 10)*6\n        self.connection.append(CustomLinear(con_matrix6))#sma-m1\n        \n        self.stdp = []\n        self.stdp.append(STDP(self.node[0], self.connection[0]))#0\n        self.stdp.append(STDP(self.node[2], self.connection[3]))#1\n        self.stdp.append(MutliInputSTDP(self.node[1], [self.connection[1], self.connection[2]]))#2\n        self.stdp.append(MutliInputSTDP(self.node[3], [self.connection[4], self.connection[5]]))#3\n        self.stdp.append(STDP(self.node[4], self.connection[6]))#4\n        self.stdp.append(STDP(self.node[1],self.connection[2]))#5\n        self.stdp.append(STDP(self.node[3],self.connection[5]))#6\n    def forward(self, x1,x2):\n        \"\"\"\n        计算前向传播过程\n        :return:x是脉冲\n        \"\"\"\n        out__m, dw0 = self.stdp[0](x1)#node0\n        out__p, dw3 = self.stdp[1](x2)#node2\n        out__ifg,dw_p_i=self.stdp[2](out__m,out__p)#node1\n        out__sma,dw_p_s=self.stdp[3](out__m,out__p)#node3\n        out__m1,dw1=self.stdp[4](out__sma)#node4\n    \n        return dw_p_i,dw_p_s,out__ifg,out__sma,out__m1\n    \n    def empathy(self,x3):\n        out_p,dw2=self.stdp[1](x3)#node2\n        out_ifg,dw4=self.stdp[5](out_p)#node1\n        out_sma,dw5=self.stdp[6](out_p)#node3\n        out_m1,dw6=self.stdp[4](out_sma)#node4\n        return out_ifg,out_sma,out_m1\n        \n    def UpdateWeight(self, i, dw, delta):\n        \"\"\"\n        更新第i组连接的权重 根据传入的dw值\n        :param i: 要更新的连接的索引\n        :param dw: 更新的量\n        :return: None\n        \"\"\"\n        
self.connection[i].update(dw*delta)\n        self.connection[i].weight.data= torch.clamp(self.connection[i].weight.data,-1,4)\n        \n    def reset(self):\n        \"\"\"\n        reset神经元或学习法则的中间量\n        :return: None\n        \"\"\"\n        for i in range(5):\n            self.node[i].n_reset()\n        for i in range(len(self.stdp)):\n            self.stdp[i].reset()\n\ndef BNESNN_train():  \n    \n    state=GRF(X,1)\n    prediction=GRF(Y,2)\n\n    T=100 \n    epoch=10\n    for k in range(epoch):\n        print('epoch:',k)\n        for n in range(4):\n            snn1.reset()\n            train_state = np.array(encode(state[n], 12))\n            train_state=torch.tensor(train_state,dtype=torch.float32)\n            train_prediction = np.array(encode(prediction[n], 24))\n            train_prediction=torch.tensor(train_prediction,dtype=torch.float32)\n            for i in range(T):\n                OUTPUT = snn1(train_state[:,i],train_prediction[:,i])\n                snn1.UpdateWeight(1,OUTPUT[0][0],1)\n\n\n\ndef BAESNN_train():  \n    s = env.reset()\n    env._set_danger()\n    env._set_wall()\n    pain=0\n    i=0\n    set_pain=0\n    env._set_switch()\n    for i in range(100):\n        snn1.reset()\n        T=100\n        pain=0\n        print('**************step:',i)\n        env.render()\n        \n        action = np.random.choice(list(range(env.n_actions)))\n        print('action:',action)\n        d,d_pre,s_,sss = env.step(s, action, pain)\n        print('d:',d,'d_pre:',d_pre,'sss:',sss)\n        env.render()\n            \n        while (d==np.array([0,0])).all():\n            action = np.random.choice(list(range(env.n_actions)))\n            print('action:',action)\n            d,d_pre,s_,sss = env.step(s, action, pain)\n            print('d:',d,'d_pre:',d_pre,'sss:',sss)\n            env.render()\n        \n        \n        \n        aa=np.argwhere(X==action)[0][0]\n        for i in range(4):\n            if (Y[i]==d).all():\n                
b=i\n        print('aa:',aa,'b:',b)       \n        state=GRF(X,1)\n        prediction=GRF(Y,2)\n        x=encode(state[aa],12)\n        y=encode(prediction[b],24)\n        train_state = np.array(x)\n        train_state=torch.tensor(train_state,dtype=torch.float32)\n        train_prediction = np.array(y)\n        train_prediction=torch.tensor(train_prediction,dtype=torch.float32)\n        OUT_PAIN=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0.]])  \n        spike_pain=[]\n        spike_error=[]\n        for i in range(T):\n            OUTPUT_TEST = snn1.calculate_error(train_state[:,i],train_prediction[:,i])\n            spike_pain.append(OUTPUT_TEST[4])\n            spike_error.append(OUTPUT_TEST[3])\n            if OUTPUT_TEST[3].sum() != 0:\n                print('OUTPUT_TEST3:',i,OUTPUT_TEST[3])\n            if OUTPUT_TEST[4].sum() != 0:#pain brain area\n                print('OUTPUT_TEST4:',i,OUTPUT_TEST[4])\n                OUT_PAIN=OUTPUT_TEST[4]\n                pain=1\n                set_pain=1\n        spike_pain = torch.stack(spike_pain)\n        spike_error=torch.stack(spike_error)\n        if pain==1:\n            spike_rate_vis_1d(spike_error)\n            spike_rate_vis_1d(spike_pain)\n        print('pain:',pain)\n        \n        \n        \n        \n        \n        snn2.reset()\n        T2=20\n        X1= OUT_PAIN.view(1, -1) \n        X2=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0.]])  \n        print('X1,X2:',X1,X2)\n        for i in range(T2):\n            if i>=2:\n                X2=X1\n            OUTPUT = snn2(X1,X2)\n            snn2.UpdateWeight(2,OUTPUT[0][1],0.01)\n            snn2.UpdateWeight(5,OUTPUT[1][1],-0.1)\n            if OUTPUT[2][0][0]==1:\n                
env.canvas.itemconfig(env.rect, fill=\"red\", outline='red')\n            if OUTPUT[2][0][0]==0:\n                env.canvas.itemconfig(env.rect, fill=\"green\", outline='green')\n        env.render()\n        \n        print('out_ifg:',OUTPUT[2])\n        print('out_sma:',OUTPUT[3])\n        print('out_m1:',OUTPUT[4])\n        print('con2:',snn2.connection[2].weight.data)\n        print('con5:',snn2.connection[5].weight.data)\n        \n        s = s_\n        if set_pain==1 and pain==0:\n            env.render()\n            break\n    env.destroy()\n                \n\ndef BAESNN_test():\n    s1,s=env2.reset()\n    pain=0\n    pain1 = 0\n    i=0\n    set_pain=0\n    \n    for i in range(100):\n        \n        snn1.reset()\n        T=100\n        pain=0\n        print('**************test_step:',i)\n        env2.render()\n        \n        action1 = np.random.choice(list(range(env.n_actions)))\n        print('action1:',action1)\n        d,d_pre,s1_,sss = env2.step(s1, action1, pain1)\n        print('d:',d,'d_pre:',d_pre,'sss:',sss)\n        env2.render()\n            \n        while (d==np.array([0,0])).all():\n            action1 = np.random.choice(list(range(env.n_actions)))\n            print('action1:',action1)\n            d,d_pre,s1_,sss = env2.step(s, action1, pain1)\n            print('d:',d,'d_pre:',d_pre,'sss:',sss)\n            env2.render()\n        \n        \n        \n        aa=np.argwhere(X==action1)[0][0]\n        for i in range(4):\n            if (Y[i]==d).all():\n                b=i\n        # print('aa:',aa,'b:',b)       \n        state=GRF(X,1)\n        prediction=GRF(Y,2)\n        x=encode(state[aa],12)\n        y=encode(prediction[b],24)\n        train_state = np.array(x)\n        train_state=torch.tensor(train_state,dtype=torch.float32)\n        train_prediction = np.array(y)\n        train_prediction=torch.tensor(train_prediction,dtype=torch.float32)\n        OUT_PAIN=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n             
                   0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n                                0., 0., 0., 0.]])  \n        for i in range(T):\n            OUTPUT_TEST = snn1.calculate_error(train_state[:,i],train_prediction[:,i])\n            if OUTPUT_TEST[3].sum() != 0:\n                print('OUTPUT_TEST3:',i,OUTPUT_TEST[3])\n            if OUTPUT_TEST[4].sum() != 0:#pain brain area\n                print('OUTPUT_TEST4:',i,OUTPUT_TEST[4])\n                OUT_PAIN=OUTPUT_TEST[4]\n                pain1=1\n                set_pain=1\n        print('pain1:',pain1)\n        \n\n        \n        env2.generate_expression1(pain1)\n        \n        \n        snn2.reset()\n        T2=20\n        X3= OUT_PAIN.view(1, -1) \n       \n        for i in range(T2):\n            OUT=snn2.empathy(X3)\n            print('out_ifg:',OUT[0])\n        \n        if OUT[0][0][0]==1:\n            pain=1        \n        \n        if OUT[0][0][0]==0:\n            pain=0  \n                \n        if pain==1:\n            env2.agent_help()\n            \n        s1 = s1_\n        env2.render()\n\n        if pain==0 and set_pain==1:\n            env2.render()\n            break\n  \n    env2.destroy()\n  \n\n\n\n\n\nif __name__ == \"__main__\":\n    env = Maze() \n    snn1 = BNESNN()\n    snn2 = BAESNN() \n    BNESNN_train()\n    BAESNN_train()\n    env.mainloop()\n    \n    env2 = Maze2()\n    BAESNN_test()\n    env2.mainloop()"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BRP-SNN/README.md",
    "content": ""
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BRP-SNN/env_poly_SNN.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\nUNIT = 40   # pixels\nMAZE_H = 9  # grid height\nMAZE_W = 4 # grid width\n\n\nclass Maze(tk.Tk, object):\n    def __init__(self):\n        super(Maze, self).__init__()\n        self.action_space = ['u', 'd', 'l', 'r']\n        self.n_actions = len(self.action_space)\n        self.title('self-pain')\n        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))\n        self._build_maze()\n        self.danger=0\n        self.action_hurt=0\n        self.sensory_hurt = 0\n        self.open_door = 0\n        self.pain_state=0\n\n    # create environment\n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                           height=MAZE_W * UNIT,\n                           width=MAZE_H * UNIT)\n\n        # create grids\n        for c in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT\n            self.canvas.create_line(x0, y0, x1, y1)\n        for r in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r\n            self.canvas.create_line(x0, y0, x1, y1)\n\n        self.orgin=[20,20]\n        # create agent\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin[0]-15,#5\n            self.orgin[1]-15,#5\n            # 右上\n            self.orgin[0]+15,#35\n            self.orgin[1]-15,#5\n            # 右下+\n            self.orgin[0]+15,#35\n            self.orgin[1],#20\n            # 顶点\n            self.orgin[0],#20\n            self.orgin[1]+15,#35\n            # 左下+\n            self.orgin[0]-15,#5\n            self.orgin[1],#20\n        ]\n        self.rect = self.canvas.create_polygon(self.points1, fill=\"green\")\n        self.canvas.pack()\n\n\n    #reset agent location\n    def reset(self):\n        self.open_door = 0\n        self.update()\n        time.sleep(0.5)\n        self.canvas.delete(self.rect)\n     
   self.orgin = [20, 20]\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin[0] - 15,  # 5\n            self.orgin[1] - 15,  # 5\n            # 右上\n            self.orgin[0] + 15,  # 35\n            self.orgin[1] - 15,  # 5\n            # 右下+\n            self.orgin[0] + 15,  # 35\n            self.orgin[1],  # 20\n            # 顶点\n            self.orgin[0],  # 20\n            self.orgin[1] + 15,  # 35\n            # 左下+\n            self.orgin[0] - 15,  # 5\n            self.orgin[1],  # 20\n        ]\n        self.rect = self.canvas.create_polygon(self.points1, fill=\"green\")\n        return self.canvas.coords(self.rect)\n\n\n    def step(self, s, action, pain):\n        s = self.canvas.coords(self.rect)\n        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n\n        # danger or switch\n        if self.danger==1:\n            if all(self.centre == self.oval_center):\n                s_color = 'yellow'\n                self.canvas.delete(self.wall[3])\n                self.render()\n               \n                self.open_door = 1\n\n                move = np.array([80, 0])\n                self.canvas.move(self.rect, move[0], move[1])\n\n                s = self.canvas.coords(self.rect)\n                self.render()\n                \n            elif all(self.centre == self.hell1_center):\n                s_color = 'black'\n                self.action_hurt = 1\n                self.render()\n            else:\n                s_color = 'white'\n\n\n\n\n        # modify current state\n        self.canvas.delete(self.rect)# 主要为开关那几步考虑，所以重复写了\n        self.centre = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n\n        if action==0:\n            self.points0 = [\n                # 右下\n                self.centre[0] + 15,  # 35\n                self.centre[1] + 15,  # 35\n                # 左下\n                self.centre[0] - 15,  # 5\n                self.centre[1] + 15,  # 35\n                # 左上+\n                
self.centre[0] - 15,  # 5\n                self.centre[1],  # 20\n                # 顶点\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 5\n                # 右上+\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points0, fill=color)\n        if action==1:\n            self.points1 = [\n                # 左上\n                self.centre[0] - 15,  # 5\n                self.centre[1] - 15,  # 5\n                # 右上\n                self.centre[0] + 15,  # 35\n                self.centre[1] - 15,  # 5\n                # 右下+\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n                # 顶点\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 35\n                # 左下+\n                self.centre[0] - 15,  # 5\n                self.centre[1],  # 20\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points1, fill=color)\n        if action==2:\n            self.points2 = [\n                # 左下\n                self.centre[0] - 15,  # 5\n                self.centre[1] + 15,  # 35\n                # 左上\n                self.centre[0] - 15,  # 5\n                self.centre[1] - 15,  # 5\n                # 右上+\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 5\n                # 顶点\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n                # 右下+\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 35\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = 
\"red\"\n            self.rect = self.canvas.create_polygon(self.points2, fill=color)\n        if action==3:\n            self.points3 = [\n                # 右上\n                self.centre[0] + 15,  # 20+15\n                self.centre[1] - 15,  # 20-15\n                # 右下\n                self.centre[0] + 15,  # 20+15\n                self.centre[1] + 15,  # 20+15\n                # 左下+\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 20+15\n                # 顶点\n                self.centre[0] - 15,  # 20-15\n                self.centre[1],  # 20\n                # 左上+\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 20-15\n\n            ]\n            if pain==0:\n                color=\"green\"\n            if pain == 1:\n                color = \"red\"\n            self.rect = self.canvas.create_polygon(self.points3, fill=color)\n        s = self.canvas.coords(self.rect)\n        self.render()#显示当前的动作指令是什么\n       \n\n\n        if s[0] > (9 / 2) * 40:\n            self.action_hurt = 0\n            \n        base_action = np.array([0, 0])\n        if self.action_hurt == 0:\n            true_action = action\n        else:\n            if action == 0:\n                true_action = 1\n            if action == 1:\n                true_action = 0\n            if action == 2:\n                true_action = 3\n            if action == 3:\n                true_action = 2\n\n        # predict next state\n        # predict next state\n        self.centre1 = [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n        pre_displacement1 = np.array([0, 0])\n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:  # 120\n            if action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n        
    elif action == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action == 3:  # left\n                if self.centre1[0] > UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        else:\n            if action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif action == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        \n        \n        \n        \n        \n        # true next state\n        displacement1 = np.array([0, 0])\n        \n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if true_action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1=np.array([0,-40])\n            elif true_action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1=np.array([0,40])\n            elif true_action == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    displacement1=np.array([40,0])\n            elif true_action == 3:  # left\n                if self.centre1[0] > UNIT:\n                    displacement1=np.array([-40,0])\n        else:\n            if true_action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1=np.array([0,-40])\n            elif true_action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n               
     displacement1=np.array([0,40])\n            elif true_action == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    displacement1=np.array([40,0])\n            elif true_action == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    displacement1=np.array([-40,0])\n        self.canvas.move(self.rect, displacement1[0], displacement1[1])\n        s1_ = self.canvas.coords(self.rect)\n        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]\n\n\n        return displacement1, pre_displacement1,s1_,sss\n\n    \n\n    def _set_danger(self):\n        self.hell1_center = np.array([60, 60])\n        self.hell1 = self.canvas.create_oval(\n            self.hell1_center[0] - 15, self.hell1_center[1] - 15,\n            self.hell1_center[0] + 15, self.hell1_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n        self.hell = self.canvas.coords(self.hell1)\n        self.canvas.pack()\n        self.danger=1\n\n    def _set_switch(self):\n        self.oval_center = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        self.oval = self.canvas.create_oval(\n            self.oval_center[0] - 15, self.oval_center[1] - 15,\n            self.oval_center[0] + 15, self.oval_center[1] + 15,\n            fill='yellow')\n        self.switch = self.canvas.coords(self.oval)\n        self.canvas.pack()\n\n\n    def _set_wall(self):\n        wall_center=[]\n        self.wall=[]\n        for a in range(MAZE_W):\n            wall_center.append([0,0])\n            self.wall.append([])\n        for b in range(MAZE_W):\n            wall_center[b]=np.array([(MAZE_H*UNIT)/2,((b)*UNIT)+UNIT/2])\n            self.wall[b] = self.canvas.create_rectangle(\n                wall_center[b][0] - 20, wall_center[b][1] - 20,\n                wall_center[b][0] + 20, wall_center[b][1] + 20,\n                fill='grey')\n        
self.wall0 = self.canvas.coords(self.wall[0])\n        self.wall1 = self.canvas.coords(self.wall[1])\n        self.wall2 = self.canvas.coords(self.wall[2])\n        self.wall3 = self.canvas.coords(self.wall[3])\n\n        # self.canvas.pack()\n\n    def generate_expression(self,pain):\n        if pain==1:\n            self.canvas.itemconfig(self.rect, fill=\"red\", outline='red')\n            # self.canvas.pack()\n        if pain == 0:\n            self.canvas.itemconfig(self.rect, fill=\"green\", outline='green')\n            # self.canvas.pack()\n\n    def render(self):\n        time.sleep(0.2)\n        self.update()\n\n    # def getter(self, widget):\n    #     widget.update()\n    #     x = tk.Tk.winfo_rootx(self) + widget.winfo_x()\n    #     y = tk.Tk.winfo_rooty(self) + widget.winfo_y()\n    #     x1 = x + widget.winfo_width()\n    #     y1 = y + widget.winfo_height()\n    #     ImageGrab.grab().crop((x, y, x1, y1)).save(\"first.jpg\")\n    #     return ImageGrab.grab().crop((x, y, x1, y1))\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/affective_empathy/BRP-SNN/env_two_poly_SNN.py",
    "content": "import numpy as np\nnp.random.seed(1)\nimport tkinter as tk\nimport time\nfrom PIL import ImageGrab\n\n\nUNIT = 40   # pixels\nMAZE_H = 9  # grid height\nMAZE_W = 4 # grid width\n\n\nclass Maze2(tk.Tk, object):\n    def __init__(self):\n        super(Maze2, self).__init__()\n        self.action_space = ['u', 'd', 'l', 'r']\n        self.action_space1 = ['u', 'd', 'l', 'r']\n        self.n_actions = len(self.action_space)\n        self.n_actions1 = len(self.action_space1)\n        self.title('two_agent_empathy')\n        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))\n        self._build_maze()\n        self.danger=0\n        self.action_hurt=0\n        self.sensory_hurt = 0\n        self.action_hurt1 = 0\n        self.sensory_hurt1 = 0\n        self.open_door=0\n\n    # create environment\n    def _build_maze(self):\n        self.canvas = tk.Canvas(self, bg='white',\n                           height=MAZE_W * UNIT,\n                           width=MAZE_H * UNIT)\n\n        # create grids\n        for c in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT\n            self.canvas.create_line(x0, y0, x1, y1)\n        for r in range(0, MAZE_H * UNIT, UNIT):\n            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r\n            self.canvas.create_line(x0, y0, x1, y1)\n\n        # create switch\n        self.oval_center = np.array([(MAZE_H * UNIT)/2-UNIT+80, ((MAZE_W+4) * UNIT)/2-UNIT/2-80])\n        self.oval = self.canvas.create_oval(\n            self.oval_center[0] - 15, self.oval_center[1] - 15,\n            self.oval_center[0] + 15, self.oval_center[1] + 15,\n            fill='yellow')\n        self.switch = self.canvas.coords(self.oval)\n\n        self.orgin1 = np.array([20, 20])\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1] - 15,  # 5\n            # 右上\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1] - 15,  
# 5\n            # 右下+\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1],  # 20\n            # 顶点\n            self.orgin1[0],  # 20\n            self.orgin1[1] + 15,  # 35\n            # 左下+\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1],  # 20\n        ]\n        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill=\"blue\")\n\n        self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])\n        # 下\n        self.points = [\n            # 左上\n            self.orgin[0] - 15,  # 5\n            self.orgin[1] - 15,  # 5\n            # 右上\n            self.orgin[0] + 15,  # 35\n            self.orgin[1] - 15,  # 5\n            # 右下+\n            self.orgin[0] + 15,  # 35\n            self.orgin[1],  # 20\n            # 顶点\n            self.orgin[0],  # 20\n            self.orgin[1] + 15,  # 35\n            # 左下+\n            self.orgin[0] - 15,  # 5\n            self.orgin[1],  # 20\n        ]\n        self.agent = self.canvas.create_polygon(self.points, fill=\"green\")\n\n        wall_center = []\n        self.wall = []\n        for i in range(MAZE_W):\n            wall_center.append([])\n            self.wall.append([])\n        for i in range(MAZE_W):\n            wall_center[i] = np.array([(MAZE_H * UNIT) / 2, ((i) * UNIT) + UNIT / 2])\n            self.wall[i] = self.canvas.create_rectangle(\n                wall_center[i][0] - 20, wall_center[i][1] - 20,\n                wall_center[i][0] + 20, wall_center[i][1] + 20,\n                fill='grey')\n\n        self.hell1_center = np.array([100, 20])\n        self.hell1 = self.canvas.create_oval(\n            self.hell1_center[0] - 15, self.hell1_center[1] - 15,\n            self.hell1_center[0] + 15, self.hell1_center[1] + 15,\n            fill='black')\n        self.hell2_center = np.array([60, 100])\n        self.hell2 = self.canvas.create_oval(\n            self.hell2_center[0] - 15, self.hell2_center[1] - 15,\n            self.hell2_center[0] + 15, 
self.hell2_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n\n        self.danger = 1\n\n        self.canvas.pack()\n\n    #reset agent location\n    def reset(self):\n        self.update()\n        time.sleep(0.5)\n        self.canvas.delete(self.agent1)\n        self.canvas.delete(self.agent)\n        self.orgin1 = np.array([20, 20])\n        # 下\n        self.points1 = [\n            # 左上\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1] - 15,  # 5\n            # 右上\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1] - 15,  # 5\n            # 右下+\n            self.orgin1[0] + 15,  # 35\n            self.orgin1[1],  # 20\n            # 顶点\n            self.orgin1[0],  # 20\n            self.orgin1[1] + 15,  # 35\n            # 左下+\n            self.orgin1[0] - 15,  # 5\n            self.orgin1[1],  # 20\n        ]\n        self.agent1 = self.canvas.create_polygon(self.points1, outline='black',fill=\"blue\")\n\n        self.orgin = np.array([MAZE_H * UNIT - UNIT / 2, 20])\n        # 下\n        self.points = [\n            # 左上\n            self.orgin[0] - 15,  # 5\n            self.orgin[1] - 15,  # 5\n            # 右上\n            self.orgin[0] + 15,  # 35\n            self.orgin[1] - 15,  # 5\n            # 右下+\n            self.orgin[0] + 15,  # 35\n            self.orgin[1],  # 20\n            # 顶点\n            self.orgin[0],  # 20\n            self.orgin[1] + 15,  # 35\n            # 左下+\n            self.orgin[0] - 15,  # 5\n            self.orgin[1],  # 20\n        ]\n        self.agent = self.canvas.create_polygon(self.points, fill=\"green\")\n\n        return self.canvas.coords(self.agent1),self.canvas.coords(self.agent)\n\n\n\n    def step(self, s, action, pain):\n        s1 = self.canvas.coords(self.agent1)\n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if all(self.centre1 == self.hell1_center):\n            self.action_hurt1 = 1\n    
    if all(self.centre1 == self.hell2_center):\n            self.action_hurt1 = 1\n        \n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 ==self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 ==self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*2, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 == self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*3, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 == self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n        self.oval_center111 = np.array([(MAZE_H * UNIT) / 2 - UNIT*4, ((MAZE_W + 4) * UNIT) / 2 - UNIT / 2])\n        if all(self.centre1 == self.oval_center111):\n            move = np.array([80, 0])\n            self.canvas.move(self.agent1, move[0], move[1])\n            s1 = self.canvas.coords(self.agent1)\n            self.render()\n\n        \n\n\n\n        #显示当前的动作指令是什么\n        self.canvas.delete(self.agent1)\n        self.centre = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n\n        if action==0:\n            self.points0 = [\n                
# 右下\n                self.centre[0] + 15,  # 35\n                self.centre[1] + 15,  # 35\n                # 左下\n                self.centre[0] - 15,  # 5\n                self.centre[1] + 15,  # 35\n                # 左上+\n                self.centre[0] - 15,  # 5\n                self.centre[1],  # 20\n                # 顶点\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 5\n                # 右上+\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points0, fill=color)\n        if action==1:\n            self.points1 = [\n                # 左上\n                self.centre[0] - 15,  # 5\n                self.centre[1] - 15,  # 5\n                # 右上\n                self.centre[0] + 15,  # 35\n                self.centre[1] - 15,  # 5\n                # 右下+\n                self.centre[0] + 15,  # 35\n                self.centre[1],  # 20\n                # 顶点\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 35\n                # 左下+\n                self.centre[0] - 15,  # 5\n                self.centre[1],  # 20\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points1, fill=color)\n        if action==2:\n            self.points2 = [\n                # 左下\n                self.centre[0] - 15,  # 5\n                self.centre[1] + 15,  # 35\n                # 左上\n                self.centre[0] - 15,  # 5\n                self.centre[1] - 15,  # 5\n                # 右上+\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 5\n                # 顶点\n                self.centre[0] + 15,  # 35\n                
self.centre[1],  # 20\n                # 右下+\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 35\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points2, fill=color)\n        if action==3:\n            self.points3 = [\n                # 右上\n                self.centre[0] + 15,  # 20+15\n                self.centre[1] - 15,  # 20-15\n                # 右下\n                self.centre[0] + 15,  # 20+15\n                self.centre[1] + 15,  # 20+15\n                # 左下+\n                self.centre[0],  # 20\n                self.centre[1] + 15,  # 20+15\n                # 顶点\n                self.centre[0] - 15,  # 20-15\n                self.centre[1],  # 20\n                # 左上+\n                self.centre[0],  # 20\n                self.centre[1] - 15,  # 20-15\n\n            ]\n            if pain==0:\n                color=\"blue\"\n            if pain == 1:\n                color = \"red\"\n            self.agent1 = self.canvas.create_polygon(self.points3, fill=color)\n        \n        s1 = self.canvas.coords(self.agent1)\n        self.render()#显示当前的动作指令是什么\n\n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        if self.centre1[0] > (9 / 2) * 40:\n            self.action_hurt1 = 0\n\n        # whether hurt\n        if self.action_hurt1 == 0:\n            true_action = action\n        else:\n            if action == 0:\n                true_action = 1\n            if action == 1:\n                true_action = 0\n            if action == 2:\n                true_action = 3\n            if action == 3:\n                true_action = 2\n        \n\n            \n        base_action = np.array([0, 0])\n        \n\n       \n        # predict next state\n        self.centre1 = [(s1[4] + s1[8]) / 2, (s1[5] + s1[9]) / 2]\n        pre_displacement1 = np.array([0, 0])\n     
   if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:  # 120\n            if action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif action == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action == 3:  # left\n                if self.centre1[0] > UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        else:\n            if action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    pre_displacement1 = np.array([0, -40])\n            elif action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    pre_displacement1 = np.array([0, 40])\n            elif action == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    pre_displacement1 = np.array([40, 0])\n            elif action == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    pre_displacement1 = np.array([-40, 0])\n        \n        \n        \n        \n        \n        # true next state\n        displacement1 = np.array([0, 0])\n        \n        if self.centre1[0] <= ((MAZE_H - 1) / 2 + 1) * UNIT:\n            if true_action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1=np.array([0,-40])\n            elif true_action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1=np.array([0,40])\n            elif true_action == 2:  # right\n                if self.centre1[0] < ((MAZE_H - 1) / 2 - 1) * UNIT:\n                    displacement1=np.array([40,0])\n            elif true_action == 3:  # 
left\n                if self.centre1[0] > UNIT:\n                    displacement1=np.array([-40,0])\n        else:\n            if true_action == 0:  # up\n                if self.centre1[1] > UNIT:\n                    displacement1=np.array([0,-40])\n            elif true_action == 1:  # down\n                if self.centre1[1] < (MAZE_W - 1) * UNIT:\n                    displacement1=np.array([0,40])\n            elif true_action == 2:  # right\n                if self.centre1[0] < (MAZE_H - 1) * UNIT:\n                    displacement1=np.array([40,0])\n            elif true_action == 3:  # left\n                if self.centre1[0] > ((MAZE_H - 1) / 2 + 2) * UNIT:\n                    displacement1=np.array([-40,0])\n        self.canvas.move(self.agent1, displacement1[0], displacement1[1])\n        s1_ = self.canvas.coords(self.agent1)\n        sss = [(s1_[4] + s1_[8]) / 2, (s1_[5] + s1_[9]) / 2]\n\n\n        return displacement1, pre_displacement1,s1_,sss\n\n\n    def agent_help(self):\n        s = self.canvas.coords(self.agent)\n        self.centre2= [(s[4] + s[8]) / 2, (s[5] + s[9]) / 2]\n        if all(self.centre2 == self.oval_center):\n            self.canvas.delete(self.wall[3])\n            self.render()\n            self.open_door=1\n        else:    \n            self.canvas.move(self.agent, -40, 0)  # move agent\n            self.render()\n            self.canvas.move(self.agent, -40, 0)\n            self.render()\n            self.canvas.move(self.agent, -40, 0)\n            self.render()\n            self.canvas.move(self.agent, 0, 40)\n            self.render()\n        s_ = self.canvas.coords(self.agent)  # next state\n        \n        return s_\n\n    def _set_danger(self):\n        hell1_center = np.array([140, 60])\n        self.hell1 = self.canvas.create_agent1angle(\n            hell1_center[0] - 15, hell1_center[1] - 15,\n            hell1_center[0] + 15, hell1_center[1] + 15,\n            fill='black')\n        hell2_center = 
np.array([100, 140])\n        self.hell2 = self.canvas.create_agent1angle(\n            hell2_center[0] - 15, hell2_center[1] - 15,\n            hell2_center[0] + 15, hell2_center[1] + 15,\n            fill='black')\n        # self.canvas.create_bitmap((40 , 40), bitmap='error')\n        self.canvas.pack()\n        self.danger=1\n\n\n    def _set_wall(self):\n        wall_center=[]\n        self.wall=[]\n        for i in range(MAZE_W):\n            wall_center.append([])\n            self.wall.append([])\n        for i in range(MAZE_W):\n            wall_center[i]=np.array([(MAZE_H*UNIT)/2,((i)*UNIT)+UNIT/2])\n            self.wall[i] = self.canvas.create_agent1angle(\n                wall_center[i][0] - 20, wall_center[i][1] - 20,\n                wall_center[i][0] + 20, wall_center[i][1] + 20,\n                fill='grey')\n        self.canvas.pack()\n\n    \n    def generate_expression1(self,pain1):\n        if pain1==1:\n            self.canvas.itemconfig(self.agent1, fill=\"red\", outline='black')\n            self.canvas.pack()\n        if pain1 ==0:\n            self.canvas.itemconfig(self.agent1, fill=\"blue\", outline='black')\n            self.canvas.pack()\n    def render(self):\n        time.sleep(0.2)\n        self.update()\n\n\n\n\n"
  },
  {
    "path": "examples/Social_Cognition/mirror_test/README.md",
    "content": "# Mirror Test \n\nThe mirror_test.py implements the core code of the Multi-Robots Mirror Self-Recognition Test in \"Toward Robot Self-Consciousness (II): Brain-Inspired Robot Bodily Self Model for Self-Recognition\".\n\nThe experiment is: three robots with identical appearance move their arms randomly in front of the mirror at the same time. \n\nIn the training stage, according to the spiking time difference of neurons in IPLM and IPLV, the robot learns the correlations between self-generated actions and visual feedbacks in motion by learning with spike timing dependent plasticity (STDP) mechanism. \n\nIn the test stage, the robot can predicts the visual feedback generated by its arm movement according to the training results. With the InsulaNet, the robot can identify which mirror image belongs to it.\n\nIn the result, Motion Detection shows the results of visual detection, and Motion Prediction shows the visual feedback generated by itself. The red line in the figure indicates that the robot determines that the corresponding mirror belongs to itself.\n\nDifferences from the original article:\nSince there is no motion error under the simulation conditions, the theta_threshold is set to zero.\n\n\n### Citation \nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{zeng2018toward,\n  title={Toward robot self-consciousness (ii): brain-inspired robot bodily self model for self-recognition},\n  author={Zeng, Yi and Zhao, Yuxuan and Bai, Jun and Xu, Bo},\n  journal={Cognitive Computation},\n  volume={10},\n  number={2},\n  pages={307--320},\n  year={2018},\n  publisher={Springer}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and 
Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```\n"
  },
  {
    "path": "examples/Social_Cognition/mirror_test/mirror_test.py",
    "content": "from braincog.base.brainarea.Insula import *\nfrom braincog.base.brainarea.IPL import *\nfrom braincog.base.learningrule.STDP import *\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.CustomLinear import *\nimport random\nimport numpy as np\nimport torch\nimport os\nimport sys\nfrom torch import nn\nfrom torch.nn import Parameter\n\nimport abc\nimport math\nfrom abc import ABC\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\nfrom braincog.base.strategy.surrogate import *\n\nimport os\nos.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    Set the number of neurons, and each neuron represents unique motion information (such as angle)\n    \"\"\"\n    # number of neurons\n    num_neuron = 5\n    num_vPMC = num_neuron\n    num_STS = num_neuron\n    num_IPLM = num_neuron\n    num_IPLV = num_neuron\n    num_Insula = num_neuron\n\n    \"\"\"\n    Setting the network structure and the initial weight of IPL\n    \"\"\"\n    # IPLNet\n    # connection\n    connection = []\n    # vPMC-IPLM\n    con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2.5\n    connection.append(CustomLinear(con_matrix0))\n    # STS-IPLV\n    con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2.5\n    connection.append(CustomLinear(con_matrix1))\n    # IPLM-IPLV\n    con_matrix2 = torch.zeros([num_IPLM, num_IPLV], dtype=torch.float)\n    connection.append(CustomLinear(con_matrix2))\n\n    IPL = IPLNet(connection)\n\n    print(\"IPL Connection (Before training):\", connection[2].weight)\n\n    \"\"\"\n    Setting the network structure and the initial weight of Insula\n    \"\"\"\n    # InsulaNet\n    # connection\n    Insula_connection = []\n    # IPLV-Insula\n    con_matrix0 = torch.eye(num_IPLM, dtype=torch.float) * 2\n    Insula_connection.append(CustomLinear(con_matrix0))\n    # STS-Insula\n    
con_matrix1 = torch.eye(num_IPLV, dtype=torch.float) * 2\n    Insula_connection.append(CustomLinear(con_matrix1))\n\n    Insula = InsulaNet(Insula_connection)\n\n    \"\"\"\n    Training process\n    :param train_num: number of movements during training\n    \"\"\"\n    # Train\n    for vPMC_Angel in range(1, num_vPMC + 1):\n        # vPMC Angle\n        vPMC_Angel_v = torch.zeros([1, num_vPMC], dtype=torch.float)\n        vPMC_Angel_v[0, vPMC_Angel - 1] = 20\n\n        dwIPL_temp = torch.zeros([num_IPLM, num_IPLV], dtype=torch.float)\n\n        train_num = 10\n        for i_train in range(train_num):\n\n            # STS 1\n            STS_Angel_1 = vPMC_Angel\n            for t in range(2):\n                vPMC_input = vPMC_Angel_v\n                STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)\n                STS_Angel_v[0, STS_Angel_1 - 1] = 20\n                STS_input = STS_Angel_v\n                IPLV_out, dwIPL = IPL(vPMC_input, STS_input)\n                dwIPL_temp = dwIPL_temp + dwIPL\n            IPL.reset()\n\n            # STS 2\n            STS_Angel_2 = random.randint(1, num_neuron)\n            for t in range(2):\n                vPMC_input = vPMC_Angel_v\n                STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)\n                STS_Angel_v[0, STS_Angel_2 - 1] = 20\n                STS_input = STS_Angel_v\n                IPLV_out, dwIPL = IPL(vPMC_input, STS_input)\n                dwIPL_temp = dwIPL_temp + dwIPL\n            IPL.reset()\n\n            # STS 3\n            STS_Angel_3 = random.randint(1, num_neuron)\n            for t in range(2):\n                vPMC_input = vPMC_Angel_v\n                STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)\n                STS_Angel_v[0, STS_Angel_3 - 1] = 20\n                STS_input = STS_Angel_v\n                IPLV_out, dwIPL = IPL(vPMC_input, STS_input)\n                dwIPL_temp = dwIPL_temp + dwIPL\n            IPL.reset()\n\n        
IPL.UpdateWeight(2, dwIPL_temp)\n\n    print(\"IPL Connection (After training):\", connection[2].weight)\n\n    \"\"\"\n    Test process\n    :param move_count: number of movements during test\n    \"\"\"\n    # Test\n    move_count = 10\n    TestList_vPMC_Angel = np.random.randint(1, num_vPMC, move_count)\n    TestList_STS_Angel_1 = TestList_vPMC_Angel\n    TestList_STS_Angel_2 = np.random.randint(1, num_STS, move_count)\n    TestList_STS_Angel_3 = np.random.randint(1, num_STS, move_count)\n    TestMat_STS_Angle = np.vstack((TestList_STS_Angel_1, TestList_STS_Angel_2, TestList_STS_Angel_3))\n    np.random.shuffle(TestMat_STS_Angle)\n\n    TestList_IPLV_out = []\n    for i_test in range(move_count):\n        Test_vPMC_Angel = TestList_vPMC_Angel[i_test]\n        Test_vPMC_Angel_v = torch.zeros([1, num_vPMC], dtype=torch.float)\n        Test_vPMC_Angel_v[0, Test_vPMC_Angel - 1] = 20\n        Test_STS_Angel_v = torch.zeros([1, num_STS], dtype=torch.float)\n        for t in range(2):\n            IPL(Test_vPMC_Angel_v, Test_STS_Angel_v)\n            IPLV_out_f = torch.argmax(IPL.node[1].u) + 1\n        IPL.reset()\n        TestList_IPLV_out.append(IPLV_out_f.numpy().item())\n\n    confidence = [0, 0, 0]\n    for i in range(move_count):\n        theta_predict = TestList_IPLV_out[i]\n        theta_visual_1 = TestMat_STS_Angle[0][i]\n        theta_visual_2 = TestMat_STS_Angle[1][i]\n        theta_visual_3 = TestMat_STS_Angle[2][i]\n\n        Test_IPL_v = torch.zeros([1, num_IPLV], dtype=torch.float)\n        Test_IPL_v[0, theta_predict - 1] = 20\n\n        Test_STS1_v = torch.zeros([1, num_STS], dtype=torch.float)\n        Test_STS1_v[0, theta_visual_1 - 1] = 20\n        for t in range(2):\n            Insula(Test_IPL_v, Test_STS1_v)\n        if sum(sum(Insula.out_Insula)) > 0:\n            confidence[0] = confidence[0] + 1\n        Insula.reset()\n\n        Test_STS2_v = torch.zeros([1, num_STS], dtype=torch.float)\n        Test_STS2_v[0, theta_visual_2 - 1] = 20\n      
  for t in range(2):\n            Insula(Test_IPL_v, Test_STS2_v)\n        if sum(sum(Insula.out_Insula)) > 0:\n            confidence[1] = confidence[1] + 1\n        Insula.reset()\n\n        Test_STS3_v = torch.zeros([1, num_STS], dtype=torch.float)\n        Test_STS3_v[0, theta_visual_3 - 1] = 20\n        for t in range(2):\n            Insula(Test_IPL_v, Test_STS3_v)\n        if sum(sum(Insula.out_Insula)) > 0:\n            confidence[2] = confidence[2] + 1\n        Insula.reset()\n\n    x_0 = torch.arange(0, move_count)\n    x_1 = torch.arange(move_count * 1, move_count * 2)\n    x_2 = torch.arange(move_count * 2, move_count * 3)\n\n    color_list = ['k', 'k', 'k']\n    color_list[confidence.index(max(confidence))] = 'r'\n\n    plt.subplot(211)\n    plt.figure(1)\n    plt.plot(x_0, TestMat_STS_Angle[0], color=color_list[0])\n    plt.plot(x_1, TestMat_STS_Angle[1], color=color_list[1])\n    plt.plot(x_2, TestMat_STS_Angle[2], color=color_list[2])\n    plt.title(\"Motion Detection\")\n    plt.subplot(212)\n    plt.plot(x_0, TestList_IPLV_out, color='r')\n    plt.title(\"Motion Prediction\")\n    plt.tight_layout()\n    plt.show()\n"
  },
  {
    "path": "examples/Spiking-Transformers/LIFNode.py",
    "content": "from timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\n\nclass MyBaseNode(BaseNode):\n    def __init__(self, threshold=0.5, step=4, layer_by_layer=False, mem_detach=False):\n        super().__init__(threshold=threshold, step=step, layer_by_layer=layer_by_layer, mem_detach=mem_detach)\n\n    def rearrange2node(self, inputs):\n        if self.groups != 1:\n            if len(inputs.shape) == 4:\n                outputs = rearrange(inputs, 'b (c t) w h -> t b c w h', t=self.step)\n            elif len(inputs.shape) == 2:\n                outputs = rearrange(inputs, 'b (c t) -> t b c', t=self.step)\n            else:\n                raise NotImplementedError\n\n        elif self.layer_by_layer:\n            if len(inputs.shape) == 4:\n                outputs = rearrange(inputs, '(t b) c w h -> t b c w h', t=self.step)\n\n            # 加入适配Transformer T B N C的rearange2node分支\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, '(t b) n c -> t b n c', t=self.step)\n            elif len(inputs.shape) == 2:\n                outputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)\n            else:\n                raise NotImplementedError\n\n\n        else:\n            outputs = inputs\n\n        return outputs\n\n    def rearrange2op(self, inputs):\n        if self.groups != 1:\n            if len(inputs.shape) == 5:\n                outputs = rearrange(inputs, 't b c w h -> b (c t) w h')\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, ' t b c -> b (c t)')\n            else:\n                raise NotImplementedError\n        elif 
self.layer_by_layer:\n            if len(inputs.shape) == 5:\n                outputs = rearrange(inputs, 't b c w h -> (t b) c w h')\n\n            # 加入适配Transformer T B N C的rearange2op分支\n            elif len(inputs.shape) == 4:\n                outputs = rearrange(inputs, ' t b n c -> (t b) n c')\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, ' t b c -> (t b) c')\n            else:\n                raise NotImplementedError\n\n        else:\n            outputs = inputs\n\n        return outputs\n\n\nclass MyGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=4., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return sigmoid.apply(x, alpha)\n\n\nclass MyNode(MyBaseNode):\n    def __init__(self, threshold=1., step=4, layer_by_layer=True, tau=2., act_fun=MyGrad, mem_detach=True, *args,\n                 **kwargs):\n        super().__init__(threshold=threshold, step=step, layer_by_layer=layer_by_layer, mem_detach=mem_detach)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=4., requires_grad=False)\n\n    def integral(self, inputs):\n        self.mem = self.mem + (inputs - self.mem) / self.tau\n\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        self.mem = self.mem * (1 - self.spike.detach())"
  },
  {
    "path": "examples/Spiking-Transformers/README.md",
    "content": "# Spiking Transformers Reproduced With Braincog\nHere is the current Spiking Transformer code reproduced using [BrainCog](http://www.brain-cog.network/). Welcome to follow the work of BrainCog and utilize the [BrainCog framework](https://github.com/BrainCog-X/Brain-Cog) to create relevant brain-inspired AI endeavors. The works implemented here will also be merged into BrainCog Repo.\n\n### Models\n**Spikformer(ICLR 2023)**\n[Zhou, Z., Zhu, Y., He, C., Wang, Y., Yan, S., Tian, Y., & Yuan, L. (2022). Spikformer: When spiking neural network meets transformer. arXiv preprint arXiv:2209.15425.](https://openreview.net/forum?id=frE4fUwz_h)\n![alt text](/img/spikformer.png)\n\n**Spike-driven Transformer(Nips 2023)**\n[Yao, M., Hu, J., Zhou, Z., Yuan, L., Tian, Y., Xu, B., & Li, G. (2024). Spike-driven transformer. Advances in Neural Information Processing Systems, 36.](https://proceedings.neurips.cc/paper_files/paper/2023/hash/ca0f5358dbadda74b3049711887e9ead-Abstract-Conference.html)\n![alt text](/img/sdv1.png)\n\n\n**Spike-driven Transformer V2(ICLR 2024)**\n[Yao, M., Hu, J., Hu, T., Xu, Y., Zhou, Z., Tian, Y., ... & Li, G. (2023, October). Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips. In The Twelfth International Conference on Learning Representations.](https://openreview.net/forum?id=1SIBN5Xyw7)\n![alt text](/img/sdv2.png)\n\n## Models in comming soon\n**SpikingResFormer**\n[Shi, X., Hao, Z., & Yu, Z. (2024). SpikingResformer: Bridging ResNet and Vision Transformer in Spiking Neural Networks. arXiv preprint arXiv:2403.14302.](https://arxiv.org/abs/2403.14302)\n\n**TIM**\n[Shen, S., Zhao, D., Shen, G., & Zeng, Y. (2024). TIM: An Efficient Temporal Interaction Module for Spiking Transformer. arXiv preprint arXiv:2401.11687.](https://arxiv.org/abs/2401.11687)\n\n**SGLFormer(Frontiers in Neuroscience)**\n[Zhang, H., Zhou, C., Yu, L., Huang, L., Ma, Z., Fan, X., ... 
& Tian, Y. (2024). SGLFormer: Spiking Global-Local-Fusion Transformer with High Performance. Frontiers in Neuroscience, 18, 1371290.](https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2024.1371290/full)\n\n**QKFormer(CVPR2024)**\n[Zhou, C., Zhang, H., Zhou, Z., Yu, L., Huang, L., Fan, X., ... & Tian, Y. (2024). QKFormer: Hierarchical Spiking Transformer using QK Attention. arXiv preprint arXiv:2403.16552.](https://arxiv.org/abs/2403.16552)\n\n\n## Requirements\n- Braincog\n- einops >= 0.4.1\n- timm >= 0.5.4\n\n## Training Examples\n### Training on CIFAR10-DVS\npython main.py --dataset dvsc10 --epochs 500 --batch-size 16 --seed 42 --event-size 64 --model spikformer_dvs\n### Training on ImageNet\npython main.py --dataset imnet --epochs 500 --batch-size 16 --seed 42 --model spikformer"
  },
  {
    "path": "examples/Spiking-Transformers/datasets.py",
    "content": "import os\nimport warnings\nimport random\nimport torchvision.datasets\n\nimport braincog.datasets.ucf101_dvs\n\ntry:\n    import tonic\n    from tonic import DiskCachedDataset\nexcept:\n    warnings.warn(\"tonic should be installed, 'pip install git+https://github.com/FloyedShen/tonic.git'\")\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils\nimport torchvision.datasets as datasets\nfrom timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.data import create_transform, distributed_sampler\nfrom timm.data.loader import PrefetchLoader\nfrom tonic import DiskCachedDataset\nfrom torchvision import transforms\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\nfrom braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull\nfrom braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot\nfrom braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet\n# from braincog.base.conversion.conversion import CIFAR10Policy, Cutout\n# from .cut_mix import CutMix, EventMix, MixUp\n# from .rand_aug import *\n# from .event_drop import event_drop\n# from .utils import dvs_channel_check_expend, rescale\n\nDVSCIFAR10_MEAN_16 = [0.3290, 0.4507]\nDVSCIFAR10_STD_16 = [1.8398, 1.6549]\n\nDATA_DIR = '/data/datasets'\n\nDEFAULT_CROP_PCT = 0.875\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\nIMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)\nIMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)\n\nCIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)\nCIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)\nCIFAR100_DEFAULT_MEAN = (0.5071, 0.4867, 0.4408)\nCIFAR100_DEFAULT_STD = (0.2675, 0.2565, 0.2761)\n\n\ndef unpack_mix_param(args):\n    mix_up = args['mix_up'] if 'mix_up' in args else False\n    cut_mix = args['cut_mix'] if 'cut_mix' in args 
else False\n    event_mix = args['event_mix'] if 'event_mix' in args else False\n    beta = args['beta'] if 'beta' in args else 1.\n    prob = args['prob'] if 'prob' in args else .5\n    num = args['num'] if 'num' in args else 1\n    num_classes = args['num_classes'] if 'num_classes' in args else 10\n    noise = args['noise'] if 'noise' in args else 0.\n    gaussian_n = args['gaussian_n'] if 'gaussian_n' in args else None\n    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n\n\n\ndef build_transform(is_train, img_size):\n    \"\"\"\n    构建数据增强, 适用于static data\n    :param is_train: 是否训练集\n    :param img_size: 输出的图像尺寸\n    :return: 数据增强策略\n    \"\"\"\n    resize_im = img_size > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=img_size,\n            is_training=True,\n            color_jitter=0.4,\n            # auto_augment='rand-m9-mstd0.5-inc1',\n            interpolation='bicubic',\n            # re_prob=0.25,\n            # re_mode='pixel',\n            # re_count=1,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(\n                img_size, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        size = int((256 / 224) * img_size)\n        t.append(\n            # to maintain same ratio w.r.t. 
224 images\n            transforms.Resize(size, interpolation=3),\n        )\n        t.append(transforms.CenterCrop(img_size))\n\n    t.append(transforms.ToTensor())\n    if img_size > 32:\n        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    else:\n        t.append(transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD))\n    return transforms.Compose(t)\n\n\ndef build_dataset(is_train, img_size, dataset, path, same_da=False):\n    \"\"\"\n    构建带有增强策略的数据集\n    :param is_train: 是否训练集\n    :param img_size: 输出图像尺寸\n    :param dataset: 数据集名称\n    :param path: 数据集路径\n    :param same_da: 为训练集使用测试集的增广方法\n    :return: 增强后的数据集\n    \"\"\"\n    transform = build_transform(False, img_size) if same_da else build_transform(is_train, img_size)\n\n    if dataset == 'CIFAR10':\n        dataset = datasets.CIFAR10(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 10\n    elif dataset == 'CIFAR100':\n        dataset = datasets.CIFAR100(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 100\n    else:\n        raise NotImplementedError\n\n    return dataset, nb_classes\n\n\nclass MNISTData(object):\n    \"\"\"\n    Load MNIST datesets.\n    \"\"\"\n\n    def __init__(self,\n                 data_path: str,\n                 batch_size: int,\n                 train_trans: Sequence[torch.nn.Module] = None,\n                 test_trans: Sequence[torch.nn.Module] = None,\n                 pin_memory: bool = True,\n                 drop_last: bool = True,\n                 shuffle: bool = True,\n                 ) -> None:\n        self._data_path = data_path\n        self._batch_size = batch_size\n        self._pin_memory = pin_memory\n        self._drop_last = drop_last\n        self._shuffle = shuffle\n        self._train_transform = transforms.Compose(train_trans) if train_trans else None\n        self._test_transform = transforms.Compose(test_trans) if 
test_trans else None\n\n    def get_data_loaders(self):\n        print('Batch size: ', self._batch_size)\n        train_datasets = datasets.MNIST(root=self._data_path, train=True, transform=self._train_transform, download=True)\n        test_datasets = datasets.MNIST(root=self._data_path, train=False, transform=self._test_transform, download=True)\n        train_loader = torch.utils.data.DataLoader(\n            train_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=self._drop_last, shuffle=self._shuffle\n        )\n        test_loader = torch.utils.data.DataLoader(\n            test_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=False\n        )\n        return train_loader, test_loader\n\n    def get_standard_data(self):\n        MNIST_MEAN = 0.1307\n        MNIST_STD = 0.3081\n        self._train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                                    transforms.ToTensor(),\n                                                    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        self._test_transform = transforms.Compose([transforms.ToTensor(),\n                                                   transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        return self.get_data_loaders()\n\n\ndef get_mnist_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取MNIST数据\n    http://data.pymvpa.org/datasets/mnist/\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    MNIST_MEAN = 0.1307\n    MNIST_STD = 0.3081\n    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:\n        train_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n        test_transform = transforms.Compose([\n            
transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n    else:\n        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                              # transforms.RandomRotation(10),\n                                              transforms.ToTensor(),\n                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        test_transform = transforms.Compose([transforms.ToTensor(),\n                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n\n    train_datasets = datasets.MNIST(\n        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.MNIST(\n        root=DATA_DIR, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_fashion_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取fashion MNIST数据\n    http://arxiv.org/abs/1708.07747\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                          transforms.RandomHorizontalFlip(),\n                                          transforms.RandomRotation(10),\n                                          transforms.ToTensor()])\n    test_transform = transforms.Compose([transforms.ToTensor()])\n\n    train_datasets = datasets.FashionMNIST(\n     
   root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.FashionMNIST(\n        root=DATA_DIR, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    # \"\"\"\n    # 获取CIFAR10数据\n    #  https://www.cs.toronto.edu/~kriz/cifar.html\n    # :param batch_size: batch size\n    # :param kwargs:\n    # :return: (train loader, test loader, mixup_active, mixup_fn)\n    # \"\"\"\n    # train_datasets, _ = build_dataset(True, 32, 'CIFAR10', DATA_DIR, same_da)\n    # test_datasets, _ = build_dataset(False, 32, 'CIFAR10', DATA_DIR, same_da)\n    #\n    # train_loader = torch.utils.data.DataLoader(\n    #     train_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=True, shuffle=True,\n    #     num_workers=num_workers\n    # )\n    #\n    # test_loader = torch.utils.data.DataLoader(\n    #     test_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=False,\n    #     num_workers=num_workers\n    # )\n    normalize = transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD)\n    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),\n                                          CIFAR10Policy(),\n                                          transforms.ToTensor(),\n                                          Cutout(n_holes=1, length=16),\n                                          normalize])\n    transform_test = 
transforms.Compose([transforms.ToTensor(), normalize])\n    train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform_train)\n    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform_test)\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,  batch_size=batch_size,\n        shuffle=True, num_workers=num_workers,\n        pin_memory=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        shuffle=False, num_workers=num_workers,\n        pin_memory=True\n    )\n    return train_loader, test_loader, None, None\n\n\ndef get_cifar100_data(batch_size, num_workers=8, same_data=False, *args, **kwargs):\n    # \"\"\"\n    # 获取CIFAR100数据\n    # https://www.cs.toronto.edu/~kriz/cifar.html\n    # :param batch_size: batch size\n    # :param kwargs:\n    # :return: (train loader, test loader, mixup_active, mixup_fn)\n    # \"\"\"\n    # train_datasets, _ = build_dataset(True, 32, 'CIFAR100', DATA_DIR, same_data)\n    # test_datasets, _ = build_dataset(False, 32, 'CIFAR100', DATA_DIR, same_data)\n    #\n    # train_loader = torch.utils.data.DataLoader(\n    #     train_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    # )\n    #\n    # test_loader = torch.utils.data.DataLoader(\n    #     test_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=False, num_workers=num_workers\n    # )\n    # return train_loader, test_loader, False, None\n    normalize = transforms.Normalize(CIFAR100_DEFAULT_MEAN, CIFAR100_DEFAULT_STD)\n    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),\n                                          CIFAR10Policy(),\n                                          transforms.ToTensor(),\n                                          Cutout(n_holes=1, length=16),\n            
                              normalize])\n    transform_test = transforms.Compose([transforms.ToTensor(), normalize])\n    train_dataset = datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=transform_train)\n    test_dataset = datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=transform_test)\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,  batch_size=batch_size,\n        shuffle=True, num_workers=num_workers,\n        pin_memory=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        shuffle=False, num_workers=num_workers,\n        pin_memory=True\n    )\n    return train_loader, test_loader, None, None\n\n\ndef get_imnet_data(args, _logger, data_config, num_aug_splits, **kwargs):\n    \"\"\"\n    获取ImageNet数据集\n    http://arxiv.org/abs/1409.0575\n    :param args: 其他的参数\n    :param _logger: 日志路径\n    :param data_config: 增强策略\n    :param num_aug_splits: 不同增强策略的数量\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_dir = os.path.join(DATA_DIR, 'ILSVRC2012/train')\n    if not os.path.exists(train_dir):\n        _logger.error(\n            'Training folder does not exist at: {}'.format(train_dir))\n        exit(1)\n    dataset_train = ImageDataset(train_dir)\n    # collate_fn = None\n    # mixup_fn = None\n    # mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None\n    # if mixup_active:\n    #     mixup_args = dict(\n    #         mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n    #         prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n    #         label_smoothing=args.smoothing, num_classes=args.num_classes)\n    #     if args.prefetcher:\n    #         # collate conflict (need to support deinterleaving in collate mixup)\n    #         assert not num_aug_splits\n    #         collate_fn = FastCollateMixup(**mixup_args)\n    #     else:\n    #         mixup_fn = Mixup(**mixup_args)\n\n    # if num_aug_splits > 1:\n    #     dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)\n\n    train_interpolation = args.train_interpolation\n    if args.no_aug or not train_interpolation:\n        train_interpolation = data_config['interpolation']\n    loader_train = create_loader(\n        dataset_train,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        is_training=True,\n        use_prefetcher=args.prefetcher,\n        no_aug=args.no_aug,\n        # re_prob=args.reprob,\n        # re_mode=args.remode,\n        # re_count=args.recount,\n        # re_split=args.resplit,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        # vflip=args.vflip,\n        # color_jitter=args.color_jitter,\n        # auto_augment=args.aa,\n        # num_aug_splits=num_aug_splits,\n        interpolation=train_interpolation,\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        # collate_fn=collate_fn,\n        pin_memory=args.pin_mem,\n        # use_multi_epochs_loader=args.use_multi_epochs_loader\n    )\n\n    eval_dir = os.path.join(DATA_DIR, 'ILSVRC2012/val')\n    if not os.path.isdir(eval_dir):\n        eval_dir = os.path.join(DATA_DIR, 
'ILSVRC2012/validation')\n        if not os.path.isdir(eval_dir):\n            _logger.error(\n                'Validation folder does not exist at: {}'.format(eval_dir))\n            exit(1)\n    dataset_eval = ImageDataset(eval_dir)\n\n    loader_eval = create_loader(\n        dataset_eval,\n        input_size=data_config['input_size'],\n        batch_size=args.validation_batch_size_multiplier * args.batch_size,\n        is_training=False,\n        use_prefetcher=args.prefetcher,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        crop_pct=data_config['crop_pct'],\n        pin_memory=args.pin_mem,\n    )\n    return loader_train, loader_eval, None, None\n\n\ndef get_dvsg_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS Gesture数据\n    DOI: 10.1109/CVPR.2017.781\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.DVSGesture.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 48\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),\n                                              transform=train_transform, train=True)\n    test_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),\n                                             transform=test_transform, train=False)\n\n    train_transform = 
transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n        transforms.RandomCrop(size, padding=size // 12),\n        # lambda x: event_drop(x),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               
num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_dvsc10_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        
tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)\n    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # lambda x: event_drop(x),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n    
                                  cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                             
 prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_nmnist_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取N-MNIST数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = tonic.datasets.NMNIST.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform)\n    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # 
lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # lambda x: event_drop(x),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, 
indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        
test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_NCALTECH101_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = braincog.datasets.ncaltech101.NCALTECH101.sensor_size\n    cls_count = braincog.datasets.ncaltech101.NCALTECH101.cls_count\n    dataset_length = braincog.datasets.ncaltech101.NCALTECH101.length\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += train_sample\n        train_sample_weight.extend(\n            [sample_weight] * train_sample\n        )\n        train_sample_weight.extend(\n            [0.] 
* test_sample\n        )\n        train_sample_index.extend(\n            list((range(idx_begin, idx_begin + train_sample)))\n        )\n        test_sample_index.extend(\n            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)\n    test_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: print(x.shape),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: temporal_flatten(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, 
RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n              
                indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_UCF101DVS_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取UCF101-DVS数据\n    Bi et al., Graph-Based Spatio-Temporal Feature Learning for Neuromorphic Vision Sensing\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = braincog.datasets.ucf101_dvs.UCF101DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=True, transform=train_transform)\n    test_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=False, transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: 
ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        # transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'UCF101DVS/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'UCF101DVS/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = 
CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size, shuffle=True,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size, shuffle=False,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_HMDBDVS_data(batch_size, step, **kwargs):\n    sensor_size = braincog.datasets.hmdb_dvs.HMDBDVS.sensor_size\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=train_transform)\n    test_dataset = 
braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=test_transform)\n\n    cls_count = train_dataset.cls_count\n    dataset_length = train_dataset.length\n\n    portion = .5\n    # portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += train_sample\n        train_sample_weight.extend(\n            [sample_weight] * train_sample\n        )\n        train_sample_weight.extend(\n            [0.] * test_sample\n        )\n        lst = list(range(idx_begin, idx_begin + train_sample + test_sample))\n        random.seed(0)\n        random.shuffle(lst)\n        train_sample_index.extend(\n            lst[:train_sample]\n            # list((range(idx_begin, idx_begin + train_sample)))\n        )\n        test_sample_index.extend(\n            lst[train_sample:train_sample + test_sample]\n            # list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: print(x.shape),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: 
torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: temporal_flatten(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'HMDBDVS/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'HMDBDVS/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 
indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\n# def get_NCARS_data(batch_size, step, **kwargs):\n#     \"\"\"\n#     获取N-Cars数据\n#     https://ieeexplore.ieee.org/document/8578284/\n#     :param batch_size: batch size\n#     :param step: 仿真步长\n#     :param kwargs:\n#     :return: (train loader, test loader, mixup_active, mixup_fn)\n#     \"\"\"\n#     sensor_size = tonic.datasets.NCARS.sensor_size\n#     size = kwargs['size'] if 'size' in kwargs else 48\n#\n#     train_transform = transforms.Compose([\n#         # tonic.transforms.Denoise(filter_time=10000),\n#         # tonic.transforms.DropEvent(p=0.1),\n#         tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n#     ])\n#     test_transform = transforms.Compose([\n#         # tonic.transforms.Denoise(filter_time=10000),\n#         tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n#     ])\n#\n#     train_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=train_transform, train=True)\n#     test_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=test_transform, 
train=False)\n#\n#     train_transform = transforms.Compose([\n#         lambda x: torch.tensor(x, dtype=torch.float),\n#         lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n#         lambda x: dvs_channel_check_expend(x),\n#         transforms.RandomCrop(size, padding=size // 12),\n#         transforms.RandomHorizontalFlip(),\n#         transforms.RandomRotation(15)\n#     ])\n#     test_transform = transforms.Compose([\n#         lambda x: torch.tensor(x, dtype=torch.float),\n#         lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n#         lambda x: dvs_channel_check_expend(x),\n#     ])\n#     if 'rand_aug' in kwargs.keys():\n#         if kwargs['rand_aug'] is True:\n#             n = kwargs['randaug_n']\n#             m = kwargs['randaug_m']\n#             train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n#\n#     # if 'temporal_flatten' in kwargs.keys():\n#     #     if kwargs['temporal_flatten'] is True:\n#     #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n#     #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n#\n#     train_dataset = DiskCachedDataset(train_dataset,\n#                                       cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/train_cache_{}'.format(step)),\n#                                       transform=train_transform, num_copies=3)\n#     test_dataset = DiskCachedDataset(test_dataset,\n#                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/test_cache_{}'.format(step)),\n#                                      transform=test_transform, num_copies=3)\n#\n#     mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n#     mixup_active = cut_mix | event_mix | mix_up\n#\n#     if cut_mix:\n#         train_dataset = CutMix(train_dataset,\n#                                beta=beta,\n#                        
        prob=prob,\n#                                num_mix=num,\n#                                num_class=num_classes,\n#                                noise=noise)\n#\n#     if event_mix:\n#         train_dataset = EventMix(train_dataset,\n#                                  beta=beta,\n#                                  prob=prob,\n#                                  num_mix=num,\n#                                  num_class=num_classes,\n#                                  noise=noise,\n#                                  gaussian_n=gaussian_n)\n#     if mix_up:\n#         train_dataset = MixUp(train_dataset,\n#                               beta=beta,\n#                               prob=prob,\n#                               num_mix=num,\n#                               num_class=num_classes,\n#                               noise=noise)\n#\n#     train_loader = torch.utils.data.DataLoader(\n#         train_dataset, batch_size=batch_size,\n#         pin_memory=True, drop_last=True, num_workers=8,\n#         shuffle=True,\n#     )\n#\n#     test_loader = torch.utils.data.DataLoader(\n#         test_dataset, batch_size=batch_size,\n#         pin_memory=True, drop_last=False, num_workers=2,\n#         shuffle=False,\n#     )\n#\n#     return train_loader, test_loader, mixup_active, None\n\n\ndef get_nomni_data(batch_size, train_portion=1., **kwargs):\n    \"\"\"\n    获取N-Omniglot数据\n    :param batch_size:batch的大小\n    :param data_mode:一共full nkks pair三种模式\n    :param frames_num:一个样本帧的个数\n    :param data_type:event frequency两种模式\n    \"\"\"\n    data_mode = kwargs[\"data_mode\"] if \"data_mode\" in kwargs else \"full\"\n    frames_num = kwargs[\"frames_num\"] if \"frames_num\" in kwargs else 10\n    data_type = kwargs[\"data_type\"] if \"data_type\" in kwargs else \"event\"\n\n    train_transform = transforms.Compose([\n        transforms.Resize((64, 64))])\n    test_transform = transforms.Compose([\n        transforms.Resize((64, 64))])\n    if data_mode == 
\"full\":\n        train_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=True, frames_num=frames_num,\n                                       data_type=data_type,\n                                       transform=train_transform)\n        test_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=False, frames_num=frames_num,\n                                      data_type=data_type,\n                                      transform=test_transform)\n\n    elif data_mode == \"nkks\":\n        train_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),\n                                            n_way=kwargs[\"n_way\"],\n                                            k_shot=kwargs[\"k_shot\"],\n                                            k_query=kwargs[\"k_query\"],\n                                            train=True,\n                                            frames_num=frames_num,\n                                            data_type=data_type,\n                                            transform=train_transform)\n        test_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),\n                                           n_way=kwargs[\"n_way\"],\n                                           k_shot=kwargs[\"k_shot\"],\n                                           k_query=kwargs[\"k_query\"],\n                                           train=False,\n                                           frames_num=frames_num,\n                                           data_type=data_type,\n                                           transform=test_transform)\n    elif data_mode == \"pair\":\n        train_datasets = NOmniglotTrainSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), use_frame=True,\n                                           frames_num=frames_num, data_type=data_type,\n                                           use_npz=False, resize=105)\n        test_datasets = 
NOmniglotTestSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), time=2000, way=kwargs[\"n_way\"],\n                                         shot=kwargs[\"k_shot\"], use_frame=True,\n                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)\n\n    else:\n        pass\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size, num_workers=12,\n        pin_memory=True, drop_last=True, shuffle=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size, num_workers=12,\n        pin_memory=True, drop_last=False\n    )\n    return train_loader, test_loader, None, None\n"
  },
  {
    "path": "examples/Spiking-Transformers/main.py",
    "content": "import argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport random as buildin_random\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\n# from braincog.datasets.datasets import *\nfrom datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN, SNN5\n# from braincog.model_zoo.fc_snn import SHD_SNN\nfrom braincog.model_zoo.resnet19_snn import resnet19\n#from braincog.model_zoo.sew_resnet import sew_resnet18, sew_resnet34, sew_resnet50\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix, plot_mem_distribution\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model, register_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\nfrom torch.utils.tensorboard import SummaryWriter\n\n# load spiking transformer models\nfrom models.spikformer import spikformer\nfrom models.spikformer_dvs import spikformer_dvs\nfrom models.spike_driven_transformer import sd_transformer\nfrom models.spike_driven_transformer_dvs import sd_transformer_dvs\nfrom models.spike_driven_transformer_v2 import sd_transformer_v2\nfrom models.spike_driven_transformer_v2_dvs import sd_transformer_v2_dvs\n\n# choose 
ur device here\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='dvsc10', type=str)\nparser.add_argument('--model', default='spikformer', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, 
avgmaxc). Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=1e-4,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=400, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochss to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=0.,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=0.,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/home/shensicheng/code/SpikingTransformers', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--tensorboard-dir', default='./runs', type=str)\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\nparser.add_argument('--conv-type', type=str, default='normal')\nparser.add_argument('--sew-cnf', type=str, default='ADD')\nparser.add_argument('--rand-step', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\nparser.add_argument('--n-encode-type', type=str, default='linear')\nparser.add_argument('--n-preact', action='store_true')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. 
'\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--tet-loss', action='store_true')\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=2.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. 
(default: False)')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\nparser.add_argument('--mem-dist', action='store_true')\nparser.add_argument('--adaptation-info', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    if args.local_rank == 0:\n        output_base = args.output if args.output else 
'./output'\n        exp_name = '-'.join([\n            args.model,\n            args.dataset,\n            args.node_type,\n            str(args.step),\n            args.suffix,\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base, 'train', exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n        summary_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_dir, exp_name))\n        args.tensorboard_prefix = os.path.join(args.dataset, args.model)\n    else:\n        summary_writer = None\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' 
% args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        # pretrained=args.pretrained,\n        # num_classes=args.num_classes,\n        # dataset=args.dataset,\n        # step=args.step,\n        # encode_type=args.encode,\n        # node_type=eval(args.node_type),\n        # threshold=args.threshold,\n        # tau=args.tau,\n        # sigmoid_thres=args.sigmoid_thres,\n        # requires_thres_grad=args.requires_thres_grad,\n        # spike_output=not args.no_spike_output,\n        # act_fun=args.act_fun,\n        # temporal_flatten=args.temporal_flatten,\n        # layer_by_layer=args.layer_by_layer,\n        # n_groups=args.n_groups,\n        # n_encode_type=args.n_encode_type,\n        # n_preact=args.n_preact,\n        # tet_loss=args.tet_loss,\n        # sew_cnf=args.sew_cnf,\n        # conv_type=args.conv_type,\n    )\n\n    _logger.info('[MODEL ARCH]\\n{}'.format(model))\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = 
convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    use_amp = None\n    if args.amp:\n        # for backwards compat, `--amp` arg tries apex before native amp\n        if has_apex:\n            args.apex_amp = True\n        elif has_native_amp:\n            args.native_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                        \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n    if args.num_gpu > 1:\n        if use_amp == 'apex':\n            _logger.warning(\n                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')\n            use_amp = None\n        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()\n        assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n    else:\n        model = model.cuda()\n        if args.channels_last:\n            model = model.to(memory_format=torch.channels_last)\n\n    optimizer = create_optimizer(args, model)\n\n    _logger.info('[OPTIMIZER]\\n{}'.format(optimizer))\n\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. 
Training in float32.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume and args.eval_checkpoint == '':\n        args.eval_checkpoint = args.resume\n    if args.resume:\n        args.eval = True\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n        # print(model.get_attr('mu'))\n        # print(model.get_attr('sigma'))\n        if hasattr(model, 'set_threshold'):\n            model.set_threshold(args.threshold)\n\n    if args.critical_loss or args.spike_rate:\n        model.set_requires_fp(True)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume=args.resume)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    model_without_ddp = model\n    if args.distributed:\n        if args.sync_bn:\n            assert not args.split_bn\n            try:\n                if has_apex and use_amp != 'native':\n                    # Apex SyncBN preferred unless native amp is activated\n                    model = convert_syncbn_model(model)\n                else:\n                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                if args.local_rank == 0:\n                    _logger.info(\n                        'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using '\n                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n            except Exception as e:\n                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n        if has_apex and use_amp != 'native':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model.cuda(), device_ids=[args.local_rank],\n                              find_unused_parameters=True)  # can use device str in Torch >= 1.1\n        model_without_ddp = model.module\n    # NOTE: EMA model does not need to be wrapped by DDP\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n 
       beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n    )\n    # _logger.info('train_loader:\\n{}\\nval_loader:\\n{}'.format(loader_train, loader_eval))\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n    elif args.loss_fn == 'onehot-mse':\n        train_loss_fn = OnehotMse(args.num_classes)\n        validate_loss_fn = OnehotMse(args.num_classes)\n    else:\n        if args.jsd:\n            assert num_aug_splits > 1  # JSD only valid with aug splits set\n            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n        elif mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        elif args.smoothing:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    if args.tet_loss:\n        train_loss_fn = TetLoss(train_loss_fn)\n        validate_loss_fn = TetLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    if args.eval:  # evaluate the model\n        # if args.distributed:\n        #     raise NotImplementedError('eval not has not been verified for distributed')\n        # else:\n        #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)\n        model.eval()\n        for 
t in range(1, args.step * 3):\n        # for t in range(args.step, args.step + 1):\n            model.set_attr('step', t)\n            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,\n                                   visualize=args.visualize, spike_rate=args.spike_rate,\n                                   tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)\n            print(f\"[STEP:{t}], Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n        return\n\n    saver = None\n    if args.local_rank == 0:\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=3)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    try:  # train the model\n        if args.reset_drop:\n            model_without_ddp.reset_drop_path(0.0)\n        for epoch in range(start_epoch, args.epochs):\n            if epoch == 0 and args.reset_drop:\n                model_without_ddp.reset_drop_path(args.drop_path)\n\n            if args.distributed:\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler,\n                model_ema=model_ema, mixup_fn=mixup_fn, summary_writer=summary_writer\n            )\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn 
== 'reduce')\n\n            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                ema_eval_metrics = validate(\n                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                    visualize=args.visualize, spike_rate=args.spike_rate,\n                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer\n                )\n                eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            # if saver is not None and epoch >= args.n_warm_up:\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, 
mixup_fn=None, summary_writer=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    # closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    # t, k = adjust_surrogate_coeff(100, args.epochs)\n    # model.set_attr('t', t)\n    # model.set_attr('k', k)\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    iters_per_epoch = len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        if args.rand_step:\n            step = buildin_random.randint(1, args.step + 2)\n            model.set_attr('step', step)\n\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            loss = loss_fn(output, target)\n        if args.tet_loss:\n            output = output.mean(0)\n\n        if not (args.cut_mix | args.mix_up | args.event_mix | (args.cutmix != 0.) 
| (args.mixup != 0.)):\n            # print(output.shape, target.shape)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            # if args.opt == 'lamb':\n            #     optimizer.step(epoch=epoch)\n            # else:\n            optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n\n        if args.local_rank == 0:\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)\n\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            if args.distributed:\n                loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n\n            
losses_m.update(loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n                # closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                # if args.distributed:\n                _logger.info(\n                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'LR: {lr:.3e}  '\n                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                        epoch,\n                        batch_idx, len(loader),\n                        100. * batch_idx / last_idx,\n                        loss=losses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        batch_time=batch_time_m,\n                        rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                        lr=lr,\n                        data_time=data_time_m\n                    ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    if args.local_rank == 0:\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top1'), top1_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top5'), top5_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/loss'), losses_m.avg, epoch)\n\n    if args.rand_step:\n        model.set_attr('step', args.step)\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False, summary_writer=None):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    # closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n    spike_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec = []\n    mem_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    iters_per_epoch = len(loader)\n    with torch.no_grad():\n\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat or args.mem_dist) and not args.critical_loss:\n                    model.set_requires_fp(True)\n\n            with 
amp_autocast():\n                output = model(inputs)\n\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            # print(args.rank, output.shape, target.shape, max(target))\n            loss = loss_fn(output, target)\n            if args.tet_loss:\n                output = output.mean(0)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            # closses_m.update(closs, inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n\n            if args.local_rank == 0:\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)\n\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n            if not args.distributed and 
spike_rate:\n                spike_m.update(model.get_tot_spike() / output.size(0), output.size(0))\n\n                if not args.distributed and spike_rate:\n                    _logger.info(\n                        '[Spike Info]: {spike.val} ({spike.avg})'.format(\n                            spike=spike_m\n                        )\n                    )\n            if last_batch or batch_idx % args.log_interval == 0:\n                _logger.info(\n                    'Eval : {} '\n                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                        epoch,\n                        batch_idx,\n                        last_idx,\n                        batch_time=batch_time_m,\n                        loss=losses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if args.local_rank == 0:\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top1'), top1_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top5'), top5_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/loss'), losses_m.avg, epoch)\n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Spiking-Transformers/models/spike_driven_transformer.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom LIFNode import MyNode  # LIFNode setting for Spiking Tranformers\nfrom functools import partial\n\n__all__ = ['spikformer']\n\n'''The input shape of neuromorphic datasets in Spiking Transformer when using Braincog\nare used to set to 64*64 '''\n\n\n\nclass MLP(BaseModule):\n    #Linear here is subsituted by convs\n    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=10, encode_type='direct')\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step, tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step, tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, N = x.shape\n\n        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()\n        x = self.fc1_conv(x.flatten(0, 1)) \n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N\n        \n        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()\n        x = self.fc2_conv(x.flatten(0, 1))\n        x = self.fc2_bn(x).reshape(T, B, C, 
N).contiguous()\n        \n        return x\n\nclass SSA(BaseModule):\n    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,\n                 proj_drop=0., sr_ratio=1):\n        super().__init__(step=10, encode_type='direct')\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        # for shortcut\n        self.head_lif = MyNode(step=step, tau=2.0)\n\n        self.num_heads = num_heads\n        # scale\n        self.scale = 0.25\n\n        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.q_bn = nn.BatchNorm1d(dim)\n        self.q_lif = MyNode(step=step, tau=2.0)\n\n        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.k_bn = nn.BatchNorm1d(dim)\n        self.k_lif = MyNode(step=step, tau=2.0)\n\n        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.v_bn = nn.BatchNorm1d(dim)\n        self.v_lif = MyNode(step=step, tau=2.0)\n\n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )\n\n        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.proj_bn = nn.BatchNorm1d(dim)\n        self.proj_lif = MyNode(step=step, tau=2.0, )\n\n        self.sd_lif = PLIFNode(step=step, threshold=0.5, tau=2)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, N = x.shape\n\n        x_for_qkv = x.flatten(0, 1)  # TB, C N\n\n        x_for_qkv = self.head_lif(x_for_qkv)\n\n        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C N\n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()  # T B C N\n        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        q = q_conv_out.reshape(T, B, N, self.num_heads, C // 
self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()\n        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n\n        # Spike-driven Transformer attention\n        kv = k.mul(v)\n        kv = kv.sum(dim=-2, keepdim=True)\n        kv = self.sd_lif(kv)\n\n        x = q.mul(kv)\n\n        x = x.transpose(3,4).reshape(T, B, C, N).contiguous() # T B C N\n        # ignore following lines for membrane shortcut\n        # x = self.attn_lif(x.flatten(0,1)) #[TB] C N\n        # x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) #T B C N\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = SSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(step=step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        # residual connection\n        x = x + self.attn(x)\n        x = x + self.mlp(x)\n        return x\n\n\n# embed_dims = 256\nclass 
SPS(BaseModule):\n    def __init__(self, step=10, encode_type='direct', img_size_h=128, img_size_w=128, patch_size=4, in_channels=2,\n                 embed_dims=256):\n        super().__init__(step=10, encode_type='direct')\n        self.image_size = [img_size_h, img_size_w]\n        patch_size = to_2tuple(patch_size)  # 4->(4,4)\n        self.patch_size = patch_size  # patch_size\n        self.C = in_channels  # image_channel\n        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n\n        # DVS with 2 more Maxpooling\n\n        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)\n        self.proj_lif = MyNode(step=step, tau=2.0)\n        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)\n        self.proj_lif1 = MyNode(step=step, tau=2.0)\n        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)\n        self.proj_lif2 = MyNode(step=step, tau=2.0)\n        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn3 = nn.BatchNorm2d(embed_dims)\n        # self.proj_lif3 = MyNode(step=step, tau=2.0)\n        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.rpe_conv = 
nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.rpe_bn = nn.BatchNorm2d(embed_dims)\n        self.rpe_lif = MyNode(step=step, tau=2.0)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = self.proj_conv(x.flatten(0, 1))  # have some fire value\n        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif(x.flatten(0, 1)).contiguous()\n        x = self.maxpool(x)\n\n        x = self.proj_conv1(x)\n        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()\n        x = self.proj_lif1(x.flatten(0, 1)).contiguous()\n        x = self.maxpool1(x)\n\n        x = self.proj_conv2(x)\n        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()\n        x = self.proj_lif2(x.flatten(0, 1)).contiguous()\n        x = self.maxpool2(x)\n\n        x = self.proj_conv3(x)\n        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8)\n        # abandon the LIF here to leverage membrane shortcut\n        # x = self.proj_lif3(x.flatten(0, 1)).contiguous()  \n        x = self.maxpool3(x.flatten(0,1)).reshape(T, B, -1, H // 16, W // 16)\n\n\n        # The order here is different from spikformer for using membrain shortcut\n        x_rpe = self.rpe_lif(x.flatten(0, 1)).contiguous() \n        x_rpe = self.rpe_bn(self.rpe_conv(x_rpe)).reshape(T, B, -1, H // 16, W // 16).contiguous()\n        \n        x = x + x_rpe # membrane shortcut\n\n        x = x.reshape(T, B, -1, (H // 16) * (H // 16)).contiguous()\n\n        return x  # T B C N\n\n\nclass Spikformer(BaseModule):\n    def __init__(self, step=10, encode_type='direct',\n                 img_size_h=224, img_size_w=224, patch_size=16, in_channels=3, num_classes=1000,\n                 embed_dims=512, num_heads=12, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=8, 
sr_ratios=4,\n                 ):\n        super().__init__(step=10, encode_type='direct')\n        self.step = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n        \n\n        # for membrane shortcut\n        self.final_lif = MyNode(step=step,tau=2.0)\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n\n        patch_embed = SPS(step=step,\n                          img_size_h=img_size_h,\n                          img_size_w=img_size_w,\n                          patch_size=patch_size,\n                          in_channels=in_channels,\n                          embed_dims=embed_dims)\n\n        block = nn.ModuleList([Block(step=step,\n                                     dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,\n                                     qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],\n                                     norm_layer=norm_layer, sr_ratio=sr_ratios)\n\n                               for j in range(depths)])\n\n        setattr(self, f\"patch_embed\", patch_embed)\n        setattr(self, f\"block\", block)\n\n        # classification head\n        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                
nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n\n        block = getattr(self, f\"block\")\n        patch_embed = getattr(self, f\"patch_embed\")\n\n        x = patch_embed(x)\n        for blk in block:\n            x = blk(x)\n        # for membrane shortcut\n        T, B , C, N = x.shape\n        x = self.final_lif(x.flatten(0,1)).reshape(T, B, C, N).contiguous()\n        return x.mean(3)\n\n    def forward(self, x):\n        self.reset()\n        x = self.encoder(x)  \n        x = self.forward_features(x)\n        x = self.head(x.mean(0))\n        return x\n\n\n\n# Adjust ur hyperparams here\n@register_model\ndef sd_transformer(pretrained=False, **kwargs):\n    model = Spikformer(step = 4,\n        img_size_h=224, img_size_w=224,\n        patch_size=16, embed_dims=512, num_heads=16, mlp_ratios=4,\n        in_channels=3, num_classes=1000, qkv_bias=False,\n        depths=8, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n"
  },
  {
    "path": "examples/Spiking-Transformers/models/spike_driven_transformer_dvs.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom LIFNode import MyNode  # LIFNode setting for Spiking Tranformers\nfrom functools import partial\n\n__all__ = ['spikformer']\n\n'''The input shape of neuromorphic datasets in Spiking Transformer when using Braincog\nare used to set to 64*64 '''\n\n\n\nclass MLP(BaseModule):\n    #Linear here is subsituted by convs\n    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=10, encode_type='direct')\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step, tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step, tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, N = x.shape\n\n        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()\n        x = self.fc1_conv(x.flatten(0, 1)) \n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N\n        \n        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()\n        x = self.fc2_conv(x.flatten(0, 1))\n        x = self.fc2_bn(x).reshape(T, B, C, 
N).contiguous()\n        \n        return x\n\nclass SSA(BaseModule):\n    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,\n                 proj_drop=0., sr_ratio=1):\n        super().__init__(step=10, encode_type='direct')\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        # for shortcut\n        self.head_lif = MyNode(step=step, tau=2.0)\n\n        self.num_heads = num_heads\n        # scale\n        self.scale = 0.25\n\n        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.q_bn = nn.BatchNorm1d(dim)\n        self.q_lif = MyNode(step=step, tau=2.0)\n\n        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.k_bn = nn.BatchNorm1d(dim)\n        self.k_lif = MyNode(step=step, tau=2.0)\n\n        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.v_bn = nn.BatchNorm1d(dim)\n        self.v_lif = MyNode(step=step, tau=2.0)\n\n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )\n\n        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.proj_bn = nn.BatchNorm1d(dim)\n        self.proj_lif = MyNode(step=step, tau=2.0, )\n\n        self.sd_lif = PLIFNode(step=step, threshold=0.5, tau=2)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, N = x.shape\n\n        x_for_qkv = x.flatten(0, 1)  # TB, C N\n\n        x_for_qkv = self.head_lif(x_for_qkv)\n\n        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C N\n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()  # T B C N\n        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        q = q_conv_out.reshape(T, B, N, self.num_heads, C // 
self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()\n        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n\n        # Spike-driven Transformer attention\n        kv = k.mul(v)\n        kv = kv.sum(dim=-2, keepdim=True)\n        kv = self.sd_lif(kv)\n\n        x = q.mul(kv)\n\n        x = x.transpose(3,4).reshape(T, B, C, N).contiguous() # T B C N\n        # ignore following lines for membrane shortcut\n        # x = self.attn_lif(x.flatten(0,1)) #[TB] C N\n        # x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) #T B C N\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = SSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(step=step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        # residual connection\n        x = x + self.attn(x)\n        x = x + self.mlp(x)\n        return x\n\n\n# embed_dims = 256\nclass 
SPS(BaseModule):\n    def __init__(self, step=10, encode_type='direct', img_size_h=128, img_size_w=128, patch_size=4, in_channels=2,\n                 embed_dims=256):\n        super().__init__(step=10, encode_type='direct')\n        self.image_size = [img_size_h, img_size_w]\n        patch_size = to_2tuple(patch_size)  # 4->(4,4)\n        self.patch_size = patch_size  # patch_size\n        self.C = in_channels  # image_channel\n        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n\n        # DVS with 2 more Maxpooling\n\n        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)\n        self.proj_lif = MyNode(step=step, tau=2.0)\n        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)\n        self.proj_lif1 = MyNode(step=step, tau=2.0)\n        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)\n        self.proj_lif2 = MyNode(step=step, tau=2.0)\n        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn3 = nn.BatchNorm2d(embed_dims)\n        # self.proj_lif3 = MyNode(step=step, tau=2.0)\n        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.rpe_conv = 
nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.rpe_bn = nn.BatchNorm2d(embed_dims)\n        self.rpe_lif = MyNode(step=step, tau=2.0)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = self.proj_conv(x.flatten(0, 1))  # have some fire value\n        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif(x.flatten(0, 1)).contiguous()\n        x = self.maxpool(x)\n\n        x = self.proj_conv1(x)\n        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()\n        x = self.proj_lif1(x.flatten(0, 1)).contiguous()\n        x = self.maxpool1(x)\n\n        x = self.proj_conv2(x)\n        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()\n        x = self.proj_lif2(x.flatten(0, 1)).contiguous()\n        x = self.maxpool2(x)\n\n        x = self.proj_conv3(x)\n        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8)\n        # abandon the LIF here to leverage membrane shortcut\n        # x = self.proj_lif3(x.flatten(0, 1)).contiguous()  \n        x = self.maxpool3(x.flatten(0,1)).reshape(T, B, -1, H // 16, W // 16)\n\n\n        # The order here is different from spikformer for using membrane shortcut\n        x_rpe = self.rpe_lif(x.flatten(0, 1)).contiguous() \n        x_rpe = self.rpe_bn(self.rpe_conv(x_rpe)).reshape(T, B, -1, H // 16, W // 16).contiguous()\n        \n        x = x + x_rpe # membrane shortcut\n\n        x = x.reshape(T, B, -1, (H // 16) * (W // 16)).contiguous()\n\n        return x  # T B C N\n\n\nclass Spikformer(BaseModule):\n    def __init__(self, step=10, encode_type='direct',\n                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=10,\n                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=2, 
sr_ratios=4,\n                 ):\n        super().__init__(step=step, encode_type=encode_type)\n        self.step = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n        \n\n        # for membrane shortcut\n        self.final_lif = MyNode(step=step,tau=2.0)\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n\n        patch_embed = SPS(step=step,\n                          img_size_h=img_size_h,\n                          img_size_w=img_size_w,\n                          patch_size=patch_size,\n                          in_channels=in_channels,\n                          embed_dims=embed_dims)\n\n        block = nn.ModuleList([Block(step=step,\n                                     dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,\n                                     qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],\n                                     norm_layer=norm_layer, sr_ratio=sr_ratios)\n\n                               for j in range(depths)])\n\n        setattr(self, f\"patch_embed\", patch_embed)\n        setattr(self, f\"block\", block)\n\n        # classification head\n        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                
nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n\n        block = getattr(self, f\"block\")\n        patch_embed = getattr(self, f\"patch_embed\")\n\n        x = patch_embed(x)\n        for blk in block:\n            x = blk(x)\n        # for membrane shortcut\n        T, B , C, N = x.shape\n        x = self.final_lif(x.flatten(0,1)).reshape(T, B, C, N).contiguous()\n        return x.mean(3)\n\n    def forward(self, x):\n        self.reset()\n        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]\n        x = self.forward_features(x)\n        x = self.head(x.mean(0))\n        return x\n\n\n\n# Adjust ur hyperparams here\n@register_model\ndef sd_transformer_dvs(pretrained=False, **kwargs):\n    model = Spikformer(step = 8,\n        img_size_h=64, img_size_w=64,\n        patch_size=4, embed_dims=256, num_heads=16, mlp_ratios=4,\n        in_channels=2, num_classes=10, qkv_bias=False,\n        depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n"
  },
  {
    "path": "examples/Spiking-Transformers/models/spike_driven_transformer_v2.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom LIFNode import MyNode  # LIFNode setting for Spiking Tranformers\nfrom functools import partial\n\n__all__ = ['spikformer']\n\n'''The input shape of neuromorphic datasets in Spiking Transformer when using Braincog\nare used to set to 64*64 '''\n\n'''Here the second version of Spike-driven Transformer only open sourced the\n code for img cla '''\n\n\n# Modified Operators\nclass BNAndPadLayer(nn.Module):\n    def __init__(\n        self,\n        pad_pixels,\n        num_features,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super(BNAndPadLayer, self).__init__()\n        self.bn = nn.BatchNorm2d(\n            num_features, eps, momentum, affine, track_running_stats\n        )\n        self.pad_pixels = pad_pixels\n\n    def forward(self, input):\n        output = self.bn(input)\n        if self.pad_pixels > 0:\n            if self.bn.affine:\n                pad_values = (\n                    self.bn.bias.detach()\n                    - self.bn.running_mean\n                    * self.bn.weight.detach()\n                    / torch.sqrt(self.bn.running_var + self.bn.eps)\n                )\n            else:\n                pad_values = -self.bn.running_mean / torch.sqrt(\n                    self.bn.running_var + self.bn.eps\n                )\n            output = F.pad(output, [self.pad_pixels] * 4)\n            pad_values = pad_values.view(1, -1, 1, 1)\n            output[:, :, 0 : self.pad_pixels, :] = pad_values\n            output[:, :, -self.pad_pixels 
:, :] = pad_values\n            output[:, :, :, 0 : self.pad_pixels] = pad_values\n            output[:, :, :, -self.pad_pixels :] = pad_values\n        return output\n    @property\n    def weight(self):\n        return self.bn.weight\n\n    @property\n    def bias(self):\n        return self.bn.bias\n\n    @property\n    def running_mean(self):\n        return self.bn.running_mean\n\n    @property\n    def running_var(self):\n        return self.bn.running_var\n\n    @property\n    def eps(self):\n        return self.bn.eps\n\n\nclass RepConv(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        bias=False,\n    ):\n        super().__init__()\n        # hidden_channel = in_channel\n        conv1x1 = nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False, groups=1)\n        bn = BNAndPadLayer(pad_pixels=1, num_features=in_channels)\n        conv3x3 = nn.Sequential(\n            nn.Conv2d(in_channels, in_channels, 3, 1, 0, groups=in_channels, bias=False),\n            nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=1, bias=False),\n            nn.BatchNorm2d(out_channels),\n        )\n\n        self.body = nn.Sequential(conv1x1, bn, conv3x3)\n\n    def forward(self, x):\n        return self.body(x)\n\n\nclass SepConv(BaseModule):\n    r\"\"\"\n    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        step=8,\n        encode_type='direct',\n        expansion_ratio=2,\n        act2_layer=nn.Identity,\n        bias=False,\n        kernel_size=7,\n        padding=3,\n    ):\n        super().__init__(step=step,encode_type=encode_type,)\n        med_channels = int(expansion_ratio * dim)\n        self.lif1 = MyNode(step=step,tau=2.0)\n        self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias)\n        self.bn1 = nn.BatchNorm2d(med_channels)\n        self.lif2 =MyNode(step=step,tau=2.0)\n        
self.dwconv = nn.Conv2d(\n            med_channels,\n            med_channels,\n            kernel_size=kernel_size,\n            padding=padding,\n            groups=med_channels,\n            bias=bias,\n        )  # depthwise conv\n        self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias)\n        self.bn2 = nn.BatchNorm2d(dim)\n\n    def forward(self, x):\n        self.reset()\n        T, B, C, H, W = x.shape\n        x = self.lif1(x.flatten(0,1)).reshape(T,B,C,H,W).contiguous()\n        x = self.bn1(self.pwconv1(x.flatten(0, 1))).reshape(T, B, -1, H, W)\n        x = self.lif2(x.flatten(0,1)).reshape(T,B,-1,H,W).contiguous()\n        x = self.dwconv(x.flatten(0, 1))\n        x = self.bn2(self.pwconv2(x)).reshape(T, B, -1, H, W)\n        return x # T B C H W\n\n\nclass MLP(BaseModule):\n    #Linear here is subsituted by convs\n    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=10, encode_type='direct')\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step, tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step, tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = x.flatten(3)  # T B C N\n\n        _, _, _, N = x.shape \n        \n        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()\n        x = self.fc1_conv(x.flatten(0, 1)) \n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B 
C N\n        \n        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()\n        x = self.fc2_conv(x.flatten(0, 1))\n        x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous()\n        \n        return x  # T B C H W\n\n\n# convs in SDSA V3/V4 should be substituted\nclass SDSA(BaseModule):\n    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,\n                 proj_drop=0., sr_ratio=1):\n        super().__init__(step=10, encode_type='direct')\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        self.num_heads = num_heads\n        # scale\n        self.scale = 0.125\n\n        self.head_lif = MyNode(step=step, tau=2.0) # for spike-drivens\n\n        self.q_conv = RepConv(dim, dim, bias=False)\n        self.q_bn = nn.BatchNorm2d(dim)\n        self.q_lif = MyNode(step=step, tau=2.0)\n\n        self.k_conv = RepConv(dim, dim, bias=False)\n        self.k_bn = nn.BatchNorm2d(dim)\n        self.k_lif = MyNode(step=step, tau=2.0)\n\n        self.v_conv = RepConv(dim, dim, bias=False)\n        self.v_bn = nn.BatchNorm2d(dim)\n        self.v_lif = MyNode(step=step, tau=2.0)\n\n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )\n\n        self.proj_conv = RepConv(dim, dim, bias=False)\n        self.proj_bn =  nn.BatchNorm2d(dim)\n        \n\n\n    def forward(self, x):\n        self.reset()\n        \n        #different here\n        T, B, C, H, W = x.shape\n\n        N  = H * W\n\n        x = self.head_lif(x.flatten(0,1)).reshape(T, B, C, H, W).contiguous()\n\n        x_for_qkv = x.flatten(0, 1)  # TB C H W\n\n        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C H W\n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()  # T B C H W\n        q_conv_out = 
self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1,-2)  # T B N C\n        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()\n        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1,-2)  # T B N C\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1,-2)  # T B N C\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        x = k.transpose(-2, -1) @ v\n        x = (q @ x) * self.scale\n\n        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()\n        x = self.attn_lif(x).reshape(T, B, C, H, W)\n        x = x.reshape(T, B, C, H, W)\n        x = x.flatten(0, 1)\n        x = self.proj_conv(x)\n        x = self.proj_bn(x).reshape(T, B, C, H, W)\n\n        return x # T B C H W\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = SDSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(step=step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        # residual connection\n        x = x + 
self.attn(x)\n        x = x + self.mlp(x)\n        return x\n    \nclass DownSampling(BaseModule):\n    def __init__(\n        self,\n        step=10,\n        encode_type='direct',\n        in_channels=2,\n        embed_dims=512,\n        kernel_size=3,\n        stride=2,\n        padding=1,\n        first_layer=True,\n    ):\n        super().__init__(step=step,\n        encode_type=encode_type,)\n\n        self.encode_conv = nn.Conv2d(\n            in_channels,\n            embed_dims,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n\n        self.encode_bn = nn.BatchNorm2d(embed_dims)\n        if not first_layer:\n            self.encode_lif = MyNode(\n                tau=2.0,step=step\n            )\n\n    def forward(self, x):\n        self.reset()\n        \n        T, B, C, H, W = x.shape\n\n        if hasattr(self, \"encode_lif\"):\n            x = self.encode_lif(x.flatten(0,1)).reshape(T,B,C,H,W).contiguous()\n        x = self.encode_conv(x.flatten(0, 1))\n        _, _, H, W = x.shape\n        x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous()\n\n        return x\n\nclass ConvBlock(BaseModule):\n    def __init__(\n        self,\n        dim,\n        step=10,\n        encode_type='direct',\n        mlp_ratio=4.0,\n    ):\n        super().__init__(step=step,\n        encode_type=encode_type,)\n\n        self.Conv = SepConv(step=step,dim=dim)\n        # self.Conv = MHMC(dim=dim)\n\n        self.lif1 = MyNode(step=step,tau=2.0)\n        self.conv1 = nn.Conv2d(\n            dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False\n        )\n        # self.conv1 = RepConv(dim, dim*mlp_ratio)\n        self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) \n        self.lif2 = MyNode(step=step,tau=2.0)\n        self.conv2 = nn.Conv2d(\n            dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False\n        )\n        # self.conv2 = RepConv(dim*mlp_ratio, dim)\n        
self.bn2 = nn.BatchNorm2d(dim)  \n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = self.Conv(x) + x\n        x_feat = x\n        x = self.bn1(self.conv1(self.lif1(x.flatten(0,1)))).reshape(T, B, 4 * C, H, W)\n        x = self.bn2(self.conv2(self.lif2(x.flatten(0, 1)))).reshape(T, B, C, H, W)\n        x = x_feat + x\n\n        return x\nclass Spikformer(BaseModule):\n    def __init__(self, step=4, encode_type='direct',\n                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=1000,\n                 embed_dims=512, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=8, sr_ratios=4,kd=False,\n                 ):\n        super().__init__(step=10, encode_type='direct')\n        self.step = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n        \n        self.block3_depths = 6\n        # for membrane shortcut\n        self.final_lif = MyNode(step=step,tau=2.0)\n        # channel for dvs\n        # 16 32 64 128 256\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n        self.downsample1_1 = DownSampling(\n            step=step,\n            in_channels=in_channels,\n            embed_dims=embed_dims // 16,\n            kernel_size=7,\n            stride=2,\n            padding=3,\n            first_layer=True,\n        )\n\n        self.ConvBlock1_1 = nn.ModuleList(\n            [ConvBlock(step=step,dim= embed_dims // 16, mlp_ratio=mlp_ratios)]\n        )\n\n        self.downsample1_2 = DownSampling(\n            step=step,\n            in_channels =  embed_dims // 16,\n            embed_dims= embed_dims // 8,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            first_layer=False,\n        )\n        self.ConvBlock1_2 = 
nn.ModuleList(\n            [ConvBlock(step=step,dim=embed_dims // 8, mlp_ratio=mlp_ratios)]\n        )\n\n        self.downsample2 = DownSampling(\n            step=step,\n            in_channels=embed_dims // 8,\n            embed_dims=embed_dims // 4,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            first_layer=False,\n        )\n\n        self.ConvBlock2_1 = nn.ModuleList(\n            [ConvBlock(step=step,dim=embed_dims // 4, mlp_ratio=mlp_ratios)]\n        )\n\n        self.ConvBlock2_2 = nn.ModuleList(\n            [ConvBlock(step=step,dim=embed_dims // 4, mlp_ratio=mlp_ratios)]\n        )\n\n        self.downsample3 = DownSampling(\n            step=step,\n            in_channels=embed_dims // 4,\n            embed_dims=embed_dims // 2,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            first_layer=False,\n        )\n\n        self.block3 = nn.ModuleList(\n            [\n                Block(\n                    step=step,\n                    dim=embed_dims // 2,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratios,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    # drop_path=dpr[j],\n                    norm_layer=norm_layer,\n                    sr_ratio=sr_ratios,\n                )\n                for j in range(self.block3_depths)\n            ]\n        )\n\n        self.downsample4 = DownSampling(\n            step=step,\n            in_channels=embed_dims // 2,\n            embed_dims=embed_dims,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            first_layer=False,\n        )\n\n        self.block4 = nn.ModuleList(\n            [\n                Block(\n                    step=step,\n                    dim=embed_dims,\n                    num_heads=num_heads,\n        
            mlp_ratio=mlp_ratios,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[j],\n                    norm_layer=norm_layer,\n                    sr_ratio=sr_ratios,\n                )\n                for j in range(self.depths-self.block3_depths)\n            ]\n        )\n\n        # classification head\n        self.lif = MyNode(step=step,tau=2.0,)\n        self.head = (\n            nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n        self.kd = kd\n        if self.kd:\n            self.head_kd = (\n                nn.Linear(embed_dims, num_classes)\n                if num_classes > 0\n                else nn.Identity()\n            )\n        self.apply(self._init_weights)\n\n        # setattr(self, f\"patch_embed\", patch_embed)\n        # setattr(self, f\"block\", block)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n        x = self.downsample1_1(x)\n        for blk in self.ConvBlock1_1:\n            x = blk(x)\n        x = self.downsample1_2(x)\n        for blk in self.ConvBlock1_2:\n            x = blk(x)\n\n        x = 
self.downsample2(x)\n        for blk in self.ConvBlock2_1:\n            x = blk(x)\n        for blk in self.ConvBlock2_2:\n            x = blk(x)\n\n        x = self.downsample3(x)\n        for blk in self.block3: # attention here\n            x = blk(x)\n\n        x = self.downsample4(x) # attention here\n        for blk in self.block4:\n            x = blk(x)\n        return x  # T,B,C,H,W\n    \n    def forward(self, x):\n        self.reset()\n        x = self.encoder(x)    # [T, N, 2, *, *]\n        x = self.forward_features(x)\n        x = x.flatten(3).mean(3)\n        T,B,_ = x.shape\n        x_lif = self.lif(x.flatten(0,1)).reshape(T,B,-1)\n        x = self.head(x_lif).mean(0)\n        if self.kd:\n            x_kd = self.head_kd(x_lif).mean(0)\n            if self.training:\n                return x, x_kd\n            else:\n                return (x + x_kd) / 2\n        return x\n\n\n\n# Adjust ur hyperparams here\n@register_model\ndef sd_transformer_v2(pretrained=False, **kwargs):\n    model = Spikformer(step = 4,\n        img_size_h=224, img_size_w=224,\n        patch_size=16, embed_dims=512, num_heads=12, mlp_ratios=4,\n        in_channels=3, num_classes=1000, qkv_bias=False,\n        depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n"
  },
  {
    "path": "examples/Spiking-Transformers/models/spike_driven_transformer_v2_dvs.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom LIFNode import MyNode  # LIFNode setting for Spiking Tranformers\nfrom functools import partial\n\n__all__ = ['spikformer']\n\n'''The input shape of neuromorphic datasets in Spiking Transformer when using Braincog\nare used to set to 64*64 '''\n\n'''Here the second version of Spike-driven Transformer only open sourced the\n code for img cla '''\n\n\n# Modified Operators\nclass BNAndPadLayer(nn.Module):\n    def __init__(\n        self,\n        pad_pixels,\n        num_features,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super(BNAndPadLayer, self).__init__()\n        self.bn = nn.BatchNorm2d(\n            num_features, eps, momentum, affine, track_running_stats\n        )\n        self.pad_pixels = pad_pixels\n\n    def forward(self, input):\n        output = self.bn(input)\n        if self.pad_pixels > 0:\n            if self.bn.affine:\n                pad_values = (\n                    self.bn.bias.detach()\n                    - self.bn.running_mean\n                    * self.bn.weight.detach()\n                    / torch.sqrt(self.bn.running_var + self.bn.eps)\n                )\n            else:\n                pad_values = -self.bn.running_mean / torch.sqrt(\n                    self.bn.running_var + self.bn.eps\n                )\n            output = F.pad(output, [self.pad_pixels] * 4)\n            pad_values = pad_values.view(1, -1, 1, 1)\n            output[:, :, 0 : self.pad_pixels, :] = pad_values\n            output[:, :, -self.pad_pixels 
:, :] = pad_values\n            output[:, :, :, 0 : self.pad_pixels] = pad_values\n            output[:, :, :, -self.pad_pixels :] = pad_values\n        return output\n    @property\n    def weight(self):\n        return self.bn.weight\n\n    @property\n    def bias(self):\n        return self.bn.bias\n\n    @property\n    def running_mean(self):\n        return self.bn.running_mean\n\n    @property\n    def running_var(self):\n        return self.bn.running_var\n\n    @property\n    def eps(self):\n        return self.bn.eps\n\n\nclass RepConv(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        bias=False,\n    ):\n        super().__init__()\n        # hidden_channel = in_channel\n        conv1x1 = nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False, groups=1)\n        bn = BNAndPadLayer(pad_pixels=1, num_features=in_channels)\n        conv3x3 = nn.Sequential(\n            nn.Conv2d(in_channels, in_channels, 3, 1, 0, groups=in_channels, bias=False),\n            nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=1, bias=False),\n            nn.BatchNorm2d(out_channels),\n        )\n\n        self.body = nn.Sequential(conv1x1, bn, conv3x3)\n\n    def forward(self, x):\n        return self.body(x)\n\n\nclass SepConv(BaseModule):\n    r\"\"\"\n    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        step=8,\n        encode_type='direct',\n        expansion_ratio=2,\n        act2_layer=nn.Identity,\n        bias=False,\n        kernel_size=7,\n        padding=3,\n    ):\n        super().__init__(step=step,encode_type=encode_type,)\n        med_channels = int(expansion_ratio * dim)\n        self.lif1 = MyNode(step=step,tau=2.0)\n        self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias)\n        self.bn1 = nn.BatchNorm2d(med_channels)\n        self.lif2 =MyNode(step=step,tau=2.0)\n        
self.dwconv = nn.Conv2d(\n            med_channels,\n            med_channels,\n            kernel_size=kernel_size,\n            padding=padding,\n            groups=med_channels,\n            bias=bias,\n        )  # depthwise conv\n        self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias)\n        self.bn2 = nn.BatchNorm2d(dim)\n\n    def forward(self, x):\n        self.reset()\n        T, B, C, H, W = x.shape\n        x = self.lif1(x.flatten(0,1)).reshape(T,B,C,H,W).contiguous()\n        x = self.bn1(self.pwconv1(x.flatten(0, 1))).reshape(T, B, -1, H, W)\n        x = self.lif2(x.flatten(0,1)).reshape(T,B,-1,H,W).contiguous()\n        x = self.dwconv(x.flatten(0, 1))\n        x = self.bn2(self.pwconv2(x)).reshape(T, B, -1, H, W)\n        return x # T B C H W\n\n\nclass MLP(BaseModule):\n    #Linear here is subsituted by convs\n    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=10, encode_type='direct')\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step, tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step, tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = x.flatten(3)  # T B C N\n\n        _, _, _, N = x.shape \n        \n        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, C, N).contiguous()\n        x = self.fc1_conv(x.flatten(0, 1)) \n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B 
C N\n        \n        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()\n        x = self.fc2_conv(x.flatten(0, 1))\n        x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous()\n        \n        return x  # T B C H W\n\n\n# convs in SDSA V3/V4 should be substituted\nclass SDSA(BaseModule):\n    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,\n                 proj_drop=0., sr_ratio=1):\n        super().__init__(step=10, encode_type='direct')\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        self.num_heads = num_heads\n        # scale\n        self.scale = 0.125\n\n        self.head_lif = MyNode(step=step, tau=2.0) # for spike-drivens\n\n        self.q_conv = RepConv(dim, dim, bias=False)\n        self.q_bn = nn.BatchNorm2d(dim)\n        self.q_lif = MyNode(step=step, tau=2.0)\n\n        self.k_conv = RepConv(dim, dim, bias=False)\n        self.k_bn = nn.BatchNorm2d(dim)\n        self.k_lif = MyNode(step=step, tau=2.0)\n\n        self.v_conv = RepConv(dim, dim, bias=False)\n        self.v_bn = nn.BatchNorm2d(dim)\n        self.v_lif = MyNode(step=step, tau=2.0)\n\n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )\n\n        self.proj_conv = RepConv(dim, dim, bias=False)\n        self.proj_bn =  nn.BatchNorm2d(dim)\n        \n\n\n    def forward(self, x):\n        self.reset()\n        \n        #different here\n        T, B, C, H, W = x.shape\n\n        N  = H * W\n\n        x = self.head_lif(x.flatten(0,1)).reshape(T, B, C, H, W).contiguous()\n\n        x_for_qkv = x.flatten(0, 1)  # TB C H W\n\n        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C H W\n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()  # T B C H W\n        q_conv_out = 
self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1,-2)  # T B N C\n        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()\n        k_conv_out = self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1,-2)  # T B N C\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N).transpose(-1,-2)  # T B N C\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        x = k.transpose(-2, -1) @ v\n        x = (q @ x) * self.scale\n\n        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()\n        x = self.attn_lif(x).reshape(T, B, C, H, W)\n        x = x.reshape(T, B, C, H, W)\n        x = x.flatten(0, 1)\n        x = self.proj_conv(x)\n        x = self.proj_bn(x).reshape(T, B, C, H, W)\n\n        return x # T B C H W\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = SDSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(step=step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        # residual connection\n        x = x + 
self.attn(x)\n        x = x + self.mlp(x)\n        return x\n    \nclass DownSampling(BaseModule):\n    def __init__(\n        self,\n        step=10,\n        encode_type='direct',\n        in_channels=2,\n        embed_dims=256,\n        kernel_size=3,\n        stride=2,\n        padding=1,\n        first_layer=True,\n    ):\n        super().__init__(step=step,\n        encode_type=encode_type,)\n\n        self.encode_conv = nn.Conv2d(\n            in_channels,\n            embed_dims,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n        )\n\n        self.encode_bn = nn.BatchNorm2d(embed_dims)\n        if not first_layer:\n            self.encode_lif = MyNode(\n                tau=2.0,step=step\n            )\n\n    def forward(self, x):\n        self.reset()\n        \n        T, B, C, H, W = x.shape\n\n        if hasattr(self, \"encode_lif\"):\n            x = self.encode_lif(x.flatten(0,1)).reshape(T,B,C,H,W).contiguous()\n        x = self.encode_conv(x.flatten(0, 1))\n        _, _, H, W = x.shape\n        x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous()\n\n        return x\n\nclass ConvBlock(BaseModule):\n    def __init__(\n        self,\n        dim,\n        step=10,\n        encode_type='direct',\n        mlp_ratio=4.0,\n    ):\n        super().__init__(step=step,\n        encode_type=encode_type,)\n\n        self.Conv = SepConv(step=step,dim=dim)\n        # self.Conv = MHMC(dim=dim)\n\n        self.lif1 = MyNode(step=step,tau=2.0)\n        self.conv1 = nn.Conv2d(\n            dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False\n        )\n        # self.conv1 = RepConv(dim, dim*mlp_ratio)\n        self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) \n        self.lif2 = MyNode(step=step,tau=2.0)\n        self.conv2 = nn.Conv2d(\n            dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False\n        )\n        # self.conv2 = RepConv(dim*mlp_ratio, dim)\n        
self.bn2 = nn.BatchNorm2d(dim)  \n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = self.Conv(x) + x\n        x_feat = x\n        x = self.bn1(self.conv1(self.lif1(x.flatten(0,1)))).reshape(T, B, 4 * C, H, W)\n        x = self.bn2(self.conv2(self.lif2(x.flatten(0, 1)))).reshape(T, B, C, H, W)\n        x = x_feat + x\n\n        return x\nclass Spikformer(BaseModule):\n    def __init__(self, step=10, encode_type='direct',\n                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=10,\n                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=2, sr_ratios=4,kd=False,\n                 ):\n        super().__init__(step=10, encode_type='direct')\n        self.step = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n        \n        self.block3_depths = 1\n        # for membrane shortcut\n        self.final_lif = MyNode(step=step,tau=2.0)\n        # channel for dvs\n        # 16 32 64 128 256\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n        self.downsample1_1 = DownSampling(\n            step=step,\n            in_channels=in_channels,\n            embed_dims=embed_dims // 16,\n            kernel_size=7,\n            stride=2,\n            padding=3,\n            first_layer=True,\n        )\n\n        self.ConvBlock1_1 = nn.ModuleList(\n            [ConvBlock(step=step,dim= embed_dims // 16, mlp_ratio=mlp_ratios)]\n        )\n\n        self.downsample1_2 = DownSampling(\n            step=step,\n            in_channels =  embed_dims // 16,\n            embed_dims= embed_dims // 8,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            first_layer=False,\n        )\n        self.ConvBlock1_2 = 
nn.ModuleList(\n            [ConvBlock(step=step,dim=embed_dims // 8, mlp_ratio=mlp_ratios)]\n        )\n\n        self.downsample2 = DownSampling(\n            step=step,\n            in_channels=embed_dims // 8,\n            embed_dims=embed_dims // 4,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            first_layer=False,\n        )\n\n        self.ConvBlock2_1 = nn.ModuleList(\n            [ConvBlock(step=step,dim=embed_dims // 4, mlp_ratio=mlp_ratios)]\n        )\n\n        self.ConvBlock2_2 = nn.ModuleList(\n            [ConvBlock(step=step,dim=embed_dims // 4, mlp_ratio=mlp_ratios)]\n        )\n\n        self.downsample3 = DownSampling(\n            step=step,\n            in_channels=embed_dims // 4,\n            embed_dims=embed_dims // 2,\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            first_layer=False,\n        )\n\n        self.block3 = nn.ModuleList(\n            [\n                Block(\n                    step=step,\n                    dim=embed_dims // 2,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratios,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    # drop_path=dpr[j],\n                    norm_layer=norm_layer,\n                    sr_ratio=sr_ratios,\n                )\n                for j in range(self.block3_depths)\n            ]\n        )\n\n        self.downsample4 = DownSampling(\n            step=step,\n            in_channels=embed_dims // 2,\n            embed_dims=embed_dims,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            first_layer=False,\n        )\n\n        self.block4 = nn.ModuleList(\n            [\n                Block(\n                    step=step,\n                    dim=embed_dims,\n                    num_heads=num_heads,\n        
            mlp_ratio=mlp_ratios,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[j],\n                    norm_layer=norm_layer,\n                    sr_ratio=sr_ratios,\n                )\n                for j in range(self.depths-self.block3_depths)\n            ]\n        )\n\n        # classification head\n        self.lif = MyNode(step=step,tau=2.0,)\n        self.head = (\n            nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n        self.kd = kd\n        if self.kd:\n            self.head_kd = (\n                nn.Linear(embed_dims, num_classes)\n                if num_classes > 0\n                else nn.Identity()\n            )\n        self.apply(self._init_weights)\n\n        # setattr(self, f\"patch_embed\", patch_embed)\n        # setattr(self, f\"block\", block)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n        x = self.downsample1_1(x)\n        for blk in self.ConvBlock1_1:\n            x = blk(x)\n        x = self.downsample1_2(x)\n        for blk in self.ConvBlock1_2:\n            x = blk(x)\n\n        x = 
self.downsample2(x)\n        for blk in self.ConvBlock2_1:\n            x = blk(x)\n        for blk in self.ConvBlock2_2:\n            x = blk(x)\n\n        x = self.downsample3(x)\n        for blk in self.block3: # attention here\n            x = blk(x)\n\n        x = self.downsample4(x) # attention here\n        for blk in self.block4:\n            x = blk(x)\n        return x  # T,B,C,H,W\n    \n    def forward(self, x):\n        self.reset()\n        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]\n        x = self.forward_features(x)\n        x = x.flatten(3).mean(3)\n        T,B,_ = x.shape\n        x_lif = self.lif(x.flatten(0,1)).reshape(T,B,-1)\n        x = self.head(x_lif).mean(0)\n        if self.kd:\n            x_kd = self.head_kd(x_lif).mean(0)\n            if self.training:\n                return x, x_kd\n            else:\n                return (x + x_kd) / 2\n        return x\n\n\n\n# Adjust ur hyperparams here\n@register_model\ndef sd_transformer_v2_dvs(pretrained=False, **kwargs):\n    model = Spikformer(step = 8,\n        img_size_h=64, img_size_w=64,\n        patch_size=4, embed_dims=256, num_heads=16, mlp_ratios=4,\n        in_channels=2, num_classes=10, qkv_bias=False,\n        depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n"
  },
  {
    "path": "examples/Spiking-Transformers/models/spikformer.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom LIFNode import MyNode  # LIFNode setting for Spiking Tranformers\nfrom functools import partial\n\n__all__ = ['spikformer']\n\n\nclass MLP(BaseModule):\n    # Linear -> BN -> LIF -> Linear -> BN -> LIF\n    def __init__(self, in_features, step=4, encode_type='direct', hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=step, encode_type=encode_type)\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_linear = nn.Linear(in_features, hidden_features)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step,tau=2.0)\n\n        self.fc2_linear = nn.Linear(hidden_features, out_features)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step,tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, N, C = x.shape\n\n        x_ = x.flatten(0, 1)  # TB N C\n\n        x = self.fc1_linear(x_)\n        x = self.fc1_bn(x.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, self.c_hidden).contiguous()  # T B N C\n        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, N, self.c_hidden)\n\n        x = self.fc2_linear(x.flatten(0, 1))\n        x = self.fc2_bn(x.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, C).contiguous()\n        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, N, self.c_output)\n        return x\n\n\nclass SSA(BaseModule):\n    def 
__init__(self, dim, step=4, encode_type='rate', num_heads=12, qkv_bias=False, qk_scale=None, attn_drop=0.,\n                 proj_drop=0., sr_ratio=1):\n        super().__init__(step=step, encode_type=encode_type)\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n        # 多头注意力 # of heads\n        self.num_heads = num_heads\n        # scale参数，用于防止KQ乘积结果过大\n        self.scale = 0.125\n\n        self.q_linear = nn.Linear(dim, dim)\n        self.q_bn = nn.BatchNorm1d(dim)\n        self.q_lif = MyNode(step=step,tau=2.0)\n\n        self.k_linear = nn.Linear(dim, dim)\n        self.k_bn = nn.BatchNorm1d(dim)\n        self.k_lif = MyNode(step=step,tau=2.0)\n\n        self.v_linear = nn.Linear(dim, dim)\n        self.v_bn = nn.BatchNorm1d(dim)\n        self.v_lif = MyNode(step=step,tau=2.0)\n\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )\n\n        self.proj_linear = nn.Linear(dim, dim)\n        self.proj_bn = nn.BatchNorm1d(dim)\n        self.proj_lif = MyNode(step=step, tau=2.0, )\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, N, C = x.shape\n\n        x_for_qkv = x.flatten(0, 1)  # TB, N, C\n\n        q_linear_out = self.q_linear(x_for_qkv)  # [TB, N, C]\n        q_linear_out = self.q_bn(q_linear_out.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N,\n                                                                                           C).contiguous()  # T B N C\n        q_linear_out = self.q_lif(q_linear_out.flatten(0, 1)).reshape(T, B, N, C)  # TB N C\n        q = q_linear_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_linear_out = self.k_linear(x_for_qkv)\n        k_linear_out = self.k_bn(k_linear_out.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, C).contiguous()\n        k_linear_out = self.k_lif(k_linear_out.flatten(0, 1)).reshape(T, B, N, C)  # TB N C\n        k = 
k_linear_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_linear_out = self.v_linear(x_for_qkv)\n        v_linear_out = self.v_bn(v_linear_out.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, C).contiguous()\n        v_linear_out = self.v_lif(v_linear_out.flatten(0, 1)).reshape(T, B, N, C)  # TB N C\n        v = v_linear_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        # @表示矩阵乘法,与matmul等价\n        # K,QV -> attention -> scale -> LIF -> Linear -> BN\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        x = attn @ v\n        x = x.transpose(2, 3).reshape(T, B, N, C).contiguous()\n        x = self.attn_lif(x.flatten(0, 1)).reshape(T, B, N, C)  # T B N C\n        x = x.flatten(0, 1)  # TB N C\n        x = self.proj_lif(self.proj_bn(self.proj_linear(x).transpose(-1, -2)).transpose(-1, -2)).reshape(T, B, N, C)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step =4,  mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.step = 4\n        self.norm1 = norm_layer(dim)\n        self.attn = SSA(dim, step=self.step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(step=self.step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        # residual\n        x = x + self.attn(x)\n        x = x + self.mlp(x)\n        return x\n\n\n# SPS for dimension adjustment\n# embed_dims = 256\nclass SPS(BaseModule):\n    def __init__(self, step=4, encode_type='direct', img_size_h=32, img_size_w=32, patch_size=4, in_channels=3,\n                 embed_dims=384):\n     
   super().__init__(step=step, encode_type=encode_type)\n        self.image_size = [img_size_h, img_size_w]\n\n        # timm内置to_2tuple把整形转换成2元元组\n        patch_size = to_2tuple(patch_size)  # 4->(4,4)\n        self.patch_size = patch_size  # patch_size\n        self.C = in_channels  # image_channel\n        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]  # 重新计算patch之后的图片大小\n        self.num_patches = self.H * self.W  # patch数量\n\n        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)\n        self.proj_lif = MyNode(step=step,tau=2.0)\n\n        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)\n        self.proj_lif1 = MyNode(step=step,tau=2.0)\n\n        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)\n        self.proj_lif2 = MyNode(step=step,tau=2.0)\n        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn3 = nn.BatchNorm2d(embed_dims)\n        self.proj_lif3 = MyNode(step=step,tau=2.0)\n        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.rpe_bn = nn.BatchNorm2d(embed_dims)\n        self.rpe_lif = MyNode(step=step,tau=2.0)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = self.proj_conv(x.flatten(0, 1))  # have some fire value\n        x = 
self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif(x.flatten(0, 1)).contiguous()\n\n        x = self.proj_conv1(x)\n        x = self.proj_bn1(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif1(x.flatten(0, 1)).contiguous()\n\n        x = self.proj_conv2(x)\n        x = self.proj_bn2(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif2(x.flatten(0, 1)).contiguous()\n        x = self.maxpool2(x)\n\n        x = self.proj_conv3(x)\n        x = self.proj_bn3(x).reshape(T, B, -1, H // 2, W // 2).contiguous()\n        x = self.proj_lif3(x.flatten(0, 1)).contiguous()\n        x = self.maxpool3(x)\n\n        x_feat = x.reshape(T, B, -1, H // 4, W // 4).contiguous()\n        x = self.rpe_conv(x)\n        x = self.rpe_bn(x).reshape(T, B, -1, H // 4, W // 4).contiguous()\n        x = self.rpe_lif(x.flatten(0, 1)).reshape(T, B, -1, H // 4, W // 4)\n\n        x = x + x_feat\n\n        x = x.flatten(-2).transpose(-1, -2)  # T,B,N,C\n        return x\n\n\nclass Spikformer(BaseModule):\n    def __init__(self, step=4, encode_type='direct',\n                 img_size_h=224, img_size_w=224, patch_size=16, in_channels=3, num_classes=1000,\n                 embed_dims=384, num_heads=12, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=4, sr_ratios=4,\n                 ):\n        super().__init__(step=step, encode_type=encode_type)\n        self.step = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n\n        patch_embed = SPS(step = self.step,\n                          img_size_h=img_size_h,\n                          img_size_w=img_size_w,\n                          patch_size=patch_size,\n                          in_channels=in_channels,\n                      
    embed_dims=embed_dims)\n\n        block = nn.ModuleList([Block(step=self.step,\n            dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,\n            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],\n            norm_layer=norm_layer, sr_ratio=sr_ratios)\n\n            for j in range(depths)])\n\n        setattr(self, f\"patch_embed\", patch_embed)\n        setattr(self, f\"block\", block)\n\n        # classification head\n        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n\n        block = getattr(self, f\"block\")\n        patch_embed = getattr(self, f\"patch_embed\")\n\n        x = patch_embed(x)\n        for blk in block:\n            x = blk(x)\n        return x.mean(2)\n\n    def forward(self, x):\n        x = self.encoder(x)\n        x = self.forward_features(x)\n        x = self.head(x.mean(0))\n        return x\n\n\n@register_model\ndef spikformer(pretrained=False, **kwargs):\n    model = Spikformer(\n        step=4,\n        img_size_h=224, img_size_w=224,\n        patch_size=16, embed_dims=512, num_heads=8, mlp_ratios=4,\n   
     in_channels=3, num_classes=1000, qkv_bias=False,\n        depths=8, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n\n\n"
  },
  {
    "path": "examples/Spiking-Transformers/models/spikformer_dvs.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom LIFNode import MyNode  # LIFNode setting for Spiking Tranformers\nfrom functools import partial\n\n__all__ = ['spikformer']\n\n'''The input shape of neuromorphic datasets in Spiking Transformer when using Braincog\nare used to set to 64*64 '''\n\n\nclass MLP(BaseModule):\n    #Linear here is subsituted by convs\n    def __init__(self, in_features, step=10, encode_type='direct', hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=step, encode_type=encode_type)\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step, tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step, tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, N = x.shape\n\n        x = self.fc1_conv(x.flatten(0, 1))\n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()  # T B C N\n        x = self.fc1_lif(x.flatten(0, 1)).reshape(T, B, self.c_hidden, N).contiguous()\n\n        x = self.fc2_conv(x.flatten(0, 1))\n        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()\n        x = self.fc2_lif(x.flatten(0, 1)).reshape(T, B, C, 
N).contiguous()\n        return x\n\n\nclass SSA(BaseModule):\n    def __init__(self, dim, step=10, encode_type='direct', num_heads=16, qkv_bias=False, qk_scale=None, attn_drop=0.,\n                 proj_drop=0., sr_ratio=1):\n        super().__init__(step=step, encode_type=encode_type)\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        self.num_heads = num_heads\n        # scale\n        self.scale = 0.25\n\n        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.q_bn = nn.BatchNorm1d(dim)\n        self.q_lif = MyNode(step=step, tau=2.0)\n\n        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.k_bn = nn.BatchNorm1d(dim)\n        self.k_lif = MyNode(step=step, tau=2.0)\n\n        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.v_bn = nn.BatchNorm1d(dim)\n        self.v_lif = MyNode(step=step, tau=2.0)\n\n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5, )\n\n        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)\n        self.proj_bn = nn.BatchNorm1d(dim)\n        self.proj_lif = MyNode(step=step, tau=2.0, )\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, N = x.shape\n\n        x_for_qkv = x.flatten(0, 1)  # TB, C N\n\n        q_conv_out = self.q_conv(x_for_qkv)  # [TB] C N\n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()  # T B C N\n        q_conv_out = self.q_lif(q_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        q = q_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()\n        k_conv_out = 
self.k_lif(k_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0, 1)).reshape(T, B, C, N)  # TB C N\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        # @表示矩阵乘法,与matmul等价\n        # K,QV -> attention -> scale -> LIF -> Linear -> BN\n        attn = (q @ k.transpose(-2, -1))\n        x = (attn @ v) * self.scale\n\n        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()  # T B C N\n        x = self.attn_lif(x.flatten(0, 1))  # [TB] C N\n        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N)  # T B C N\n\n        return x\n\n\n# 整个encoder block,要在SSA和MLP的基础上加入残差\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = SSA(dim, step=step, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(step=step, in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        # residual connection\n        x = x + self.attn(x)\n        x = x + self.mlp(x)\n        return x\n\n\n# embed_dims = 256\nclass SPS(BaseModule):\n    def __init__(self, step=10, encode_type='direct', img_size_h=128, img_size_w=128, patch_size=4, in_channels=2,\n                 embed_dims=256):\n        super().__init__(step=step, 
encode_type=encode_type)\n        self.image_size = [img_size_h, img_size_w]\n\n        # timm内置to_2tuple把整形转换成2元元组\n        patch_size = to_2tuple(patch_size)  # 4->(4,4)\n        self.patch_size = patch_size  # patch_size\n        self.C = in_channels  # image_channel\n        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n\n        # DVS with 2 more Maxpooling\n\n        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)\n        self.proj_lif = MyNode(step=step, tau=2.0)\n        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)\n        self.proj_lif1 = MyNode(step=step, tau=2.0)\n        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)\n        self.proj_lif2 = MyNode(step=step, tau=2.0)\n        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn3 = nn.BatchNorm2d(embed_dims)\n        self.proj_lif3 = MyNode(step=step, tau=2.0)\n        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.rpe_bn = nn.BatchNorm2d(embed_dims)\n        self.rpe_lif = MyNode(step=step, 
tau=2.0)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n\n        x = self.proj_conv(x.flatten(0, 1))  # have some fire value\n        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif(x.flatten(0, 1)).contiguous()\n        x = self.maxpool(x)\n\n        x = self.proj_conv1(x)\n        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()\n        x = self.proj_lif1(x.flatten(0, 1)).contiguous()\n        x = self.maxpool1(x)\n\n        x = self.proj_conv2(x)\n        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()\n        x = self.proj_lif2(x.flatten(0, 1)).contiguous()\n        x = self.maxpool2(x)\n\n        x = self.proj_conv3(x)\n        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8).contiguous()\n        x = self.proj_lif3(x.flatten(0, 1)).contiguous()\n        x = self.maxpool3(x)\n\n        x_rpe = self.rpe_bn(self.rpe_conv(x)).reshape(T, B, -1, H // 16, W // 16).contiguous()\n        x_rpe = self.rpe_lif(x_rpe.flatten(0, 1)).contiguous()\n        x = x + x_rpe\n        x = x.reshape(T, B, -1, (H // 16) * (H // 16)).contiguous()\n\n        return x  # T B C N\n\n\nclass Spikformer(nn.Module):\n    def __init__(self, step=10,\n                 img_size_h=64, img_size_w=64, patch_size=4, in_channels=2, num_classes=10,\n                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=2, sr_ratios=4,\n                 ):\n        super().__init__()\n        self.step = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n\n        patch_embed = SPS(step=step,\n                          img_size_h=img_size_h,\n                          img_size_w=img_size_w,\n  
                        patch_size=patch_size,\n                          in_channels=in_channels,\n                          embed_dims=embed_dims)\n\n        block = nn.ModuleList([Block(step=step,\n                                     dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,\n                                     qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],\n                                     norm_layer=norm_layer, sr_ratio=sr_ratios)\n\n                               for j in range(depths)])\n\n        setattr(self, f\"patch_embed\", patch_embed)\n        setattr(self, f\"block\", block)\n\n        # classification head\n        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n\n        block = getattr(self, f\"block\")\n        patch_embed = getattr(self, f\"patch_embed\")\n\n        x = patch_embed(x)\n        for blk in block:\n            x = blk(x)\n        return x.mean(3)\n\n    def forward(self, x):\n        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]\n        x = self.forward_features(x)\n        x = 
self.head(x.mean(0))\n        return x\n\n\n\n# Adjust ur hyperparams here\n@register_model\ndef spikformer_dvs(pretrained=False, **kwargs):\n    model = Spikformer(step = 8,\n        img_size_h=64, img_size_w=64,\n        patch_size=4, embed_dims=256, num_heads=16, mlp_ratios=4,\n        in_channels=2, num_classes=10, qkv_bias=False,\n        depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n"
  },
  {
    "path": "examples/Structural_Development/DPAP/README.md",
    "content": "# Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks #\n\n## Requirments ##\n* matplotlib==3.5.1\n* numpy==1.22.4\n* Pillow==9.3.0\n* scipy==1.9.3\n* tensorboardX==2.5.1\n* torch==1.8.1+cu111\n* torchvision==0.9.1+cu111\n\n\n## Run ##\n\n``` CUDA_VISIBLE_DEVICES=0 python prun_ main.py```\n\n## Citation ##\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{han2024similarity,\n  title={Developmental Plasticity-inspired Adaptive Pruning for Deep Spiking and Artificial Neural Networks},\n  author={Han, Bing and Zhao, Feifei and Zeng, Yi and Shen Guobin},\n  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},\n  year={2024}\n  }\n  \n@article{zeng2023braincog,\n  title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier},\n}\n```\n\nEnjoy!\n"
  },
  {
    "path": "examples/Structural_Development/DPAP/mask_model.py",
    "content": "import abc\nfrom functools import partial\nfrom torch.nn import functional as F\nimport torchvision\nfrom timm.models import register_model\n\nfrom braincog.base.node.node import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom utils import *\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nconvlayer = [-1,0, 1, 3, 4, 6, 7]\nfclayer=[8,9]\nimgsize = [32,32, 32, 16,16, 16, 8,8, 8]\nsize = [3,128, 128, 256, 256, 512,512]\nsize_pool = [3,128, 128,128, 256, 256,256, 512,512]\nfcsize=[512*8*8,512]\n\nclass my_cifar_model(BaseModule):\n    def __init__(self,\n                 num_classes=10,\n                 step=8,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n\n        self.num_classes = num_classes\n\n        # self.node = node_type\n        # if issubclass(self.node, BaseNode):\n        #     self.node = partial(self.node, **kwargs, step=step)\n\n\n        self.feature = nn.Sequential(\n            BaseConvModule(size[0], size[1], kernel_size=(3, 3), padding=(1, 1)),\n            BaseConvModule(size[1],size[2], kernel_size=(3, 3), padding=(1, 1)),\n            nn.MaxPool2d(2),\n            BaseConvModule(size[2], size[3], kernel_size=(3, 3), padding=(1, 1)),\n            BaseConvModule(size[3], size[4], kernel_size=(3, 3), padding=(1, 1)),\n            nn.MaxPool2d(2),\n            BaseConvModule(size[4], size[5], kernel_size=(3, 3), padding=(1, 1)),\n            BaseConvModule(size[5], size[6], kernel_size=(3, 3), padding=(1, 1)),\n        )\n        self.cfla=self._cflatten()\n        self.fc_prun = self._create_fc_prun()\n        self.fc = self._create_fc()\n\n    def _cflatten(self):\n        fc = nn.Sequential(\n            nn.Flatten(),\n        )\n        return fc\n        \n   
 def _create_fc_prun(self):\n        fc = nn.Sequential(\n            BaseLinearModule(fcsize[0], fcsize[1])\n        )\n        return fc\n\n    def _create_fc(self):\n        fc = nn.Sequential(\n            BaseLinearModule(fcsize[1], self.num_classes)\n        )\n        return fc\n    \n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n\n        self.reset()\n        if not self.training:\n            self.fire_rate.clear()\n\n        outputs = []\n        spikes=[]\n                \n        for t in range(self.step):\n            spikest=[]\n            x = inputs[t]\n            if x.shape[-1] > 32:\n                x = F.interpolate(x, size=[64, 64])\n            spikest.append(x.detach())\n            for i in range(len(self.feature)):\n                spikei=self.feature[i](x)\n                x=spikei\n                spikest.append(spikei.detach())\n\n            x=self.cfla(x)\n            spikest.append(x.detach())\n            x=self.fc_prun(x)\n            spikest.append(x.detach())\n            x = self.fc(x)\n            spikes.append(spikest)\n\n            outputs.append(x)\n\n        return sum(outputs) / len(outputs),spikes\n\n\nclass Mask:\n    def __init__(self, model,batch,step):\n        self.model = model\n        self.fullbook={}\n        self.mat = {}\n        self.feature=model.feature\n        self.fc={}\n        self.fc[1]=model.fc_prun[0]\n        self.fc[2]=model.fc[0]\n        self.n_delta={}\n        self.ww_delta={}\n        self.reduce={}\n        self.reduceww={}\n        self.batch=batch\n        self.step=step\n\n    def init_length(self):\n        for i in range(1,len(convlayer)):\n            index=convlayer[i]\n            self.fullbook[index] =torch.ones((size[i],size[i-1],3,3),device=device)\n            self.n_delta[index]=torch.zeros(size[i],device=device)\n            self.reduce[index] = 10*torch.ones(size[i],device=device)\n        for i in range(1,len(fclayer)):\n            
index=fclayer[i]\n            self.fullbook[index] = torch.ones((fcsize[i],fcsize[i-1]),device=device)\n            self.n_delta[index]=torch.zeros(fcsize[i],device=device)\n            self.ww_delta[index]=torch.zeros(fcsize[i]*fcsize[i-1],device=device)\n            self.reduce[index] = 10*torch.ones(fcsize[i],device=device)\n            self.reduceww[index] = 10*torch.ones(fcsize[i]*fcsize[i-1],device=device)\n            \n            \n    def get_filter_codebook(self,ww,dendrite,ii,index,epoch): \n        if ii == 4:\n            wconv= dendrite#.cpu().numpy()\n            self.n_delta[index]=(unit(wconv)*2-0.65)\n            pos=torch.nonzero(self.n_delta[index]>0)\n            self.n_delta[index][pos]=self.n_delta[index][pos]+5\n            print(wconv.mean(),wconv.max(), wconv.min())\n            self.reduce[index]=self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-5)/12))\n            filter_ind = torch.nonzero(self.reduce[index] <0)\n            print(self.reduce[index].mean(),self.reduce[index].max(),self.reduce[index].min(),len(filter_ind))\n             \n            for x in range(0, len(filter_ind)):\n                self.fullbook[index][filter_ind[x]] = 0\n      \n        if ii == 2:\n            length=ww.size()[0]*ww.size()[1]\n            book=torch.ones(length,device=device)\n            filter_ww = ww.view(-1)#.cpu().numpy()\n            self.ww_delta[index]=(unit(filter_ww)*2-1.5)\n            pos=torch.nonzero(self.ww_delta[index]>0)\n            self.ww_delta[index][pos]=self.ww_delta[index][pos]+2\n            self.reduceww[index]= self.reduceww[index]*0.999+self.ww_delta[index]*math.exp(-int((epoch-5)/13))\n            filter_indww =torch.nonzero(self.reduceww[index] < 0)\n            book[filter_indww]=0\n            book=book.reshape((ww.size()[0],-1))\n            self.fullbook[index]=self.fullbook[index]*book\n            
print(self.reduceww[index].mean(),self.reduceww[index].max(),self.reduceww[index].min(),len(filter_indww))\n                \n            wconv= dendrite#.cpu().numpy()\n            self.n_delta[index]=(unit(wconv)*2-1.5)\n            pos=torch.nonzero(self.n_delta[index]>0)\n            self.n_delta[index][pos]=self.n_delta[index][pos]+2\n            self.reduce[index]=self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-5)/13))\n            filter_ind = torch.nonzero(self.reduce[index] <0)\n            print(self.reduce[index].mean(),self.reduce[index].max(),self.reduce[index].min(),len(filter_ind))\n             \n            for x in range(0, len(filter_ind)):\n                self.fullbook[index][filter_ind[x]] = 0\n\n        return self.fullbook[index]\n\n    def convert2tensor(self, x):\n        x = torch.FloatTensor(x)\n        return x\n\n    def init_mask(self, wwfc,convtra,epoch):\n        for i in range(1,len(convlayer)):\n            index=convlayer[i]\n            ww = wwfc[index]\n            dendrite=convtra[index]\n            self.mat[index]=self.get_filter_codebook(ww, dendrite,4,index,epoch)\n            #self.mat[index] = self.convert2tensor(self.mat[index]).cuda()\n        for i in range(1,len(fclayer)):\n            index=fclayer[i]\n            ww=wwfc[index]\n            dendrite=convtra[index]\n            self.mat[index]=self.get_filter_codebook(ww,dendrite,2,index,epoch)\n            #self.mat[index] = self.convert2tensor(self.mat[index]).cuda()\n\n    def do_mask(self):\n        for i in range(1,len(convlayer)):\n            index=convlayer[i]\n            ww = self.feature[index].conv.weight\n            maskww=ww*self.mat[index]\n            self.feature[index].conv.weight.data=maskww\n        for i in range(1,len(fclayer)):\n            ind=fclayer[i]\n            ww = self.fc[i].fc.weight\n            maskww=ww*self.mat[ind]\n            self.fc[i].fc.weight.data=maskww\n\n    def if_zero(self):\n        cc=[]\n        
for i in range(1,len(convlayer)):\n            ww=self.feature[convlayer[i]].conv.weight\n            b = ww.data.view(-1).cpu().numpy()\n            print(\"number of weight is %d, zero is %.3f\" %(len(b),100*(len(b)- np.count_nonzero(b))/len(b)))\n            cc.append(100*(len(b)- np.count_nonzero(b))/len(b))\n        for i in range(1,len(fcsize)):\n            ww=self.fc[i].fc.weight\n            b = ww.data.view(-1).cpu().numpy()\n            print(\"number of weight is %d, zero is %.3f\" %(len(b),100*(len(b)- np.count_nonzero(b))/len(b)))\n            cc.append(100*(len(b)- np.count_nonzero(b))/len(b))\n        return cc\n\nclass Trace:\n    def __init__(self, model,batch,step):\n        self.model = model\n        self.feature=model.feature\n        self.ctrace={}\n        self.fctrace={}\n        self.csum={}\n        self.fcsum={}\n        self.delta = 0.5\n        self.batch=batch\n        self.step=step\n\n    def computing_trace(self,spikes):\n        for i in range(len(imgsize)):\n            index=i-1\n            self.ctrace[index]=torch.zeros((self.batch,size_pool[i],imgsize[i],imgsize[i]),device=device)\n        for i in range(len(fclayer)):\n            index=fclayer[i]\n            self.fctrace[index]=torch.zeros((self.batch,fcsize[i]),device=device)\n        for t in range(self.step):      \n            for i in range(len(imgsize)):\n                index=i-1\n                sp=spikes[t][index+1].detach()\n                #print(sp.size(),self.ctrace[index].size())\n                self.ctrace[index]=self.delta*self.ctrace[index].cuda()+sp.cuda()\n            for i in range(len(fclayer)):\n                index=fclayer[i]\n                sp=spikes[t][index+1].detach()\n                self.fctrace[index]=self.delta*self.fctrace[index].cuda()+sp.cuda()\n        for i in range(len(imgsize)):\n            index=i-1\n            self.csum[index]=self.ctrace[index]/(self.step)\n            
self.csum[index]=torch.sum(torch.sum(self.csum[index],dim=2),dim=2)\n        for i in range(len(fclayer)):\n            index=fclayer[i]\n            self.fcsum[index]=self.fctrace[index]/(self.step)\n        return self.csum,self.fcsum\n"
  },
  {
    "path": "examples/Structural_Development/DPAP/prun_main.py",
    "content": "import argparse\nimport time\nimport os\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nimport sys\nsys.path.append('..')\nimport torch\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nimport logging\nfrom timm.utils import *\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\n\nfrom braincog.base.node.node import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\n\nfrom mask_model import *\nfrom utils import *\n\n_logger = logging.getLogger('train')\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nexp_name = '-'.join([datetime.now().strftime(\"%Y%m%d-%H%M%S\"),'c10'])\noutput_dir = get_outdir('./', 'train', exp_name)\nsetup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n_logger.info(exp_name)\n\nconfig_parser = cfg = argparse.ArgumentParser(description='Training Config', add_help=False)\ndataset='cifar10'\nnum_classes=10\nstep=8\nencode='direct'\nnode_type='PLIFNode'\nthresh=0.5\ntau=2.0\n\ntorch.backends.cudnn.benchmark = True\ndevicee=0\nseed=42\nchannels = 2\nbatch_size=50\nepochs=300\n\nlr=5e-3\nlinear_scaled_lr = lr * batch_size/ 1024.0\ncfg.opt='adamw'\ncfg.lr=linear_scaled_lr\ncfg.weight_decay=0.01\ncfg.momentum=0.9\ncfg.epochs=epochs\ncfg.sched='cosine'\ncfg.min_lr=1e-5\ncfg.warmup_lr=1e-6\ncfg.warmup_epochs=5\ncfg.cooldown_epochs=10\ncfg.decay_rate=0.1\n\neval_metric='top1'\nbest_test = 0\nbest_testepoch = 0\nbest_testprun = 0\nbest_testepochprun = 0\nepoch_prune = 1\nrate_decay_epoch=30\nNUM = 0\n\ntorch.cuda.set_device('cuda:%d' % devicee)\ntorch.manual_seed(seed)\n\nmodel = my_cifar_model(step=step,encode_type=encode,node_type=node_type,num_classes=num_classes)\nmodel = 
model.cuda()\nprint(model)\n\noptimizer = create_optimizer(cfg, model)\nlr_scheduler, num_epochs = create_scheduler(cfg, optimizer)\n\nloader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % dataset)(batch_size=batch_size, step=step)\ntrain_loss_fn = UnilateralMse(1.)\nvalidate_loss_fn = UnilateralMse(1.)\n\n\nm = Mask(model,batch_size,step)\nm.init_length()\ntrace=Trace(model,batch_size,step)\n\nneuron_th,spines,bcm,epoch_trace = init(batch_size,convlayer,fclayer,size,fcsize)\n\ndef BCM_and_trace(NUM,trace,spikes,neuron_th,bcm,epoch_trace):\n    NUM = NUM + 1\n    csum,fcsum= trace.computing_trace(spikes)\n    for i in range(1,len(convlayer)):\n        index=convlayer[i]\n        post1 = (csum[index] * (csum[index] - neuron_th[index]))\n        hebb1 = torch.mm(post1.T, csum[index-1]) \n        bcm[index] = bcm[index] + hebb1\n        neuron_th[index] = torch.div(neuron_th[index] * (NUM - 1) + csum[index], NUM)\n        cs=torch.sum(csum[index],dim=0)\n        epoch_trace[index] = epoch_trace[index] + cs\n\n    for i in range(1,len(fclayer)):\n        index = fclayer[i]\n        post1 = (fcsum[index] * (fcsum[index] - neuron_th[index]))\n        hebb1 = torch.mm(post1.T, fcsum[fclayer[i - 1]])\n        bcm[index] = bcm[index] + hebb1\n        neuron_th[index] = torch.div(neuron_th[index] * (NUM - 1) + fcsum[index], NUM)\n        cs=torch.sum(fcsum[index],dim=0)\n        epoch_trace[index] = epoch_trace[index] + cs\n    return epoch_trace,bcm,NUM\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn,trace,NUM,bcm,neuron_th,epoch_trace,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n\n    model.train()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n\n    for 
batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n        output,spikes = model(inputs)\n\n        epoch_trace,bcm,NUM = BCM_and_trace(NUM,trace,spikes,neuron_th,bcm,epoch_trace)\n\n        loss = loss_fn(output, target)\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n        losses_m.update(loss.item(), inputs.size(0))\n        top1_m.update(acc1.item(), inputs.size(0))\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx %100 == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n            print(\"Train: epoch:\",epoch,batch_idx,\"/\",len(loader),\"loss:\",losses_m.avg,\"acc1:\", top1_m.avg,\"lr:\",lr)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n        # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)]),epoch_trace,bcm,NUM\n\ndef validate(model, loader, loss_fn, amp_autocast=suppress, log_suffix=''):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            last_batch = batch_idx == last_idx\n            inputs = inputs.type(torch.FloatTensor).cuda()\n            target = target.cuda()\n\n            output,spikes = model(inputs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            loss = loss_fn(output, 
target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            if last_batch or batch_idx %100 == 0:\n                print(\"Test: loss:\",losses_m.avg,\"acc1:\", top1_m.avg)\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n\n\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n\n    return metrics\n\n\nfor epoch in range(epochs):\n\n    train_metrics, epoch_trace, bcm, N = train_epoch(\n        epoch, model, loader_train, optimizer, train_loss_fn,trace,NUM,bcm,neuron_th,epoch_trace,\n        lr_scheduler=lr_scheduler)\n    NUM = N\n\n    for i in range(1,len(convlayer)):\n        index=convlayer[i]\n        bcmconv = torch.sum(bcm[index], dim=1)\n        bcmconv=unit_tensor(bcmconv)\n        traconv=unit_tensor(epoch_trace[index])\n        spines[index]=bcmconv*traconv\n    for i in range(1, len(fclayer)):\n        index=fclayer[i]\n        bcmfc = torch.sum(bcm[index], dim=1)\n        bcmfc=unit_tensor(bcmfc)\n        trafc=unit_tensor(epoch_trace[index])\n        spines[index]=bcmfc*trafc\n\n    if epoch>4:\n        m.model = model\n        m.init_mask(bcm,spines,epoch)\n        m.do_mask()\n        print(\"Done pruning\")\n        cc=m.if_zero()\n        model = m.model\n\n    eval_metrics = validate(model, loader_eval, validate_loss_fn)\n    top1=eval_metrics['top1']\n    if top1 > best_testprun:\n        best_testprun = top1\n        best_testepochprun =epoch\n    if epoch%40==0:\n        print('best acc:',best_testprun,'best epoch:',best_testepochprun)\n    if epoch>4:\n        _logger.info('*** epoch: {0} (pruning rate {1},acc:{2})'.format(epoch, cc,top1))\n    if lr_scheduler is not None:\n        lr_scheduler.step(epoch + 1, 
eval_metrics[eval_metric])\n"
  },
  {
    "path": "examples/Structural_Development/DPAP/utils.py",
    "content": "import torch\nimport numpy as np\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ndef print_log(print_string, log):\n    #print(\"{}\".format(print_string))\n    log.write('{}\\n'.format(print_string))\n    log.flush()\n\ndef unit(x):\n    if x.size()[0]>0:\n        xnp=x.cpu().numpy()\n        maxx=np.percentile(xnp, 75)\n        minx=torch.min(x)\n        marge=maxx-minx\n        if marge!=0:\n            xx=(x-minx)/marge\n            xx=torch.clip(xx, 0,1)\n        else:\n            xx=0.5*torch.ones_like(x)\n        return xx\n    else:\n        return x\n\ndef unit_tensor(x):\n    if x.size()[0]>0:\n        maxx=torch.max(x)\n        minx=torch.min(x)\n        marge=maxx-minx\n        if marge!=0:\n            xx=(x-minx)/marge\n        else:\n            xx=0.5*torch.ones_like(x)\n        return xx\n    else:\n        return x\n\ndef init(batch,convlayer,fclayer,size,fcsize):\n    neuron_th={}\n    convtra = {}\n    bcm={}\n    epoch_trace = {}\n    for i in range(1,len(convlayer)):\n        index=convlayer[i]\n        neuron_th[index]=torch.zeros((batch,size[i]),device=device)\n        convtra[index] = torch.zeros(size[i],device=device)\n        bcm[index]=torch.zeros(size[i],size[i-1],device=device)\n        epoch_trace[index] = torch.zeros((size[i]),device=device)\n    for i in range(1,len(fclayer)):\n        index=fclayer[i]\n        neuron_th[index]=torch.zeros((batch,fcsize[i]),device=device)\n        convtra[index]=torch.zeros(fcsize[i],device=device)\n        bcm[index]=torch.zeros(fcsize[i],fcsize[i-1],device=device)\n        epoch_trace[index] = torch.zeros(fcsize[i],device=device)\n    return neuron_th,convtra,bcm,epoch_trace\n"
  },
  {
    "path": "examples/Structural_Development/DSD-SNN/README.md",
    "content": "# Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks #\n\n## Requirments ##\n* numpy\n* timm\n* pytorch >= 1.7.0\n* collections\n* argparse\n\n## Introduction ##\nDynamic Structure Development of Spiking Neural Networks (DSD-SNN) for efficient and adaptive continual learning:   \n\ngrow new neurons and prune redundant neurons, increasing memory capacity and reducing computational overhead.  \n\nverlap shared structure to leverage acquired knowledge to new tasks, empowering a single network to support multiple incremental tasks.   \n\nWe validate the effectiveness of the DSD-SNN multiple TIL and CIL benchmarks.\n\n## Run ##\n \n\n```CUDA_VISIBLE_DEVICES=0 python main_simplified.py```   \n\n## Citation ##\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{han2022developmental,\n  title={Enhancing Efficient Continual Learning with Dynamic Structure Development of Spiking Neural Networks},\n  author={Han, Bing and Zhao, Feifei and Zeng, Yi and Wenxuan, Pan and Shen, Guobin},\n  booktitle = {Proceedings of the Thirty-First International Joint Conference on\n               Artificial Intelligence, {IJCAI-23}},\n  publisher = {International Joint Conferences on Artificial Intelligence Organization},\n  year={2023}\n  }\n\n@article{zeng2023braincog,\n  title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier},\n}\n```\n\nEnjoy!\n"
  },
  {
    "path": "examples/Structural_Development/DSD-SNN/cifar100/available.py",
    "content": "from torchvision import datasets, transforms\nfrom manipulate import UnNormalize\n\n\n# specify available data-sets.\nAVAILABLE_DATASETS = {\n    'MNIST': datasets.MNIST,\n    'CIFAR100': datasets.CIFAR100,\n    'CIFAR10': datasets.CIFAR10,\n}\n\n# specify available transforms.\nAVAILABLE_TRANSFORMS = {\n    'MNIST': [\n        transforms.ToTensor(),\n    ],\n    'MNIST32': [\n        transforms.Pad(2),\n        transforms.ToTensor(),\n    ],\n    'CIFAR10': [\n        transforms.ToTensor(),\n    ],\n    'CIFAR100': [\n        transforms.ToTensor(),\n    ],\n    'CIFAR10_norm': [\n        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])\n    ],\n    'CIFAR100_norm': [\n        transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761])\n    ],\n    'CIFAR10_denorm': UnNormalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),\n    'CIFAR100_denorm': UnNormalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),\n    'augment_from_tensor': [\n        transforms.ToPILImage(),\n        transforms.RandomCrop(32, padding=4, padding_mode='symmetric'),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n    ],\n    'augment': [\n        transforms.RandomCrop(32, padding=4, padding_mode='symmetric'),\n        transforms.RandomHorizontalFlip(),\n    ],\n}\n\n# specify configurations of available data-sets.\nDATASET_CONFIGS = {\n    'MNIST': {'size': 28, 'channels': 1, 'classes': 10},\n    'MNIST32': {'size': 32, 'channels': 1, 'classes': 10},\n    'CIFAR10': {'size': 32, 'channels': 3, 'classes': 10},\n    'CIFAR100': {'size': 32, 'channels': 3, 'classes': 100},\n}\n"
  },
  {
    "path": "examples/Structural_Development/DSD-SNN/cifar100/main_simplified.py",
    "content": "\n\nimport argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.utils import save_feature_map\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torchvision import transforms\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\nfrom maskcl2 import *\n# from ptflops import get_model_complexity_info\nfrom thop import profile, clever_format\nfrom manipulate import SubDataset\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\nfrom available import AVAILABLE_DATASETS, AVAILABLE_TRANSFORMS, DATASET_CONFIGS\nfrom torch.utils.data.dataloader import DataLoader\nfrom torch.utils.data import ConcatDataset\nimport copy\nfrom vgg_snn import SNN\n\n# torch.cuda.set_device(9)\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\nparser = 
argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='cifar100', type=str)\nparser.add_argument('--model', default='cifar_convnet', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--num-classes', type=int, default=100, metavar='N',\n                    help='number of label classes (default: 100)')\nparser.add_argument('--task_num', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 10)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=50, metavar='N',\n                    help='inputs batch size for training (default: 128)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule 
parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=1e-2, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-4, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=600, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n                    help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n      
              help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Misc\nparser.add_argument('--seed', type=int, default=0, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=25, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=4, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--device', type=int, default=0)\nparser.add_argument('--output', default='/home/hanbing/brain/bp2/', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=4, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--thresh', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--n_warm_up', type=int, default=0,\n                    help='Warm up epoch, replace all node to ReLU to warm up weights in network before (default: 0)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n (default: 
15) (0-30)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, 
default_flow_style=False)\n    return args, args_text\n\ndef get_dataset(name, type='train', download=True, capacity=None, permutation=None, dir='./store/datasets',\n                verbose=False, augment=False, normalize=False, target_transform=None):\n    '''Create [train|valid|test]-dataset.'''\n\n    data_name = 'MNIST' if name in ('MNIST28', 'MNIST32') else name\n    dataset_class = AVAILABLE_DATASETS[data_name]\n\n    # specify image-transformations to be applied\n    transforms_list = [*AVAILABLE_TRANSFORMS['augment']] if augment else []\n    transforms_list += [*AVAILABLE_TRANSFORMS[name]]\n    if normalize:\n        transforms_list += [*AVAILABLE_TRANSFORMS[name+\"_norm\"]]\n    # if permutation is not None:\n    #     transforms_list.append(transforms.Lambda(lambda x, p=permutation: permutate_image_pixels(x, p)))\n    dataset_transform = transforms.Compose(transforms_list)\n\n    # load data-set\n    dataset = dataset_class('{dir}/{name}'.format(dir=dir, name=data_name), train=False if type=='test' else True,\n                            download=download, transform=dataset_transform, target_transform=target_transform)\n\n    # print information about dataset on the screen\n    if verbose:\n        print(\" --> {}: '{}'-dataset consisting of {} samples\".format(name, type, len(dataset)))\n\n    # if dataset is (possibly) not large enough, create copies until it is.\n    if capacity is not None and len(dataset) < capacity:\n        dataset = ConcatDataset([copy.deepcopy(dataset) for _ in range(int(np.ceil(capacity / len(dataset))))])\n\n    return dataset\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    output_base = args.output if args.output else './output'\n    exp_name = '-'.join([\n        datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n        'SNN',\n        args.dataset,\n        str(args.seed),\n        'gwu',\n        
'xin-epochs'\n    ])\n    output_dir = get_outdir(output_base, 'train', exp_name)\n    print(output_dir)\n    args.output_dir = output_dir\n    setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    torch.cuda.set_device('cuda:%d' % args.device)\n\n    torch.manual_seed(args.seed)\n\n    model = SNN(\n        num_classes=args.num_classes,\n        dataset=args.dataset,\n        step=args.step,\n        encode_type=args.encode,\n        node_type=eval(args.node_type),\n        threshold=args.thresh,\n        tau=args.tau,\n        spike_output=not args.no_spike_output,\n        act_fun=args.act_fun,\n        temporal_flatten=args.temporal_flatten,\n        layer_by_layer=args.layer_by_layer,\n        batch_size=args.batch_size,\n        task_num=args.task_num\n    )\n\n\n    print(model)\n    # for n,p in enumerate(model.parameters()):\n    #     print(n,p.size())\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.img_size, args.img_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size / 1024.0\n    args.lr = linear_scaled_lr\n\n    model = model.cuda()\n    optimizer = create_optimizer(args, model)\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume:\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    lr_scheduler, num_epochs 
= create_scheduler(args, optimizer)\n    m = Mask(model)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    _logger.info('Scheduled epochs: {}'.format(num_epochs))\n    batch_size=args.batch_size\n    data_dir = '/data0/datasets/'\n    trainset = get_dataset('CIFAR100', type=\"train\", dir=data_dir)\n    testset = get_dataset('CIFAR100', type=\"test\", dir=data_dir)\n    out_num=int(args.num_classes/args.task_num)\n    labels_per_dataset_train = [list(np.array(range(out_num))+out_num*context_id) for context_id in range(args.task_num)]\n    labels_per_dataset_test = [list(np.array(range(out_num))+out_num*context_id) for context_id in range(args.task_num)]\n    train_datasets = []\n    for labels in labels_per_dataset_train:\n        target_transform = transforms.Lambda(lambda y, x=labels[0]: y-x)\n        train_datasets.append(SubDataset(trainset, labels, target_transform=target_transform))\n    test_datasets = []\n    for labels in labels_per_dataset_test:\n        target_transform = transforms.Lambda(lambda y, x=labels[0]: y-x) \n        test_datasets.append(SubDataset(testset, labels, target_transform=target_transform))\n   \n    train_data = []\n    test_data = []\n    t_data=[]\n    for task in range(len(train_datasets)):\n        train_data.append(DataLoader(train_datasets[task], batch_size=batch_size, shuffle=True, drop_last=True, **({'num_workers': 4, 'pin_memory': True})))\n        test_data.append(DataLoader(test_datasets[task], batch_size=batch_size, shuffle=True, drop_last=True, **({'num_workers': 4, 'pin_memory': True})))\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n\n    else:\n\n        
train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    saver = CheckpointSaver(\n        model=model, optimizer=optimizer, args=args,\n        checkpoint_dir=output_dir, recovery_dir=output_dir)\n    with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n        f.write(args_text)\n    loader_his=[]\n    task_ready={}\n    for index, item in enumerate(model.parameters()):\n        if len(item.size()) > 1 and index<=40:\n            task_ready[index]=torch.zeros(item.size(),device=device)\n    try:  # train the model\n        task_count=0\n        regularization_terms= {}\n        for task in range(len(train_datasets)):\n            print(\"Task:\",task)\n            if task==0:\n                m.model = model\n                mat=m.init_length()\n                model = m.model\n                epochs=50\n            else:\n                m.model = model\n                mat,task_ready,taskmaskk,taskww=m.init_grow(task)\n                model = m.model\n                epochs=30\n            ta_his=[i for i in range(task+1)]\n            for epoch in range(epochs):\n                loader_train = iter(train_data[task])\n                if task==0:\n                    train_epoch(epoch, task, model, loader_train, optimizer, train_loss_fn, args,mat,task_ready,taskww=None,\n                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,regularization_terms=regularization_terms)\n                else:\n                    train_epoch(epoch, task, model, loader_train, optimizer, train_loss_fn, args,mat,task_ready,taskww=taskww,\n                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,regularization_terms=regularization_terms)\n            \n      
          print(epoch)\n\n                if epoch>0:\n                    m.model = model\n                    m.init_mask(task,epoch)\n                    mat=m.do_mask(task)\n                    model = m.model\n\n                ta_his=[i for i in range(task+1)]\n                for t in ta_his:\n                    loader_his=iter(test_data[t])\n                    validate(t, model, loader_his, validate_loss_fn, args,mat)\n\n                cc=m.if_zero()\n                _logger.info('*** epoch: {0}, task: {1}, pruning: {2}'.format(epoch,task, cc))\n            p_index=m.record()\n                    \n        \n    except KeyboardInterrupt:\n        pass\n    # if best_metric is not None:\n    #     _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, task,model, loader, optimizer, loss_fn, args,mat,task_ready,taskww=None,\n        lr_scheduler=None, saver=None, output_dir='',regularization_terms={}):\n\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n\n        output = model(inputs, mat)\n        t_preds = output[task].cuda()\n        loss = loss_fn(t_preds, target)\n\n        # if len(regularization_terms)>0:\n        #     reg_loss = 0\n        #     for i,reg_term in regularization_terms.items():\n        #         task_reg_loss = 0\n        #         importance = reg_term['importance']\n        #         task_param = reg_term['task_param']\n        #         for n, p in enumerate(model.parameters()):\n        #             if 
len(p.size())>=1:\n        #                 task_reg_loss += (importance[n] * (p - task_param[n]) ** 2).sum()\n        #         reg_loss += task_reg_loss\n        #     loss += 10000 * reg_loss\n\n        acc1, acc5 = accuracy(t_preds, target, topk=(1, 5))\n\n        losses_m.update(loss.item(), inputs.size(0))\n        top1_m.update(acc1.item(), inputs.size(0))\n        top5_m.update(acc5.item(), inputs.size(0))\n\n        optimizer.zero_grad()\n        loss.backward()\n        # for index, item in enumerate(model.parameters()):\n        #     if len(item.size()) > 1 and index<=40:\n        #         gradmask=torch.where(task_ready[index]>0,0.0,1.0)\n        #         item.grad=item.grad*gradmask\n        optimizer.step()\n        for index, item in enumerate(model.parameters()):\n            if len(item.size()) > 1 and index<=40:\n                if index<40:\n                    ready=task_ready[index].view(task_ready[index].size()[0],-1)\n                    ready=torch.sum(ready,dim=1)\n                else:\n                    ready=torch.sum(task_ready[index],dim=1)\n                windex=torch.nonzero(ready>0)\n                for i in range(len(windex)):\n                    item.data[windex[i]]=taskww[index][windex[i]]\n\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            # lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            # lr = sum(lrl) / len(lrl)\n\n            _logger.info(\n                'Train: {} [{:>4d}/{} ({:>3.0f}%)]  ' \n                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'\n                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n    
                epoch,\n                    batch_idx, len(loader),\n                    100. * batch_idx / last_idx,\n                    loss=losses_m,\n                    top1=top1_m, top5=top5_m,\n                    batch_time=batch_time_m,\n                    rate=inputs.size(0) / batch_time_m.val,\n                    rate_avg=inputs.size(0)  / batch_time_m.avg,\n                    data_time=data_time_m))\n\n        # if saver is not None and args.recovery_interval and (\n        #         last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n        #     saver.save_recovery(epoch, batch_idx=batch_idx)\n\n    #     if lr_scheduler is not None:\n    #         lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n    #     end = time.time()\n    # # end for\n\n    # if hasattr(optimizer, 'sync_lookahead'):\n    #     optimizer.sync_lookahead()\n\n    # return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(task, model, loader, loss_fn, args, mat,log_suffix='', visualize=False, spike_rate=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n    model.eval()\n    end = time.time()\n    with torch.no_grad():\n        last_idx = len(loader) - 1\n        for batch_idx, (inputs, target) in enumerate(loader):\n            last_batch = batch_idx == last_idx\n            inputs = inputs.type(torch.FloatTensor).cuda()\n            target = target.cuda()\n\n            output = model(inputs,mat)\n            t_preds = output[task]\n            loss = loss_fn(t_preds, target)\n            acc1, acc5 = accuracy(t_preds, target, topk=(1, 5))\n\n            reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n\n            batch_time_m.update(time.time() - end)\n     
       end = time.time()\n\n        log_name = 'Test'+str(task) + log_suffix\n        _logger.info(\n            '{0}: [{1:>4d}/{2}]  '\n            'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n            'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  ' \n            'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n            'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                log_name, batch_idx, last_idx, batch_time=batch_time_m,\n                loss=losses_m, top1=top1_m, top5=top5_m))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n\n    # return metrics\n    \nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/Structural_Development/DSD-SNN/cifar100/manipulate.py",
    "content": "import torch\nfrom torch.utils.data import Dataset\n\n\ndef permutate_image_pixels(image, permutation):\n    '''Permutate the pixels of an image according to [permutation].\n\n    [image]         3D-tensor containing the image\n    [permutation]   <ndarray> of pixel-indeces in their new order'''\n\n    if permutation is None:\n        return image\n    else:\n        c, h, w = image.size()\n        image = image.view(c, -1)\n        image = image[:, permutation]  #--> same permutation for each channel\n        image = image.view(c, h, w)\n        return image\n\n#----------------------------------------------------------------------------------------------------------#\n\nclass SubDataset(Dataset):\n    '''To sub-sample a dataset, taking only those samples with label in [sub_labels].\n\n    After this selection of samples has been made, it is possible to transform the target-labels,\n    which can be useful when doing continual learning with fixed number of output units.'''\n\n    def __init__(self, original_dataset, sub_labels, target_transform=None):\n        super().__init__()\n        self.dataset = original_dataset\n        self.sub_indeces = []\n        for index in range(len(self.dataset)):\n            if hasattr(original_dataset, \"train_labels\"):\n                if self.dataset.target_transform is None:\n                    label = self.dataset.train_labels[index]\n                else:\n                    label = self.dataset.target_transform(self.dataset.train_labels[index])\n            elif hasattr(self.dataset, \"test_labels\"):\n                if self.dataset.target_transform is None:\n                    label = self.dataset.test_labels[index]\n                else:\n                    label = self.dataset.target_transform(self.dataset.test_labels[index])\n            else:\n                label = self.dataset[index][1]\n            if label in sub_labels:\n                self.sub_indeces.append(index)\n        
self.target_transform = target_transform\n\n    def __len__(self):\n        return len(self.sub_indeces)\n\n    def __getitem__(self, index):\n        sample = self.dataset[self.sub_indeces[index]]\n        if self.target_transform:\n            target = self.target_transform(sample[1])\n            sample = (sample[0], target)\n        return sample\n\n\nclass MemorySetDataset(Dataset):\n    '''Create dataset from list of <np.arrays> with shape (N, C, H, W) (i.e., with N images each).\n\n    The images at the i-th entry of [memory_sets] belong to class [i], unless a [target_transform] is specified'''\n\n    def __init__(self, memory_sets, target_transform=None):\n        super().__init__()\n        self.memory_sets = memory_sets\n        self.target_transform = target_transform\n\n    def __len__(self):\n        total = 0\n        for class_id in range(len(self.memory_sets)):\n            total += len(self.memory_sets[class_id])\n        return total\n\n    def __getitem__(self, index):\n        total = 0\n        for class_id in range(len(self.memory_sets)):\n            examples_in_this_class = len(self.memory_sets[class_id])\n            if index < (total + examples_in_this_class):\n                class_id_to_return = class_id if self.target_transform is None else self.target_transform(class_id)\n                example_id = index - total\n                break\n            else:\n                total += examples_in_this_class\n        image = torch.from_numpy(self.memory_sets[class_id][example_id])\n        return (image, class_id_to_return)\n\n\nclass TransformedDataset(Dataset):\n    '''To modify an existing dataset with a transform.\n    This is useful for creating different permutations of MNIST without loading the data multiple times.'''\n\n    def __init__(self, original_dataset, transform=None, target_transform=None):\n        super().__init__()\n        self.dataset = original_dataset\n        self.transform = transform\n        self.target_transform 
= target_transform\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        (input, target) = self.dataset[index]\n        if self.transform:\n            input = self.transform(input)\n        if self.target_transform:\n            target = self.target_transform(target)\n        return (input, target)\n\n# ----------------------------------------------------------------------------------------------------------#\n\nclass UnNormalize(object):\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, tensor):\n        \"\"\"Denormalize image, either single image (C,H,W) or image batch (N,C,H,W)\"\"\"\n        batch = (len(tensor.size()) == 4)\n        for t, m, s in zip(tensor.permute(1, 0, 2, 3) if batch else tensor, self.mean, self.std):\n            t.mul_(s).add_(m)\n            # The normalize code -> t.sub_(m).div_(s)\n        return tensor\n"
  },
  {
    "path": "examples/Structural_Development/DSD-SNN/cifar100/maskcl2.py",
    "content": "import numpy as np\r\nimport torch\r\nimport math\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\nimport random\r\n\r\ndef unit(x):\r\n    if x.size()[0]>0:\r\n        xnp=x.cpu().numpy()\r\n        maxx=torch.max(x)\r\n        #maxx=np.percentile(xnp, 99.5)\r\n        minx=torch.min(x)\r\n        marge=maxx-minx\r\n        if marge!=0:\r\n            xx=(x-minx)/marge\r\n            xx=torch.clip(xx, 0,1)\r\n        else:\r\n            xx=0.5*torch.ones_like(x)\r\n        return xx\r\n    else:\r\n        return x\r\n        \r\nclass Mask:\r\n    def __init__(self, model):\r\n        self.model = model\r\n        self.mat = {}\r\n        self.p_index={}\r\n        self.p_num={}\r\n        self.k=15\r\n        self.task_ready={}\r\n        self.taskmask={}\r\n        self.init_rate=0.3\r\n        self.grow_rate=0.125\r\n\r\n        self.prunconv_init=0.8\r\n        self.prunfc_init=1.3\r\n        self.prunconv_grow=0.5\r\n        self.prunfc_grow=1\r\n\r\n        self.n_delta={}\r\n        self.reduce={}\r\n        self.taskww={}\r\n\r\n    def init_length(self):\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1:\r\n                print(index,item.size())\r\n                self.mat[index]=torch.ones(item.size(),device=device)\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1:\r\n                if index<=40:\r\n                    self.p_index[index]=torch.tensor([])\r\n                    self.task_ready[index]=torch.zeros(item.size(),device=device)\r\n                    self.reduce[index] = 5*torch.ones(item.size()[0],device=device)\r\n                    if len(item.size()) == 4:\r\n                        self.p_num[index]=torch.zeros(item.size()[0],device=device)\r\n                        self.mat[index][int(self.init_rate*item.size()[0]):]=0.0\r\n                        if index+5<40:\r\n           
                 self.mat[index+5][:,int(self.init_rate*item.size()[0]):]=0.0\r\n                        if index+5==40:\r\n                            self.mat[index+5][:,int(self.init_rate*item.size()[0])*16:]=0.0\r\n                    if len(item.size()) == 2:\r\n                        self.p_num[index]=torch.zeros(item.size()[0]*item.size()[1],device=device)\r\n                        self.mat[index][int(self.init_rate*item.size()[0]):]=0.0\r\n                if index>40:\r\n                    self.mat[index]=torch.ones(item.size(),device=device)\r\n                    if index==44:\r\n                        self.mat[index]=torch.zeros(item.size(),device=device)\r\n                        self.mat[index][:,:int(self.init_rate*item.size()[1])]=1.0\r\n        return self.mat\r\n            \r\n    def get_filter_codebook(self,index,ww,task,epoch): \r\n        if task==0:\r\n            pruncon=self.prunconv_init\r\n            prunfc=self.prunfc_init\r\n        else:\r\n            pruncon=self.prunconv_grow\r\n            prunfc=self.prunfc_grow\r\n\r\n        if len(ww.size()) == 4:\r\n            p_ww=ww.view(ww.size()[0],-1)\r\n            p_ww=torch.sum(p_ww,dim=1)\r\n\r\n            task_use=self.mat[index]-torch.sign(self.task_ready[index])\r\n            nouse=torch.sum(task_use.view(ww.size()[0],-1),dim=1)\r\n            no=torch.nonzero(nouse<0.1)\r\n            p_ww[no]=p_ww.max()\r\n            \r\n            self.n_delta[index]=(unit(p_ww)*2-pruncon)\r\n            pos=torch.nonzero(self.n_delta[index]>0)\r\n            self.n_delta[index][pos]=self.n_delta[index][pos]+3\r\n            self.reduce[index]=self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-1)/13))\r\n            p_ind = torch.nonzero(self.reduce[index] <0)\r\n            print(self.reduce[index].mean(),self.reduce[index].max(),self.reduce[index].min(),len(p_ind))\r\n            for x in range(0, len(p_ind)):\r\n                self.mat[index][p_ind[x]] = 0\r\n       
         if index+5<40:\r\n                    self.mat[index+5][:,p_ind[x]]=0\r\n                if index+5==40:\r\n                    self.mat[index+5][:,p_ind[x]*16:(p_ind[x]+1)*16]=0\r\n            self.mat[index]=torch.sign(self.mat[index]+ self.task_ready[index])\r\n\r\n        if len(ww.size()) == 2:\r\n            p_ww=torch.sum(ww,dim=1)\r\n\r\n            task_use=self.mat[index]-torch.sign(self.task_ready[index])\r\n            nouse=torch.sum(task_use,dim=1)\r\n            no=torch.nonzero(nouse<0.1)\r\n            p_ww[no]=p_ww.max()\r\n\r\n            self.n_delta[index]=(unit(p_ww)*2-prunfc)\r\n            print(self.n_delta[index].mean(),self.n_delta[index].max(),self.n_delta[index].min())\r\n            pos=torch.nonzero(self.n_delta[index]>0)\r\n            self.n_delta[index][pos]=self.n_delta[index][pos]+3\r\n            self.reduce[index]=self.reduce[index]*0.999+self.n_delta[index]*math.exp(-int((epoch-1)/13))\r\n            p_ind = torch.nonzero(self.reduce[index] <0)\r\n            print(self.reduce[index].mean(),self.reduce[index].max(),self.reduce[index].min(),len(p_ind))\r\n            index_ta=44+task*2\r\n            for x in range(0, len(p_ind)):\r\n                self.mat[index][p_ind[x]] = 0\r\n                self.mat[index_ta][:,p_ind[x]]=0\r\n            self.mat[index]=torch.sign(self.mat[index]+self.task_ready[index])\r\n\r\n    def convert2tensor(self, x):\r\n        x = torch.FloatTensor(x)\r\n        return x\r\n\r\n    def init_mask(self,task,epoch):\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1 and index<=40:\r\n                self.get_filter_codebook(index,abs(item.data),task,epoch)\r\n\r\n    def do_mask(self,task):\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1 and index<=40:\r\n                ww=item.data\r\n                item.data=ww*self.mat[index].cuda()\r\n        return self.mat\r\n\r\n    def 
init_grow(self,task):\r\n        self.taskmask[task]={}\r\n        index_ta=44+task*2\r\n        self.mat[index_ta]=torch.zeros(self.mat[index_ta].size(),device=device)\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1 and index<=40:\r\n                self.task_ready[index]=self.task_ready[index]+self.mat[index]\r\n                self.taskmask[task][index]=self.mat[index]\r\n                self.taskww[index]=item.data.clone()\r\n                ind_all=[x for x in range(item.size()[0])]\r\n                pp=list(np.array(self.p_index[index]))\r\n                ind_empty=set(ind_all)-set(pp)\r\n                ind_empty=list(ind_empty)\r\n                # random.shuffle(ind_empty)\r\n                ind_grow=ind_empty[:int(item.size()[0]*self.grow_rate)]\r\n                ind_grow=torch.tensor(ind_grow)\r\n                ww_g=torch.empty(item.size(),device=device)\r\n                if index<40:\r\n                    torch.nn.init.kaiming_uniform_(ww_g, a=math.sqrt(5))\r\n                if index==40:\r\n                    kk=1/math.sqrt(item.size()[1])\r\n                    torch.nn.init.uniform_(ww_g,a=-kk,b=kk)\r\n                for x in range(0, len(ind_grow)):\r\n                    self.mat[index][ind_grow[x]] = 1.0\r\n                    item.data[ind_grow[x]]=ww_g[ind_grow[x]]\r\n                    if index==40:\r\n                        self.mat[index_ta][:,ind_grow[x]]=1.0\r\n                self.mat[index]=torch.sign(self.mat[index]+self.task_ready[index])\r\n                self.p_num[index]=torch.zeros(item.size()[0],device=device)\r\n                self.reduce[index] = 5*torch.ones(item.size()[0],device=device)\r\n                #self.mat[index_ta]=self.mat[index_ta]+self.mat[index_ta-2]\r\n        # nn=torch.sum(self.task_ready[24],dim=1)\r\n        # use_nn=torch.nonzero(nn>1)\r\n        # for x in range(0, len(use_nn)):\r\n        #     self.mat[index_ta][:,use_nn[x]] = 1.0\r\n   
     for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1 and index<40:\r\n                nsum=self.mat[index].view(item.size()[0],-1)\r\n                nn=torch.sum(abs(nsum),dim=1)\r\n                empy_nn=torch.nonzero(nn<0.00001)\r\n                for x in range(0, len(empy_nn)):\r\n                    if index+5<40:\r\n                        self.mat[index+5][:,empy_nn[x]]=0\r\n                    if index+5==40:\r\n                        self.mat[index+5][:,empy_nn[x]*16:(empy_nn[x]+1)*16]=0\r\n\r\n        return self.mat,self.task_ready,self.taskmask,self.taskww\r\n\r\n    def if_zero(self):\r\n        cc=[]\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1 and index<=40:\r\n                b = item.data.view(-1).cpu().numpy()\r\n                print(\"number of weight is %d, zero is %.3f\" %(len(b),100*(len(b)- np.count_nonzero(b))/len(b)))\r\n                cc.append(100*(len(b)- np.count_nonzero(b))/len(b))\r\n                if index==40:\r\n                    nouse=torch.sum(self.mat[index],dim=1)\r\n                    no=torch.nonzero(nouse<0.1)\r\n                    print(len(no))\r\n                    cc.append(len(no))\r\n        return cc\r\n\r\n    def record(self):\r\n        for index, item in enumerate(self.model.parameters()):\r\n            if len(item.size()) > 1 and index<=40:\r\n                nsum=self.mat[index].view(item.size()[0],-1)\r\n                nn=torch.sum(abs(nsum),dim=1)\r\n                epoch_select=torch.nonzero(nn>0.00001)\r\n                select=set(epoch_select)|set(self.p_index[index])\r\n                self.p_index[index]= torch.tensor(list(select))\r\n        return self.p_index"
  },
  {
    "path": "examples/Structural_Development/DSD-SNN/cifar100/vgg_snn.py",
    "content": "# encoding: utf-8\r\n# Author    : Floyed<Floyed_Shen@outlook.com>\r\n# Datetime  : 2022/7/26 18:56\r\n# User      : Floyed\r\n# Product   : PyCharm\r\n# Project   : BrainCog\r\n# File      : vgg_snn.py\r\n# explain   :\r\n\r\nfrom functools import partial\r\nfrom torch.nn import functional as F\r\nimport torchvision\r\nfrom timm.models import register_model\r\n\r\nfrom braincog.datasets import is_dvs_data\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.connection.layer import *\r\nfrom braincog.base.encoder.encoder import *\r\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\r\n\r\n@register_model\r\nclass SNN(BaseModule):\r\n    def __init__(self,\r\n                 num_classes=100,\r\n                 step=8,\r\n                 node_type=LIFNode,\r\n                 encode_type='direct',\r\n                 batch_size=100,\r\n                 task_num=10,\r\n                 *args,\r\n                 **kwargs):\r\n        super().__init__(step, encode_type, *args, **kwargs)\r\n\r\n        self.n_preact = kwargs['n_preact'] if 'n_preact' in kwargs else False\r\n        self.batch_size=batch_size\r\n        self.num_classes = num_classes\r\n        self.task_num=task_num\r\n        self.out_num=int(self.num_classes/self.task_num)\r\n\r\n        self.node = node_type\r\n        if issubclass(self.node, BaseNode):\r\n            self.node = partial(self.node, **kwargs, step=step)\r\n\r\n        self.dataset = kwargs['dataset']\r\n        if not is_dvs_data(self.dataset):\r\n            init_channel = 3\r\n            output_size = 2\r\n        else:\r\n            init_channel = 2\r\n            output_size = 3\r\n        #self.channel_number=[256,512,1024]\r\n        self.channel_number=[512,1024,2048]\r\n\r\n        self.feature = nn.Sequential(\r\n            BaseConvModule(init_channel, self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            
BaseConvModule(self.channel_number[0], self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            nn.AvgPool2d(2),\r\n            BaseConvModule(self.channel_number[0], self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            BaseConvModule(self.channel_number[0], self.channel_number[0], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            nn.AvgPool2d(2),\r\n            BaseConvModule(self.channel_number[0], self.channel_number[1], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            BaseConvModule(self.channel_number[1], self.channel_number[1], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            nn.AvgPool2d(2),\r\n            BaseConvModule(self.channel_number[1], self.channel_number[2], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n            BaseConvModule(self.channel_number[2], self.channel_number[2], kernel_size=(3, 3), padding=(1, 1), node=self.node),\r\n        )\r\n\r\n        self.fc = nn.Sequential(\r\n            nn.Flatten(),\r\n            BaseLinearModule(\r\n                self.channel_number[2]*4*4, self.channel_number[2], node=self.node),\r\n        )\r\n\r\n        self.dec = nn.ModuleDict()\r\n        for task in range(self.task_num):\r\n            ta=str(task)\r\n            self.dec[ta] = self._create_decision()\r\n        \r\n\r\n    def logits(self, x):\r\n        outputs =torch.zeros((self.task_num,self.batch_size,self.out_num),device='cuda')\r\n        for task, func in self.dec.items():\r\n            ta=int(task)\r\n            outputs[ta]=func(x)\r\n        return outputs\r\n\r\n    def _create_decision(self):\r\n        fc = nn.Linear(self.channel_number[2], self.out_num)\r\n        # fc = BaseLinearModule(1024, 10, node=self.node)\r\n        return fc\r\n\r\n    def forward(self, inputs, mat):\r\n        inputs = self.encoder(inputs)\r\n        self.reset()\r\n        step = self.step\r\n        outputs = 
[]\r\n        for index, item in enumerate(self.parameters()):\r\n            if len(item.size()) > 1:\r\n                ww=item.data\r\n                item.data=ww*mat[index].cuda()\r\n\r\n        for t in range(step):\r\n            x = inputs[t]\r\n            x = self.feature(x)\r\n            x = self.fc(x)\r\n            x = self.logits(x)\r\n            outputs.append(x)\r\n\r\n        out=sum(outputs).cuda()\r\n\r\n        return out / step\r\n\r\n\r\n\r\n# class MaskConvModule(nn.Module):\r\n#     \"\"\"\r\n#     SNN卷积模块\r\n#     :param in_channels: 输入通道数\r\n#     :param out_channels: 输出通道数\r\n#     :param kernel_size: kernel size\r\n#     :param stride: stride\r\n#     :param padding: padding\r\n#     :param bias: Bias\r\n#     :param node: 神经元类型\r\n#     :param kwargs:\r\n#     \"\"\"\r\n#     def __init__(self,\r\n#                  in_channels: int,\r\n#                  out_channels: int,\r\n#                  kernel_size=(3, 3),\r\n#                  stride=(1, 1),\r\n#                  padding=(1, 1),\r\n#                  bias=False,\r\n#                  node=PLIFNode,\r\n#                  **kwargs):\r\n\r\n#         super().__init__()\r\n\r\n#         if node is None:\r\n#             raise TypeError\r\n\r\n#         self.groups = kwargs['groups'] if 'groups' in kwargs else 1\r\n#         self.conv = MConv2d(in_channels=in_channels * self.groups,\r\n#                               out_channels=out_channels * self.groups,\r\n#                               kernel_size=kernel_size,\r\n#                               padding=padding,\r\n#                               stride=stride,\r\n#                               bias=bias)\r\n\r\n#         self.bn = nn.BatchNorm2d(out_channels * self.groups)\r\n\r\n#         self.node = partial(node, **kwargs)()\r\n\r\n#         self.activation = nn.Identity()\r\n\r\n#     def forward(self, x, mat):\r\n#         x = self.conv(x,mat)\r\n#         x = self.bn(x)\r\n#         x = self.node(x)\r\n#         
return x\r\n\r\n# class MConv2d(nn.Conv2d):\r\n\r\n#     def __init__(self, in_channels, out_channels, kernel_size, stride=1,\r\n#                  padding=0, dilation=1, groups=1, bias=True, gain=True):\r\n#         super(MConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,\r\n#                                        padding, dilation, groups, bias)\r\n\r\n#         self.gain = 1.\r\n\r\n#     def forward(self, x, mat):\r\n#         weight = self.weight\r\n#         weight = weight*mat\r\n#         return F.conv2d(x, weight, self.bias, self.stride,\r\n#                         self.padding, self.dilation, self.groups)\r\n\r\n# class MaskLinearModule(nn.Module):\r\n#     \"\"\"\r\n#     线性模块\r\n#     :param in_features: 输入尺寸\r\n#     :param out_features: 输出尺寸\r\n#     :param bias: 是否有Bias, 默认 ``False``\r\n#     :param node: 神经元类型, 默认 ``LIFNode``\r\n#     :param args:\r\n#     :param kwargs:\r\n#     \"\"\"\r\n#     def __init__(self,\r\n#                  in_features: int,\r\n#                  out_features: int,\r\n#                  bias=True,\r\n#                  node=LIFNode,\r\n#                  *args,\r\n#                  **kwargs):\r\n#         super().__init__()\r\n#         if node is None:\r\n#             raise TypeError\r\n\r\n#         self.fc = MLinear(in_features=in_features,\r\n#                                 out_features=out_features, bias=bias)\r\n#         self.node = partial(node, **kwargs)()\r\n\r\n#     def forward(self, x,mat):\r\n#         outputs = self.fc(x,mat)\r\n#         return self.node(outputs)\r\n\r\n# class MLinear(nn.Linear):\r\n#     def __init__(self, in_features: int, out_features: int, bias: bool = True):\r\n#         super(MLinear, self).__init__(in_features, out_features, bias)\r\n#         self.gain = 1.\r\n\r\n#     def forward(self, input, mat):\r\n#         weight = self.weight\r\n#         weight = weight*mat\r\n#         return F.linear(input, weight, self.bias)"
  },
  {
    "path": "examples/Structural_Development/ELSM/evolve.py",
    "content": "import time\nimport threading\nfrom threading import Thread\nimport os\nimport networkx as nx\nimport numpy as np\nfrom population import *\nimport nsganet as engine\nfrom pymop.problem import Problem\nfrom pymoo.optimize import minimize\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\nimport logging\nfrom model import *\nfrom spikes import calc_f2\nfrom mul import mul_f1\n\n_logger = logging.getLogger('')\nconfig_parser = parser = argparse.ArgumentParser(description='Evolution Config', add_help=False)\n\nparser = argparse.ArgumentParser(description='ELSM')\nparser.add_argument('--device', type=int, default=2)\nparser.add_argument('--seed', type=int, default=68, metavar='S')\nparser.add_argument('--datapath', default='', type=str, metavar='PATH')\nparser.add_argument('--output', default='', type=str, metavar='PATH')\nparser.add_argument('--liquid-size', type=int, default=8000)\nparser.add_argument('--pop-size', type=int, default=80)\nparser.add_argument('--up', type=int, default=32000000)\nparser.add_argument('--low', type=int, default=320000)\n\nparser.add_argument('--n_offspring', type=int, default=100)\nparser.add_argument('--n_gens', type=int, default=10000)\nparser.add_argument('--arand', type=float, default=285)\nparser.add_argument('--brand', type=float, default=1.8)\n\n\ndef _parse_args():\n    args_config, remaining = config_parser.parse_known_args()\n    args = parser.parse_args(remaining)\n    return args\n\n\n\nclass Evolve(Problem):\n    # first define the NAS problem (inherit from pymop)\n    def __init__(self, args,n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\n        self.xl = lb\n        self.xu = ub\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\n        self.args=args\n\n\n    def _evaluate(self, x, out, 
*args, **kwargs):\n        \n\n        objs = np.full((x.shape[0], self.n_obj), np.nan)\n        g1 = np.full((x.shape[0]), np.nan)\n        g2 = np.full((x.shape[0]), np.nan)\n        gen_dir=os.path.join(self.args.output,'generaion'+str(kwargs['algorithm'].n_gen))\n        os.makedirs(gen_dir,exist_ok = True)\n        # np.save(os.path.join(gen_dir,\"x.npy\"),x)\n        lsms = x.reshape(x.shape[0],self.args.liquid_size,self.args.liquid_size)\n        for i in range(x.shape[0]):\n            temp_G = nx.Graph(lsms[i])\n            nx.write_gpickle(temp_G, os.path.join(gen_dir,str(i)+\".pkl\"))\n        self.ob1=mul_f1(pop=x.shape[0],steps=10,rootdir=gen_dir)\n\n        for i in range(x.shape[0]):\n            arch_id = self._n_evaluated + 1\n            print('\\n')\n            _logger.info('Network= {}'.format(arch_id))\n            genome = x[i, :]\n\n            g1[i]= genome.sum()-self.args.up\n            g2[i]= self.args.low-genome.sum()\n            lsmm = genome.reshape(self.args.liquid_size,self.args.liquid_size)\n            small_coe_a,small_coe_b=self.ob1[i]\n            lsmm=torch.tensor(lsmm,device='cuda:%d' % self.args.device).float()\n            crit = calc_f2(lsmm,'cuda:%d' % self.args.device)\n            objs[i, 1] = abs(crit-1)\n            # all objectives assume to be MINIMIZED !!!!!                
\n            objs[i, 0] = -(small_coe_a/self.args.arand)/(small_coe_b/self.args.brand)\n            \n\n            _logger.info('small word= {}'.format(objs[i, 0]))\n            _logger.info('criticality= {}'.format(objs[i, 1]))\n\n            self._n_evaluated += 1\n\n        out[\"F\"] = objs\n        out[\"G\"] = np.column_stack([g1,g2])\n        # if your NAS problem has constraints, use the following line to set constraints\n        # out[\"G\"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Define what statistics to print or save for each generation\n# ---------------------------------------------------------------------------------------------------------\ndef do_every_generations(algorithm):\n    # this function will be call every generation\n    # it has access to the whole algorithm class\n    gen = algorithm.n_gen\n    pop_var = algorithm.pop.get(\"X\")\n    pop_obj = algorithm.pop.get(\"F\")\n    \n    # report generation info to files\n    _logger.info(\"generation = {}\".format(gen))\n    _logger.info(\"population error1: best = {}, mean = {}, \"\n                 \"median1 = {}, worst1 = {}\".format(np.min(pop_obj[:, 0]), np.mean(pop_obj[:, 0]),\n                                                  np.median(pop_obj[:, 0]), np.max(pop_obj[:, 0])))\n    _logger.info('Best1 Genome id= {}'.format(np.argmin(pop_obj[:, 0])))\n\n    _logger.info(\"population error2: best = {}, mean = {}, \"\n                 \"median2 = {}, worst2 = {}\".format(np.min(pop_obj[:, 1]), np.mean(pop_obj[:, 1]),\n                                                  np.median(pop_obj[:, 1]), np.max(pop_obj[:, 1])))\n    _logger.info('Best2 Genome id= {}'.format(np.argmin(pop_obj[:, 1])))\n    if gen%20==0:\n        best_sid=np.argmin(pop_obj[:, 0])\n        best_sname='-'.join([\n                'gen'+str(gen),\n                's'+str(float('%.4f' % 
pop_obj[best_sid, 0])),\n                'c'+str(float('%.4f' % pop_obj[best_sid, 1])),\n            ])\n        best_cid=np.argmin(pop_obj[:, 1])\n        best_cname='-'.join([\n                'gen'+str(gen),\n                's'+str(float('%.4f' % pop_obj[best_cid, 0])),\n                'c'+str(float('%.4f' % pop_obj[best_cid, 1])),\n            ])\n        \n        np.save(os.path.join('',best_sname+datetime.now().strftime(\"%Y%m%d-%H%M%S\")),pop_var[np.argmin(pop_obj[:, 0])])\n        np.save(os.path.join('',best_cname+datetime.now().strftime(\"%Y%m%d-%H%M%S\")),pop_var[np.argmin(pop_obj[:, 1])])\n\nif __name__ == '__main__':\n    args = _parse_args()\n    out_base_dir= os.path.join(args.output, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n    os.makedirs(out_base_dir,exist_ok = True)\n    args.output=out_base_dir\n    setup_default_logging(log_path=os.path.join(out_base_dir, 'log.txt'))\n\n    kkk = Evolve(args,n_var=args.liquid_size*args.liquid_size, \n                  n_obj=2, n_constr=2)\n    method = engine.nsganet(pop_size=args.pop_size,\n                            sampling=RandomSampling(var_type='custom'),\n                            mutation=BinaryBitflipMutation(),\n                            n_offsprings=args.n_offspring,\n                            eliminate_duplicates=True)\n    kres=minimize(kkk,\n                   method,\n                   callback=do_every_generations,\n                   termination=('n_gen', args.n_gens))\n\n\n"
  },
  {
    "path": "examples/Structural_Development/ELSM/lsm.py",
    "content": "from __future__ import print_function\nimport torchvision\nimport torchvision.transforms as transforms\nimport os\nimport time\nimport numpy as np\nimport torch\nfrom torch import nn as nn\nfrom mnistmodel import *\nfrom tqdm import tqdm\nimport argparse\nfrom datetime import datetime\nimport logging\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy\nfrom braincog.base.utils import UnilateralMse, MixLoss\nfrom braincog.base.learningrule.STDP import *\n\ndevice='cuda:7'\n\ndef lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):\n    \"\"\"Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.\"\"\"\n    if epoch % lr_decay_epoch == 0 and epoch > 1:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = param_group['lr'] * 0.1\n    return optimizer   \n\n\nbatch_size=100\nliquid_size=8000\n\nlearning_rate = 1e-3\nnum_epochs = 100  # max epoch\n\ndata_path = '/data'  \nload_path=''\ntrain_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=False, transform=transforms.ToTensor())\ntrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)\n\ntest_set = torchvision.datasets.MNIST(root=data_path, train=False, download=False, transform=transforms.ToTensor())\ntest_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)\n\nsnn = SNN(ins=784,\n        batchsize=batch_size,\n        device=device,\n        liquid_size=liquid_size,\n        lsm_tau=lsm_tau,\n        
lsm_th=lsm_th)\nsnn.load_state_dict(torch.load(load_path)['fc'])\nsnn.learning_rule=[]\nsnn.con[0].load_state_dict(torch.load(load_path)['lsm0'])\nw2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)\nsnn.connectivity_matrix=torch.load(load_path)['connectivity_matrix'].to(device)\nw2tmp.weight.data=(torch.load(load_path)['liquid_weight'].to(device))*snn.connectivity_matrix\nsnn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp]))  # pm\nsnn.eval()\nsnn.to(device)\n\nls = 'mse'\n\nif ls == 'ce':\n    criterion = nn.CrossEntropyLoss()\nelif ls == 'bce':\n    criterion = nn.BCEWithLogitsLoss()\nelif ls == 'mse':\n    criterion = UnilateralMse(1.)\nelif ls == 'sce':\n    criterion = LabelSmoothingCrossEntropy()\nelif ls == 'sbce':\n    criterion = LabelSmoothingBCEWithLogitsLoss()\nelif ls == 'umse':\n    criterion = UnilateralMse(.5)\n\noptimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)\n\nl=[]\nbest_acc=0\nfor epoch in range(num_epochs):\n    running_loss = 0\n    start_time = time.time()\n    for i, (images, labels) in enumerate(tqdm(train_loader)):\n        snn.zero_grad()\n        optimizer.zero_grad()\n        images = images.float().to(device)\n        outputs = snn(images)\n        labels=labels.to(device)\n        loss = criterion(outputs, labels)\n        running_loss += loss.item()\n        loss.backward()\n\n        optimizer.step()\n        snn.reset()\n        if (i + 1) % 100 == 0:\n            running_loss = 0\n\n    correct = 0\n    total = 0\n    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)\n\n    for batch_idx, (inputs, targets) in enumerate(test_loader):\n        inputs = inputs.float().to(device)\n        snn.zero_grad()\n        optimizer.zero_grad()\n        outputs = snn(inputs)\n        targets=targets.to(device)\n        loss = criterion(outputs, targets)\n\n        _, predicted = outputs.max(1)\n        total += float(targets.size(0))\n        correct += 
float(predicted.eq(targets).sum().item())\n        snn.reset()\n        if batch_idx % 100 == 0:\n            acc = 100. * float(correct) / float(total)\n            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)\n    print('Test Accuracy: %.3f' % (100 * correct / total))\n    acc = 100. * float(correct) / float(total)\n    if best_acc < acc:\n        best_acc = acc\n    print(best_acc)\n    l.append(best_acc)\n\n\n\n"
  },
  {
    "path": "examples/Structural_Development/ELSM/model.py",
    "content": "from functools import partial\nfrom torch.nn import functional as F\nfrom torch import nn as nn\nimport torchvision, pprint\nfrom copy import deepcopy\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.base.brainarea.BrainArea import BrainArea\nfrom braincog.base.connection.CustomLinear import *\nfrom braincog.base.learningrule.STDP import *\nimport matplotlib.pyplot as plt\n\n\n\n\n@register_model\nclass nSNN(BaseModule):\n    def __init__(self,\n                 batchsize,\n                 liquid_size,\n                 device,\n                 connectivity_matrix,\n                 num_classes=10,\n                 step=1,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 lsm_th=0.3,\n                 fc_th=0.3,\n                 lsm_tau=3,\n                 fc_tau=3,\n                 ins=1156,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.batchsize=batchsize\n        self.ins=ins\n        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)\n        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)\n        self.liquid_size=liquid_size\n        self.device=device\n        self.con=[]\n        self.learning_rule=[]\n        self.connectivity_matrix=connectivity_matrix\n        w1tmp=nn.Linear(ins,liquid_size,bias=False).to(device)\n        self.con.append(w1tmp)\n        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)\n        self.liquid_weight=w2tmp.weight.data\n        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix\n        self.con.append(w2tmp)\n\n        
self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]]))  # pm\n\n        self.fc = nn.Sequential(\n            nn.Linear(liquid_size,num_classes),\n            self.node_fc()\n        )\n\n    def forward(self, x):\n        sum_spike=0\n        self.out = torch.zeros(x.shape[0], self.liquid_size).to(self.device)\n        tw=x.shape[1]\n        self.tw=tw\n        self.firing_tw=torch.zeros(tw, self.batchsize, self.liquid_size).to(self.device)\n\n        for t in range(tw):\n            self.out, self.dw = self.learning_rule[0](x[:,t,:], self.out)\n            out_liquid=self.out[:,0:self.liquid_size]\n            xout = self.fc(out_liquid)\n            sum_spike=sum_spike+xout\n            self.firing_tw[t]=out_liquid\n        outputs = sum_spike / tw\n        return outputs\n\n\n@register_model\nclass mSNN(BaseModule):\n    def __init__(self,\n                 batchsize,\n                 liquid_size,\n                 device,\n                 connectivity_matrix,\n                 num_classes=10,\n                 step=1,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 lsm_th=0.3,\n                 fc_th=0.3,\n                 lsm_tau=3,\n                 fc_tau=3,\n                 tw=100,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.batchsize=batchsize\n\n        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)\n        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)\n        self.liquid_size=liquid_size\n        self.out = torch.zeros(self.batchsize, liquid_size).to(device)\n        self.device=device\n        self.con=[]\n        self.learning_rule=[]\n        self.connectivity_matrix=connectivity_matrix\n        w1tmp=nn.Linear(784,liquid_size,bias=False).to(device)\n        self.con.append(w1tmp)\n        
w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)\n\n        self.liquid_weight=w2tmp.weight.data\n        \n        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix\n        self.con.append(w2tmp)\n        self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]]))  # pm\n\n        self.fc = nn.Sequential(\n            nn.Linear(liquid_size,num_classes),\n            self.node_fc()\n        )\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], -1)\n        sum_spike=0\n        time_window=20\n        self.tw=time_window\n        self.firing_tw=torch.zeros(time_window, self.batchsize, self.liquid_size).to(self.device)\n        self.out = torch.zeros(self.batchsize, self.liquid_size).to(self.device)\n        for t in range(time_window):\n\n            self.out, self.dw = self.learning_rule[0](x, self.out)\n\n            out_liquid=self.out[:,0:self.liquid_size]\n            xout = self.fc(out_liquid)\n            sum_spike=sum_spike+xout\n            self.firing_tw[t]=out_liquid\n        # print(out_liquid.sum())\n        # print(xout.sum())\n        outputs = sum_spike / time_window\n        return outputs\n\n"
  },
  {
    "path": "examples/Structural_Development/ELSM/nsganet.py",
    "content": "import numpy as np\n\nfrom pymoo.algorithms.genetic_algorithm import GeneticAlgorithm\nfrom pymoo.docs import parse_doc_string\nfrom pymoo.model.individual import Individual\nfrom pymoo.model.survival import Survival\nfrom pymoo.operators.crossover.point_crossover import PointCrossover\nfrom pymoo.operators.mutation.polynomial_mutation import PolynomialMutation\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.selection.tournament_selection import compare, TournamentSelection\nfrom pymoo.util.display import disp_multi_objective\nfrom pymoo.util.dominator import Dominator\nfrom pymoo.util.non_dominated_sorting import NonDominatedSorting\nfrom pymoo.util.randomized_argsort import randomized_argsort\n\n\n# =========================================================================================================\n# Implementation\n# based on nsga2 from https://github.com/msu-coinlab/pymoo\n# =========================================================================================================\n\n\nclass NSGANet(GeneticAlgorithm):\n\n    def __init__(self, **kwargs):\n        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)\n        super().__init__(**kwargs)\n\n        self.tournament_type = 'comp_by_dom_and_crowding'\n        self.func_display_attrs = disp_multi_objective\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Binary Tournament Selection Function\n# ---------------------------------------------------------------------------------------------------------\n\n\ndef binary_tournament(pop, P, algorithm, **kwargs):\n    if P.shape[1] != 2:\n        raise ValueError(\"Only implemented for binary tournament!\")\n\n    tournament_type = algorithm.tournament_type\n    S = np.full(P.shape[0], np.nan)\n\n    for i in range(P.shape[0]):\n\n        a, b = P[i, 
0], P[i, 1]\n\n        # if at least one solution is infeasible\n        if pop[a].CV > 0.0 or pop[b].CV > 0.0:\n            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)\n\n        # both solutions are feasible\n        else:\n\n            if tournament_type == 'comp_by_dom_and_crowding':\n                rel = Dominator.get_relation(pop[a].F, pop[b].F)\n                if rel == 1:\n                    S[i] = a\n                elif rel == -1:\n                    S[i] = b\n\n            elif tournament_type == 'comp_by_rank_and_crowding':\n                S[i] = compare(a, pop[a].rank, b, pop[b].rank,\n                               method='smaller_is_better')\n\n            else:\n                raise Exception(\"Unknown tournament type.\")\n\n            # if rank or domination relation didn't make a decision compare by crowding\n            if np.isnan(S[i]):\n                S[i] = compare(a, pop[a].get(\"crowding\"), b, pop[b].get(\"crowding\"),\n                               method='larger_is_better', return_random_if_equal=True)\n\n    return S[:, None].astype(np.int)\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Survival Selection\n# ---------------------------------------------------------------------------------------------------------\n\n\nclass RankAndCrowdingSurvival(Survival):\n\n    def __init__(self) -> None:\n        super().__init__(True)\n\n    def _do(self, pop, n_survive, D=None, **kwargs):\n\n        # get the objective space values and objects\n        F = pop.get(\"F\")\n\n        # the final indices of surviving individuals\n        survivors = []\n\n        # do the non-dominated sorting until splitting front\n        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)\n\n        for k, front in enumerate(fronts):\n\n            # calculate the crowding distance of the front\n            
crowding_of_front = calc_crowding_distance(F[front, :])\n\n            # save rank and crowding in the individual class\n            for j, i in enumerate(front):\n                pop[i].set(\"rank\", k)\n                pop[i].set(\"crowding\", crowding_of_front[j])\n\n            # current front sorted by crowding distance if splitting\n            if len(survivors) + len(front) > n_survive:\n                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')\n                I = I[:(n_survive - len(survivors))]\n\n            # otherwise take the whole front unsorted\n            else:\n                I = np.arange(len(front))\n\n            # extend the survivors by all or selected individuals\n            survivors.extend(front[I])\n\n        return pop[survivors]\n\n\ndef calc_crowding_distance(F):\n    infinity = 1e+14\n\n    n_points = F.shape[0]\n    n_obj = F.shape[1]\n\n    if n_points <= 2:\n        return np.full(n_points, infinity)\n    else:\n\n        # sort each column and get index\n        I = np.argsort(F, axis=0, kind='mergesort')\n\n        # now really sort the whole array\n        F = F[I, np.arange(n_obj)]\n\n        # get the distance to the last element in sorted list and replace zeros with actual values\n        dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \\\n               - np.concatenate([np.full((1, n_obj), -np.inf), F])\n\n        index_dist_is_zero = np.where(dist == 0)\n\n        dist_to_last = np.copy(dist)\n        for i, j in zip(*index_dist_is_zero):\n            dist_to_last[i, j] = dist_to_last[i - 1, j]\n\n        dist_to_next = np.copy(dist)\n        for i, j in reversed(list(zip(*index_dist_is_zero))):\n            dist_to_next[i, j] = dist_to_next[i + 1, j]\n\n        # normalize all the distances\n        norm = np.max(F, axis=0) - np.min(F, axis=0)\n        norm[norm == 0] = np.nan\n        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm\n\n        # 
if we divided by zero because all values in one columns are equal replace by none\n        dist_to_last[np.isnan(dist_to_last)] = 0.0\n        dist_to_next[np.isnan(dist_to_next)] = 0.0\n\n        # sum up the distance to next and last and norm by objectives - also reorder from sorted list\n        J = np.argsort(I, axis=0)\n        crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj\n\n    # replace infinity with a large number\n    crowding[np.isinf(crowding)] = infinity\n\n    return crowding\n\n\n# =========================================================================================================\n# Interface\n# =========================================================================================================\n\n\ndef nsganet(\n        pop_size=100,\n        sampling=RandomSampling(var_type=np.int),\n        selection=TournamentSelection(func_comp=binary_tournament),\n        crossover=PointCrossover(n_points=2),\n        mutation=PolynomialMutation(eta=3, var_type=np.int),\n        \n        eliminate_duplicates=True,\n        n_offsprings=None,\n        **kwargs):\n    \"\"\"\n\n    Parameters\n    ----------\n    pop_size : {pop_size}\n    sampling : {sampling}\n    selection : {selection}\n    crossover : {crossover}\n    mutation : {mutation}\n    eliminate_duplicates : {eliminate_duplicates}\n    n_offsprings : {n_offsprings}\n\n    Returns\n    -------\n    nsganet : :class:`~pymoo.model.algorithm.Algorithm`\n        Returns an NSGANet algorithm object.\n\n\n    \"\"\"\n\n    return NSGANet(pop_size=pop_size,\n                   sampling=sampling,\n                   selection=selection,\n                   crossover=crossover,\n                   mutation=mutation,\n                   survival=RankAndCrowdingSurvival(),\n                   eliminate_duplicates=eliminate_duplicates,\n                   n_offsprings=n_offsprings,\n                   
**kwargs)\n\n\nparse_doc_string(nsganet)\n"
  },
  {
    "path": "examples/Structural_Development/ELSM/spikes.py",
    "content": "from __future__ import print_function\nimport torchvision\nimport torchvision.transforms as transforms\nimport os\nimport numpy as np\nimport torch\nfrom torch import nn as nn\nfrom model import *\nfrom tqdm import tqdm\nimport argparse\nfrom datetime import datetime\nimport logging\nfrom timm.utils import *\nfrom spikingjelly.datasets.n_mnist import NMNIST\nfrom timm.loss import LabelSmoothingCrossEntropy\nfrom braincog.base.utils.criterions import *\nimport networkx as nx\nimport time\nfrom braincog.base.learningrule.STDP import *\n\ndef randbool(size, p=0.5):\n    return torch.rand(*size) < p\n\ndef calc_f2(con,device):       \n    batch_size=1\n    liquid_size=8000\n    images=torch.load('/1000images.pt')\n    labels=torch.load('/1000labels.pt')\n\n    load_path='970.t7'\n\n\n    snn = nSNN(ins=2312,\n            batchsize=batch_size,\n            device=device,\n            liquid_size=liquid_size,\n            lsm_tau=2.0,\n            lsm_th=0.20,\n            connectivity_matrix=randbool([liquid_size, liquid_size],p=0.01).to(device).int())\n\n    snn.load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['fc'])\n    snn.con[0].load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['lsm0'])\n\n    snn.to(device)\n    criterion = UnilateralMse(1.)\n\n    optimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)\n\n    k=0\n    sbr=0\n    snn.connectivity_matrix=con\n    snn.learning_rule=[]\n    w2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)\n\n    w2tmp.weight.data=(torch.load(load_path,map_location={'cuda:2':device})['liquid_weight'])*snn.connectivity_matrix\n    snn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp])) \n    snn.eval()\n    for label,data in zip(labels,images):\n        running_loss = 0\n        snn.zero_grad()\n        optimizer.zero_grad()\n        data = data.to(device)\n        label = label.to(device)\n        
data=data.reshape(batch_size,data.shape[0],-1) \n        output = snn(data)\n        # print(torch.argmax(output)==label)\n\n        out_liquid=snn.firing_tw.squeeze(-2)\n\n        mupost=torch.matmul(con,out_liquid.unsqueeze(-1))\n        mupre=torch.matmul(con.t(),out_liquid.unsqueeze(-1))\n        for t in range(snn.tw):\n            if t>5 and t<snn.tw-5:\n                mupost[t] = torch.sum(mupost[t+1:t+5],dim=0)\n                mupre[t] = torch.sum(mupre[t-5:t-1],dim=0)\n        br=mupost/mupre\n        br[torch.isnan(br)] = 0\n        br[torch.isinf(br)] = 0\n        br=(torch.sum(out_liquid*br.squeeze(-1),dim=1)/torch.sum(out_liquid,dim=1)).sum()/snn.tw\n        if torch.isnan(br):\n            continue\n        k+=1\n        if k==500:\n            break\n\n        sbr+=br\n\n        snn.reset()\n    # print(sbr/k)\n\n    return sbr/k\n\n    \n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/README.md",
    "content": "# Similarity-based context aware continual learning for spiking neural networks #\n\n## Requirments ##\n* albumentations==1.1.0\n* easydict==1.9\n* matplotlib==3.5.1\n* nni==2.10\n* numpy==1.22.4\n* opencv_python==4.5.5.62\n* Pillow==9.3.0\n* sacred==0.8.2\n* scikit_learn==1.1.3\n* scipy==1.9.3\n* tensorboardX==2.5.1\n* thop==0.0.31.post2005241907\n* torch==1.8.1+cu111\n* torchvision==0.9.1+cu111\n\n\n## Run ##\n\n``` CUDA_VISIBLE_DEVICES=0 python3 -m main train with \"./SCA-SNN/configs/train.yaml\" exp.name=\"cifar_b0_10s\" exp.savedir=\"./log/\" exp.saveckpt=\"./ckpts_cifar_b0_10s/\" exp.ckptdir=\"./log/\" exp.tensorboard_dir=\"./tensorboard/\" exp.debug=False --name=\"cifar_b0_10s\" -D --force```\n\n## Citation ##\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{han2024similarity,\n  title={Similarity-based context aware continual learning for spiking neural networks},\n  author={Han, Bing and Zhao, Feifei and Li Yang and Kong Qingqun and Li Xianqi and Zeng, Yi},\n  year={2024}\n  }\n  \n@article{zeng2023braincog,\n  title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier},\n}\n```\n\nEnjoy!\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/configs/train.yaml",
    "content": "exp:\n  name: \"CIFAR_B0_10S\"\n  savedir: \"./logs\"\n  tensorboard_dir: \"./tensorboard\"\n  debug: False\n\n\n#Model Cfg\nmodel: \"incmodel\"\nconvnet: 'resnet18' \ntrain_head: 'softmax'\ninfer_head: 'softmax'\nchannel: 64\nuse_bias: False\nlast_relu: False\n\ndea: True\nuse_div_cls: False\ndiv_type: \"n+1\" # n+t, 1+1\ndistillation: False #True --CIL; False --TIL\ndisttype: \"KL\"\ntemperature: 2\ndistlamb: 1\nfeature_type: \"ffm\" # se\nattention_use_residual: True\nignore_new: True\n\nprune: False\n\nattention:\n  add_kl: True\n  kd_warm_up: 50\n  kd_loss_weight: 0.5\n  kl_loss_weight: 0.5\n\nreuse_oldfc: False\nweight_normalization: False\nval_per_n_epoch: -1 # Validation Per N epoch. -1 means the function is off.\nsave_ckpt: True\nsave_mem: True\nload_mem: False\n\n#Optimization;Training related\ntask_max: 10\nlr_min: 0.00005\nlr: 0.1\nweight_decay: 0.0005\ndynamic_weight_decay: False\nscheduler: 'multistep'\nscheduling:\n  - 100\n  - 120\nlr_decay: 0.1\noptimizer: \"sgd\"\nepochs: 170\nresampling: False\nwarmup: True\nwarmup_epochs: 10\n\npostprocessor:\n  enable: True\n  type: 'bic'\n  epochs: 1\n  batch_size: 128\n  lr: 0.1\n  scheduling:\n    - 60\n    - 90\n    - 120\n  lr_decay_factor: 0.1\n  weight_decay: 0.0005\n\npretrain:\n  epochs: 200\n  lr: 0.1\n  scheduling:\n    - 60\n    - 120\n    - 160\n  lr_decay: 0.1\n  weight_decay: 0.0005\n\n\n# Dataset Cfg\ndataset: \"cifar100\" #'imagenet100', 'cifar100'\ntrial: 2\nincrement: 10\nbatch_size: 50\nworkers: 1\nvalidation: 0 # Validation split (0. 
<= x <= 1.)\nrandom_classes: False #Randomize classes order of increment\nstart_class: 0 # number of tasks for the first step, start from 0.\nstart_task: 0\nmax_task: # Cap the number of task\n\n#Memory\ncoreset_strategy: \"iCaRL\"  # iCaRL, random\nmem_size_mode: \"uniform_fixed_total_mem\" #uniform_fixed_per_cls, uniform_fixed_total_mem\nmemory_size: 2000 # Max number of storable examplars\nfixed_memory_per_cls: 20 # the fixed number of exemplars per cls\n\n# Misc\ndevice: 0 #GPU index to use, for cpu use -1\nseed: 1993\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/classifier.py",
    "content": "import math\n\nimport torch\nfrom torch.nn.parameter import Parameter\nfrom torch.nn import functional as F\nfrom torch.nn import Module\n\n\nclass CosineClassifier(Module):\n    def __init__(self, in_features, n_classes, sigma=True):\n        super(CosineClassifier, self).__init__()\n        self.in_features = in_features\n        self.out_features = n_classes\n        self.weight = Parameter(torch.Tensor(n_classes, in_features))\n        if sigma:\n            self.sigma = Parameter(torch.Tensor(1))\n        else:\n            self.register_parameter('sigma', None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1. / math.sqrt(self.weight.size(1))\n        self.weight.data.uniform_(-stdv, stdv)\n        if self.sigma is not None:\n            self.sigma.data.fill_(1)  #for initializaiton of sigma\n\n    def forward(self, input):\n        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))\n        if self.sigma is not None:\n            out = self.sigma * out\n        return out\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/imbalance.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nimport numpy as np\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\nclass BiC(nn.Module):\n    def __init__(self, lr, scheduling, lr_decay_factor, weight_decay, batch_size, epochs):\n        super(BiC, self).__init__()\n        self.beta = torch.nn.Parameter(torch.ones(1))  #.cuda()\n        self.gamma = torch.nn.Parameter(torch.zeros(1))  #.cuda()\n        self.lr = lr\n        self.scheduling = scheduling\n        self.lr_decay_factor = lr_decay_factor\n        self.weight_decay = weight_decay\n        self.class_specific = False\n        self.batch_size = batch_size\n        self.epochs = epochs\n        self.bic_flag = False\n\n    def reset(self, lr=None, scheduling=None, lr_decay_factor=None, weight_decay=None, n_classes=-1):\n        with torch.no_grad():\n            if lr is None:\n                lr = self.lr\n            if scheduling is None:\n                scheduling = self.scheduling\n            if lr_decay_factor is None:\n                lr_decay_factor = self.lr_decay_factor\n            if weight_decay is None:\n                weight_decay = self.weight_decay\n            if self.class_specific:\n                assert n_classes != -1\n                self.beta = torch.nn.Parameter(torch.ones(n_classes).cuda())\n                self.gamma = torch.nn.Parameter(torch.zeros(n_classes).cuda())\n            else:\n                self.beta = torch.nn.Parameter(torch.ones(1).cuda())\n                self.gamma = torch.nn.Parameter(torch.zeros(1).cuda())\n            self.optimizer = torch.optim.SGD([self.beta, self.gamma], lr=lr, momentum=0.9, weight_decay=weight_decay)\n            # self.scheduler = CosineAnnealingLR(self.optimizer, 10)\n            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, scheduling, gamma=lr_decay_factor)\n\n    def extract_preds_and_targets(self, model, loader,taski,mask):\n        preds, targets = [], 
[]\n        with torch.no_grad():\n            for (x, y) in loader:\n                preds.append(model(taski,x.cuda(),mask)['logit'])\n                targets.append(y.cuda())\n        return torch.cat((preds)), torch.cat((targets))\n\n    def update(self, logger, task_size, model, loader, loss_criterion=None,taski=None,mask=None):\n        if task_size == 0:\n            logger.info(\"no new task for BiC!\")\n            return\n        if loss_criterion is None:\n            loss_criterion = F.cross_entropy\n\n        self.bic_flag = True\n        logger.info(\"Begin BiC ...\")\n        model.eval()\n\n        for epoch in range(self.epochs):\n            preds_, targets_ = self.extract_preds_and_targets(model, loader,taski,mask)\n            order = np.arange(preds_.shape[0])\n            np.random.shuffle(order)\n\n            preds, targets = preds_.clone(), targets_.clone()\n            preds, targets = preds[order], targets[order]\n            _loss = 0.0\n            _correct = 0\n            _count = 0\n            for start in range(0, preds.shape[0], self.batch_size):\n                if start + self.batch_size < preds.shape[0]:\n                    out = preds[start:start + self.batch_size, :].clone()\n                    lbls = targets[start:start + self.batch_size]\n                else:\n                    out = preds[start:, :].clone()\n                    lbls = targets[start:]\n                if self.class_specific is False:\n                    out1 = out[:, :-task_size].clone()\n                    out2 = out[:, -task_size:].clone()\n                    outputs = torch.cat((out1, out2 * self.beta + self.gamma), 1)\n                else:\n                    outputs = out * self.beta + self.gamma\n                loss = loss_criterion(outputs, lbls)\n                self.optimizer.zero_grad()\n                loss.backward()\n                self.optimizer.step()\n                _, pred = outputs.max(1)\n                _correct += (pred == 
lbls).sum()\n                _count += lbls.size(0)\n                _loss += loss.item() * outputs.shape[0]\n            logger.info(\"epoch {} loss {:4f} acc {:4f}\".format(epoch, _loss / preds.shape[0], _correct / _count))\n\n            self.scheduler.step()\n        logger.info(\"beta {:.4f} gamma {:.4f}\".format(self.beta.cpu().item(), self.gamma.cpu().item()))\n\n    @torch.no_grad()\n    def post_process(self, preds, task_size):\n        if self.class_specific is False:\n            if task_size != 0:\n                preds[:, -task_size:] = preds[:, -task_size:] * self.beta + self.gamma\n        else:\n            preds = preds * self.beta + self.gamma\n        return preds\n\n\n\nclass CR(object):\n    def __init__(self):\n        self.gamma = None\n\n    @torch.no_grad()\n    def update(self, classifier, task_size):\n        old_weight_norm = torch.norm(classifier.weight[:-task_size], p=2, dim=1)\n        new_weight_norm = torch.norm(classifier.weight[-task_size:], p=2, dim=1)\n        #self.gamma = old_weight_norm.mean() / new_weight_norm.mean()\n        gamma=1.15*(old_weight_norm.mean() / new_weight_norm.mean())\n        gamma = torch.clamp(gamma, 0, 1)\n        self.gamma= gamma\n        \n    @torch.no_grad()\n    def post_process(self, logits, task_size):\n        logits[:, -task_size:] = logits[:, -task_size:] * self.gamma\n        return logits\n\n\nclass All_av(object):\n    def __init__(self):\n        self.gamma = []\n\n    @torch.no_grad()\n    def update(self, classifier, task_size, classnum_list, taski):\n        self.gamma = []\n        for i in range(taski+1):\n            old_weight_norm = torch.norm(classifier.weight[:-task_size], p=2, dim=1)\n            new_weight_norm = torch.norm(classifier.weight[sum(classnum_list[:i]):sum(classnum_list[:i+1])], p=2, dim=1)\n            gamma=1.5*(old_weight_norm.mean() / new_weight_norm.mean())\n            gamma = torch.clamp(gamma, 0, 1)\n            self.gamma.append(gamma)\n\n    
@torch.no_grad()\n    def post_process(self, logits, task_size, classnum_list, taski):\n        for i in range(taski+1):\n            logits[:, sum(classnum_list[:i]):sum(classnum_list[:i+1])] = logits[:, sum(classnum_list[:i]):sum(classnum_list[:i+1])] * self.gamma[i]     \n        return logits\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/maskcl2.py",
    "content": "import numpy as np\nimport torch\nimport math\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nimport random\nfrom copy import deepcopy\n\ndef unit(x):\n    if x.size()[0]>0:\n        xnp=x.cpu().numpy()\n        maxx=torch.max(x)\n        #maxx=np.percentile(xnp, 99.5)\n        minx=torch.min(x)\n        marge=maxx-minx\n        if marge!=0:\n            xx=(x-minx)/marge\n            xx=torch.clip(xx, 0,1)\n        else:\n            xx=0.5*torch.ones_like(x)\n        return xx\n    else:\n        return x\n        \nclass Mask:\n    def __init__(self, model):\n        self.model = model\n        self.mat = {}\n        self.p_index={}\n        self.p_num={}\n        self.k=15\n        self.task_ready={}\n        self.regutask_ready={}\n        self.taskmask={}\n        self.init_rate=0.3\n        self.grow_rate=0.125\n\n        self.prunconv_init=0.8\n        self.prunfc_init=1.3\n        self.prunconv_grow=0.5\n        self.prunfc_grow=1\n\n        self.n_delta={}\n        self.ren_delta={}\n        self.reduce={}\n        self.rereduce={}\n        self.taskww={}\n        self.tasknore={}\n\n    def init_length(self,task=0,task_nn=None):\n        for index, item in enumerate(self.model.parameters()):\n            if len(item.size()) > 1 and item.size()[-1]!=1:\n                print(index,item.size())\n                self.mat[index]=torch.ones(item.size(),device=device)\n        for t in range(task):\n            self.rereduce[t]={}\n            for index, item in enumerate(self.model.parameters()):\n                if True:\n                    if index<=20:\n                        c_index=0\n                    elif index<=45:\n                        c_index=1\n                    elif index<=70:\n                        c_index=2\n                    else:\n                        c_index=3\n                taskindb=task_nn[t-1][c_index]\n                taskinda=task_nn[t][c_index]\n                
lenre=taskinda-taskindb\n                self.rereduce[t][index] = 1*torch.ones(lenre,device=device)\n\n        return self.mat\n            \n\n    def get_filter_reuse(self,index,ww,task,epoch,c_index,cdim_before=None,task_nn=None,all_dist=None,bias=0): \n        lenre=cdim_before[1]-cdim_before[0]\n        similar=1-all_dist+bias\n        if similar<0.2:\n            similar=0.2\n        if similar>0.9:\n            similar=0.9\n\n        revalue=similar*torch.ones(lenre).cuda() #1/8,1/4,1/2,1,1.5\n\n        if len(ww.size()) == 4:\n            # p_www=ww*self.mat[index]\n            p_ww=torch.sum(torch.sum(torch.sum(ww,dim=3),dim=2),dim=1)\n\n            # p_ww=p_ww[cdim_before[0]:cdim_before[1]]\n            \n            ren_delta=-(2*unit(p_ww)-revalue)#revalue0.8\n            #print(self.ren_delta[index])\n            pos=torch.nonzero(ren_delta>0)\n            ren_delta[pos]=ren_delta[pos]+3\n            self.rereduce[task][index]=self.rereduce[task][index]*0.999+ren_delta*math.exp(-int((epoch-1)/2))\n            p_ind = torch.nonzero(self.rereduce[task][index] <0)\n            matkey=self.mat.keys()\n            matkey=torch.tensor(list(matkey))\n            matindex=torch.nonzero(matkey==index)\n            next_index=matkey[matindex+1]\n            for x in range(0, len(p_ind)):\n                self.mat[next_index.item()][:,p_ind[x]+cdim_before[0]]=0\n            b = self.mat[next_index.item()][:,cdim_before[0]:cdim_before[1]].reshape(-1).cpu().numpy()\n            pruning=100*(len(b)- np.count_nonzero(b))/len(b)\n            #print(index,self.rereduce[task][index].mean(),self.rereduce[task][index].max(),self.rereduce[task][index].min(),len(b)-np.count_nonzero(b),pruning)\n\n    def convert2tensor(self, x):\n        x = torch.FloatTensor(x)\n        return x\n\n    def init_mask(self,task,epoch,dim_cur=None,task_nn=None,all_dist=None,all_model=None):\n        for t in range(task):\n            similart=all_dist[t]\n            for index, item in 
enumerate(all_model[t].parameters()):\n                if len(item.size()) > 2 and item.size()[-1]!=1 and index<95:\n                    if index<=20:\n                        c_index=0\n                    elif index<=45:\n                        c_index=1\n                    elif index<=70:\n                        c_index=2\n                    else:\n                        c_index=3\n                        \n                    if index<=25:\n                        bias=0.2\n                    elif index<=50:\n                        bias=0.1\n                    elif index<=74:\n                        bias=-0.1\n                    else:\n                        bias=0.2\n                        \n                    taskindb=task_nn[t-1][c_index]\n                    taskinda=task_nn[t][c_index]\n                    cdim_before=[taskindb,taskinda]\n                    self.get_filter_reuse(index,abs(item.grad),t,epoch,c_index,cdim_before,task_nn=task_nn,all_dist=similart,bias=bias)\n\n    def do_mask(self,task):\n        for index, item in enumerate(self.model.parameters()):\n            if len(item.size()) > 1 and item.size()[-1]!=1:\n                ww=item.data\n                item.data=ww*self.mat[index].cuda()\n        return self.mat\n\n    def if_zero(self):\n        cc=[]\n        for index, item in enumerate(self.model.parameters()):\n            if len(item.size()) > 1 and item.size()[-1]!=1 and index>0:\n                b = item.data.view(-1).cpu().numpy()\n                print(\"number of weight is %d, zero is %.3f\" %(len(b),100*(len(b)- np.count_nonzero(b))/len(b)))\n                cc.append(100*(len(b)- np.count_nonzero(b))/len(b))\n        return cc\n\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/network.py",
    "content": "import copy\n# import pdb\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom inclearn.tools import factory\nfrom inclearn.convnet.imbalance import CR, All_av,BiC\nfrom inclearn.convnet.classifier import CosineClassifier\n\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nclass BasicNet(nn.Module):\n    def __init__(\n        self,\n        convnet_type,\n        cfg,\n        nf=64,\n        use_bias=False,\n        init=\"kaiming\",\n        device=None,\n        dataset=\"cifar100\",\n    ):\n        super(BasicNet, self).__init__()\n        self.nf = nf\n        self.init = init\n        self.convnet_type = convnet_type\n        self.dataset = dataset\n        self.start_class = cfg['start_class']\n        self.weight_normalization = cfg['weight_normalization']\n        self.remove_last_relu = True if self.weight_normalization else False\n        self.use_bias = use_bias if not self.weight_normalization else False\n        self.dea = cfg['dea']\n        self.ft_type = cfg.get('feature_type', 'normal')\n        self.at_res = cfg.get('attention_use_residual', False)\n        self.div_type = cfg['div_type']\n        self.reuse_oldfc = cfg['reuse_oldfc']\n        self.prune = cfg.get('prune', False)\n        self.reset = cfg.get('reset_se', True)\n        self.torc=cfg['distillation']\n        self.node =LIFNode\n        self.encoder = Encoder(4, 'direct', temporal_flatten=False, layer_by_layer=False, **cfg)\n\n        # if self.dea:\n        #     print(\"Enable dynamical reprensetation expansion!\")\n        #     self.convnets = nn.ModuleList()\n        #     self.convnets.append(\n        #         factory.get_convnet(convnet_type,\n        #                             nf=nf,\n        #                             dataset=dataset,\n        #           
                  start_class=self.start_class,\n        #                             remove_last_relu=self.remove_last_relu))\n        #     self.out_dim = self.convnets[0].out_dim\n        #     self.c_dim=self.convnets[0].channel_dim\n        # else:\n        #     self.convnet = factory.get_convnet(convnet_type,\n        #                                        nf=nf,\n        #                                        dataset=dataset,\n        #                                        remove_last_relu=self.remove_last_relu)\n        #     self.out_dim = self.convnet.out_dim\n        self.channel_number=[32,64,128,256]#[32,64,128,256] # [24,48,72,96] #[24,48,96,192][32,64,128,256]\n        self.channel_dim=[48,96,192,384]\n        self.c_number1=np.array(self.channel_number)\n        if self.dea:\n            print(\"Enable dynamical reprensetation expansion!\")\n            self.convnets = nn.ModuleList()\n            self.convnets.append(\n                factory.get_convnet(convnet_type,c_dim=self.channel_number,cdim_cur=self.channel_number)\n            )\n            self.out_dim = self.channel_number[-1]\n            self.out_dim_cc = self.channel_number[-1]\n        else:\n            self.convnet = factory.get_convnet(convnet_type,\n                                               nf=nf,\n                                               dataset=dataset,\n                                               remove_last_relu=self.remove_last_relu)\n            self.out_dim = self.convnet.out_dim\n            \n        self.classifier = None\n        self.se = None\n        self.aux_classifier = None\n\n        self.n_classes = 0\n        self.ntask = 0\n        self.device = device\n\n\n        if cfg['postprocessor']['enable']:\n            if cfg['postprocessor']['type'].lower() == \"cr\":\n                self.postprocessor = CR()\n            elif cfg['postprocessor']['type'].lower() == \"aver\":\n                self.postprocessor = All_av()\n            else:\n 
               self.postprocessor = BiC(cfg['postprocessor'][\"lr\"], cfg['postprocessor'][\"scheduling\"],\n                cfg['postprocessor'][\"lr_decay_factor\"], cfg['postprocessor'][\"weight_decay\"],\n                cfg['postprocessor'][\"batch_size\"], cfg['postprocessor'][\"epochs\"])\n        \n        self.task_nn={}\n        self.task_nn[-1]=np.array([0,0,0,0])\n        self.task_nn[0]=np.array(self.channel_number)\n\n        self.to(self.device)\n\n    def forward(self, task,inputs,mask=None,classify=True):\n        inputs = self.encoder(inputs)\n        self.resetsnn()\n        step = 4\n        outputs = []\n        if self.classifier is None:\n            raise Exception(\"Add some classes before training.\")\n        \n        if mask is not None:\n            mat=mask.mat\n            if self.torc:\n                ttc=task\n            else:\n                ttc=-1\n            for index, item in enumerate(self.convnets[ttc].parameters()):\n                if len(item.size()) > 1 and item.size()[-1]!=1:\n                    ww=item.data\n                    item.data=ww*mat[index].cuda()\n\n        if self.dea:\n            # feature = [convnet(x) for convnet in self.convnets]\n            for time in range(step):\n                x_init = inputs[time]\n                task_feature={}\n                for t in range(task):\n                    task_feature[t]={}\n                    x=self.convnets[t].forward_init(x_init)\n                    task_feature[t][0]=x\n                    for l in range(len(self.convnets[t].layer_convnets)):\n                        for lc in range(len(self.convnets[t].layer_convnets[l].conv)):\n                            if lc==0 or lc==3:\n                                identity = x\n                            if lc==2:\n                                if l>0:\n                                    identity=self.convnets[t].layer_convnets[l].conv[lc](identity)\n                                x=x+identity\n       
                         task_feature[t][5*l+lc+1]=x\n                            else:\n                                old_tfeature=[]\n                                for old_t in range(t):\n                                    old_tfeature.append(task_feature[old_t][5*l+lc])\n                                old_tfeature.append(x)\n                                x= torch.cat(old_tfeature, 1)\n                                x=self.convnets[t].layer_convnets[l].conv[lc](x)\n                                if lc==4:\n                                    x=x+identity\n                                task_feature[t][5*l+lc+1]=x\n                x=self.convnets[task].forward_init(x_init)\n                for l in range(len(self.convnets[task].layer_convnets)):\n                    for lc in range(len(self.convnets[task].layer_convnets[l].conv)):\n                        if lc==0 or lc==3:\n                            identity = x\n                        if lc==2:\n                            if l>0:\n                                identity=self.convnets[task].layer_convnets[l].conv[lc](identity)\n                            x=x+identity\n                        else:\n                            mid_feature=[]\n                            for t in range(task):\n                                mid_feature.append(task_feature[t][5*l+lc])\n                            mid_feature.append(x)\n                            x= torch.cat(mid_feature, 1)\n                            x=self.convnets[task].layer_convnets[l].conv[lc](x)\n                            if lc==4:\n                                x=x+identity\n                if self.torc:\n                    last_feature=[]\n                    for t in range(task):\n                        last=len(task_feature[t])-1\n                        last_feature.append(task_feature[t][last])\n                    last_feature.append(x)\n                    outputs.append(torch.cat(last_feature, 1))\n                else:\n    
                x=self.convnets[task].avgpool(x)\n                    x=x.view(x.size()[0],-1)\n                    outputs.append(x)\n            feature=sum(outputs).cuda()/ step\n            last_dim =x.size(1)\n            width = feature.size(1) \n            \n            if self.torc:\n                if self.reset:\n                    se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device)\n                    features = se(feature)\n                else:\n                    features = self.se(feature)\n            else:\n                features=feature\n   \n        else:\n            features = self.convnet(x)\n\n        if self.torc:\n            if classify==True:\n                logits = self.convnets[-1].classifer(features)\n\n                div_logits = self.convnets[-1].aux_classifier(features[:, -last_dim:]) if self.ntask > 1 else None\n            else:\n                logits=None\n                div_logits=None\n        else:\n            if classify==True:\n                logits = self.convnets[task].classifer(features)\n            else:\n                logits=None\n\n            div_logits=None\n        \n        return {'feature': features, 'logit': logits, 'div_logit': div_logits, 'features': feature}\n\n    def caculate_dim(self, x):\n        feature = [convnet(x) for convnet in self.convnets]\n        features = torch.cat(feature, 1)\n\n        width = features.size(1)\n\n        # se = factory.get_attention(width, self.ft_type, self.at_res).to(self.device)\n        se = factory.get_attention(width, \"ce\", self.at_res).cuda()\n        features = se(features)\n\n        # import pdb\n        # pdb.set_trace()\n        return features.size(1), feature[-1].size(1)     \n\n    @property\n    def features_dim(self,ntask):\n        if self.dea:\n            return self.out_dim#+ntask*self.channel_number1[-1]\n        else:\n            return self.out_dim\n\n    def freeze(self):\n        for param in 
self.parameters():\n            param.requires_grad = False\n        self.eval()\n        return self\n\n    def copy(self):\n        return copy.deepcopy(self)\n\n    def add_classes(self, n_classes,min_dist):\n        self.ntask += 1\n\n        if self.dea:\n            self._add_classes_multi_fc(n_classes,min_dist)\n        else:\n            self._add_classes_single_fc(n_classes)\n\n        self.n_classes += n_classes\n\n    def _add_classes_multi_fc(self, n_classes,min_dist):\n        self.classifier=self.convnets[-1].classifer\n        if self.ntask > 1:\n            if min_dist<0.1:\n                min_dist=0.1\n            self.channel_number1=np.array([32,64,96,128])*1*(1-math.exp(-5*min_dist)) #0.5,1,1.5,2,3,4 [24,48,72,96][16,32,48,64]\n            self.channel_number1=self.channel_number1.astype(np.int64)\n            self.channel_dim=self.channel_number1\n            self.c_number1=self.c_number1+np.array(self.channel_number1)\n            self.task_nn[self.ntask-1]=self.c_number1\n            new_clf = factory.get_convnet(\"resnet18\",c_dim=self.c_number1,cdim_cur=self.channel_number1).to(self.device)\n            self.out_dim=self.out_dim+self.channel_number1[-1]\n            self.out_dim_cc=self.channel_number1[-1]\n            self.convnets.append(new_clf)\n        \n        if self.torc:\n            if not self.reset:\n                self.se = factory.get_attention(512*len(self.convnets), self.ft_type, self.at_res)\n                self.se.to(self.device)\n\n            if self.classifier is not None:\n                weight = copy.deepcopy(self.classifier.weight.data)\n\n            fc = self._gen_classifier(self.out_dim, self.n_classes + n_classes)\n            if self.classifier is not None and self.reuse_oldfc:\n                fc.weight.data[:self.n_classes, :(self.out_dim - self.out_dim_cc)] = weight\n            del self.classifier\n            self.classifier = fc\n            self.convnets[-1].classifer=self.classifier\n        else:\n 
           fc = self._gen_classifier(self.out_dim_cc, n_classes)\n            del self.classifier\n            self.classifier = fc\n            self.convnets[-1].classifer=fc\n\n        if self.torc:\n            if self.div_type == \"n+1\":\n                div_fc = self._gen_classifier(self.out_dim_cc, n_classes + 1)\n            elif self.div_type == \"1+1\":\n                div_fc = self._gen_classifier(self.out_dim_cc, 2)\n            elif self.div_type == \"n+t\":\n                div_fc = self._gen_classifier(self.out_dim_cc, self.ntask + n_classes)\n            else:\n                div_fc = self._gen_classifier(self.out_dim_cc, self.n_classes + n_classes)\n            del self.aux_classifier\n            self.aux_classifier = div_fc\n            self.convnets[-1].aux_classifier=self.aux_classifier\n\n    def _add_classes_single_fc(self, n_classes):\n        if self.classifier is not None:\n            weight = copy.deepcopy(self.classifier.weight.data)\n            if self.use_bias:\n                bias = copy.deepcopy(self.classifier.bias.data)\n\n        classifier = self._gen_classifier(self.features_dim, self.n_classes + n_classes)\n\n        if self.classifier is not None and self.reuse_oldfc:\n            classifier.weight.data[:self.n_classes] = weight\n            if self.use_bias:\n                classifier.bias.data[:self.n_classes] = bias\n\n        del self.classifier\n        self.classifier = classifier\n\n    def _gen_classifier(self, in_features, n_classes):\n        if self.weight_normalization:\n            classifier = CosineClassifier(in_features, n_classes).to(self.device)\n            # classifier = CosineClassifier(in_features, n_classes).cuda()\n        else:\n            classifier = nn.Linear(in_features, n_classes, bias=self.use_bias).to(self.device)\n            # classifier = nn.Linear(in_features, n_classes, bias=self.use_bias).cuda()\n            if self.init == \"kaiming\":\n                
nn.init.kaiming_normal_(classifier.weight, nonlinearity=\"linear\")\n            if self.use_bias:\n                nn.init.constant_(classifier.bias, 0.0)\n\n        return classifier\n    \n    def resetsnn(self):\n        \"\"\"\n        重置所有神经元的膜电位\n        :return:\n        \"\"\"\n        for mod in self.convnets.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/resnet.py",
    "content": "\"\"\"Taken & slightly modified from:\n* https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nfrom torch.nn import functional as F\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n        self.remove_last_relu = remove_last_relu\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        if not 
self.remove_last_relu:\n            out = self.relu(out)\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = conv1x1(inplanes, planes)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = conv3x3(planes, planes, stride)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = conv1x1(planes, planes * self.expansion)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ChannelAttention(nn.Module):\n    def __init__(self, in_planes, ratio=16):\n        super(ChannelAttention, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        # 共享权重的MLP\n        self.fc1   = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)\n        self.relu1 = nn.ReLU()\n        self.fc2   = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x):\n        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))\n        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))\n        out = avg_out + max_out\n        return self.sigmoid(out)\n\n\nclass SpatialAttention(nn.Module):\n    def __init__(self, kernel_size=7):\n        super(SpatialAttention, self).__init__()\n        assert kernel_size in (3, 
7), 'kernel size must be 3 or 7'\n        padding = 3 if kernel_size == 7 else 1\n        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x):\n        avg_out = torch.mean(x, dim=1, keepdim=True)\n        max_out, _ = torch.max(x, dim=1, keepdim=True)\n        x = torch.cat([avg_out, max_out], dim=1)\n        x = self.conv1(x)\n        return self.sigmoid(x)\n\nclass SEFeatureAt(nn.Module):\n    def __init__(self, inplanes, type, at_res):\n        super(SEFeatureAt, self).__init__()\n        self.se = nn.Sequential(\n            nn.AdaptiveAvgPool2d((1,1)),\n            nn.Conv2d(inplanes,inplanes//16,kernel_size=1),\n            nn.ReLU(),\n            nn.Conv2d(inplanes//16,inplanes,kernel_size=1),\n            nn.Sigmoid()\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.type = type\n        self.at_res = at_res\n        self.ca = ChannelAttention(inplanes)\n        self.sa = SpatialAttention()\n\n    def forward(self, x):\n        residual = x\n        if self.type == \"se\":\n            attention = self.se(x)\n            x = x * attention\n        elif self.type == \"ffm\":\n            x = self.ca(x) * x\n            x = self.sa(x) * x\n        if self.at_res:\n            x += residual\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)   \n\n        return x\n\n\nclass ResNet(nn.Module):\n    def __init__(self,\n                 block,\n                 layers,\n                 nf=64,\n                 zero_init_residual=True,\n                 dataset='cifar',\n                 start_class=0,\n                 remove_last_relu=False):\n        super(ResNet, self).__init__()\n        self.remove_last_relu = remove_last_relu\n        self.inplanes = nf\n        if 'cifar' in dataset:\n            self.conv1 = nn.Sequential(nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False),\n                                       
nn.BatchNorm2d(nf), nn.ReLU(inplace=True))\n        elif 'imagenet' in dataset:\n            if start_class == 0:\n                self.conv1 = nn.Sequential(\n                    nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False),\n                    nn.BatchNorm2d(nf),\n                    nn.ReLU(inplace=True),\n                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n                )\n            else:\n                # Following PODNET implmentation\n                self.conv1 = nn.Sequential(\n                    nn.Conv2d(3, nf, kernel_size=3, stride=1, padding=1, bias=False),\n                    nn.BatchNorm2d(nf),\n                    nn.ReLU(inplace=True),\n                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n                )\n\n        self.layer1 = self._make_layer(block, 1 * nf, layers[0])\n        self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 8 * nf, layers[3], stride=2, remove_last_relu=remove_last_relu)\n\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n\n        self.out_dim = 8 * nf * block.expansion\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                
elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, remove_last_relu=False, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        if remove_last_relu:\n            for i in range(1, blocks - 1):\n                layers.append(block(self.inplanes, planes))\n            layers.append(block(self.inplanes, planes, remove_last_relu=True))\n        else:\n            for _ in range(1, blocks):\n                layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def reset_bn(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d):\n                m.reset_running_stats()\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        # x = self.avgpool(x)\n        # x = x.view(x.size(0), -1)\n        return x\n\n\ndef resnet18(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-18 model.\n\n    \"\"\"\n    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))\n    return model\n\n\ndef resnet34(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-34 model.\n\n    \"\"\"\n    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))\n    return model\n\n\ndef resnet50(pretrained=False, **kwargs):\n    \"\"\"Constructs a 
ResNet-50 model.\n\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))\n    return model\n\n\ndef resnet101(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-101 model.\n\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))\n    return model\n\n\ndef resnet152(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-152 model.\n\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))\n    return model\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/sew_resnet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom copy import deepcopy\n\ntry:\n    from torchvision.models.utils import load_state_dict_from_url\nexcept ImportError:\n    from torchvision._internally_replaced_utils import load_state_dict_from_url\nfrom braincog.base.node import *\nfrom braincog.model_zoo.base_module import *\nfrom braincog.datasets import is_dvs_data\nfrom timm.models import register_model\n__all__ = ['SEWResNet', 'sew_resnet18', 'sew_resnet34', 'sew_resnet50', 'sew_resnet101',\n           'sew_resnet152', 'sew_resnext50_32x4d', 'sew_resnext101_32x8d',\n           'sew_wide_resnet50_2', 'sew_wide_resnet101_2']\n\nmodel_urls = {\n    \"resnet18\": \"https://download.pytorch.org/models/resnet18-f37072fd.pth\",\n    \"resnet34\": \"https://download.pytorch.org/models/resnet34-b627a593.pth\",\n    \"resnet50\": \"https://download.pytorch.org/models/resnet50-0676ba61.pth\",\n    \"resnet101\": \"https://download.pytorch.org/models/resnet101-63fe2227.pth\",\n    \"resnet152\": \"https://download.pytorch.org/models/resnet152-394f9c45.pth\",\n    \"resnext50_32x4d\": \"https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth\",\n    \"resnext101_32x8d\": \"https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth\",\n    \"wide_resnet50_2\": \"https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth\",\n    \"wide_resnet101_2\": \"https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth\",\n}\n\n# modified by https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py\n\ndef sew_function(x: torch.Tensor, y: torch.Tensor, cnf:str):\n    if cnf == 'ADD':\n        return x + y\n    elif cnf == 'AND':\n        return x * y\n    elif cnf == 'IAND':\n        return x * (1. 
- y)\n    else:\n        raise NotImplementedError\n\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, planes_cur,stride=1, downsample=None, groups=1,\n                 dilation=1, norm_layer=None, cnf: str = None, node: callable = LIFNode, **kwargs):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        # if groups != 1 or base_width != 64:\n        #     raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        self.conv=nn.Sequential(\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        BaseConvModule(inplanes, planes_cur, kernel_size=(3, 3), stride=stride,padding=(1, 1), node=node),\n        BaseConvModule(planes, planes_cur, kernel_size=(3, 3), padding=(1, 1), node=node),\n        downsample,\n        BaseConvModule(planes, planes_cur, kernel_size=(3, 3), padding=(1, 1), node=node),\n        BaseConvModule(planes, planes_cur, kernel_size=(3, 3), padding=(1, 1), node=node),)\n        self.cnf = cnf\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n\n        out = self.conv2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample_sn(self.downsample(x))\n\n        out = sew_function(identity, out, self.cnf)\n\n        return out\n\n    def extra_repr(self) -> 
str:\n        return super().extra_repr() + f'cnf={self.cnf}'\n\n\n\nclass SEWResNet(BaseModule):\n    def __init__(self, block, layers, c_dim=[64,128,256,512], cdim_cur=[],step=4,encode_type=\"direct\",zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None, cnf: str =  'ADD',   *args,**kwargs):\n        super().__init__(            \n            step,\n            encode_type,\n            *args,\n            **kwargs\n        )\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.groups=groups\n\n        self.node = LIFNode\n        if issubclass(self.node, BaseNode):\n            self.node = partial(self.node, **kwargs, step=step)\n        self.c_dim=c_dim\n        if len(cdim_cur)>0:\n            self.cdim_cur=cdim_cur\n        else:\n            self.cdim_cur=self.c_dim\n        self.inplanes = c_dim[0]\n        self.inplanes_cur = cdim_cur[0]\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n\n \n\n        self.conv1 = nn.Conv2d(3, self.cdim_cur[0], kernel_size=3, stride=1, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.cdim_cur[0])\n        self.node1 = self.node()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer_convnets = nn.ModuleList()\n        self.layer_convnets.append(self._make_layer(block, c_dim[0], self.cdim_cur[0],layers[0], cnf=cnf, node=self.node, 
**kwargs))\n        self.layer_convnets.append(self._make_layer(block, c_dim[1], self.cdim_cur[1], layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0], cnf=cnf, node=self.node, **kwargs))\n        self.layer_convnets.append(self._make_layer(block, c_dim[2], self.cdim_cur[2], layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1], cnf=cnf, node=self.node, **kwargs))\n        self.layer_convnets.append(self._make_layer(block, c_dim[3], self.cdim_cur[3], layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2], cnf=cnf, node=self.node, **kwargs))\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        # self.fc = nn.Linear(512 * block.expansion, num_classes)\n        self.classifer=None\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, planes_cur,blocks, stride=1, dilate=False, cnf: str=None, node: callable = None, **kwargs):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n     
       downsample = nn.Sequential(\n                conv1x1(self.inplanes_cur, planes_cur * block.expansion, stride),\n                norm_layer(planes_cur * block.expansion),\n                node()\n            )\n\n        layers =block(self.inplanes, planes, planes_cur,stride, downsample, self.groups,\n                             previous_dilation, norm_layer, cnf, node, **kwargs)\n        self.inplanes = planes * block.expansion\n        self.inplanes_cur= planes_cur * block.expansion\n        # for _ in range(1, blocks):\n        #     layers.append(block(self.inplanes, planes, groups=self.groups,\n        #                         dilation=self.dilation,\n        #                         norm_layer=norm_layer, cnf=cnf, node=node, **kwargs))\n\n        return layers\n\n    def forward_init(self, inputs):\n        # See note [TorchScript super()]\n        x = self.conv1(inputs)\n        x = self.bn1(x)\n        x = self.node1(x)\n        # x = self.maxpool(x)\n        return x\n\n    \n    def forward_impl(self, inputs):\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        return x\n\n\n\ndef _sew_resnet(arch, block, layers, c_dim,cdim_cur,pretrained, progress, cnf,  **kwargs):\n    model = SEWResNet(block, layers, c_dim=c_dim,cdim_cur=cdim_cur,cnf=cnf,  **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n \n@register_model\ndef sew_resnet18(c_dim=[64,128,256,512],cdim_cur=[],pretrained=False, progress=True, cnf: str = None,  **kwargs):\n    \"\"\"\n    :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet\n    :type pretrained: bool\n    :param progress: If True, displays a progress bar of the download to stderr\n    :type progress: bool\n    :param cnf: the name of 
spike-element-wise function\n    :type cnf: str\n    :param node: a spiking neuron layer\n    :type node: callable\n    :param kwargs: kwargs for `node`\n    :type kwargs: dict\n    :return: Spiking ResNet-18\n    :rtype: torch.nn.Module\n    The spike-element-wise ResNet-18 `\"Deep Residual Learning in Spiking Neural Networks\" <https://arxiv.org/abs/2102.04159>`_ modified by the ResNet-18 model from `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n    \"\"\"\n\n    return _sew_resnet('resnet18', BasicBlock, [2, 2, 2, 2], c_dim,cdim_cur,pretrained, progress,  'ADD', **kwargs)\n\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/convnet/utils.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.optim import SGD\nimport torch.nn.functional as F\nfrom inclearn.tools.metrics import ClassErrorMeter, AverageValueMeter\n\n\ndef finetune_last_layer(\n    logger,\n    network,\n    loader,\n    n_class,\n    nepoch=30,\n    lr=0.1,\n    scheduling=[15, 35],\n    lr_decay=0.1,\n    weight_decay=5e-4,\n    loss_type=\"ce\",\n    temperature=5.0,\n    test_loader=None,\n    samples_per_cls = []\n):\n    network.eval()\n    #if hasattr(network.module, \"convnets\"):\n    #    for net in network.module.convnets:\n    #        net.eval()\n    #else:\n    #    network.module.convnet.eval()\n    optim = SGD(network.module.classifier.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, scheduling, gamma=lr_decay)\n\n    if loss_type == \"ce\":\n        criterion = nn.CrossEntropyLoss()\n    else:\n        criterion = nn.BCEWithLogitsLoss()\n\n    logger.info(\"Begin finetuning last layer\")\n\n    for i in range(nepoch):\n        total_loss = 0.0\n        total_correct = 0.0\n        total_count = 0\n        # print(f\"dataset loader length {len(loader.dataset)}\")\n        for inputs, targets in loader:\n            inputs, targets = inputs.cuda(), targets.cuda()\n            if loss_type == \"bce\":\n                targets = to_onehot(targets, n_class)\n            outputs = network(inputs)['logit']\n            _, preds = outputs.max(1)\n            optim.zero_grad()\n            if loss_type == \"cb\":\n                loss = CB_loss(targets, outputs / temperature, samples_per_cls, n_class, \"focal\")\n            else:\n                loss = criterion(outputs / temperature, targets)\n            loss.backward()\n            optim.step()\n            total_loss += loss * inputs.size(0)\n            total_correct += (preds == targets).sum()\n            total_count += inputs.size(0)\n\n        if test_loader is not 
None:\n            test_correct = 0.0\n            test_count = 0.0\n            with torch.no_grad():\n                for inputs, targets in test_loader:\n                    outputs = network(inputs.cuda())['logit']\n                    _, preds = outputs.max(1)\n                    test_correct += (preds.cpu() == targets).sum().item()\n                    test_count += inputs.size(0)\n\n        scheduler.step()\n        if test_loader is not None:\n            logger.info(\n                \"Epoch %d finetuning loss %.3f acc %.3f Eval %.3f\" %\n                (i, total_loss.item() / total_count, total_correct.item() / total_count, test_correct / test_count))\n        else:\n            logger.info(\"Epoch %d finetuning loss %.3f acc %.3f\" %\n                        (i, total_loss.item() / total_count, total_correct.item() / total_count))\n    return network\n\n\n\ndef extract_features(task_i,model, loader,mask=None):\n    targets, features = [], []\n    model.eval()\n    with torch.no_grad():\n        for _inputs, _targets in loader:\n            _inputs = _inputs.cuda()\n            _targets = _targets.numpy()\n            _features = model(task_i,_inputs,mask)['feature'].detach().cpu().numpy()\n            features.append(_features)\n            targets.append(_targets)\n\n    return np.concatenate(features), np.concatenate(targets)\n\n\ndef calc_class_mean(network, loader, class_idx, metric):\n    EPSILON = 1e-8\n    features, targets = extract_features(network, loader)\n    # norm_feats = features/(np.linalg.norm(features, axis=1)[:,np.newaxis]+EPSILON)\n    # examplar_mean = norm_feats.mean(axis=0)\n    examplar_mean = features.mean(axis=0)\n    if metric == \"cosine\" or metric == \"weight\":\n        examplar_mean /= (np.linalg.norm(examplar_mean) + EPSILON)\n    return examplar_mean\n\n\ndef update_classes_mean(network, inc_dataset, n_classes, task_size, share_memory=None, metric=\"cosine\", EPSILON=1e-8):\n    loader = 
inc_dataset._get_loader(inc_dataset.data_inc,\n                                     inc_dataset.targets_inc,\n                                     shuffle=False,\n                                     share_memory=share_memory,\n                                     mode=\"test\")\n    class_means = np.zeros((n_classes, network.module.features_dim))\n    count = np.zeros(n_classes)\n    network.eval()\n    with torch.no_grad():\n        for x, y in loader:\n            feat = network(x.cuda())['feature']\n            for lbl in torch.unique(y):\n                class_means[lbl] += feat[y == lbl].sum(0).cpu().numpy()\n                count[lbl] += feat[y == lbl].shape[0]\n        for i in range(n_classes):\n            class_means[i] /= count[i]\n            if metric == \"cosine\" or metric == \"weight\":\n                class_means[i] /= (np.linalg.norm(class_means) + EPSILON)\n    return class_means\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/datasets/data.py",
    "content": "import random\nimport cv2\nimport numpy as np\nimport os.path as osp\nfrom copy import deepcopy\nfrom PIL import Image\nimport multiprocessing as mp\nfrom multiprocessing import Pool\nimport albumentations as A\nfrom albumentations.pytorch import ToTensorV2\n\nimport warnings\nwarnings.filterwarnings(\"ignore\", \"Corrupt EXIF data\", UserWarning)\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler\nfrom torchvision import datasets, transforms\nfrom torchvision.datasets.folder import pil_loader\n\nfrom .dataset import get_dataset\nfrom inclearn.tools.data_utils import construct_balanced_subset\n\n\ndef get_data_folder(data_folder, dataset_name):\n    return osp.join(data_folder, dataset_name)\n\n\nclass IncrementalDataset:\n    def __init__(\n        self,\n        trial_i,\n        dataset_name,\n        random_order=False,\n        shuffle=True,\n        workers=10,\n        batch_size=128,\n        seed=1,\n        increment=10,\n        validation_split=0.0,\n        resampling=False,\n        data_folder=\"./data\",\n        start_class=0,\n        torc = False\n    ):\n\n        # The info about incremental split\n        self.torc=torc\n        self.trial_i = trial_i\n        self.start_class = start_class\n        #the number of classes for each step in incremental stage\n        self.task_size = increment\n        self.increments = []\n        self.random_order = random_order\n        self.validation_split = validation_split\n\n        #-------------------------------------\n        #Dataset Info\n        #-------------------------------------\n        self.data_folder =  get_data_folder('/data0/datasets/', 'CIFAR100/')\n        self.dataset_name = dataset_name\n        # self.transform = transform\n        self.train_dataset = None\n        self.test_dataset = None\n        self.n_tot_cls = -1\n        datasets = get_dataset(dataset_name)\n        
self._setup_data(datasets)\n\n        self._workers = workers\n        self._shuffle = shuffle\n        self._batch_size = batch_size\n        self._resampling = resampling\n        #Currently, don't support multiple datasets\n        self.train_transforms = datasets.train_transforms\n        self.test_transforms = datasets.test_transforms\n        #torchvision or albumentations\n        self.transform_type = datasets.transform_type\n\n        # memory Mt\n        self.data_memory = None\n        self.targets_memory = None\n        # Incoming data D_t\n        self.data_cur = None\n        self.targets_cur = None\n        # Available data \\tilde{D}_t = D_t \\cup M_t\n        self.data_inc = None  # Cur task data + memory\n        self.targets_inc = None\n        # Available data stored in cpu memory.\n        self.shared_data_inc = None\n        self.shared_test_data = None\n\n        #Current states for Incremental Learning Stage.\n        self._current_task = 0\n\n    @property\n    def n_tasks(self):\n        return len(self.increments)\n\n    def new_task(self,task_i):\n\n        if self._current_task >= len(self.increments):\n            raise Exception(\"No more tasks.\")\n\n        min_class, max_class, x_train, y_train, x_test, y_test,x_val, y_val = self._get_cur_step_data_for_raw_data(task_i)\n\n        self.data_cur, self.targets_cur = x_train, y_train\n\n        if self.torc:\n            if self.data_memory is not None:\n                print(\"Set memory of size: {}.\".format(len(self.data_memory)))\n                if len(self.data_memory) != 0:\n                    x_train = np.concatenate((x_train, self.data_memory))\n                    y_train = np.concatenate((y_train, self.targets_memory))\n\n\n        self.data_inc, self.targets_inc = x_train, y_train\n        self.data_test_inc, self.targets_test_inc = x_test, y_test\n\n        train_loader = self._get_loader(x_train, y_train, mode=\"train\")\n        if self.torc:\n            val_loader = 
self._get_loader(x_val, y_val, shuffle=False, mode=\"test\")\n            test_loader = self._get_loader(x_test, y_test, shuffle=False, mode=\"test\")\n        else:\n            val_loader=[]\n            test_loader=[]\n            for i in range(len(x_test)):\n                val_loader.append(self._get_loader(x_test[i], y_test[i], shuffle=False, mode=\"test\"))\n                test_loader.append(self._get_loader(x_test[i], y_test[i], shuffle=False, mode=\"test\"))\n                \n        task_info = {\n            \"min_class\": min_class,\n            \"max_class\": max_class,\n            \"increment\": self.increments[self._current_task],\n            \"task\": self._current_task,\n            \"max_task\": len(self.increments),\n            \"n_train_data\": len(x_train),\n            \"n_test_data\": len(y_train),\n        }\n\n        self._current_task += 1\n        return task_info, train_loader, val_loader, test_loader\n\n\n    def _get_cur_step_data_for_raw_data(self, task_i):\n        min_class = sum(self.increments[:self._current_task])\n        max_class = sum(self.increments[:self._current_task + 1])\n\n        if self.torc:\n            x_train, y_train = self._select(task_i,self.data_train, self.targets_train, low_range=min_class, high_range=max_class)\n            x_test, y_test = self._select(task_i,self.data_test, self.targets_test, low_range=0, high_range=max_class)\n            x_val, y_val = self._select(task_i,self.data_test, self.targets_test, low_range=min_class, high_range=max_class)\n        else:\n            x_test=[]\n            y_test=[]\n            x_train, y_train = self._select(task_i,self.data_train, self.targets_train, low_range=min_class, high_range=max_class)\n            min_c=0\n            num_c=max_class-min_class\n            for taski in range(task_i+1):\n                x_test_i, y_test_i = self._select(taski,self.data_test, self.targets_test, low_range=min_c, high_range=min_c+num_c)\n                
x_test.append(x_test_i)\n                y_test.append(y_test_i)\n                min_c+=num_c\n            x_val=x_test\n            y_val=y_test\n        return min_class, max_class, x_train, y_train, x_test, y_test,x_val, y_val\n\n\n    #--------------------------------\n    #           Data Setup\n    #--------------------------------\n    def _setup_data(self, dataset):\n        # FIXME: handles online loading of images\n        self.data_train, self.targets_train = [], []\n        self.data_test, self.targets_test = [], []\n        self.data_val, self.targets_val = [], []\n        self.increments = []\n        self.class_order = []\n\n        current_class_idx = 0  # When using multiple datasets\n        train_dataset = dataset(self.data_folder, train=True)\n        test_dataset = dataset(self.data_folder, train=False)\n        self.train_dataset = train_dataset\n        self.test_datasets = test_dataset\n        self.n_tot_cls = self.train_dataset.n_cls  #number of classes in whole dataset\n\n        self._setup_data_for_raw_data(dataset, train_dataset, test_dataset, current_class_idx)\n        # !list\n        self.data_train = np.concatenate(self.data_train)\n        self.targets_train = np.concatenate(self.targets_train)\n        self.data_val = np.concatenate(self.data_val)\n        self.targets_val = np.concatenate(self.targets_val)\n        self.data_test = np.concatenate(self.data_test)\n        self.targets_test = np.concatenate(self.targets_test)\n\n    def _setup_data_for_raw_data(self, dataset, train_dataset, test_dataset, current_class_idx=0):\n        increment = self.task_size\n\n        x_train, y_train = train_dataset.data, np.array(train_dataset.targets)\n        x_val, y_val, x_train, y_train = self._split_per_class(x_train, y_train, self.validation_split)\n        x_test, y_test = test_dataset.data, np.array(test_dataset.targets)\n\n        # Get Class Order\n        order = [i for i in range(len(np.unique(y_train)))]\n        if 
self.random_order:\n            random.seed(self._seed)  # Ensure that following order is determined by seed:\n            random.shuffle(order)\n        elif dataset.class_order(self.trial_i) is not None:\n            order = dataset.class_order(self.trial_i)\n\n        self.class_order.append(order)\n        y_train = self._map_new_class_index(y_train, order)\n        y_val = self._map_new_class_index(y_val, order)\n        y_test = self._map_new_class_index(y_test, order)\n\n        y_train += current_class_idx\n        y_val += current_class_idx\n        y_test += current_class_idx\n\n        current_class_idx += len(order)\n        if self.start_class == 0:\n            # increment = 10, 那么 increments 就是 [10, 10, 10, 10, ...]\n            self.increments = [increment for _ in range(len(order) // increment)]\n        else:\n            self.increments.append(self.start_class)\n            for _ in range((len(order) - self.start_class) // increment):\n                self.increments.append(increment)\n        self.data_train.append(x_train)\n        self.targets_train.append(y_train)\n        self.data_val.append(x_val)\n        self.targets_val.append(y_val)\n        self.data_test.append(x_test)\n        self.targets_test.append(y_test)\n\n    @staticmethod\n    def _split_per_class(x, y, validation_split=0.0):\n        \"\"\"Splits train data for a subset of validation data.\n\n        Split is done so that each class has a much data.\n        \"\"\"\n        shuffled_indexes = np.random.permutation(x.shape[0])\n        x = x[shuffled_indexes]\n        y = y[shuffled_indexes]\n\n        x_val, y_val = [], []\n        x_train, y_train = [], []\n\n        for class_id in np.unique(y):\n            class_indexes = np.where(y == class_id)[0]\n            nb_val_elts = int(class_indexes.shape[0] * validation_split)\n\n            val_indexes = class_indexes[:nb_val_elts]\n            train_indexes = class_indexes[nb_val_elts:]\n\n            
x_val.append(x[val_indexes])\n            y_val.append(y[val_indexes])\n            x_train.append(x[train_indexes])\n            y_train.append(y[train_indexes])\n\n        # !list\n        x_val, y_val = np.concatenate(x_val), np.concatenate(y_val)\n        x_train, y_train = np.concatenate(x_train), np.concatenate(y_train)\n\n        return x_val, y_val, x_train, y_train\n\n    @staticmethod\n    def _map_new_class_index(y, order):\n        \"\"\"Transforms targets for new class order.\"\"\"\n        return np.array(list(map(lambda x: order.index(x), y)))\n\n    def _select(self, task_i,x, y, low_range=0, high_range=0):\n        idxes = sorted(np.where(np.logical_and(y >= low_range, y < high_range))[0])\n        if isinstance(x, list):\n            selected_x = [x[idx] for idx in idxes]\n        else:\n            selected_x = x[idxes]\n        if self.torc:\n            selected_y=y[idxes]\n        else:\n            selected_y=y[idxes]-low_range\n        return selected_x, selected_y\n\n    #--------------------------------\n    #           Get Loader\n    #--------------------------------\n    def get_datainc_loader(self, mode='train'):\n        print(self.data_inc.shape)\n        train_loader = self._get_loader(self.data_inc, self.targets_inc, mode=mode)\n        return train_loader\n\n    def get_custom_loader_from_memory(self, class_indexes, mode=\"test\"):\n        if not isinstance(class_indexes, list):\n            class_indexes = [class_indexes]\n        data, targets = [], []\n        for class_index in class_indexes:\n            class_data, class_targets = self._select(self.data_memory,\n                                                     self.targets_memory,\n                                                     low_range=class_index,\n                                                     high_range=class_index + 1)\n            data.append(class_data)\n            targets.append(class_targets)\n\n        data = np.concatenate(data)\n        targets 
= np.concatenate(targets)\n\n        return data, targets, self._get_loader(data, targets, shuffle=False, mode=mode)\n\n    def _get_loader(self, x, y, share_memory=None, shuffle=True, mode=\"train\", batch_size=None, resample=None):\n        if \"balanced\" in mode:\n            x, y = construct_balanced_subset(x, y)\n\n        batch_size = batch_size if batch_size is not None else self._batch_size\n\n        if \"train\" in mode:\n            trsf = self.train_transforms\n            resample_ = self._resampling if resample is None else True\n            if resample_ is False:\n                sampler = None\n            else:\n                sampler = get_weighted_random_sampler(y)\n            shuffle = False if resample_ is True else True\n        elif \"test\" in mode:\n            trsf = self.test_transforms\n            sampler = None\n        elif mode == \"flip\":\n            if \"imagenet\" in self.dataset_name:\n                trsf = A.Compose([A.HorizontalFlip(p=1.0), *self.test_transforms.transforms])\n            else:\n                trsf = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), *self.test_transforms.transforms])\n            sampler = None\n        else:\n            raise NotImplementedError(\"Unknown mode {}.\".format(mode))\n\n        return DataLoader(DummyDataset(x,\n                                       y,\n                                       trsf,\n                                       trsf_type=self.transform_type,\n                                       share_memory_=share_memory,\n                                       dataset_name=self.dataset_name),\n                          batch_size=batch_size,\n                          shuffle=shuffle,\n                          num_workers=self._workers,\n                          sampler=sampler,\n                          pin_memory=True)\n\n    def get_custom_loader(self, class_indexes, mode=\"test\", data_source=\"train\", imgs=None, tgts=None):\n        
\"\"\"Returns a custom loader.\n\n        :param class_indexes: A list of class indexes that we want.\n        :param mode: Various mode for the transformations applied on it.\n        :param data_source: Whether to fetch from the train, val, or test set.\n        :return: The raw data and a loader.\n        \"\"\"\n        if not isinstance(class_indexes, list):  # TODO: deprecated, should always give a list\n            class_indexes = [class_indexes]\n\n        if data_source == \"train\":\n            x, y = self.data_inc, self.targets_inc\n        elif data_source == \"val\":\n            x, y = self.data_val, self.targets_val\n        elif data_source == \"test\":\n            x, y = self.data_test, self.targets_test\n        elif data_source == 'specified' and imgs is not None and tgts is not None:\n            x, y = imgs, tgts\n        else:\n            raise ValueError(\"Unknown data source <{}>.\".format(data_source))\n\n        data, targets = [], []\n        for class_index in class_indexes:\n            class_data, class_targets, = self._select(x, y, low_range=class_index, high_range=class_index + 1)\n            data.append(class_data)\n            targets.append(class_targets)\n\n        data = np.concatenate(data)\n        targets = np.concatenate(targets)\n\n        return data, targets, self._get_loader(data, targets, shuffle=False, mode=mode)\n\n\nclass DummyDataset(torch.utils.data.Dataset):\n    def __init__(self, x, y, trsf, trsf_type, share_memory_=None, dataset_name=None):\n        self.dataset_name = dataset_name\n        self.x, self.y = x, y\n        self.trsf = trsf\n        self.trsf_type = trsf_type\n        self.manager = mp.Manager()\n        self.buffer_size = 4000000\n        if share_memory_ is None:\n            if self.x.shape[0] > self.buffer_size:\n                self.share_memory = self.manager.list([None for i in range(self.buffer_size)])\n            else:\n                self.share_memory = self.manager.list([None for 
i in range(len(x))])\n        else:\n            self.share_memory = share_memory_\n\n    def __len__(self):\n        if isinstance(self.x, list):\n            return len(self.x)\n        else:\n            return self.x.shape[0]\n\n    def __getitem__(self, idx):\n        x, y, = self.x[idx], self.y[idx]\n        if isinstance(x, np.ndarray):\n            # assume cifar\n            x = Image.fromarray(x)\n        else:\n            # Assume the dataset is ImageNet\n            if idx < len(self.share_memory):\n                if self.share_memory[idx] is not None:\n                    x = self.share_memory[idx]\n                else:\n                    x = cv2.imread(x)\n                    x = x[:, :, ::-1]\n                    self.share_memory[idx] = x\n            else:\n                x = cv2.imread(x)\n                x = x[:, :, ::-1]\n\n        if 'torch' in self.trsf_type:\n            x = self.trsf(x)\n        else:\n            x = self.trsf(image=x)['image']\n        return x, y\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/datasets/dataset.py",
    "content": "import os.path as osp\nimport numpy as np\nimport glob\n\nfrom albumentations.pytorch import ToTensorV2\n\nfrom torchvision import datasets, transforms\nimport torch\nfrom inclearn.tools.cutout import Cutout\nfrom inclearn.tools.autoaugment_extra import ImageNetPolicy\n\n\ndef get_datasets(dataset_names):\n    return [get_dataset(dataset_name) for dataset_name in dataset_names.split(\"-\")]\n\n\ndef get_dataset(dataset_name):\n    if dataset_name == \"cifar10\":\n        return iCIFAR10\n    elif dataset_name == \"cifar100\":\n        return iCIFAR100\n    elif \"imagenet100\" in dataset_name:\n        return iImageNet100\n    else:\n        raise NotImplementedError(\"Unknown dataset {}.\".format(dataset_name))\n\n\nclass DataHandler:\n    base_dataset = None\n    train_transforms = []\n    common_transforms = [ToTensorV2()]\n    class_order = None\n\n\nclass iCIFAR10(DataHandler):\n    base_dataset_cls = datasets.cifar.CIFAR10\n    transform_type = 'torchvision'\n    train_transforms = transforms.Compose([\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip(),\n        # transforms.ColorJitter(brightness=63 / 255),\n        transforms.ToTensor(),\n        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n    ])\n    test_transforms = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n    ])\n\n    def __init__(self, data_folder, train, is_fine_label=False):\n        self.base_dataset = self.base_dataset_cls(data_folder, train=train, download=True)\n        self.data = self.base_dataset.data\n        self.targets = self.base_dataset.targets\n        self.n_cls = 10\n\n    @property\n    def is_proc_inc_data(self):\n        return False\n\n    @classmethod\n    def class_order(cls, trial_i):\n        return [4, 0, 2, 5, 8, 3, 1, 6, 9, 7]\n\n\nclass iCIFAR100(iCIFAR10):\n    label_list = [\n        
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy',\n        'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',\n        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',\n        'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',\n        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange',\n        'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',\n        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',\n        'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',\n        'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',\n        'willow_tree', 'wolf', 'woman', 'worm'\n    ]\n    base_dataset_cls = datasets.cifar.CIFAR100\n    transform_type = 'torchvision'\n    train_transforms = transforms.Compose([\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip(),\n        transforms.ColorJitter(brightness=63 / 255),\n        transforms.ToTensor(),\n        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),\n    ])\n    test_transforms = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),\n    ])\n\n    def __init__(self, data_folder, train, is_fine_label=False):\n        self.base_dataset = self.base_dataset_cls(data_folder, train=train, download=True)\n        self.data = self.base_dataset.data\n        self.targets = self.base_dataset.targets\n        self.n_cls = 100\n        self.transform_type = 
'torchvision'\n\n    @property\n    def is_proc_inc_data(self):\n        return False\n\n    @classmethod\n    def class_order(cls, trial_i):\n        return [\n                87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45,\n                88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6,\n                46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76,\n                40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39\n            ]\n\n\nclass DataHandler:\n    base_dataset = None\n    train_transforms = []\n    common_transforms = [ToTensorV2()]\n    class_order = None\n\n\nclass iImageNet100(DataHandler):\n\n    base_dataset_cls = datasets.ImageFolder\n    transform_type = 'torchvision'\n    train_transforms = transforms.Compose([\n        transforms.ToPILImage(),\n        transforms.ToTensor(),\n        Cutout(n_holes=1, length=16),\n        transforms.ToPILImage(),\n        transforms.RandomResizedCrop(32),\n        transforms.RandomHorizontalFlip(),\n        ImageNetPolicy(),\n        transforms.ToTensor(),\n        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n\n    ])\n    test_transforms = transforms.Compose([\n        transforms.ToPILImage(),\n        transforms.ToTensor(),\n        transforms.ToPILImage(),\n        transforms.Resize(32),\n        transforms.CenterCrop(32),\n        transforms.ToTensor(),\n        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n    ])\n\n    def __init__(self, data_folder, train, is_fine_label=False):\n        if train is True:\n            self.base_dataset = self.base_dataset_cls(osp.join(data_folder, \"train\"))\n        else:\n            self.base_dataset = self.base_dataset_cls(osp.join(data_folder, \"val\"))\n\n        self.data, self.targets = 
zip(*self.base_dataset.samples)\n        self.data = np.array(self.data)\n        self.targets = np.array(self.targets)\n        self.n_cls = 200\n\n    @property\n    def is_proc_inc_data(self):\n        return False\n\n    @classmethod\n    def class_order(cls, trial_i):\n       return [\n            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,\n            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,\n            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,\n            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33,\n                        168, 156, 178, 108, 123, 184, 190, 165, 174, 176, 140, 189, 103,\n       192, 155, 109, 126, 180, 143, 138, 158, 170, 177, 101, 185, 119,\n       117, 150, 128, 153, 113, 181, 145, 182, 106, 159, 183, 116, 115,\n       144, 191, 141, 172, 160, 179, 152, 120, 110, 131, 154, 137, 195,\n       114, 171, 196, 198, 197, 102, 164, 166, 142, 122, 135, 186, 124,\n       134, 187, 121, 199, 100, 188, 127, 118, 194, 111, 112, 147, 125,\n       130, 146, 162, 169, 136, 161, 107, 163, 175, 105, 132, 104, 151,\n       148, 173, 193, 139, 167, 129, 149, 157, 133,\n\n       268, 256, 278, 208, 223, 284, 290, 265, 274, 276, 240, 289, 203,\n       292, 255, 209, 226, 280, 243, 238, 258, 270, 277, 201, 285, 219,\n       217, 250, 228, 253, 213, 281, 245, 282, 206, 259, 283, 216, 215,\n       244, 291, 241, 272, 260, 279, 252, 220, 210, 231, 254, 237, 295,\n       214, 271, 296, 298, 297, 202, 264, 266, 242, 222, 235, 286, 224,\n       234, 287, 221, 299, 200, 288, 227, 218, 294, 211, 212, 247, 225,\n       230, 246, 262, 269, 236, 261, 207, 263, 275, 205, 232, 204, 251,\n       248, 273, 293, 239, 267, 229, 249, 257, 233, 368, 356, 378, 308,\n       323, 384, 390, 365, 374, 376, 340, 389, 303, 392, 355, 309, 326,\n       380, 343, 338, 
358, 370, 377, 301, 385, 319, 317, 350, 328, 353,\n       313, 381, 345, 382, 306, 359, 383, 316, 315, 344, 391, 341, 372,\n       360, 379, 352, 320, 310, 331, 354, 337, 395, 314, 371, 396, 398,\n       397, 302, 364, 366, 342, 322, 335, 386, 324, 334, 387, 321, 399,\n       300, 388, 327, 318, 394, 311, 312, 347, 325, 330, 346, 362, 369,\n       336, 361, 307, 363, 375, 305, 332, 304, 351, 348, 373, 393, 339,\n       367, 329, 349, 357, 333,\n\n       468, 456, 478, 408, 423, 484, 490, 465, 474, 476, 440, 489, 403,\n       492, 455, 409, 426, 480, 443, 438, 458, 470, 477, 401, 485, 419,\n       417, 450, 428, 453, 413, 481, 445, 482, 406, 459, 483, 416, 415,\n       444, 491, 441, 472, 460, 479, 452, 420, 410, 431, 454, 437, 495,\n       414, 471, 496, 498, 497, 402, 464, 466, 442, 422, 435, 486, 424,\n       434, 487, 421, 499, 400, 488, 427, 418, 494, 411, 412, 447, 425,\n       430, 446, 462, 469, 436, 461, 407, 463, 475, 405, 432, 404, 451,\n       448, 473, 493, 439, 467, 429, 449, 457, 433, 568, 556, 578, 508,\n       523, 584, 590, 565, 574, 576, 540, 589, 503, 592, 555, 509, 526,\n       580, 543, 538, 558, 570, 577, 501, 585, 519, 517, 550, 528, 553,\n       513, 581, 545, 582, 506, 559, 583, 516, 515, 544, 591, 541, 572,\n       560, 579, 552, 520, 510, 531, 554, 537, 595, 514, 571, 596, 598,\n       597, 502, 564, 566, 542, 522, 535, 586, 524, 534, 587, 521, 599,\n       500, 588, 527, 518, 594, 511, 512, 547, 525, 530, 546, 562, 569,\n       536, 561, 507, 563, 575, 505, 532, 504, 551, 548, 573, 593, 539,\n       567, 529, 549, 557, 533,\n\n       668, 656, 678, 608, 623, 684, 690, 665, 674, 676, 640, 689, 603,\n       692, 655, 609, 626, 680, 643, 638, 658, 670, 677, 601, 685, 619,\n       617, 650, 628, 653, 613, 681, 645, 682, 606, 659, 683, 616, 615,\n       644, 691, 641, 672, 660, 679, 652, 620, 610, 631, 654, 637, 695,\n       614, 671, 696, 698, 697, 602, 664, 666, 642, 622, 635, 686, 624,\n       634, 687, 621, 699, 600, 688, 627, 618, 694, 
611, 612, 647, 625,\n       630, 646, 662, 669, 636, 661, 607, 663, 675, 605, 632, 604, 651,\n       648, 673, 693, 639, 667, 629, 649, 657, 633, 768, 756, 778, 708,\n       723, 784, 790, 765, 774, 776, 740, 789, 703, 792, 755, 709, 726,\n       780, 743, 738, 758, 770, 777, 701, 785, 719, 717, 750, 728, 753,\n       713, 781, 745, 782, 706, 759, 783, 716, 715, 744, 791, 741, 772,\n       760, 779, 752, 720, 710, 731, 754, 737, 795, 714, 771, 796, 798,\n       797, 702, 764, 766, 742, 722, 735, 786, 724, 734, 787, 721, 799,\n       700, 788, 727, 718, 794, 711, 712, 747, 725, 730, 746, 762, 769,\n       736, 761, 707, 763, 775, 705, 732, 704, 751, 748, 773, 793, 739,\n       767, 729, 749, 757, 733,\n\n       868, 856, 878, 808, 823, 884, 890, 865, 874, 876, 840, 889, 803,\n       892, 855, 809, 826, 880, 843, 838, 858, 870, 877, 801, 885, 819,\n       817, 850, 828, 853, 813, 881, 845, 882, 806, 859, 883, 816, 815,\n       844, 891, 841, 872, 860, 879, 852, 820, 810, 831, 854, 837, 895,\n       814, 871, 896, 898, 897, 802, 864, 866, 842, 822, 835, 886, 824,\n       834, 887, 821, 899, 800, 888, 827, 818, 894, 811, 812, 847, 825,\n       830, 846, 862, 869, 836, 861, 807, 863, 875, 805, 832, 804, 851,\n       848, 873, 893, 839, 867, 829, 849, 857, 833, 968, 956, 978, 908,\n       923, 984, 990, 965, 974, 976, 940, 989, 903, 992, 955, 909, 926,\n       980, 943, 938, 958, 970, 977, 901, 985, 919, 917, 950, 928, 953,\n       913, 981, 945, 982, 906, 959, 983, 916, 915, 944, 991, 941, 972,\n       960, 979, 952, 920, 910, 931, 954, 937, 995, 914, 971, 996, 998,\n       997, 902, 964, 966, 942, 922, 935, 986, 924, 934, 987, 921, 999,\n       900, 988, 927, 918, 994, 911, 912, 947, 925, 930, 946, 962, 969,\n       936, 961, 907, 963, 975, 905, 932, 904, 951, 948, 973, 993, 939,\n       967, 929, 949, 957, 933\n\n        ]\n\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/models/__init__.py",
    "content": "from .incmodel import IncModel\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/models/base.py",
    "content": "import abc\nimport logging\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom inclearn.tools.metrics import ClassErrorMeter\n\nLOGGER = logging.Logger(\"IncLearn\", level=\"INFO\")\n\n\nclass IncrementalLearner(abc.ABC):\n    \"\"\"Base incremental learner.\n\n    Methods are called in this order (& repeated for each new task):\n\n    1. set_task_info\n    2. before_task\n    3. train_task\n    4. after_task\n    5. eval_task\n    \"\"\"\n    def __init__(self, *args, **kwargs):\n        self._increments = []\n        self._seen_classes = []\n\n    def set_task_info(self, task, total_n_classes, increment, n_train_data, n_test_data, n_tasks):\n        self._task = task\n        self._task_size = increment\n        self._increments.append(self._task_size)\n        self._total_n_classes = total_n_classes\n        self._n_train_data = n_train_data\n        self._n_test_data = n_test_data\n        self._n_tasks = n_tasks\n\n    def before_task(self, taski, inc_dataset,mask,min_dist,all_dist):\n        LOGGER.info(\"Before task\")\n        self.eval()\n        self._before_task(taski, inc_dataset,mask,min_dist,all_dist)\n\n    def train_task(self, task_i,train_loader, val_loader,mask,min_dist,all_dist):\n        LOGGER.info(\"train task\")\n        self.train()\n        self._train_task(task_i,train_loader, val_loader,mask,min_dist,all_dist)\n\n    def after_task(self, taski, inc_dataset,mask):\n        LOGGER.info(\"after task\")\n        self.eval()\n        self._after_task(taski, inc_dataset,mask)\n\n    def eval_task(self, task_i,data_loader,mask):\n        LOGGER.info(\"eval task\")\n        self.eval()\n        return self._eval_task(task_i,data_loader,mask)\n\n    def get_memory(self):\n        return None\n\n    def eval(self):\n        raise NotImplementedError\n\n    def train(self):\n        raise NotImplementedError\n\n    def _before_task(self, data_loader):\n        pass\n\n    def _train_task(self, train_loader, 
val_loader):\n        raise NotImplementedError\n\n    def _after_task(self, data_loader):\n        pass\n\n    def _eval_task(self, data_loader):\n        raise NotImplementedError\n\n    @property\n    def _new_task_index(self):\n        return self._task * self._task_size\n\n    @property\n    def _memory_per_class(self):\n        \"\"\"Returns the number of examplars per class.\"\"\"\n        return self._memory_size.mem_per_cls\n\n    def _after_epoch(self, epoch, avg_loss, train_new_accu, train_old_accu, accu):\n        self._run.log_scalar(f\"train_loss_trial{self._trial_i}_task{self._task}\", avg_loss, epoch + 1)\n        self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/train_loss\", avg_loss, epoch + 1)\n\n        # self._run.log_scalar(f\"train_new_accu_trial{self._trial_i}_task{self._task}\",\n        #                      train_new_accu.value()[0], epoch + 1)\n        # self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/train_new_accu\",\n        #                              train_new_accu.value()[0], epoch + 1)\n\n        # if self._task != 0:\n        #     self._run.log_scalar(f\"train_old_accu_trial{self._trial_i}_task{self._task}\",\n        #                          train_old_accu.value()[0], epoch + 1)\n        #     self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/train_old_accu\",\n        #                                  train_old_accu.value()[0], epoch + 1)\n\n        self._run.log_scalar(f\"train_accu_trial{self._trial_i}_task{self._task}\", accu.value()[0], epoch + 1)\n        self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/train_accu\", accu.value()[0], epoch + 1)\n        # self._tensorboard.close()\n        self._tensorboard.flush()\n\n    def _validation(self, val_loader, epoch):\n        topk = 5 if self._n_classes >= 5 else self._n_classes\n        if self._val_per_n_epoch != -1 and epoch % self._val_per_n_epoch == 0:\n            _val_loss = 0\n  
          _val_accu = ClassErrorMeter(accuracy=True, topk=[1, topk])\n            _val_new_accu = ClassErrorMeter(accuracy=True)\n            _val_old_accu = ClassErrorMeter(accuracy=True)\n            self._parallel_network.eval()\n            with torch.no_grad():\n                for i, (inputs, targets) in enumerate(val_loader, 1):\n                    old_classes = targets < (self._n_classes - self._task_size)\n                    new_classes = targets >= (self._n_classes - self._task_size)\n                    val_loss, _ = self._forward_loss(\n                        inputs,\n                        targets,\n                        old_classes,\n                        new_classes,\n                        accu=_val_accu,\n                        old_accu=_val_old_accu,\n                        new_accu=_val_new_accu,\n                    )\n                    _val_loss += val_loss.item()\n            self._ex.logger.info(\n                f\"epoch{epoch} val acc:{_val_accu.value()[0]:.2f}, val top5acc:{_val_accu.value()[1]:.2f}\")\n            # Test accu\n            self._run.log_scalar(f\"test_accu_trial{self._trial_i}_task{self._task}\", _val_accu.value()[0], epoch + 1)\n            self._run.log_scalar(f\"test_5accu_trial{self._trial_i}_task{self._task}\", _val_accu.value()[1], epoch + 1)\n            self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/test_accu\",\n                                         _val_accu.value()[0], epoch + 1)\n            self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/test_5accu\",\n                                         _val_accu.value()[1], epoch + 1)\n\n            # Test new accu\n            self._run.log_scalar(f\"test_new_accu_trial{self._trial_i}_task{self._task}\",\n                                 _val_new_accu.value()[0], epoch + 1)\n            self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/test_new_accu\",\n                                     
    _val_new_accu.value()[0], epoch + 1)\n\n            # Test old accu\n            if self._task != 0:\n                self._run.log_scalar(f\"test_old_accu_trial{self._trial_i}_task{self._task}\",\n                                     _val_old_accu.value()[0], epoch + 1)\n                self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/test_old_accu\",\n                                             _val_old_accu.value()[0], epoch + 1)\n\n            # Test loss\n            self._run.log_scalar(f\"test_loss_trial{self._trial_i}_task{self._task}\", round(_val_loss / i, 3), epoch + 1)\n            self._tensorboard.add_scalar(f\"trial{self._trial_i}_task{self._task}/test_loss\", round(_val_loss / i, 3),\n                                         epoch + 1)\n            self._tensorboard.close()"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/models/incmodel.py",
    "content": "import numpy as np\nimport random\nimport time\nimport math\nimport os\nfrom copy import deepcopy\nfrom scipy.spatial.distance import cdist\nfrom torchvision.utils import save_image\n\nimport torch\n# import pdb\nfrom torch.nn import DataParallel\nfrom torch.nn import functional as F\nfrom torch import nn\n\n\n\nfrom inclearn.convnet import network\nfrom inclearn.models.base import IncrementalLearner\nfrom inclearn.tools import factory, utils\nfrom inclearn.tools.metrics import ClassErrorMeter\nfrom inclearn.tools.memory import MemorySize\nfrom inclearn.tools.scheduler import GradualWarmupScheduler\nfrom inclearn.convnet.utils import extract_features, update_classes_mean, finetune_last_layer\n\n# Constants\nEPSILON = 1e-8\n\n\nclass IncModel(IncrementalLearner):\n    def __init__(self, cfg, trial_i, _run, ex, tensorboard, inc_dataset):\n        super().__init__()\n        self._cfg = cfg\n        self._device = cfg['device']\n        self._ex = ex\n        self._run = _run  # the sacred _run object.\n\n        # Data\n        self._inc_dataset = inc_dataset\n        self._n_classes = 0\n        self.classnum_list = []\n        self.sample_list = []\n        self._trial_i = trial_i  # which class order is used\n\n        # Optimizer paras\n        self._opt_name = cfg[\"optimizer\"]\n        self._warmup = cfg['warmup']\n        self._lr = cfg[\"lr\"]\n        self._weight_decay = cfg[\"weight_decay\"]\n        self._n_epochs = cfg[\"epochs\"]\n        self._scheduling = cfg[\"scheduling\"]\n        self._lr_decay = cfg[\"lr_decay\"]\n        self.torc=cfg['distillation']\n        self.prune = cfg.get('prune', False)\n\n\n        # Logging\n        self._tensorboard = tensorboard\n        if f\"trial{self._trial_i}\" not in self._run.info:\n            self._run.info[f\"trial{self._trial_i}\"] = {}\n        self._val_per_n_epoch = cfg[\"val_per_n_epoch\"]\n\n        # Model\n        self._dea = cfg['dea']  # Whether to expand the representation\n     
   self._network = network.BasicNet(\n            cfg[\"convnet\"],\n            cfg=cfg,\n            nf=cfg[\"channel\"],\n            device=self._device,\n            use_bias=cfg[\"use_bias\"],\n            dataset=cfg[\"dataset\"],\n        )\n\n\n        if self._cfg.get(\"caculate_params\", False):\n            self._parallel_network = self._network\n        else:\n            # 并行计算\n            # gpus = [0, 1, 2, 3]\n            # self._parallel_network = DataParallel(self._network, device_ids=gpus, output_device=gpus[0])\n            self._parallel_network = DataParallel(self._network)\n\n        self._train_head = cfg[\"train_head\"]\n        self._infer_head = cfg[\"infer_head\"]\n        self._old_model = None\n\n        # Learning\n        self._temperature = cfg[\"temperature\"]\n        self._distillation = cfg[\"distillation\"]\n        self.lamb = cfg[\"distlamb\"]\n\n        # Memory\n        self._memory_size = MemorySize(cfg[\"mem_size_mode\"], inc_dataset, cfg[\"memory_size\"],\n                                       cfg[\"fixed_memory_per_cls\"])\n        self._herding_matrix = []\n        self._coreset_strategy = cfg[\"coreset_strategy\"]\n\n        if self._cfg[\"save_ckpt\"]:\n            save_path = os.path.join(os.getcwd(), f\"{self._cfg.exp.saveckpt}\")\n            if not os.path.exists(save_path):\n                os.mkdir(save_path)\n            if self._cfg[\"save_mem\"]:\n                save_path = os.path.join(os.getcwd(), f\"{self._cfg.exp.saveckpt}/mem\")\n                if not os.path.exists(save_path):\n                    os.mkdir(save_path)\n\n    def eval(self):\n        self._parallel_network.eval()\n\n    def train(self):\n        if self._dea:\n            self._parallel_network.train()\n            self._parallel_network.module.convnets[-1].train()\n            if self._task >= 1:\n                for i in range(self._task):\n                    self._parallel_network.module.convnets[i].eval()\n        else:\n        
    self._parallel_network.train()\n\n    def _before_task(self, taski, inc_dataset,mask,min_dist,all_dist):\n        self._ex.logger.info(f\"Begin step {taski}\")\n\n        # Update Task info\n        self._task = taski\n        self._n_classes += self._task_size\n        self.classnum_list.append(self._task_size)\n        self.sample_list = [ int(2000/(self._n_classes-10)) for i in range(self._n_classes-10)] + [ 500 for i in range(10)]\n\n        # Memory\n        self._memory_size.update_n_classes(self._n_classes)\n        self._memory_size.update_memory_per_cls(self._network, self._n_classes, self._task_size)\n        self._ex.logger.info(\"Now {} examplars per class.\".format(self._memory_per_class))\n\n        self._network.add_classes(self._task_size,min_dist)\n        self._network.task_size = self._task_size\n        mask.model=self._network.convnets[-1]\n        mask.init_length(taski,task_nn=self._network.task_nn)\n        self.set_optimizer()\n\n    def set_optimizer(self, lr=None):\n        if lr is None:\n            lr = self._lr\n\n        if self._cfg[\"dynamic_weight_decay\"]:\n            # used in BiC official implementation\n            weight_decay = self._weight_decay * self._cfg[\"task_max\"] / (self._task + 1)\n        else:\n            weight_decay = self._weight_decay\n        self._ex.logger.info(\"Step {} weight decay {:.5f}\".format(self._task, weight_decay))\n\n        # if self._dea and self._task > 0 and not self._cfg.get(\"caculate_params\", False):\n        #     for i in range(self._task):\n        #         for p in self._parallel_network.module.convnets[i].parameters():\n        #             p.requires_grad = False\n\n        self._optimizer = factory.get_optimizer(self._network.convnets[-1].parameters(),\n                                                self._opt_name, lr, weight_decay)\n\n        if \"cos\" in self._cfg[\"scheduler\"]:\n            self._scheduler = 
torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, self._n_epochs)\n        else:\n            self._scheduler = torch.optim.lr_scheduler.MultiStepLR(self._optimizer,\n                                                                   self._scheduling,\n                                                                   gamma=self._lr_decay)\n\n        if self._warmup:\n            print(\"warmup\")\n            self._warmup_scheduler = GradualWarmupScheduler(self._optimizer,\n                                                            multiplier=1,\n                                                            total_epoch=self._cfg['warmup_epochs'],\n                                                            after_scheduler=self._scheduler)\n\n    def _train_task(self, task_i,train_loader, val_loader,mask,min_dist,all_dist):\n\n        self._ex.logger.info(f\"nb {len(train_loader.dataset)}\")\n\n        topk = 5 if self._n_classes > 5 else self._task_size\n        accu = ClassErrorMeter(accuracy=True, topk=[1, topk])\n        train_new_accu = ClassErrorMeter(accuracy=True)\n        train_old_accu = ClassErrorMeter(accuracy=True)\n\n        self._optimizer.zero_grad()\n        self._optimizer.step()\n\n        for epoch in range(self._n_epochs):\n            # torch.cuda.empty_cache()\n            _loss, _loss_div, _loss_trip, _loss_dist, _loss_atmap = 0.0, 0.0, 0.0, 0.0, 0.0\n            accu.reset()\n            train_new_accu.reset()\n            train_old_accu.reset()\n            if self._warmup:\n                self._warmup_scheduler.step()\n                if epoch == self._cfg['warmup_epochs']:\n                    if self.torc:\n                        self._network.convnets[-1].classifer.reset_parameters()\n                        if self._cfg['use_div_cls']:\n                            self._network.convnets[-1].aux_classifier.reset_parameters()\n                    else:\n                        
self._network.convnets[task_i].classifer.reset_parameters()\n                        if self._cfg['use_div_cls']:\n                            self._network.aux_classifier[task_i].reset_parameters()\n            for i, (inputs, targets) in enumerate(train_loader, start=1):\n                self.train()\n                self._optimizer.zero_grad()\n                old_classes = targets < (self._n_classes - self._task_size)\n                new_classes = targets >= (self._n_classes - self._task_size)\n                loss_ce, loss_div, loss_trip, loss_dist, loss_atmap = self._forward_loss(\n                    task_i,\n                    inputs,\n                    targets,\n                    old_classes,\n                    new_classes,\n                    epoch,\n                    accu=accu,\n                    new_accu=train_new_accu,\n                    old_accu=train_old_accu,\n                    mask=mask\n                )\n\n                loss = loss_ce\n\n                if self._cfg[\"distillation\"] and self._task > 0:\n                    # trade-off - the lambda from the paper if lamb=-1\n                    if self.lamb == -1:\n                        lamb = (self._n_classes - self._task_size) / self._n_classes\n                        loss = (1-lamb) * loss + lamb * loss_dist\n                    else:\n                        loss =  loss + self.lamb * loss_dist\n\n                if self._cfg[\"use_div_cls\"] and self._task > 0:\n                    loss += loss_div           \n\n\n                loss.backward()\n                self._optimizer.step()\n\n                if self.torc:\n                    if self._cfg[\"postprocessor\"][\"enable\"]:\n                        if self._cfg[\"postprocessor\"][\"type\"].lower() == \"cr\" or self._cfg[\"postprocessor\"][\"type\"].lower() == \"aver\":\n                            for p in self._network.convnets[-1].classifer.parameters():\n                                p.data.clamp_(0.0)\n\n  
              _loss += loss_ce\n                _loss_trip += loss_trip\n                _loss_div += loss_div\n                _loss_dist += loss_dist\n                _loss_atmap += loss_atmap \n            \n            if task_i>0:\n                mask.init_mask(self._task,epoch,dim_cur=self._network.channel_dim, task_nn=self._network.task_nn,all_dist=all_dist,all_model=self._network.convnets)\n                mat=mask.do_mask(self._task)\n                \n\n            _loss = _loss.item()\n            _loss_div = _loss_div.item()\n            _loss_trip = _loss_trip.item()\n            _loss_dist = _loss_dist.item()\n            _loss_atmap = _loss_atmap.item()\n            if not self._warmup:\n                self._scheduler.step()\n            self._ex.logger.info(\n                \"Task {}/{}, Epoch {}/{} => Clf loss: {} Div loss: {}, Knowledge Distllation loss:{}, Train Accu: {}, Train@5 Acc: {}\".\n                format(\n                    self._task + 1,\n                    self._n_tasks,\n                    epoch + 1,\n                    self._n_epochs,\n                    round(_loss / i, 3),\n                    round(_loss_div / i, 3),\n                    round(_loss_dist / i, 3),\n                    round(accu.value()[0], 3),\n                    round(accu.value()[1], 3),\n                ))\n\n            if self._val_per_n_epoch > 0 and epoch % self._val_per_n_epoch == 0:\n                self.validate(val_loader)\n\n        if self.torc:\n            # For the large-scale dataset, we manage the data in the shared memory.\n            self._inc_dataset.shared_data_inc = train_loader.dataset.share_memory\n\n            utils.display_weight_norm(self._ex.logger, self._parallel_network, self._increments, \"After training\")\n            utils.display_feature_norm(task_i,self._ex.logger, self._parallel_network, train_loader, self._n_classes,\n                                    self._increments, \"Trainset\",mask=mask)\n            
self._run.info[f\"trial{self._trial_i}\"][f\"task{self._task}_train_accu\"] = round(accu.value()[0], 3)\n\n    def _forward_loss(self, task_i,inputs, targets, old_classes, new_classes, epoch, accu=None, new_accu=None, old_accu=None,mask=None):\n        inputs, targets = inputs.to(self._device, non_blocking=True), targets.to(self._device, non_blocking=True)\n\n        outputs = self._parallel_network(task_i,inputs,mask)\n        if accu is not None:\n            accu.add(outputs['logit'], targets)\n        return self._compute_loss(task_i, inputs, targets, outputs, old_classes, new_classes, epoch,mask=mask)\n\n    def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):\n        \"\"\"Calculates cross-entropy with temperature scaling\"\"\"\n        out = torch.nn.functional.softmax(outputs, dim=1)\n        tar = torch.nn.functional.softmax(targets, dim=1)\n        if exp != 1:\n            out = out.pow(exp)\n            out = out / out.sum(1).view(-1, 1).expand_as(out)\n            tar = tar.pow(exp)\n            tar = tar / tar.sum(1).view(-1, 1).expand_as(tar)\n        out = out + eps / out.size(1)\n        out = out / out.sum(1).view(-1, 1).expand_as(out)\n        ce = -(tar * out.log()).sum(1)\n        if size_average:\n            ce = ce.mean()\n        return ce\n\n    def hcl(self, fstudent, fteacher, targets):\n        loss_all = 0.0\n        fs = fstudent\n        select_teacher =  self._cfg.get(\"select_teacher\",False)\n\n        if select_teacher:\n            for i in range(len(fteacher)):\n                ft = fteacher[i]\n                if i > 0:\n                    old_classes = np.logical_and((targets < (self._n_classes - self._task_size * (len(fteacher)-i))).cpu(), (targets >= (self._n_classes - self._task_size * (len(fteacher)-i+1))).cpu())\n                else:\n                    old_classes = (targets < (self._n_classes - self._task_size * len(fteacher))).cpu()\n                classes_indice = 
torch.from_numpy(np.where(old_classes==True)[0]).to(self._device)\n                # targets_old = torch.index_select(targets_old, 0, old_classes_indice)\n                # log_probs_new = torch.index_select(log_probs_new, 0, old_classes_indice)\n                fs = torch.index_select(fstudent, 0, classes_indice)\n                ft = torch.index_select(ft, 0, classes_indice) \n                n,c,h,w = fs.shape\n                if n == 0:\n                    break\n                loss = F.mse_loss(fs, ft, reduction='mean')\n                cnt = 1.0\n                tot = 1.0\n                for l in [4,2,1]:\n                    if l >=h:\n                        continue\n                    tmpfs = F.adaptive_avg_pool2d(fs, (l,l))\n                    tmpft = F.adaptive_avg_pool2d(ft, (l,l))\n                    cnt /= 2.0\n                    loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt\n                    tot += cnt\n                loss = loss / tot\n                loss_all = loss_all + loss               \n        else:\n            for i in range(len(fteacher)):\n                ft = fteacher[i]\n                n,c,h,w = fs.shape\n                if n == 0:\n                    break\n                loss = F.mse_loss(fs, ft, reduction='mean')\n                cnt = 1.0\n                tot = 1.0\n                for l in [4,2,1]:\n                    if l >=h:\n                        continue\n                    tmpfs = F.adaptive_avg_pool2d(fs, (l,l))\n                    tmpft = F.adaptive_avg_pool2d(ft, (l,l))\n                    cnt /= 2.0\n                    loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt\n                    tot += cnt\n                loss = loss / tot\n                loss_all = loss_all + loss\n        return loss_all\n\n    def _compute_loss(self, task_i, inputs, targets, outputs, old_classes, new_classes, epoch,mask=None):\n\n        loss = F.cross_entropy(outputs['logit'], targets)\n\n        
trip_loss = torch.zeros([1]).cuda()\n\n        atmap_loss = torch.zeros([1]).cuda()\n\n        if outputs['div_logit'] is not None:\n            div_targets = targets.clone()\n            if self._cfg[\"div_type\"] == \"n+1\":\n                div_targets[old_classes] = 0\n                div_targets[new_classes] -= sum(self._inc_dataset.increments[:self._task]) - 1\n            elif self._cfg[\"div_type\"] == \"1+1\":\n                div_targets[old_classes] = 0\n                div_targets[new_classes] = 1\n            elif self._cfg[\"div_type\"] == \"n+t\":\n                div_targets[new_classes] -= sum(self._inc_dataset.increments[:self._task]) - self._task\n                for i in range(self._task):\n                    if i > 0:\n                        old_class = np.logical_and((targets < (self._n_classes - self._task_size * (self._task-i))).cpu(), (targets >= (self._n_classes - self._task_size * (self._task-i+1))).cpu())\n                    else:\n                        old_class = (targets < (self._n_classes - self._task_size * self._task)).cpu()\n\n                    div_targets[old_class] = i\n            # import pdb\n            # pdb.set_trace()\n            div_loss = F.cross_entropy(outputs['div_logit'], div_targets)\n\n        else:\n            div_loss = torch.zeros([1]).cuda()   \n\n        if self._cfg[\"distillation\"] and self._old_model is not None:\n            outputs_old = self._old_model(task_i-1,inputs,mask=None)\n            targets_old = outputs_old['logit'].detach()\n\n            if self._cfg[\"disttype\"] == \"KL\":\n                log_probs_new = (outputs['logit'][:, :-self._task_size] / self._temperature).log_softmax(dim=1)\n                if self._task > 1 and self._cfg[\"postprocessor\"][\"enable\"]:\n                    if self._cfg[\"postprocessor\"][\"type\"].lower() == \"aver\":\n                        targets_old = self._old_model.module.postprocessor.post_process(targets_old, self._task_size, 
self.classnum_list[:-1], self._task-1)\n                    else:\n                        targets_old = self._old_model.module.postprocessor.post_process(targets_old, self._task_size)\n                modify =  self._cfg.get(\"modify_new\",False)\n                if modify:\n                    old_weight_norm = torch.norm(self._network.convnets[-1].classifer.weight[:-self._task_size], p=2, dim=1)\n                    new_weight_norm = torch.norm(self._network.convnets[-1].classifer.weight[-self._task_size:], p=2, dim=1)\n                    gamma = old_weight_norm.mean() / new_weight_norm.mean()\n         \n                    targets_old[new_classes,:] = targets_old[new_classes,:] * gamma\n                probs_old = (targets_old / self._temperature).softmax(dim=1)\n                \n                dist_loss = F.kl_div(log_probs_new, probs_old, reduction=\"batchmean\")\n            \n                \n            else:\n                dist_loss = self.cross_entropy(outputs['logit'][:, :-self._task_size], targets_old, exp=1.0 / self._temperature)\n        else:\n            dist_loss = torch.zeros([1]).cuda()         \n\n        return loss, div_loss, trip_loss, dist_loss, atmap_loss\n\n    def _after_task(self, taski, inc_dataset,mask=None):\n        network = deepcopy(self._parallel_network)\n        network.eval()\n        \n        if self._cfg[\"save_ckpt\"] and taski >= self._cfg[\"start_task\"] and not self.prune:\n            self._ex.logger.info(\"save model\")\n            save_path = os.path.join(os.getcwd(), f\"{self._cfg.exp.saveckpt}\")\n            torch.save(network.cpu().state_dict(), \"{}/step{}.ckpt\".format(save_path, self._task))\n\n        if self.torc:\n            if self._cfg[\"postprocessor\"][\"enable\"]:\n                self._update_postprocessor(taski,inc_dataset,mask=mask)\n\n        if self._cfg[\"infer_head\"] == 'NCM':\n            self._ex.logger.info(\"compute prototype\")\n            self.update_prototype()\n\n        if 
self._memory_size.memsize != 0:\n            self._ex.logger.info(\"build memory\")\n            self.build_exemplars(taski,inc_dataset, self._coreset_strategy,mask=mask)\n\n            if self._cfg[\"save_mem\"]:\n                save_path = os.path.join(os.getcwd(), f\"{self._cfg.exp.saveckpt}/mem\")\n                memory = {\n                    'x': inc_dataset.data_memory,\n                    'y': inc_dataset.targets_memory,\n                    'herding': self._herding_matrix\n                }\n                if not os.path.exists(save_path):\n                    os.makedirs(save_path)\n                if not (os.path.exists(f\"{save_path}/mem_step{self._task}.ckpt\") and self._cfg['load_mem']):\n                    torch.save(memory, \"{}/mem_step{}.ckpt\".format(save_path, self._task))\n                    self._ex.logger.info(f\"Save step{self._task} memory!\")\n\n        # utils.display_weight_norm(self._ex.logger, self._parallel_network, self._increments, \"After training\")\n\n        \n        self._parallel_network.eval()\n        self._old_model = deepcopy(self._parallel_network)\n        if not self._cfg.get(\"caculate_params\", False):\n            self._old_model.module.freeze()\n        del self._inc_dataset.shared_data_inc\n        self._inc_dataset.shared_data_inc = None\n\n    def _eval_task(self,task_i, data_loader,mask):\n        # if self._cfg.get(\"caculate_params\", False):\n            # from thop import profile\n            # self._parallel_network.eval()\n            # with torch.no_grad():\n            #     input = torch.randn(1, 3, 256, 256).to(self._device, non_blocking=True)\n            #     flops, params = profile(self._parallel_network, inputs=(input,))\n            # ypred = flops/1000**3\n            # ytrue = params/1000**2\n            # from torchstat import stat\n            # stat(self._parallel_network, (3, 256, 256))\n            # ypred,ytrue = 0,0\n        # else:\n        if self._infer_head == \"softmax\":\n  
          ypred, ytrue = self._compute_accuracy_by_netout(task_i,data_loader,mask)\n        elif self._infer_head == \"NCM\":\n            ypred, ytrue = self._compute_accuracy_by_ncm(data_loader)\n        else:\n            raise ValueError()\n\n        return ypred, ytrue\n\n\n    def _compute_accuracy_by_netout(self, task_i,data_loader,mask):\n        preds, targets = [], []\n        self._parallel_network.eval()\n        if self._cfg.get(\"caculate_params\", False):\n            with torch.no_grad():\n                from thop import profile\n                inputs = torch.randn(1, 3, 112, 112)\n                flops, params = profile(self._parallel_network, (inputs,))\n                preds = flops/1000**3\n                targets = params/1000**2\n                # print('flops: ', flops, 'params: ', params)\n                # for i, (inputs, lbls) in enumerate(data_loader):\n                #     from thop import profile\n                #     # inputs = inputs.to(self._device, non_blocking=True)\n                    \n                #     flops, params = profile(self._parallel_network, inputs[0])\n                #     preds = flops/1000**3\n                #     targets = params/1000**2\n                #     break                             \n        else:\n            with torch.no_grad():            \n                for i, (inputs, lbls) in enumerate(data_loader):\n                    inputs = inputs.to(self._device, non_blocking=True)\n                    _preds = self._parallel_network(task_i,inputs,mask)['logit']\n                    if self.torc:\n                        if self._cfg[\"postprocessor\"][\"enable\"] and self._task > 0:\n                            if self._cfg[\"postprocessor\"][\"type\"].lower() == \"aver\":\n                                _preds = self._network.postprocessor.post_process(_preds, self._task_size, self.classnum_list, self._task)\n                            else:\n                                _preds = 
self._network.postprocessor.post_process(_preds, self._task_size)\n                    preds.append(_preds.detach().cpu().numpy())\n                    targets.append(lbls.long().cpu().numpy())\n            preds = np.concatenate(preds, axis=0)\n            targets = np.concatenate(targets, axis=0)\n        return preds, targets\n\n    def _compute_accuracy_by_ncm(self, loader):\n        features, targets_ = extract_features(self._parallel_network, loader)\n        targets = np.zeros((targets_.shape[0], self._n_classes), np.float32)\n        targets[range(len(targets_)), targets_.astype(\"int32\")] = 1.0\n\n        class_means = (self._class_means.T / (np.linalg.norm(self._class_means.T, axis=0) + EPSILON)).T\n\n        features = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T\n        # Compute score for iCaRL\n        sqd = cdist(class_means, features, \"sqeuclidean\")\n        score_icarl = (-sqd).T\n        return score_icarl[:, :self._n_classes], targets_\n\n    def _update_postprocessor(self, taski,inc_dataset,mask=None):\n        if self._cfg[\"postprocessor\"][\"type\"].lower() == \"bic\":\n            if False:#self._cfg[\"postprocessor\"][\"disalign_resample\"] is True:\n                bic_loader = inc_dataset._get_loader(inc_dataset.data_inc,\n                                                     inc_dataset.targets_inc,\n                                                     mode=\"train\",\n                                                     resample='disalign_resample')\n            else:\n                xdata, ydata = inc_dataset._select(taski,inc_dataset.data_train,\n                                                   inc_dataset.targets_train,\n                                                   low_range=0,\n                                                   high_range=self._n_classes)\n                bic_loader = inc_dataset._get_loader(xdata, ydata, shuffle=True, mode='train')\n            bic_loss = None\n            
self._network.postprocessor.reset(n_classes=self._n_classes)\n            self._network.postprocessor.update(self._ex.logger,\n                                               self._task_size,\n                                               self._parallel_network,\n                                               bic_loader,\n                                               loss_criterion=bic_loss,\n                                               taski=taski,\n                                               mask=mask)\n        elif self._cfg[\"postprocessor\"][\"type\"].lower() == \"cr\":\n            self._ex.logger.info(\"Post processor cr update !\")\n            self._network.postprocessor.update(self._network.convnets[-1].classifer, self._task_size)\n        elif self._cfg[\"postprocessor\"][\"type\"].lower() == \"aver\":\n            self._ex.logger.info(\"Post processor aver update !\")\n            self._network.postprocessor.update(self._network.convnets[-1].classifer, self._task_size, self.classnum_list, self._task)\n\n    def update_prototype(self):\n        if hasattr(self._inc_dataset, 'shared_data_inc'):\n            shared_data_inc = self._inc_dataset.shared_data_inc\n        else:\n            shared_data_inc = None\n        self._class_means = update_classes_mean(self._parallel_network,\n                                                self._inc_dataset,\n                                                self._n_classes,\n                                                self._task_size,\n                                                share_memory=self._inc_dataset.shared_data_inc,\n                                                metric='None')\n\n    def build_exemplars(self, task_i,inc_dataset, coreset_strategy,mask=None):\n        save_path = os.path.join(os.getcwd(), f\"{self._cfg.exp.saveckpt}/mem/mem_step{self._task}.ckpt\")\n        if self._cfg[\"load_mem\"] and os.path.exists(save_path):\n            memory_states = torch.load(save_path)\n            
self._inc_dataset.data_memory = memory_states['x']\n            self._inc_dataset.targets_memory = memory_states['y']\n            self._herding_matrix = memory_states['herding']\n            self._ex.logger.info(f\"Load saved step{self._task} memory!\")\n            return\n\n        if coreset_strategy == \"random\":\n            from inclearn.tools.memory import random_selection\n\n            self._inc_dataset.data_memory, self._inc_dataset.targets_memory = random_selection(\n                self._n_classes,\n                self._task_size,\n                self._parallel_network,\n                self._ex.logger,\n                inc_dataset,\n                self._memory_per_class,\n            )\n        elif coreset_strategy == \"iCaRL\":\n            from inclearn.tools.memory import herding\n            data_inc = self._inc_dataset.shared_data_inc if self._inc_dataset.shared_data_inc is not None else self._inc_dataset.data_inc\n            self._inc_dataset.data_memory, self._inc_dataset.targets_memory, self._herding_matrix = herding(\n                task_i,\n                self._n_classes,\n                self._task_size,\n                self._parallel_network,\n                self._herding_matrix,\n                inc_dataset,\n                data_inc,\n                self._memory_per_class,\n                self._ex.logger,\n                mask=mask\n            )\n        else:\n            raise ValueError()\n\n    def validate(self, data_loader):\n        if self._infer_head == 'NCM':\n            self.update_prototype()\n        ypred, ytrue = self._eval_task(data_loader)\n        test_acc_stats = utils.compute_accuracy(ypred, ytrue, increments=self._increments, n_classes=self._n_classes)\n        self._ex.logger.info(f\"test top1acc:{test_acc_stats['top1']}\")\n        return test_acc_stats['top1']['total']\n\n    def after_prune(self, taski, inc_dataset):\n        x = torch.randn(1, 3, 32, 32)\n        self._network = 
self._network.cpu()\n        dim1, dim2 = self._network.caculate_dim(x)\n        del self._network.classifier\n        self._network.classifier = self._network._gen_classifier(dim1, self._n_classes)\n        \n        if self._network.se is not None:\n            del self._network.se\n            ft_type = self._cfg.get('feature_type', 'ce')\n            at_res = self._cfg.get('attention_use_residual', False)\n            self._network.se = factory.get_attention(dim1, ft_type, at_res)\n        if taski > 0:\n            del self._network.aux_classifier\n            self._network.aux_classifier = self._network._gen_classifier(dim2, self._task_size+1)\n\n        del self._parallel_network\n        self._parallel_network = DataParallel(self._network)\n\n\nclass DistillKL(nn.Module):\n    \"\"\"Distilling the Knowledge in a Neural Network\"\"\"\n    def __init__(self, T):\n        super(DistillKL, self).__init__()\n        self.T = T\n\n    def forward(self, y_s, y_t):\n        p_s = F.log_softmax(y_s/self.T, dim=1)\n        p_t = F.softmax(y_t/self.T, dim=1)\n        loss = F.kl_div(p_s, p_t, reduction=\"sum\") * (self.T**2) / y_s.shape[0]\n        # loss = F.kl_div(p_s, p_t, reduction=\"batchmean\") * (self.T**2) / y_s.shape[0]\n        \n        return loss\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/__init__.py",
    "content": ""
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/autoaugment_extra.py",
    "content": "from PIL import Image, ImageEnhance, ImageOps, ImageDraw\nimport numpy as np\nimport random\n\n\nclass ImageNetPolicy(object):\n    \"\"\" Randomly choose one of the best 24 Sub-policies on ImageNet.\n        Example:\n        >>> policy = ImageNetPolicy()\n        >>> transformed = policy(image)\n        Example as a PyTorch Transform:\n        >>> transform=transforms.Compose([\n        >>>     transforms.Resize(256),\n        >>>     ImageNetPolicy(),\n        >>>     transforms.ToTensor()])\n    \"\"\"\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.4, \"posterize\", 8, 0.6, \"rotate\", 9, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"posterize\", 7, 0.6, \"posterize\", 6, fillcolor),\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n\n            SubPolicy(0.4, \"equalize\", 4, 0.8, \"rotate\", 8, fillcolor),\n            SubPolicy(0.6, \"solarize\", 3, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.8, \"posterize\", 5, 1.0, \"equalize\", 2, fillcolor),\n            SubPolicy(0.2, \"rotate\", 3, 0.6, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"equalize\", 8, 0.4, \"posterize\", 6, fillcolor),\n\n            SubPolicy(0.8, \"rotate\", 8, 0.4, \"color\", 0, fillcolor),\n            SubPolicy(0.4, \"rotate\", 9, 0.6, \"equalize\", 2, fillcolor),\n            SubPolicy(0.0, \"equalize\", 7, 0.8, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n\n            SubPolicy(0.8, \"rotate\", 8, 1.0, \"color\", 2, fillcolor),\n            SubPolicy(0.8, \"color\", 8, 0.8, \"solarize\", 7, fillcolor),\n            SubPolicy(0.4, \"sharpness\", 7, 0.6, 
\"invert\", 8, fillcolor),\n            SubPolicy(0.6, \"shearX\", 5, 1.0, \"equalize\", 9, fillcolor),\n            SubPolicy(0.4, \"color\", 0, 0.6, \"equalize\", 3, fillcolor),\n\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor),\n            \n            SubPolicy(0.1, \"invert\", 7, 0.2, \"contrast\", 6, fillcolor),  # set-1\n            SubPolicy(0.7, \"rotate\", 2, 0.3, \"translateX\", 9, fillcolor),\n            SubPolicy(0.8, \"sharpness\", 1, 0.9, \"sharpness\", 3, fillcolor),\n            SubPolicy(0.5, \"shearY\", 8, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.5, \"autocontrast\", 8, 0.9, \"equalize\", 2, fillcolor),\n\n            SubPolicy(0.2, \"shearY\", 7, 0.3, \"posterize\", 7, fillcolor), # set-3\n            SubPolicy(0.4, \"color\", 3, 0.6, \"brightness\", 7, fillcolor),\n            SubPolicy(0.3, \"sharpness\", 9, 0.7, \"brightness\", 9, fillcolor),\n            SubPolicy(0.6, \"equalize\", 5, 0.5, \"equalize\", 1, fillcolor),\n            SubPolicy(0.6, \"contrast\", 7, 0.6, \"sharpness\", 5, fillcolor),\n\n            SubPolicy(0.7, \"color\", 7, 0.5, \"translateX\", 8, fillcolor),    #set-11\n            SubPolicy(0.3, \"equalize\", 7, 0.4, \"autocontrast\", 8, fillcolor),\n            SubPolicy(0.4, \"translateY\", 3, 0.2, \"sharpness\", 6, fillcolor),\n            SubPolicy(0.9, \"brightness\", 6, 0.2, \"color\", 8, fillcolor),\n            SubPolicy(0.5, \"solarize\", 2, 0.0, \"invert\", 3, fillcolor),\n\n            SubPolicy(0.2, \"equalize\", 0, 0.6, \"autocontrast\", 0, fillcolor),\n            SubPolicy(0.2, \"equalize\", 8, 0.8, \"equalize\", 4, fillcolor),\n            
SubPolicy(0.9, \"color\", 9, 0.6, \"equalize\", 6, fillcolor),\n            SubPolicy(0.8, \"autocontrast\", 4, 0.2, \"solarize\", 8, fillcolor),\n            SubPolicy(0.1, \"brightness\", 3, 0.7, \"color\", 0, fillcolor),\n\n            SubPolicy(0.4, \"solarize\", 5, 0.9, \"autocontrast\", 3, fillcolor), # set-2 \n            SubPolicy(0.9, \"translateY\", 9, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"autocontrast\", 2, 0.8, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.1, \"invert\", 3, fillcolor),\n            SubPolicy(0.7, \"translateY\", 9, 0.9, \"autocontrast\", 1, fillcolor),\n            \n            SubPolicy(0.4, \"solarize\",  5, 0.9, \"autocontrast\", 3, fillcolor), \n            SubPolicy(0.9, \"translateY\", 9, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"autocontrast\",  2, 0.8, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"equalize\",  8, 0.1, \"invert\", 3, fillcolor),\n            SubPolicy(0.7, \"translateY\", 9, 0.9, \"autocontrast\", 1, fillcolor),\n   \n            SubPolicy(0.4, \"solarize\",  5, 0.9, \"autocontrast\", 1, fillcolor), \n            SubPolicy(0.8, \"translateY\",  9, 0.9, \"translateY\", 9, fillcolor),\n            SubPolicy(0.8, \"autocontrast\",  0, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.2, \"translateY\",  7, 0.9, \"color\", 6, fillcolor),\n            SubPolicy(0.7, \"equalize\",  6, 0.4, \"color\", 9, fillcolor),\n             \n            SubPolicy(0.3, \"brightness\",  7, 0.5, \"autocontrast\", 8, fillcolor), \n            SubPolicy(0.9, \"autocontrast\",  4, 0.5, \"autocontrast\", 6, fillcolor),\n            SubPolicy(0.3, \"solarize\",  5, 0.6, \"equalize\", 5, fillcolor),\n            SubPolicy(0.2, \"translateY\",  4, 0.3, \"sharpness\", 3, fillcolor),\n            SubPolicy(0.0, \"brightness\",  8, 0.8, \"color\", 8, fillcolor),\n\n            SubPolicy(0.2, \"solarize\",  6, 0.8, \"color\", 6, 
fillcolor), \n            SubPolicy(0.2, \"solarize\",  6, 0.8, \"autocontrast\", 1, fillcolor),\n            SubPolicy(0.4, \"solarize\",  1, 0.6, \"equalize\", 5, fillcolor),\n            SubPolicy(0.0, \"brightness\",  0, 0.5, \"solarize\", 2, fillcolor),\n            SubPolicy(0.9, \"autocontrast\",  5, 0.5, \"brightness\", 3, fillcolor),\n\n            SubPolicy(0.7, \"contrast\",  5, 0.0, \"brightness\", 2, fillcolor), \n            SubPolicy(0.2, \"solarize\",  8, 0.1, \"solarize\", 5, fillcolor),\n            SubPolicy(0.5, \"contrast\",  1, 0.2, \"translateY\", 9, fillcolor),\n            SubPolicy(0.6, \"autocontrast\",  5, 0.0, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"autocontrast\",  4, 0.8, \"equalize\", 4, fillcolor),\n            \n            SubPolicy(0.0, \"brightness\",  7, 0.4, \"equalize\", 7, fillcolor), \n            SubPolicy(0.2, \"solarize\",  5, 0.7, \"equalize\", 5, fillcolor),\n            SubPolicy(0.6, \"equalize\",  8, 0.6, \"color\", 2, fillcolor),\n            SubPolicy(0.3, \"color\",  7, 0.2, \"color\", 4, fillcolor),\n            SubPolicy(0.5, \"autocontrast\",  2, 0.7, \"solarize\", 2, fillcolor),\n            \n            SubPolicy(0.2, \"autocontrast\",  0, 0.1, \"equalize\", 0, fillcolor), \n            SubPolicy(0.6, \"shearY\",  5, 0.6, \"equalize\", 5, fillcolor),\n            SubPolicy(0.9, \"brightness\",  3, 0.4, \"autocontrast\", 1, fillcolor),\n            SubPolicy(0.8, \"equalize\",  8, 0.7, \"equalize\", 7, fillcolor),\n            SubPolicy(0.7, \"equalize\",  7, 0.5, \"solarize\", 0, fillcolor),\n            \n            SubPolicy(0.8, \"equalize\",  4, 0.8, \"translateY\", 9, fillcolor), \n            SubPolicy(0.8, \"translateY\",  9, 0.6, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"translateY\",  0, 0.5, \"translateY\", 9, fillcolor),\n            SubPolicy(0.5, \"autocontrast\",  3, 0.3, \"solarize\", 4, fillcolor),\n            SubPolicy(0.5, \"solarize\",  3, 0.4, 
\"equalize\", 4, fillcolor),\n            \n            SubPolicy(0.1, \"autocontrast\",  5, 0.0, \"brightness\", 0, fillcolor), \n            SubPolicy(0.7, \"equalize\",  7, 0.6, \"autocontrast\", 4, fillcolor),\n            SubPolicy(0.1, \"color\",  8, 0.2, \"shearY\", 3, fillcolor),\n            SubPolicy(0.4, \"shearY\",  2, 0.7, \"rotate\", 0, fillcolor),\n            \n            SubPolicy(0.1, \"shearY\",  3, 0.9, \"autocontrast\", 5, fillcolor), \n            SubPolicy(0.5, \"equalize\",  0, 0.6, \"solarize\", 6, fillcolor),\n            SubPolicy(0.3, \"autocontrast\",  5, 0.2, \"rotate\", 7, fillcolor),\n            SubPolicy(0.8, \"equalize\",  2, 0.4, \"invert\", 0, fillcolor),\n            \n            SubPolicy(0.9, \"equalize\",  5, 0.7, \"color\", 0, fillcolor), \n            SubPolicy(0.1, \"equalize\",  1, 0.1, \"shearY\", 3, fillcolor),\n            SubPolicy(0.7, \"autocontrast\",  3, 0.7, \"equalize\", 0, fillcolor),\n            SubPolicy(0.5, \"brightness\",  1, 0.1, \"contrast\", 7, fillcolor),\n            SubPolicy(0.1, \"contrast\",  4, 0.6, \"solarize\", 5, fillcolor),\n            \n            SubPolicy(0.2, \"solarize\",  3, 0.0, \"shearX\", 0, fillcolor), \n            SubPolicy(0.3, \"translateX\",  0, 0.6, \"translateX\", 0, fillcolor),\n            SubPolicy(0.5, \"equalize\",  9, 0.6, \"translateY\", 7, fillcolor),\n            SubPolicy(0.1, \"shearX\",  0, 0.5, \"sharpness\", 1, fillcolor),\n            SubPolicy(0.8, \"equalize\",  6, 0.3, \"invert\", 6, fillcolor),\n            \n            SubPolicy(0.4, \"shearX\",  4, 0.9, \"autocontrast\", 2, fillcolor),\n            SubPolicy(0.0, \"shearX\",  3, 0.0, \"posterize\", 3, fillcolor),\n            SubPolicy(0.4, \"solarize\",  3, 0.2, \"color\", 4, fillcolor),\n            SubPolicy(0.1, \"equalize\",  4, 0.7, \"equalize\", 6, fillcolor),\n            \n            SubPolicy(0.3, \"equalize\",  8, 0.4, \"autocontrast\", 3, fillcolor), \n            SubPolicy(0.6, 
\"solarize\",  4, 0.7, \"autocontrast\", 6, fillcolor),\n            SubPolicy(0.2, \"autocontrast\",  9, 0.4, \"brightness\", 8, fillcolor),\n            SubPolicy(0.1, \"equalize\",  0, 0.0, \"equalize\", 6, fillcolor),\n            SubPolicy(0.8, \"equalize\",  4, 0.0, \"equalize\", 4, fillcolor),\n            \n            SubPolicy(0.5, \"equalize\",  5, 0.1, \"autocontrast\", 2, fillcolor), \n            SubPolicy(0.5, \"solarize\",  5, 0.9, \"autocontrast\", 5, fillcolor),\n        ]\n\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment ImageNet Policy\"\n\n\nclass SubPolicy(object):\n    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):\n        ranges = {\n            \"shearX\": np.linspace(0, 0.3, 10),\n            \"shearY\": np.linspace(0, 0.3, 10),\n            \"translateX\": np.linspace(0, 150 / 331, 10),\n            \"translateY\": np.linspace(0, 150 / 331, 10),\n            \"rotate\": np.linspace(0, 30, 10),\n            \"color\": np.linspace(0.0, 0.9, 10),\n            \"posterize\": np.round(np.linspace(8, 4, 10), 0).astype(np.int64),\n            \"solarize\": np.linspace(256, 0, 10),\n            \"contrast\": np.linspace(0.0, 0.9, 10),\n            \"sharpness\": np.linspace(0.0, 0.9, 10),\n            \"brightness\": np.linspace(0.0, 0.9, 10),\n            \"autocontrast\": [0] * 10,\n            \"equalize\": [0] * 10,\n            \"invert\": [0] * 10,\n            \"cutout\": np.linspace(0.0, 0.2, 10),\n        }\n    \n        def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]\n            #assert 0.0 <= v <= 0.2\n            if v <= 0.:\n                return img\n\n            v = v * img.size[0]\n\n            return CutoutAbs(img, v)\n\n            # x0 = np.random.uniform(w - v)\n            # y0 = np.random.uniform(h - 
v)\n            # xy = (x0, y0, x0 + v, y0 + v)\n            # color = (127, 127, 127)\n            # img = img.copy()\n            # PIL.ImageDraw.Draw(img).rectangle(xy, color)\n            # return img\n\n\n        def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]\n            # assert 0 <= v <= 20\n            if v < 0:\n                return img\n            w, h = img.size\n            x0 = np.random.uniform(w)\n            y0 = np.random.uniform(h)\n\n            x0 = int(max(0, x0 - v / 2.))\n            y0 = int(max(0, y0 - v / 2.))\n            x1 = min(w, x0 + v)\n            y1 = min(h, y0 + v)\n\n            xy = (x0, y0, x1, y1)\n            color = (125, 123, 114)\n            # color = (0, 0, 0)\n            img = img.copy()\n            ImageDraw.Draw(img).rectangle(xy, color)\n            return img\n\n        def rotate_with_fill(img, magnitude):\n            rot = img.convert(\"RGBA\").rotate(magnitude)\n            return Image.composite(rot, Image.new(\"RGBA\", rot.size, (128,) * 4), rot).convert(img.mode)\n\n        func = {\n            \"shearX\": lambda img, magnitude: img.transform(\n                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),\n                Image.BICUBIC, fillcolor=fillcolor),\n            \"shearY\": lambda img, magnitude: img.transform(\n                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),\n                Image.BICUBIC, fillcolor=fillcolor),\n            \"translateX\": lambda img, magnitude: img.transform(\n                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),\n                fillcolor=fillcolor),\n            \"translateY\": lambda img, magnitude: img.transform(\n                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),\n                fillcolor=fillcolor),\n            \"cutout\": lambda img, magnitude: Cutout(img, magnitude),\n    
        \"rotate\": lambda img, magnitude: rotate_with_fill(img, magnitude),\n            # \"rotate\": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),\n            \"color\": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),\n            \"posterize\": lambda img, magnitude: ImageOps.posterize(img, magnitude),\n            \"solarize\": lambda img, magnitude: ImageOps.solarize(img, magnitude),\n            \"contrast\": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"sharpness\": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"brightness\": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"autocontrast\": lambda img, magnitude: ImageOps.autocontrast(img),\n            \"equalize\": lambda img, magnitude: ImageOps.equalize(img),\n            \"invert\": lambda img, magnitude: ImageOps.invert(img)\n        }\n\n        # self.name = \"{}_{:.2f}_and_{}_{:.2f}\".format(\n        #     operation1, ranges[operation1][magnitude_idx1],\n        #     operation2, ranges[operation2][magnitude_idx2])\n        self.p1 = p1\n        self.operation1 = func[operation1]\n        self.magnitude1 = ranges[operation1][magnitude_idx1]\n        self.p2 = p2\n        self.operation2 = func[operation2]\n        self.magnitude2 = ranges[operation2][magnitude_idx2]\n\n\n    def __call__(self, img):\n        if random.random() < self.p1: img = self.operation1(img, self.magnitude1)\n        if random.random() < self.p2: img = self.operation2(img, self.magnitude2)\n        return img"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/cutout.py",
    "content": "import torch\nimport numpy as np\n\nclass Cutout(object):\n    \"\"\"Randomly mask out one or more patches from an image.\n    Args:\n        n_holes (int): Number of patches to cut out of each image.\n        length (int): The length (in pixels) of each square patch.\n    \"\"\"\n    def __init__(self, n_holes, length):\n        self.n_holes = n_holes\n        self.length = length\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (Tensor): Tensor image of size (C, H, W).\n        Returns:\n            Tensor: Image with n_holes of dimension length x length cut out of it.\n        \"\"\"\n        h = img.size(1)\n        w = img.size(2)\n\n        mask = np.ones((h, w), np.float32)\n\n        for n in range(self.n_holes):\n            y = np.random.randint(h)\n            x = np.random.randint(w)\n\n            y1 = np.clip(y - self.length // 2, 0, h)\n            y2 = np.clip(y + self.length // 2, 0, h)\n            x1 = np.clip(x - self.length // 2, 0, w)\n            x2 = np.clip(x + self.length // 2, 0, w)\n\n            mask[y1: y2, x1: x2] = 0.\n\n        mask = torch.from_numpy(mask)\n        mask = mask.expand_as(img)\n        img = img * mask\n\n        return img"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/data_utils.py",
    "content": "import numpy as np\n\n\ndef construct_balanced_subset(x, y):\n    xdata, ydata = [], []\n    minsize = np.inf\n    for cls_ in np.unique(y):\n        xdata.append(x[y == cls_])\n        ydata.append(y[y == cls_])\n        if ydata[-1].shape[0] < minsize:\n            minsize = ydata[-1].shape[0]\n    for i in range(len(xdata)):\n        # if xdata[i].shape[0] < minsize:\n            # import pdb\n            # pdb.set_trace()\n        idx = np.arange(xdata[i].shape[0])\n        np.random.shuffle(idx)\n        xdata[i] = xdata[i][idx][:minsize]\n        ydata[i] = ydata[i][idx][:minsize]\n    # !list\n    return np.concatenate(xdata, 0), np.concatenate(ydata, 0)"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/factory.py",
    "content": "from matplotlib.transforms import Transform\nimport torch\nfrom torch import nn\nfrom torch import optim\n\nfrom inclearn import models\nfrom inclearn.convnet import sew_resnet\nfrom inclearn.datasets import data\nfrom inclearn.convnet.resnet import SEFeatureAt\n\ndef get_optimizer(params, optimizer, lr, weight_decay=0.0):\n    if optimizer == \"adam\":\n        return optim.Adam(params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))\n    elif optimizer == \"sgd\":\n        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)\n    else:\n        raise NotImplementedError\n\ndef get_attention(inplane, type, at_res):\n    return SEFeatureAt(inplane, type, at_res)\n\ndef get_convnet(convnet_type, c_dim=None,cdim_cur=None,**kwargs):\n\n    if convnet_type == \"resnet18\":\n        return sew_resnet.sew_resnet18(c_dim,cdim_cur,**kwargs)\n    else:\n        raise NotImplementedError(\"Unknwon convnet type {}.\".format(convnet_type))\n\n\ndef get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset):\n    if cfg[\"model\"] == \"incmodel\":\n        return models.IncModel(cfg, trial_i, _run, ex, tensorboard, inc_dataset)\n    else:\n        raise NotImplementedError(cfg[\"model\"])\n\n\ndef get_data(cfg, trial_i):\n    return data.IncrementalDataset(\n        trial_i=trial_i,\n        dataset_name=cfg[\"dataset\"],\n        random_order=cfg[\"random_classes\"],\n        shuffle=True,\n        batch_size=cfg[\"batch_size\"],\n        workers=cfg[\"workers\"],\n        validation_split=cfg[\"validation\"],\n        resampling=cfg[\"resampling\"],\n        increment=cfg[\"increment\"],\n        data_folder=cfg[\"data_folder\"],\n        start_class=cfg[\"start_class\"],\n        torc=cfg.get(\"distillation\")\n    )\n\n\ndef set_device(cfg):\n    device_type = cfg[\"device\"]\n\n    if device_type == -1:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"cuda:{}\".format(device_type))\n\n    
cfg[\"device\"] = device\n    return device\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/memory.py",
    "content": "import numpy as np\nfrom copy import deepcopy\nimport torch\nfrom torch.nn import functional as F\n\nfrom inclearn.tools.utils import get_class_loss\nfrom inclearn.convnet.utils import extract_features\n\n\nclass MemorySize:\n    def __init__(self, mode, inc_dataset, total_memory=None, fixed_memory_per_cls=None):\n        self.mode = mode\n        assert mode.lower() in [\"uniform_fixed_per_cls\", \"uniform_fixed_total_mem\", \"dynamic_fixed_per_cls\"]\n        self.total_memory = total_memory\n        self.fixed_memory_per_cls = fixed_memory_per_cls\n        self._n_classes = 0\n        self.mem_per_cls = []\n        self._inc_dataset = inc_dataset\n\n    def update_n_classes(self, n_classes):\n        self._n_classes = n_classes\n\n    def update_memory_per_cls_uniform(self, n_classes):\n        if \"fixed_per_cls\" in self.mode:\n            self.mem_per_cls = [self.fixed_memory_per_cls for i in range(n_classes)]\n        elif \"fixed_total_mem\" in self.mode:\n            self.mem_per_cls = [self.total_memory // n_classes for i in range(n_classes)]\n        return self.mem_per_cls\n\n    def update_memory_per_cls(self, network, n_classes, task_size):\n        if \"uniform\" in self.mode:\n            self.update_memory_per_cls_uniform(n_classes)\n        else:\n            if n_classes == task_size:\n                self.update_memory_per_cls_uniform(n_classes)\n\n    @property\n    def memsize(self):\n        if self.mode == \"fixed_total_mem\":\n            return self.total_memory\n        elif self.mode == \"fixed_per_cls\":\n            return self.fixed_memory_per_cls * self._n_classes\n\n\ndef compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max):\n    EPSILON = 1e-8\n    D = feat_norm.T\n    D = D / (np.linalg.norm(D, axis=0) + EPSILON)\n\n    D2 = feat_flip.T\n    D2 = D2 / (np.linalg.norm(D2, axis=0) + EPSILON)\n\n    alph = herding_mat\n    alph = (alph > 0) * (alph < nb_max + 1) * 1.0\n\n    alph_mean = alph / 
np.sum(alph)\n\n    mean = (np.dot(D, alph_mean) + np.dot(D2, alph_mean)) / 2\n    # mean = np.dot(D, alph_mean)\n    mean /= np.linalg.norm(mean) + EPSILON\n\n    return mean, alph\n\n\ndef select_examplars(features, nb_max):\n    EPSILON = 1e-8\n    D = features.T\n    D = D / (np.linalg.norm(D, axis=0) + EPSILON)\n    mu = np.mean(D, axis=1)\n    herding_matrix = np.zeros((features.shape[0], ))\n    idxes = []\n    w_t = mu\n\n    iter_herding, iter_herding_eff = 0, 0\n\n    while not (np.sum(herding_matrix != 0) == min(nb_max, features.shape[0])) and iter_herding_eff < 1000:\n        tmp_t = np.dot(w_t, D)\n        # tmp_t = -np.linalg.norm(w_t[:,np.newaxis]-D, axis=0)\n        # tmp_t = np.linalg.norm(w_t[:,np.newaxis]-D, axis=0)\n        ind_max = np.argmax(tmp_t)\n        iter_herding_eff += 1\n        if herding_matrix[ind_max] == 0:\n            herding_matrix[ind_max] = 1 + iter_herding\n            idxes.append(ind_max)\n            iter_herding += 1\n\n        w_t = w_t + mu - D[:, ind_max]\n\n    return herding_matrix, idxes\n\n\ndef random_selection(n_classes, task_size, network, logger, inc_dataset, memory_per_class: list):\n    # TODO: Move data_memroy,targets_memory into IncDataset\n    logger.info(\"Building & updating memory.(Random Selection)\")\n    tmp_data_memory, tmp_targets_memory = [], []\n    assert len(memory_per_class) == n_classes\n    for class_idx in range(n_classes):\n        # 旧类数据从get_custom_loader_from_memory中读取，新类数据从get_custom_loader中读取\n        if class_idx < n_classes - task_size:\n            inputs, targets, loader = inc_dataset.get_custom_loader_from_memory([class_idx])\n        else:\n            inputs, targets, loader = inc_dataset.get_custom_loader(class_idx, mode=\"test\")\n        memory_this_cls = min(memory_per_class[class_idx], inputs.shape[0])\n        idxs = np.random.choice(inputs.shape[0], memory_this_cls, replace=False)\n        tmp_data_memory.append(inputs[idxs])\n        
tmp_targets_memory.append(targets[idxs])\n    tmp_data_memory = np.concatenate(tmp_data_memory)\n    tmp_targets_memory = np.concatenate(tmp_targets_memory)\n    return tmp_data_memory, tmp_targets_memory\n\n\ndef herding(task_i,n_classes, task_size, network, herding_matrix, inc_dataset, shared_data_inc, memory_per_class: list,\n            logger,mask=None):\n    \"\"\"Herding matrix: list\n    \"\"\"\n    logger.info(\"Building & updating memory.(iCaRL)\")\n    tmp_data_memory, tmp_targets_memory = [], []\n\n    for class_idx in range(n_classes):\n        inputs = inc_dataset.data_train[inc_dataset.targets_train == class_idx] \n        targets = inc_dataset.targets_train[inc_dataset.targets_train == class_idx]\n        # zi = inc_dataset.zimages[inc_dataset.zlabels == class_idx]\n        # zt = inc_dataset.zlabels[inc_dataset.zlabels == class_idx]\n        # inputs = np.concatenate((inputs, zi))\n        # targets = np.concatenate((targets, zt))\n\n\n        if class_idx >= n_classes - task_size:\n            if len(shared_data_inc) > len(inc_dataset.targets_inc):\n                share_memory = [shared_data_inc[i] for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist()]\n            else:\n                share_memory = []\n                for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist():\n                    if i < len(shared_data_inc):\n                        share_memory.append(shared_data_inc[i])\n\n            # share_memory = [shared_data_inc[i] for i in np.where(inc_dataset.targets_inc == class_idx)[0].tolist()]\n            loader = inc_dataset._get_loader(inc_dataset.data_inc[inc_dataset.targets_inc == class_idx],\n                                             inc_dataset.targets_inc[inc_dataset.targets_inc == class_idx],\n                                             share_memory=share_memory,\n                                             batch_size=128,\n                                             shuffle=False,\n        
                                     mode=\"test\")\n            features, _ = extract_features(task_i,network, loader,mask=mask)\n            # features_flipped, _ = extract_features(network, inc_dataset.get_custom_loader(class_idx, mode=\"flip\")[-1])\n            herding_matrix.append(select_examplars(features, memory_per_class[class_idx])[0])\n        alph = herding_matrix[class_idx]\n        alph = (alph > 0) * (alph < memory_per_class[class_idx] + 1) * 1.0\n        # examplar_mean, alph = compute_examplar_mean(features, features_flipped, herding_matrix[class_idx],\n        #                                             memory_per_class[class_idx])\n        tmp_data_memory.append(inputs[np.where(alph == 1)[0]])\n        tmp_targets_memory.append(targets[np.where(alph == 1)[0]])\n    tmp_data_memory = np.concatenate(tmp_data_memory)\n    tmp_targets_memory = np.concatenate(tmp_targets_memory)\n    return tmp_data_memory, tmp_targets_memory, herding_matrix\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/metrics.py",
    "content": "import numpy as np\nimport torch\nimport numbers\nimport math\n\n\nclass IncConfusionMeter:\n    \"\"\"Maintains a confusion matrix for a given classification problem.\n    The ConfusionMeter constructs a confusion matrix for a multi-class\n    classification problems. It does not support multi-label, multi-class problems:\n    for such problems, please use MultiLabelConfusionMeter.\n    Args:\n        k (int): number of classes in the classification problem\n        normalized (boolean): Determines whether or not the confusion matrix\n            is normalized or not\n    \"\"\"\n    def __init__(self, k, increments, normalized=False):\n        self.conf = np.ndarray((k, k), dtype=np.int32)\n        self.normalized = normalized\n        self.increments = increments\n        self.cum_increments = [0] + [sum(increments[:i + 1]) for i in range(len(increments))]\n        self.k = k\n        self.reset()\n\n    def reset(self):\n        self.conf.fill(0)\n\n    def add(self, predicted, target):\n        \"\"\"Computes the confusion matrix of K x K size where K is no of classes\n        Args:\n            predicted (tensor): Can be an N x K tensor of predicted scores obtained from\n                the model for N examples and K classes or an N-tensor of\n                integer values between 0 and K-1.\n            target (tensor): Can be a N-tensor of integer values assumed to be integer\n                values between 0 and K-1 or N x K tensor, where targets are\n                assumed to be provided as one-hot vectors\n        \"\"\"\n        if isinstance(predicted, torch.Tensor):\n            predicted = predicted.cpu().numpy()\n        if isinstance(target, torch.Tensor):\n            target = target.cpu().numpy()\n\n        assert predicted.shape[0] == target.shape[0], \\\n            'number of targets and predicted outputs do not match'\n\n        if np.ndim(predicted) != 1:\n            assert predicted.shape[1] == self.k, \\\n                
'number of predictions does not match size of confusion matrix'\n            predicted = np.argmax(predicted, 1)\n        else:\n            assert (predicted.max() < self.k) and (predicted.min() >= 0), \\\n                'predicted values are not between 1 and k'\n\n        onehot_target = np.ndim(target) != 1\n        if onehot_target:\n            assert target.shape[1] == self.k, \\\n                'Onehot target does not match size of confusion matrix'\n            assert (target >= 0).all() and (target <= 1).all(), \\\n                'in one-hot encoding, target values should be 0 or 1'\n            assert (target.sum(1) == 1).all(), \\\n                'multi-label setting is not supported'\n            target = np.argmax(target, 1)\n        else:\n            assert (predicted.max() < self.k) and (predicted.min() >= 0), \\\n                'predicted values are not between 0 and k-1'\n\n        # hack for bincounting 2 arrays together\n        x = predicted + self.k * target\n        bincount_2d = np.bincount(x.astype(np.int32), minlength=self.k**2)\n        assert bincount_2d.size == self.k**2\n        conf = bincount_2d.reshape((self.k, self.k))\n\n        self.conf += conf\n\n    def value(self):\n        \"\"\"\n        Returns:\n            Confusion matrix of K rows and K columns, where rows corresponds\n            to ground-truth targets and columns corresponds to predicted\n            targets.\n        \"\"\"\n        conf = self.conf.astype(np.float32)\n        new_conf = np.zeros([len(self.increments), len(self.increments) + 2])\n        for i in range(len(self.increments)):\n            idxs = range(self.cum_increments[i], self.cum_increments[i + 1])\n            new_conf[i, 0] = conf[idxs, idxs].sum()\n            new_conf[i, 1] = conf[self.cum_increments[i]:self.cum_increments[i + 1],\n                                  self.cum_increments[i]:self.cum_increments[i + 1]].sum() - new_conf[i, 0]\n            for j in 
range(len(self.increments)):\n                new_conf[i, j + 2] = conf[self.cum_increments[i]:self.cum_increments[i + 1],\n                                          self.cum_increments[j]:self.cum_increments[j + 1]].sum()\n        conf = new_conf\n        if self.normalized:\n            return conf / conf[:, 2:].sum(1).clip(min=1e-12)[:, None]\n        else:\n            return conf\n\n\nclass ClassErrorMeter:\n    def __init__(self, topk=[1], accuracy=False):\n        super(ClassErrorMeter, self).__init__()\n        self.topk = np.sort(topk)\n        self.accuracy = accuracy\n        self.reset()\n\n    def reset(self):\n        self.sum = {v: 0 for v in self.topk}\n        self.n = 0\n\n    def add(self, output, target):\n        if isinstance(output, np.ndarray):\n            output = torch.Tensor(output)\n        if isinstance(target, np.ndarray):\n            target = torch.Tensor(target)\n        # if torch.is_tensor(output):\n        #     output = output.cpu().squeeze().numpy()\n        # if torch.is_tensor(target):\n        #     target = target.cpu().squeeze().numpy()\n        # elif isinstance(target, numbers.Number):\n        #     target = np.asarray([target])\n        # if np.ndim(output) == 1:\n        #     output = output[np.newaxis]\n        # else:\n        #     assert np.ndim(output) == 2, \\\n        #         'wrong output size (1D or 2D expected)'\n        #     assert np.ndim(target) == 1, \\\n        #         'target and output do not match'\n        # assert target.shape[0] == output.shape[0], \\\n        #     'target and output do not match'\n        topk = self.topk\n        maxk = int(topk[-1])  # seems like Python3 wants int and not np.int64\n        no = output.shape[0]\n\n        pred = output.topk(maxk, 1, True, True)[1]\n        correct = pred == target.unsqueeze(1).repeat(1, pred.shape[1])\n        # pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy()\n        # correct = pred == target[:, 
np.newaxis].repeat(pred.shape[1], 1)\n\n        for k in topk:\n            self.sum[k] += no - correct[:, 0:k].sum()\n        self.n += no\n\n    def value(self, k=-1):\n        if k != -1:\n            assert k in self.sum.keys(), \\\n                'invalid k (this k was not provided at construction time)'\n            if self.n == 0:\n                return float('nan')\n            if self.accuracy:\n                return (1. - float(self.sum[k]) / self.n) * 100.0\n            else:\n                return float(self.sum[k]) / self.n * 100.0\n        else:\n            return [self.value(k_) for k_ in self.topk]\n\n\nclass AverageValueMeter:\n    def __init__(self):\n        super(AverageValueMeter, self).__init__()\n        self.reset()\n        self.val = 0\n\n    def add(self, value, n=1):\n        self.val = value\n        self.sum += value\n        self.var += value * value\n        self.n += n\n\n        if self.n == 0:\n            self.mean, self.std = np.nan, np.nan\n        elif self.n == 1:\n            self.mean, self.std = self.sum, np.inf\n            self.mean_old = self.mean\n            self.m_s = 0.0\n        else:\n            self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)\n            self.m_s += (value - self.mean_old) * (value - self.mean)\n            self.mean_old = self.mean\n            self.std = math.sqrt(self.m_s / (self.n - 1.0))\n\n    def value(self):\n        return self.mean, self.std\n\n    def reset(self):\n        self.n = 0\n        self.sum = 0.0\n        self.var = 0.0\n        self.val = 0.0\n        self.mean = np.nan\n        self.mean_old = 0.0\n        self.m_s = 0.0\n        self.std = np.nan"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/results_utils.py",
    "content": "import glob\nimport json\nimport math\nimport os\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom copy import deepcopy\n\nfrom . import utils\n\n\ndef get_template_results(cfg):\n    return {\"config\": cfg, \"results\": []}\n\n\ndef save_results(results, label):\n    del results[\"config\"][\"device\"]\n\n    folder_path = os.path.join(\"results\", \"{}_{}\".format(utils.get_date(), label))\n    if not os.path.exists(folder_path):\n        os.makedirs(folder_path)\n\n    file_path = \"{}_{}_.json\".format(utils.get_date(), results[\"config\"][\"seed\"])\n    with open(os.path.join(folder_path, file_path), \"w+\") as f:\n        json.dump(results, f, indent=2)\n\n\ndef compute_avg_inc_acc(results):\n    \"\"\"Computes the average incremental accuracy as defined in iCaRL.\n\n    The average incremental accuracies at task X are the average of accuracies\n    at task 0, 1, ..., and X.\n\n    :param accs: A list of dict for per-class accuracy at each step.\n    :return: A float.\n    \"\"\"\n    top1_tasks_accuracy = [r['top1'][\"total\"] for r in results]\n    top1acc = sum(top1_tasks_accuracy) / len(top1_tasks_accuracy)\n    if \"top5\" in results[0].keys():\n        top5_tasks_accuracy = [r['top5'][\"total\"] for r in results]\n        top5acc = sum(top5_tasks_accuracy) / len(top5_tasks_accuracy)\n    else:\n        top5acc = None\n    return top1acc, top5acc"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/scheduler.py",
    "content": "import math\nfrom torch.optim.lr_scheduler import _LRScheduler\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\n\nclass ConstantTaskLR:\n    def __init__(self, lr):\n        self._lr = lr\n\n    def get_lr(self, task_i):\n        return self._lr\n\n\nclass CosineAnnealTaskLR:\n    def __init__(self, lr_max, lr_min, task_max):\n        self._lr_max = lr_max\n        self._lr_min = lr_min\n        self._task_max = task_max\n\n    def get_lr(self, task_i):\n        return self._lr_min + (self._lr_max - self._lr_min) * (1 + math.cos(math.pi * task_i / self._task_max)) / 2\n\n\nclass GradualWarmupScheduler(_LRScheduler):\n    \"\"\" Gradually warm-up(increasing) learning rate in optimizer.\n    https://github.com/ildoonet/pytorch-gradual-warmup-lr\n    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.\n        total_epoch: target learning rate is reached at total_epoch, gradually\n        after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau)\n    \"\"\"\n    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):\n        self.multiplier = multiplier\n        if self.multiplier < 1.:\n            raise ValueError('multiplier should be greater thant or equal to 1.')\n        self.total_epoch = total_epoch\n        self.after_scheduler = after_scheduler\n        self.finished = False\n        super(GradualWarmupScheduler, self).__init__(optimizer)\n\n    def get_lr(self):\n        if self.last_epoch > self.total_epoch:\n            if self.after_scheduler:\n                if not self.finished:\n                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]\n                    self.finished = True\n                return self.after_scheduler.get_last_lr()\n            return [base_lr * self.multiplier for base_lr in self.base_lrs]\n\n        if self.multiplier == 1.0:\n            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]\n        else:\n            return [\n                base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)\n                for base_lr in self.base_lrs\n            ]\n\n    def step_ReduceLROnPlateau(self, metrics, epoch=None):\n        if epoch is None:\n            epoch = self.last_epoch + 1\n        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning\n        if self.last_epoch <= self.total_epoch:\n            warmup_lr = [\n                base_lr * ((self.multiplier - 1.) 
* self.last_epoch / self.total_epoch + 1.)\n                for base_lr in self.base_lrs\n            ]\n            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):\n                param_group['lr'] = lr\n        else:\n            if epoch is None:\n                self.after_scheduler.step(metrics, None)\n            else:\n                self.after_scheduler.step(metrics, epoch - self.total_epoch)\n\n    def step(self, epoch=None, metrics=None):\n        if type(self.after_scheduler) != ReduceLROnPlateau:\n            if self.finished and self.after_scheduler:\n                if epoch is None:\n                    self.after_scheduler.step(None)\n                else:\n                    self.after_scheduler.step(epoch - self.total_epoch)\n                self._last_lr = self.after_scheduler.get_last_lr()\n            else:\n                return super(GradualWarmupScheduler, self).step(epoch)\n        else:\n            self.step_ReduceLROnPlateau(metrics, epoch)\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/similar.py",
    "content": "import torch.nn as nn\nimport torchvision.models as models\nimport numpy as np\nimport os\nfrom sklearn.utils import shuffle\nimport torch\nimport torch.nn.functional as F\n\nclass Appr(object):\n    \"\"\" Class implementing the TALL \"\"\"\n    def __init__(self, pretrained_feat_extractor, num_task, torc,device='cuda', args=None):\n\n        self.task2expert = []\n        self.expert2task = []\n        self.torc=torc\n        self.num_task=num_task\n        self.task2mean = {}\n        self.task2cov = {}\n        for i in range(num_task):\n            self.task2mean[i]=[]\n            self.task2cov[i]=[]\n        self.task_dist = torch.zeros(num_task, num_task).to(device='cuda')\n        self.task_dist2 = torch.zeros(num_task, num_task).to(device='cuda')\n        self.feat_extractor = pretrained_feat_extractor\n        self.task_relatedness_method = \"mean\"\n        self.reuse_threshold=0.3\n        self.reuse_cell_threshold=0.75\n\n    def get_mean_cov_feats(self,  taski, data, device):\n        \"\"\"compute mean and cov for features of data extracted by the expert of task t\n\n        \"\"\"\n        # data = deepcopy(data)  # copy for using different preprocess\n        \n        # self.model.requires_grad_(False)\n        # self.model.eval()\n        # self.model.set_current_task(t)\n        if self.torc:\n            steps=int(100/self.num_task)\n            class_num = steps*taski\n            labels = torch.arange(class_num,class_num+steps).view(-1, 1).to(device)\n        else:\n            steps = int(100/self.num_task)\n            labels = torch.arange(steps).view(-1, 1).to(device)\n        all_task_feats={}\n        for t in range(taski):\n            all_task_feats[t]=[[] for _ in range(steps)]\n\n        self.feat_extractor.eval()\n        with torch.no_grad():\n            for batch_idx, (x, y) in enumerate(data):\n                x, y = x.to(device), y.to(device)\n                index = labels == y.view(1, -1) # CxC\n             
   for t_p in range(taski):\n                    feat = self.feat_extractor(t_p,x,mask=None,classify=False)['features']\n                    feat = feat.view(feat.size(0), -1)\n\n                    for i in range(steps):\n                         all_task_feats[t_p][i].append(feat[index[i]])\n            \n            \n            feat_means = {}\n            feat_covs = {}\n            all_feats_cat={}\n            for t_p in range(taski):\n                feat_means[t_p] = []\n                feat_covs [t_p]= []\n                all_feats_cat[t_p] = [torch.cat(feats, axis=0) for feats in all_task_feats[t_p]]\n\n                for feat in all_feats_cat[t_p]:\n                    feat_mean, feat_cov = gaussian_mean_cov(feat)\n                    feat_means[t_p].append(feat_mean)\n                    # feat_covs[t_p].append(feat_cov)\n            # feat_means = [torch.mean(feat, dim=0) for feat in all_feats_cat]\n\n        return feat_means, feat_covs, all_feats_cat\n    \n        \n    def after_get_mean_cov_feats(self, taski, data, device):\n        \"\"\"compute mean and cov for features of data extracted by the expert of task t\n\n        \"\"\"\n        # data = deepcopy(data)  # copy for using different preprocess\n        \n        # self.model.requires_grad_(False)\n        # self.model.eval()\n        # self.model.set_current_task(t)\n        if self.torc:\n            steps=int(100/self.num_task)\n            class_num = steps*taski\n            labels = torch.arange(class_num,class_num+steps).view(-1, 1).to(device)\n        else:\n            steps = int(100/self.num_task)\n            labels = torch.arange(steps).view(-1, 1).to(device)\n        all_task_feats=[[] for _ in range(steps)]\n        self.feat_extractor.eval()\n        with torch.no_grad():\n            for batch_idx, (x, y) in enumerate(data):\n                x, y = x.to(device), y.to(device)\n                # forward\n                feat = 
self.feat_extractor(taski,x,mask=None,classify=False)['features']\n                feat = feat.view(feat.size(0), -1)\n\n                index = labels == y.view(1, -1) # CxC\n                for i in range(steps):\n                    all_task_feats[i].append(feat[index[i]])\n            \n            feat_means= []\n            feat_covs = []\n            all_feats_cat= [torch.cat(feats, axis=0) for feats in all_task_feats]\n\n            for feat in all_feats_cat:\n                feat_mean, feat_cov = gaussian_mean_cov(feat)\n                feat_means.append(feat_mean)\n                # feat_covs.append(feat_cov)\n            # feat_means = [torch.mean(feat, dim=0) for feat in all_feats_cat]\n\n        return feat_means, feat_covs, all_feats_cat\n\n    def add_mean_cov(self, taski,mean, cov=None):\n        self.task2mean[taski].append(mean)\n        # self.task2cov[taski].append(cov)\n\n    def task_relatedness_knnkl(self, task_id, p_task_id, all_feats):\n        \"\"\"\n        Params:\n            all_feat: shape C x N_c x D\n        \"\"\"\n        # means and features of current data from expert of p_task_id\n        # feat_means, all_feats = means_and_feats[p_task_id]\n        # means of data of p_task_id\n        p_feat_means = self.task2mean[p_task_id][p_task_id] # C x D\n        feat_means = self.task2mean[task_id][p_task_id] # C x D\n        # p_feat_cov = self.task2cov[p_task_id][p_task_id] # C x D\n        # feat_cov = self.task2cov[task_id][p_task_id] # C x D\n        p_feat_means = torch.stack(p_feat_means, dim=0)\n\n        task_dist = 0\n        task_dist2=0\n        d = p_feat_means.shape[-1]\n        n = 0\n        flag = False\n        for i in range(len(feat_means)):\n            # for each current class\n            n += all_feats[i].shape[0]\n\n            dist_in = all_feats[i] - feat_means[i] # N_c x D\n            dist_in = torch.sqrt(torch.sum(dist_in ** 2, dim=-1)) # N_c\n            \n            # N_c x C x D\n            dist_out 
= torch.unsqueeze(all_feats[i], dim=1) - torch.unsqueeze(p_feat_means, dim=0)\n            dist_out, _ = torch.min(torch.sqrt(torch.sum(dist_out ** 2, dim=-1)), dim=-1)  # N_c\n\n            dist = torch.mean(torch.log(dist_out / (0.9*dist_in)))\n            #dist = torch.mean(torch.log(dist_out))\n\n            # if dist <= 0:\n            #     flag = True\n\n            task_dist += torch.maximum(dist, torch.zeros_like(dist))\n            task_dist2 += torch.maximum(dist, torch.zeros_like(dist))\n        \n        # task_dist = task_dist / len(feat_means)\n        # task_dist = 1 - torch.exp(-2*task_dist)\n        task_dist = torch.minimum(1 - torch.exp(-2*task_dist), task_dist)\n\n        if task_dist == 0:\n            task_dist = torch.ones_like(task_dist)\n\n        return task_dist,task_dist2\n\n    def get_relatedness(self, task_id, feats):\n        \"\"\"Compute relatedness\n        \n        \"\"\"\n        # the distance between task_id and task_id \n        self.task_dist[task_id][task_id] = 0\n        self.task_dist2[task_id][task_id] = 0\n\n        for p_task_id in range(task_id):\n        # for p_task_id in range(task_id + 1):\n            # task_dist = self.task_relatedness_cos(task_id, p_task_id)\n            task_dist,task_dist2 = self.task_relatedness_knnkl(task_id, p_task_id, feats[p_task_id])\n            # task_dist = self.task_relatedness_CKA(task_id, p_task_id, feats)\n            # task_dist = self.task_relatedness_gaussian_kl(task_id, p_task_id)\n            self.task_dist[task_id][p_task_id] = task_dist\n            self.task_dist[p_task_id][task_id] = task_dist\n            self.task_dist2[task_id][p_task_id] = task_dist2\n            self.task_dist2[p_task_id][task_id] = task_dist2\n    \n    def strategy(self, task_id, num_train_samples):\n        \"\"\" Find the expert to be reused. 
If not found, return -1.\n        \n        \"\"\"\n        expert = -1\n        min_dist = None\n\n        \n        all_dist = []\n        all_dist2 = []\n        for expert_id, p_tasks in enumerate(self.expert2task):\n            d = self.task_dist[task_id, p_tasks] \n            dd=  self.task_dist2[task_id, p_tasks] \n            if self.task_relatedness_method == \"mean\":\n                s = torch.mean(d).item()\n                ss = torch.mean(dd).item()\n            elif self.task_relatedness_method == \"max\":\n                s = torch.max(d).item()\n            elif self.task_relatedness_method == \"min\":\n                s = torch.min(d).item()\n            else:\n                raise Exception(\"Unknown reuse strategy !!!\")\n            all_dist.append(s)\n            all_dist2.append(ss)\n            if min_dist is None:\n                min_dist = s\n                expert = expert_id\n            elif s < min_dist:\n                min_dist = s\n                expert = expert_id\n\n        # if num_train_samples <= 25: # for s_long\n        #     all_dist = torch.tensor(all_dist)\n        #     _, expert_idx =torch.sort(all_dist)\n        #     for e in expert_idx:\n        #         if self.model.expert2max_train_samples[e] >= 10 * num_train_samples:\n        #             return \"reuse\", e\n        \n        if min_dist <= self.reuse_threshold:\n            return \"reuse\", expert ,min_dist,all_dist,all_dist2\n        # elif min_dist <= self.reuse_cell_threshold:\n        #     return \"reuse cell\", expert\n        else:\n            return \"new\", expert,min_dist,all_dist,all_dist2\n            \n    def learn(self, task_id, valid_data, batch_size,device):\n        \"\"\"learn a task \n\n        \"\"\"      \n        if task_id == 0:\n            # train\n            strategy='new'\n            expert_id=task_id\n            min_dist=0\n            all_dist=0\n            all_dist2=0\n        \n\n        else:\n            feat_means, 
feat_covs, all_feats = self.get_mean_cov_feats(\n            task_id, valid_data, device=device)\n            for t in range(task_id):\n                self.add_mean_cov(task_id,feat_means[t],feat_covs[t])\n            self.get_relatedness(task_id, all_feats)\n\n            num_train_samples=len(valid_data)*batch_size\n            strategy, expert_id,min_dist,all_dist,all_dist2 = self.strategy(task_id, num_train_samples)\n        self.expert2task.append([task_id])  \n        print(self.task_dist)\n        print(self.task_dist2)\n\n        return strategy, expert_id,min_dist,all_dist,all_dist2\n    \n    def after_learn(self, task_id, valid_data, batch_size,device):\n        \"\"\"learn a task \n\n        \"\"\" \n        feat_means, feat_covs, all_feats = self.after_get_mean_cov_feats(\n            task_id, valid_data, device=device)\n        self.add_mean_cov(task_id,feat_means,feat_covs)\n        print(self.task_dist)\n        print(self.task_dist2)\n            \n    \nclass ResNet_FE(nn.Module):\n    \"\"\"\n\tCreate a feature extractor model from an Alexnet architecture, that is used to train the autoencoder model\n\tand get the most related model whilst training a new task in a sequence\n\t\"\"\"\n    def __init__(self, resnet_model):\n        super(ResNet_FE, self).__init__()\n        self.fe_model = nn.Sequential(*list(resnet_model.children())[:-1])\n        self.fe_model.eval()\n        self.fe_model.requires_grad_(False)\n    \n    def forward(self, x):\n        return self.fe_model(x)\n    \ndef get_pretrained_feat_extractor(name):\n    \"\"\"get the feature extractor pretrained on ImageNet\n    \n    \"\"\"\n    if name == \"resnet18\":\n        feat_extractor = ResNet_FE(models.resnet18(weights=True))\n        # self.logger.info(\"Using relatedness feature extractor: ResNet18\")\n    else:\n        raise Exception(\"Unknown relatedness feature extractor !!!\")\n\n    return feat_extractor\n\ndef gaussian_mean_cov(X):\n    \"\"\"mean and covariance of 
Gaussian distribution\n    \n    Params:\n        X: N x D\n    \"\"\"\n    device = X.device\n    N, D = X.shape[0], X.shape[1]\n    u = torch.mean(X, dim=0)\n    u_row = torch.reshape(u, (1, -1))  # 1 x D\n    cov = torch.matmul(X.T, X) - N * torch.matmul(u_row.T, u_row)  # D x D\n    cov = cov / (N - 1)\n\n    cov = cov * torch.diag(torch.ones(D)).to(X.device) + (torch.diag(torch.ones(D))).to(X.device)\n\n    return u, cov\n\n\n# import torch.nn as nn\n# import torchvision.models as models\n# import numpy as np\n# import os\n# from sklearn.utils import shuffle\n# import torch\n# import torch.nn.functional as F\n\n# class Appr(object):\n#     \"\"\" Class implementing the TALL \"\"\"\n#     def __init__(self, pretrained_feat_extractor, num_task, device='cuda', args=None):\n\n#         self.task2expert = []\n#         self.expert2task = []\n        \n#         self.task2mean = []\n#         self.task2cov = []\n#         self.task_dist = torch.zeros(num_task, num_task).to(device='cuda')\n#         self.feat_extractor = get_pretrained_feat_extractor(pretrained_feat_extractor).to(device='cuda')\n#         self.task_relatedness_method = \"mean\"\n#         self.reuse_threshold=0.3\n#         self.reuse_cell_threshold=0.75\n\n#     def get_mean_cov_feats(self, t, data, device):\n#         \"\"\"compute mean and cov for features of data extracted by the expert of task t\n\n#         \"\"\"\n#         # data = deepcopy(data)  # copy for using different preprocess\n        \n#         # self.model.requires_grad_(False)\n#         # self.model.eval()\n#         # self.model.set_current_task(t)\n\n#         class_num = 10\n#         labels = torch.arange(class_num).view(-1, 1).to(device)\n#         all_feats = [[] for _ in range(class_num)]\n\n#         self.feat_extractor.eval()\n#         with torch.no_grad():\n#             for batch_idx, (x, y) in enumerate(data):\n#                 x, y = x.to(device), y.to(device)\n#                 # forward\n#                 feat = 
self.feat_extractor(x)\n#                 feat = feat.view(feat.size(0), -1)\n\n#                 index = labels == y.view(1, -1) # CxC\n#                 for i in range(class_num):\n#                     all_feats[i].append(feat[index[i]])\n            \n            \n#             all_feats_cat = [torch.cat(feats, axis=0) for feats in all_feats]\n#             feat_means = []\n#             feat_covs = []\n#             for feat in all_feats_cat:\n#                 feat_mean, feat_cov = gaussian_mean_cov(feat)\n#                 feat_means.append(feat_mean)\n#                 feat_covs.append(feat_cov)\n#             # feat_means = [torch.mean(feat, dim=0) for feat in all_feats_cat]\n\n#         return feat_means, feat_covs, all_feats_cat\n    \n#     def add_mean_cov(self, mean, cov=None):\n#         self.task2mean.append(mean)\n#         self.task2cov.append(cov)\n\n#     def task_relatedness_knnkl(self, task_id, p_task_id, all_feats):\n#         \"\"\"\n#         Params:\n#             all_feat: shape C x N_c x D\n#         \"\"\"\n#         # means and features of current data from expert of p_task_id\n#         # feat_means, all_feats = means_and_feats[p_task_id]\n#         # means of data of p_task_id\n#         p_feat_means = self.task2mean[p_task_id] # C x D\n#         feat_means = self.task2mean[task_id] # C x D\n#         p_feat_means = torch.stack(p_feat_means, dim=0)\n\n#         task_dist = 0\n#         d = p_feat_means.shape[-1]\n#         n = 0\n#         flag = False\n#         for i in range(len(feat_means)):\n#             # for each current class\n#             n += all_feats[i].shape[0]\n\n#             dist_in = all_feats[i] - feat_means[i] # N_c x D\n#             dist_in = torch.sqrt(torch.sum(dist_in ** 2, dim=-1)) # N_c\n            \n#             # N_c x C x D\n#             dist_out = torch.unsqueeze(all_feats[i], dim=1) - torch.unsqueeze(p_feat_means, dim=0)\n#             dist_out, _ = torch.min(torch.sqrt(torch.sum(dist_out ** 2, 
dim=-1)), dim=-1)  # N_c\n\n#             dist = torch.mean(torch.log(dist_out / dist_in))\n\n#             # if dist <= 0:\n#             #     flag = True\n\n#             task_dist += torch.maximum(dist, torch.zeros_like(dist))\n        \n#         # task_dist = task_dist / len(feat_means)\n#         # task_dist = 1 - torch.exp(-2*task_dist)\n#         task_dist = torch.minimum(1 - torch.exp(-2*task_dist), task_dist)\n\n#         if task_dist == 0:\n#             task_dist = torch.ones_like(task_dist)\n\n#         return task_dist\n\n#     def get_relatedness(self, task_id, feats):\n#         \"\"\"Compute relatedness\n        \n#         \"\"\"\n#         # the distance between task_id and task_id \n#         self.task_dist[task_id][task_id] = 0\n\n#         for p_task_id in range(task_id):\n#         # for p_task_id in range(task_id + 1):\n#             # task_dist = self.task_relatedness_cos(task_id, p_task_id)\n#             task_dist = self.task_relatedness_knnkl(task_id, p_task_id, feats)\n#             # task_dist = self.task_relatedness_CKA(task_id, p_task_id, feats)\n#             # task_dist = self.task_relatedness_gaussian_kl(task_id, p_task_id)\n#             self.task_dist[task_id][p_task_id] = task_dist\n#             self.task_dist[p_task_id][task_id] = task_dist\n    \n#     def strategy(self, task_id, num_train_samples):\n#         \"\"\" Find the expert to be reused. 
If not found, return -1.\n        \n#         \"\"\"\n#         expert = -1\n#         min_dist = None\n\n        \n#         all_dist = []\n#         for expert_id, p_tasks in enumerate(self.expert2task):\n#             d = self.task_dist[task_id, p_tasks]   \n#             if self.task_relatedness_method == \"mean\":\n#                 s = torch.mean(d).item()\n#             elif self.task_relatedness_method == \"max\":\n#                 s = torch.max(d).item()\n#             elif self.task_relatedness_method == \"min\":\n#                 s = torch.min(d).item()\n#             else:\n#                 raise Exception(\"Unknown reuse strategy !!!\")\n#             all_dist.append(s)\n#             if min_dist is None:\n#                 min_dist = s\n#                 expert = expert_id\n#             elif s < min_dist:\n#                 min_dist = s\n#                 expert = expert_id\n\n#         # if num_train_samples <= 25: # for s_long\n#         #     all_dist = torch.tensor(all_dist)\n#         #     _, expert_idx =torch.sort(all_dist)\n#         #     for e in expert_idx:\n#         #         if self.model.expert2max_train_samples[e] >= 10 * num_train_samples:\n#         #             return \"reuse\", e\n        \n#         if min_dist <= self.reuse_threshold:\n#             return \"reuse\", expert ,min_dist,all_dist\n#         # elif min_dist <= self.reuse_cell_threshold:\n#         #     return \"reuse cell\", expert\n#         else:\n#             return \"new\", expert,min_dist,all_dist\n            \n#     def learn(self, task_id, valid_data, batch_size,device):\n#         \"\"\"learn a task \n\n#         \"\"\"      \n#         feat_means, feat_covs, all_feats = self.get_mean_cov_feats(\n#             task_id, valid_data, device=device)\n#         self.add_mean_cov(feat_means)\n#         self.get_relatedness(task_id, all_feats)\n\n#         if task_id == 0:\n#             # train\n#             strategy='new'\n#             
expert_id=task_id\n#             min_dist=0\n#             all_dist=0\n#         else:\n#             num_train_samples=len(valid_data)*batch_size\n#             strategy, expert_id,min_dist,all_dist = self.strategy(task_id, num_train_samples)\n#         self.expert2task.append([task_id])  \n\n#         return strategy, expert_id,min_dist,all_dist\n            \n    \n# class ResNet_FE(nn.Module):\n#     \"\"\"\n# \tCreate a feature extractor model from an Alexnet architecture, that is used to train the autoencoder model\n# \tand get the most related model whilst training a new task in a sequence\n# \t\"\"\"\n#     def __init__(self, resnet_model):\n#         super(ResNet_FE, self).__init__()\n#         self.fe_model = nn.Sequential(*list(resnet_model.children())[:-1])\n#         self.fe_model.eval()\n#         self.fe_model.requires_grad_(False)\n    \n#     def forward(self, x):\n#         return self.fe_model(x)\n    \n# def get_pretrained_feat_extractor(name):\n#     \"\"\"get the feature extractor pretrained on ImageNet\n    \n#     \"\"\"\n#     if name == \"resnet18\":\n#         feat_extractor = ResNet_FE(models.resnet18(weights=True))\n#         # self.logger.info(\"Using relatedness feature extractor: ResNet18\")\n#     else:\n#         raise Exception(\"Unknown relatedness feature extractor !!!\")\n\n#     return feat_extractor\n\n# def gaussian_mean_cov(X):\n#     \"\"\"mean and covariance of Guassian distribution\n    \n#     Params:\n#         X: N x D\n#     \"\"\"\n#     device = X.device\n#     N, D = X.shape[0], X.shape[1]\n#     u = torch.mean(X, dim=0)\n#     u_row = torch.reshape(u, (1, -1))  # 1 x D\n#     cov = torch.matmul(X.T, X) - N * torch.matmul(u_row.T, u_row)  # D x D\n#     cov = cov / (N - 1)\n\n#     cov = cov * torch.diag(torch.ones(D)).to(X.device) + (torch.diag(torch.ones(D))).to(X.device)\n\n#     return u, cov\n"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/inclearn/tools/utils.py",
    "content": "import random\nfrom copy import deepcopy\nimport numpy as np\nimport datetime\n\nimport torch\n\nfrom inclearn.tools.metrics import ClassErrorMeter\nfrom sklearn.metrics import classification_report\n\n\ndef get_date():\n    return datetime.datetime.now().strftime(\"%Y%m%d\")\n\n\ndef to_onehot(targets, n_classes):\n    if not hasattr(targets, \"device\"):\n        targets = torch.from_numpy(targets)\n    onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)\n    onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0)\n    return onehot\n\n\ndef get_class_loss(network, cur_n_cls, loader):\n    class_loss = torch.zeros(cur_n_cls)\n    n_cls_data = torch.zeros(cur_n_cls)  # the num of imgs for cls i.\n    EPS = 1e-10\n    task_size = 10\n    network.eval()\n    for x, y in loader:\n        x, y = x.cuda(), y.cuda()\n        preds = network(x)['logit'].softmax(dim=1)\n        # preds[:,-task_size:] = preds[:,-task_size:].softmax(dim=1)\n        for i, lbl in enumerate(y):\n            class_loss[lbl] = class_loss[lbl] - (preds[i, lbl] + EPS).detach().log().cpu()\n            n_cls_data[lbl] += 1\n    class_loss = class_loss / n_cls_data\n    return class_loss\n\n\ndef get_featnorm_grouped_by_class(task_i,network, cur_n_cls, loader,m=None):\n    \"\"\"\n    Ret: feat_norms: list of list\n            feat_norms[idx] is the list of feature norm of the images for class idx.\n    \"\"\"\n    feats = [[] for i in range(cur_n_cls)]\n    feat_norms = np.zeros(cur_n_cls)\n    network.eval()\n    with torch.no_grad():\n        for x, y in loader:\n            x = x.cuda()\n            feat = network(task_i,x,m)['feature'].cpu()\n            for i, lbl in enumerate(y):\n                if lbl >= cur_n_cls:\n                    continue\n                feats[lbl].append(feat[y == lbl])\n    for i in range(len(feats)):\n        if len(feats[i]) != 0:\n            feat_cls = torch.cat((feats[i]))\n            feat_norms[i] = 
torch.norm(feat_cls, p=2, dim=1).mean().data.numpy()\n    return feat_norms\n\n\ndef set_seed(seed):\n    print(\"Set seed\", seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True  # This will slow down training.\n    torch.backends.cudnn.benchmark = False\n\n\ndef display_weight_norm(logger, network, increments, tag):\n    weight_norms = [[] for _ in range(len(increments))]\n    increments = np.cumsum(np.array(increments))\n    for idx in range(network.module.convnets[-1].classifer.weight.shape[0]):\n        norm = torch.norm(network.module.convnets[-1].classifer.weight[idx].data, p=2).item()\n        for i in range(len(weight_norms)):\n            if idx < increments[i]:\n                break\n        weight_norms[i].append(round(norm, 3))\n    avg_weight_norm = []\n    # all_weight_norms = []\n    for idx in range(len(weight_norms)):\n        # all_weight_norms += weight_norms[idx]\n        # logger.info(\"task %s: Weight norm per class %s\" % (str(idx), str(weight_norms[idx])))\n        avg_weight_norm.append(round(np.array(weight_norms[idx]).mean(), 3))\n\n    logger.info(\"%s: Weight norm per task %s\" % (tag, str(avg_weight_norm)))\n\n\ndef display_feature_norm(task_i,logger, network, loader, n_classes, increments, tag, return_norm=False,mask=None):\n    avg_feat_norm_per_cls = get_featnorm_grouped_by_class(task_i,network, n_classes, loader,m=mask)\n    feature_norms = [[] for _ in range(len(increments))]\n    increments = np.cumsum(np.array(increments))\n    for idx in range(len(avg_feat_norm_per_cls)):\n        for i in range(len(feature_norms)):\n            if idx < increments[i]:  #Find the mapping from class idx to step i.\n                break\n        feature_norms[i].append(round(avg_feat_norm_per_cls[idx], 3))\n    avg_feature_norm = []\n    for idx in range(len(feature_norms)):\n        
avg_feature_norm.append(round(np.array(feature_norms[idx]).mean(), 3))\n    logger.info(\"%s: Feature norm per class %s\" % (tag, str(avg_feature_norm)))\n    if return_norm:\n        return avg_feature_norm\n    else:\n        return\n\n\ndef check_loss(loss):\n    return not bool(torch.isnan(loss).item()) and bool((loss >= 0.0).item())\n\ndef class2task(class_form, classnum):\n    target_form = deepcopy(class_form)\n    for i in range(classnum):\n      mask = (target_form==i)\n      target_form[mask] = -(i//10)-1\n    target_form = (target_form+1)*(-1)\n    return target_form\n\ndef maskclass(pred, target, classnum, type='new'):\n    # type 为new，遮盖new的class，old遮盖旧的，all遮盖新旧\n    target_form = deepcopy(target)\n    pred_form = deepcopy(pred)\n    if type == 'old':\n        mask = np.logical_or(pred_form<(classnum-10), target_form<(classnum-10))\n        pred_form[mask] = 0\n        target_form[mask] = 0\n\n    if type == 'new':\n        mask = np.logical_or(pred_form>=(classnum-10), target_form>=(classnum-10))\n        pred_form[mask] = 1000\n        target_form[mask] = 1000\n\n    if type == 'all':\n        mask = (target_form>=(classnum-10))\n        target_form[mask] = 1000\n        mask = (pred_form>=(classnum-10))\n        pred_form[mask] = 1000\n\n        mask = (target_form<(classnum-10))\n        target_form[mask] = 0\n        mask = (pred_form<(classnum-10))\n        pred_form[mask] = 0\n\n        all_err = np.sum(pred_form!=target_form)\n        pred_form1 = deepcopy(pred_form)\n        mask = (target_form<(classnum-10))\n        pred_form[mask] = 0\n\n        new_old_err = np.sum(pred_form!=target_form)\n\n        mask = (target_form>=(classnum-10))\n        pred_form1[mask] = 1000        \n        old_new_err = np.sum(pred_form1!=target_form)\n        return all_err, new_old_err, old_new_err      \n\n\n    return pred_form, target_form\n\ndef compute_old_new_mix(ypred, ytrue, increments, n_classes, task_order):\n\n    task_means = []\n    for i in range 
(n_classes//10):\n        taski_mask = np.logical_and(ytrue>=i*10, ytrue<(i+1)*10)\n        task_i_mean = ypred[np.arange(ytrue.shape[0]), ytrue][taski_mask].mean().item()\n        task_means.append(task_i_mean)\n    task_mean = ypred[np.arange(ytrue.shape[0]), ytrue].mean().item()\n\n    classnum = ypred.shape[1]\n    ypred = ypred.argmax(1)\n    all_err = np.sum(ypred!=ytrue)\n    ypred_task = class2task(ypred, classnum)\n    ytrue_task = class2task(ytrue, classnum)\n    err_among_task = np.sum(ypred_task!=ytrue_task)\n    err_inner_task = all_err - err_among_task\n    # print(\"all err : {}\\n among task err: {}\\n inner task err: {}\\n\".format(all_err, err_among_task, err_inner_task))\n\n\n    ypred_new, ytrue_new =  maskclass(ypred, ytrue, n_classes, 'old')\n    new_err = np.sum(ypred_new!=ytrue_new)\n\n    ypred_old, ytrue_old =  maskclass(ypred, ytrue, n_classes, 'new')\n    old_err = np.sum(ypred_old!=ytrue_old)\n    \n    all_err, new_old_err, old_new_err =  maskclass(ypred, ytrue, n_classes, 'all')\n    print(\"******all_err:****\", all_err)\n\n    all_acc = {\"task_mean\": task_mean, \"task_means\":task_means, \"new_err\": new_err, \"old_err\":old_err, \"new_old_err\": new_old_err, \"old_new_err\": old_new_err, \"err_among_task\": err_among_task, \"err_inner_task\": err_inner_task}\n\n    return all_acc\n\n\ndef compute_task_accuracy(ypred, ytrue, increments, n_classes, task_order):\n    task_mean = ypred[np.arange(ytrue.shape[0]), ytrue].mean().item()\n    classnum = ypred.shape[1]\n    ypred = ypred.argmax(1)\n    ypred_task = class2task(ypred, classnum)\n    ytrue_task = class2task(ytrue, classnum)\n    \n\n    all_acc = {\"task_mean\": task_mean, \"class_info\": classification_report(ytrue, ypred), \"task_info\": classification_report(ytrue_task, ypred_task)}\n\n    return all_acc\n\ndef compute_accuracy(ypred, ytrue, increments, n_classes):\n    all_acc = {\"top1\": {}, \"top5\": {}}\n    topk = 5 if n_classes >= 5 else n_classes\n    ncls = 
np.unique(ytrue).shape[0]\n    if topk > ncls:\n        topk = ncls\n    all_acc_meter = ClassErrorMeter(topk=[1, topk], accuracy=True)\n    all_acc_meter.add(ypred, ytrue)\n    all_acc[\"top1\"][\"total\"] = round(all_acc_meter.value()[0], 3)\n    all_acc[\"top5\"][\"total\"] = round(all_acc_meter.value()[1], 3)\n    # all_acc[\"total\"] = round((ypred == ytrue).sum() / len(ytrue), 3)\n\n    # for class_id in range(0, np.max(ytrue), task_size):\n    start, end = 0, 0\n    for i in range(len(increments)):\n        if increments[i] <= 0:\n            pass\n        else:\n            start = end\n            end += increments[i]\n\n            idxes = np.where(np.logical_and(ytrue >= start, ytrue < end))[0]\n            topk_ = 5 if increments[i] >= 5 else increments[i]\n            ncls = np.unique(ytrue[idxes]).shape[0]\n            if topk_ > ncls:\n                topk_ = ncls\n            cur_acc_meter = ClassErrorMeter(topk=[1, topk_], accuracy=True)\n            cur_acc_meter.add(ypred[idxes], ytrue[idxes])\n            top1_acc = (ypred[idxes].argmax(1) == ytrue[idxes]).sum() / idxes.shape[0] * 100\n            if start < end:\n                label = \"{}-{}\".format(str(start).rjust(2, \"0\"), str(end - 1).rjust(2, \"0\"))\n            else:\n                label = \"{}-{}\".format(str(start).rjust(2, \"0\"), str(end).rjust(2, \"0\"))\n            all_acc[\"top1\"][label] = round(top1_acc, 3)\n            all_acc[\"top5\"][label] = round(cur_acc_meter.value()[1], 3)\n            # all_acc[label] = round((ypred[idxes] == ytrue[idxes]).sum() / len(idxes), 3)\n\n    return all_acc\n\n\ndef make_logger(log_name, savedir='.logs/'):\n    \"\"\"Set up the logger for saving log file on the disk\n    Args:\n        cfg: configuration dict\n\n    Return:\n        logger: a logger for record essential information\n    \"\"\"\n    import logging\n    import os\n    from logging.config import dictConfig\n    import time\n\n    logging_config = dict(\n        
version=1,\n        formatters={'f_t': {\n            'format': '\\n %(asctime)s | %(levelname)s | %(name)s \\t %(message)s'\n        }},\n        handlers={\n            'stream_handler': {\n                'class': 'logging.StreamHandler',\n                'formatter': 'f_t',\n                'level': logging.INFO\n            },\n            'file_handler': {\n                'class': 'logging.FileHandler',\n                'formatter': 'f_t',\n                'level': logging.INFO,\n                'filename': None,\n            }\n        },\n        root={\n            'handlers': ['stream_handler', 'file_handler'],\n            'level': logging.DEBUG,\n        },\n    )\n    # set up logger\n    log_file = '{}.log'.format(log_name)\n    # if folder not exist,create it\n    if not os.path.exists(savedir):\n        os.makedirs(savedir)\n    log_file_path = os.path.join(savedir, log_file)\n\n    logging_config['handlers']['file_handler']['filename'] = log_file_path\n\n    open(log_file_path, 'w').close()  # Clear the content of logfile\n    # get logger from dictConfig\n    dictConfig(logging_config)\n\n    logger = logging.getLogger()\n\n    return logger"
  },
  {
    "path": "examples/Structural_Development/SCA-SNN/main.py",
    "content": "import sys\nimport os\nimport os.path as osp\nimport copy\nimport time\nimport shutil\nimport cProfile\nimport logging\nfrom pathlib import Path\nimport numpy as np\nimport random\nfrom easydict import EasyDict as edict\nfrom tensorboardX import SummaryWriter\nimport os\nimport inclearn.convnet.maskcl2 as Mask\n\nos.environ['CUDA_VISIBLE_DEVICES']='0'\n\nrepo_name = 'TCIL'\nbase_dir = '/data1/hanbing/TCIL10/'\nsys.path.insert(0, base_dir)\n\nfrom sacred import Experiment\nex = Experiment(base_dir=base_dir, save_git_info=False)\n\n\nimport torch\n\nfrom inclearn.tools import factory, results_utils, utils\n\nfrom inclearn.tools.metrics import IncConfusionMeter\nfrom inclearn.tools.similar import Appr\n\n\n\ndef initialization(config, seed, mode, exp_id):\n\n    torch.backends.cudnn.benchmark = True  # This will result in non-deterministic results.\n    # ex.captured_out_filter = lambda text: 'Output capturing turned off.'\n    cfg = edict(config)\n    utils.set_seed(cfg['seed'])\n    if exp_id is None:\n        exp_id = -1\n        cfg.exp.savedir = \"./logs_aphal\"\n    logger = utils.make_logger(str(exp_id)+str(cfg.exp.name)+str(mode), savedir=cfg.exp.savedir)\n\n    # Tensorboard\n    exp_name = '{exp_id}_{cfg[\"exp\"][\"name\"]}' if exp_id is not None else '../inbox/{cfg[\"exp\"][\"name\"]}'\n    tensorboard_dir = cfg[\"exp\"][\"tensorboard_dir\"] + \"/{exp_name}\"\n\n    # If not only save latest tensorboard log.\n    # if Path(tensorboard_dir).exists():\n    #     shutil.move(tensorboard_dir, cfg[\"exp\"][\"tensorboard_dir\"] + f\"/../inbox/{time.time()}_{exp_name}\")\n\n    tensorboard = SummaryWriter(tensorboard_dir)\n\n    return cfg, logger, tensorboard\n\n\n@ex.command\ndef train(_run, _rnd, _seed):\n    cfg, ex.logger, tensorboard = initialization(_run.config, _seed, \"train\", _run._id)\n    ex.logger.info(cfg)\n    cfg.data_folder = osp.join(base_dir, \"data\")\n\n    start_time = time.time()\n    _train(cfg, _run, ex, tensorboard)\n    
ex.logger.info(\"Training finished in {}s.\".format(int(time.time() - start_time)))\n\n\ndef _train(cfg, _run, ex, tensorboard):\n    device = factory.set_device(cfg)\n    trial_i = cfg['trial']\n    torc=cfg['distillation']\n\n    inc_dataset = factory.get_data(cfg, trial_i)\n    ex.logger.info(\"classes_order\")\n    ex.logger.info(inc_dataset.class_order)\n\n    model = factory.get_model(cfg, trial_i, _run, ex, tensorboard, inc_dataset)\n    mask=Mask.Mask(model._network.convnets[-1])\n    if _run.meta_info[\"options\"][\"--file_storage\"] is not None:\n        _save_dir = osp.join(_run.meta_info[\"options\"][\"--file_storage\"], str(_run._id))\n    else:\n        _save_dir = cfg[\"exp\"][\"ckptdir\"]\n\n    results = results_utils.get_template_results(cfg)\n    appr=Appr(model._network,10,torc)\n    for task_i in range(inc_dataset.n_tasks):\n        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(task_i)\n\n        model.set_task_info(\n            task=task_info[\"task\"],\n            total_n_classes=task_info[\"max_class\"],\n            increment=task_info[\"increment\"],\n            n_train_data=task_info[\"n_train_data\"],\n            n_test_data=task_info[\"n_test_data\"],\n            n_tasks=inc_dataset.n_tasks,\n        )\n        if torc:\n            strategy, expert_id ,min_dist,all_dist,all_dist2= appr.learn(task_i, val_loader,  cfg['batch_size'],device)\n        else:\n            strategy, expert_id ,min_dist,all_dist,all_dist2= appr.learn(task_i, test_loader[task_i],  cfg['batch_size'],device)\n        print(\"Task:\",task_i,strategy, expert_id,min_dist,all_dist,all_dist2)\n\n\n        model.before_task(task_i, inc_dataset,mask,min_dist,all_dist)\n        \n        # TODO: Move to incmodel.py\n        if 'min_class' in task_info:\n            ex.logger.info(\"Train on {}->{}.\".format(task_info[\"min_class\"], task_info[\"max_class\"]))\n\n        if torc:\n            model.train_task(task_i,train_loader, 
test_loader,mask,min_dist,all_dist)\n            model.after_task(task_i, inc_dataset,mask)\n            appr.after_learn(task_i, val_loader, cfg['batch_size'],device)\n\n        else:\n            model.train_task(task_i,train_loader, val_loader[task_i],mask,min_dist,all_dist)\n            appr.after_learn(task_i, test_loader[task_i], cfg['batch_size'],device)\n\n\n        if torc:\n            ex.logger.info(\"Eval on {}->{}.\".format(0, task_info[\"max_class\"]))\n            ypred, ytrue = model.eval_task(task_i,test_loader,mask)\n            acc_stats = utils.compute_accuracy(ypred, ytrue, increments=model._increments, n_classes=model._n_classes)\n            #Logging\n            model._tensorboard.add_scalar(\"taskaccu/trial{trial_i}\", acc_stats[\"top1\"][\"total\"], task_i)\n            _run.log_scalar(\"trial{trial_i}_taskaccu\", acc_stats[\"top1\"][\"total\"], task_i)\n            _run.log_scalar(\"trial{trial_i}_task_top5_accu\", acc_stats[\"top5\"][\"total\"], task_i)\n            ex.logger.info(\"top1:\"+str(acc_stats['top1']))\n            ex.logger.info(\"top5:\"+str(acc_stats['top5']))\n            results[\"results\"].append(acc_stats)\n        else:\n            for taski in range(task_i+1):\n                ypred, ytrue = model.eval_task(taski,test_loader[taski],mask)\n    \n                acc_stats = utils.compute_accuracy(ypred, ytrue, increments=[1], n_classes=model._n_classes)\n\n                model._tensorboard.add_scalar(f\"taskaccu/trial{trial_i}\", acc_stats[\"top1\"][\"total\"], taski)\n\n                _run.log_scalar(f\"trial{trial_i}_taskaccu\", acc_stats[\"top1\"][\"total\"], taski)\n                _run.log_scalar(f\"trial{trial_i}_task_top5_accu\", acc_stats[\"top5\"][\"total\"], taski)\n\n                ex.logger.info(f\"top1:{acc_stats['top1']}\")\n                ex.logger.info(f\"top5:{acc_stats['top5']}\")\n\n                results[\"results\"].append(acc_stats)\n\n    top1_avg_acc, top5_avg_acc = 
results_utils.compute_avg_inc_acc(results[\"results\"])\n\n    _run.info[\"trial{trial_i}\"][\"avg_incremental_accu_top1\"] = top1_avg_acc\n    _run.info[\"trial{trial_i}\"][\"avg_incremental_accu_top5\"] = top5_avg_acc\n    ex.logger.info(\"Average Incremental Accuracy Top 1: {} Top 5: {}.\".format(\n        _run.info[\"trial{trial_i}\"][\"avg_incremental_accu_top1\"],\n        _run.info[\"trial{trial_i}\"][\"avg_incremental_accu_top5\"],\n    ))\n    if cfg[\"exp\"][\"name\"]:\n        results_utils.save_results(results, cfg[\"exp\"][\"name\"])\n\n\nif __name__ == \"__main__\":\n    ex.add_config(\"/data1/hanbing/SCA-SNN/configs/train.yaml\")\n    ex.run_commandline()\n"
  },
  {
    "path": "examples/Structural_Development/SD-SNN/README.md",
    "content": "# Adaptive Sparse Structure Development with Pruning and Regeneration for Spiking Neural Networks #\n\n## Requirments ##\n* numpy\n* timm\n* pytorch >= 1.7.0\n* collections\n* argparse\n\n## Run ##\n\n```CUDA_VISIBLE_DEVICES=0 python main.py```\n\n## Citation ##\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{han2025adaptive,\n  title={Adaptive sparse structure development with pruning and regeneration for spiking neural networks},\n  author={Han, Bing and Zhao, Feifei and Pan, Wenxuan and Zeng, Yi},\n  journal={Information Sciences},\n  volume={689},\n  pages={121481},\n  year={2025},\n  publisher={Elsevier}\n}\n  \n@article{zeng2023braincog,\n  title={Braincog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired ai and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier},\n}\n```\n\nEnjoy!\n"
  },
  {
    "path": "examples/Structural_Development/SD-SNN/main.py",
    "content": "import argparse\nimport time\nimport os\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nimport sys\nsys.path.append('..')\nimport logging\nimport torch\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\nfrom braincog.base.node.node import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\n\nfrom prun_and_generation import *\nfrom snn_model import *\nfrom utils import *\n\n_logger = logging.getLogger('train')\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nexp_name = '-'.join([datetime.now().strftime(\"%Y%m%d-%H%M%S\"),'c10'])\noutput_dir = get_outdir('./', 'train', exp_name)\nsetup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n_logger.info(exp_name)\n\nconfig_parser = cfg = argparse.ArgumentParser(description='Training Config', add_help=False)\nmodel='cifar_convnet'\ndataset='cifar10'\nnum_classes=10\nstep=8\nencode='direct'\nnode_type='PLIFNode'\nthresh=0.5\ntau=2.0\ntorch.backends.cudnn.benchmark = True\ndevicee=0\nseed=36\n\nchannels = 2\nlr=5e-3\nbatch_size=50\nepochs=600\nlinear_scaled_lr = lr * batch_size/ 1024.0\ncfg.opt='adamw'\ncfg.lr=linear_scaled_lr\ncfg.weight_decay=0.01\ncfg.momentum=0.9\ncfg.epochs=epochs\ncfg.sched='cosine'\ncfg.min_lr=1e-5\ncfg.warmup_lr=1e-6\ncfg.warmup_epochs=5\ncfg.cooldown_epochs=10\ncfg.decay_rate=0.1\n\nepoch_prune = 
1\neval_metric='top1'\nbest_test = 0\nbest_testepoch = 0\nbest_testprun = 0\nbest_testepochprun = 0\nspines_num=18\n\ntorch.cuda.set_device('cuda:%d' % devicee)\ntorch.manual_seed(seed)\n\nmodel = my_cifar_model(step=step,encode_type=encode,node_type=node_type,num_classes=num_classes)\nmodel = model.cuda()\nprint(model)\noptimizer = create_optimizer(cfg, model)\nlr_scheduler, num_epochs = create_scheduler(cfg, optimizer)\n\nloader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % dataset)(batch_size=batch_size, step=step)\n\ntrain_loss_fn = UnilateralMse(1.)\nvalidate_loss_fn = UnilateralMse(1.)\n\nm = Mask(model,spines_num)\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n\n    model.train()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n        output = model(inputs)\n\n        loss = loss_fn(output, target)\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n        losses_m.update(loss.item(), inputs.size(0))\n        top1_m.update(acc1.item(), inputs.size(0))\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx %100 == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n            print(\"Train: epoch:\",epoch,batch_idx,\"/\",len(loader),\"loss:\",losses_m.avg,\"acc1:\", 
top1_m.avg,\"lr:\",lr)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n        # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\ndef validate(model, loader, loss_fn, amp_autocast=suppress, log_suffix=''):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            #print(inputs.size())\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            inputs = inputs.type(torch.FloatTensor).cuda()\n            target = target.cuda()\n\n            output = model(inputs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            if last_batch or batch_idx %100 == 0:\n                print(\"Test: loss:\",losses_m.avg,\"acc1:\", top1_m.avg)\n                \n            batch_time_m.update(time.time() - end)\n            end = time.time()\n\n\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n\n    return metrics\n\n\nfor epoch in range(epochs):\n\n    train_metrics= train_epoch(\n        epoch, model, loader_train, optimizer, train_loss_fn,\n        lr_scheduler=lr_scheduler)\n\n    if epoch==0:\n        m.init_length()\n    if epoch>0:\n        m.model = model\n        
m.init_mask_dsd()\n        if epoch>spines_num:\n            matt=m.do_mask_dsd()\n        if epoch>2*spines_num:\n            m.do_growth_ww(epoch)\n            matt=m.do_pruning_dsd(epoch)\n        model = m.model\n    cc=m.if_zero()\n\n    eval_metrics = validate(model, loader_eval, validate_loss_fn)\n    top1=eval_metrics['top1']\n    if top1 > best_testprun:\n        best_testprun = top1\n        best_testepochprun =epoch\n    if epoch%40==0:\n        print('best acc:',best_testprun,'best epoch:',best_testepochprun)\n    if epoch>4:\n        _logger.info('*** epoch: {0} (pruning rate {1},acc:{2})'.format(epoch, cc,top1))\n\n    if lr_scheduler is not None:\n        lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n"
  },
  {
    "path": "examples/Structural_Development/SD-SNN/prun_and_generation.py",
    "content": "import numpy as np\r\nimport torch\r\nimport math\r\nfrom utils import *\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n\r\nclass Mask:\r\n    def __init__(self, model,count_thre):\r\n        self.model_size = {}\r\n        self.model_length = {}\r\n        self.compress_rate = {}\r\n        self.mat = {}\r\n        self.model = model\r\n        self.mask_index = []\r\n        self.distance_rate = {}\r\n        self.filter_small_index = {}\r\n        self.filter_large_index = {}\r\n        self.similar_matrix = {}\r\n        self.norm_matrix = {}\r\n\r\n        # dendritic dynamics\r\n        self.cur_range_pos = {}  # current range foe every weight\r\n        self.cur_range_neg = {}  # current range foe every weight\r\n        self.dsum_pos_out = {}  # current range foe every weight\r\n        self.dsum_neg_out = {}  # current range foe every weight\r\n        self.dsum_pos_in = {}  # current range foe every weight\r\n        self.dsum_neg_in = {}  # current range foe every weight\r\n        self.dendritic_previous_pos = {}  # the index for beyond range weight\r\n        self.dendritic_previous_neg = {}  # the index for within range weight\r\n        self.dendritic_previous_in = {}  # the index for within range weight\r\n        self.dendritic_count_pos = {}  # the count of beyond range for every weight\r\n        self.dendritic_count_neg = {}  # the count of within range for every weight\r\n        self.dendritic_count_in = {}  # the count of within range for every weight\r\n        self.out = 0\r\n        self.count_thre=count_thre\r\n        self.weight_previous = {}\r\n        self.mask={}\r\n        self.codebook = {}\r\n        self.codebookww={}\r\n        self.his_ind={}\r\n        self.his_groth={}\r\n        self.his_gro_count={}\r\n        self.prune={}\r\n        self.pruncc={}\r\n        self.prunn={}\r\n        self.his_prun={}\r\n        self.convlayer =model.convlayer \r\n        for index in 
self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                self.pruncc[index]=0.5\r\n            else:\r\n                self.prune[index]=2\r\n            self.prunn[index]=0\r\n        self.feature=self.model.feature\r\n        self.fc1=self.model.fc_prun[0]\r\n\r\n\r\n    def convert2tensor(self, x):\r\n        x = torch.FloatTensor(x)\r\n        return x\r\n\r\n    def init_length(self):\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                ww=self.feature[index].conv.weight\r\n            if index==self.convlayer[-1]:\r\n                ww=self.fc1.fc.weight\r\n            self.model_size[index]=ww.size()\r\n            self.codebook[index]=torch.ones_like(ww)\r\n\r\n        for index1 in self.model_size:\r\n            for index2 in range(0, len(self.model_size[index1])):\r\n                if index2 == 0:\r\n                    self.model_length[index1] = self.model_size[index1][0]\r\n                else:\r\n                    self.model_length[index1] *= self.model_size[index1][index2]\r\n            # dendritic parameters initialize\r\n            self.his_groth[index1]=np.array([self.model_length[index1]+1])\r\n            self.his_prun[index1]=np.array([self.model_length[index1]+1])\r\n            self.his_ind[index1]=np.array([-1])\r\n            self.dendritic_previous_pos[index1] = np.array([self.model_length[index1] + 1]) \r\n            self.dendritic_previous_neg[index1] = np.array([self.model_length[index1] + 1])\r\n            self.dendritic_count_pos[index1] = np.zeros((self.model_length[index1],))\r\n            self.dendritic_count_neg[index1] = np.zeros((self.model_length[index1],))\r\n            self.dendritic_count_in[index1] = np.zeros(\r\n                (self.model_length[index1],))  # the count of within range for every weight\r\n            self.dendritic_previous_in[index1] = np.array(\r\n                [self.model_length[index1] + 1])  # the index for within 
range weight\r\n            self.dsum_pos_out[index1] = np.zeros((self.model_length[index1],))  # current range for every weight\r\n            self.dsum_neg_out[index1] = np.zeros((self.model_length[index1],))  # current range for every weight\r\n            self.dsum_pos_in[index1] = np.zeros((self.model_length[index1],))  # current range for every weight\r\n            self.dsum_neg_in[index1] = np.zeros((self.model_length[index1],))  # current range for every weight\r\n            if index1<self.convlayer[-1]:\r\n                self.his_gro_count[index1] = np.zeros((self.model_size[index1][0]*self.model_size[index1][1]))\r\n            else:\r\n                self.his_gro_count[index1] = np.zeros((self.model_length[index1],))\r\n\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                ww=self.feature[index].conv.weight\r\n            if index==self.convlayer[-1]:\r\n                ww=self.fc1.fc.weight\r\n            weight_tmp = ww.data.view(-1)  # one conv one weight vector\r\n            weight_tmp_np = weight_tmp.cpu().numpy()\r\n            self.weight_previous[index] = abs(weight_tmp_np)\r\n            self.cur_range_pos[index] = weight_tmp_np.max() * np.ones((self.model_length[index],))  # args.init_range\r\n            self.cur_range_neg[index] = -1 * weight_tmp_np.max() * np.ones((self.model_length[index],))  # -1*args.init_range\r\n\r\n\r\n    def init_mask_dsd(self):\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                ww=self.feature[index].conv.weight\r\n            if index==self.convlayer[-1]:\r\n                ww=self.fc1.fc.weight\r\n                ww=ww.data*self.codebook[index].cuda()  \r\n                self.cur_range_pos[index], self.cur_range_neg[index] = \\\r\n                    self.get_range_weight(ww,self.model_length[index],index,self.count_thre)  # update current rang\r\n    \r\n    def get_range_weight(self, weight_torch, length, 
i, count_thre):\r\n        # >r+\r\n        weight_vec = weight_torch.view(-1)  # one conv one weight vector\r\n        weight_vec_np = weight_vec.cpu().numpy()\r\n        weight_np_abs = abs(weight_vec_np)\r\n        dendritic_pos_tmp = np.where((weight_vec_np >= self.cur_range_pos[i]))  # find the weight beyong range\r\n        pos_index = set(dendritic_pos_tmp[0]) & set(self.dendritic_previous_pos[i])  # calculate intersection  consectively\r\n        pos_zero = set([i for i in range(length)]) - pos_index  # non-intersection weight will be count from 0\r\n        pos_index = np.array(list(pos_index))\r\n        pos_zero = np.array(list(pos_zero))\r\n        if pos_zero.size > 0:\r\n            self.dendritic_count_pos[i][pos_zero] = 0\r\n            self.dsum_pos_out[i][pos_zero] = 0\r\n        if pos_index.size > 0:\r\n            self.dendritic_count_pos[i][pos_index] = self.dendritic_count_pos[i][pos_index] + 1  # intewrsection +1\r\n            self.dsum_pos_out[i][pos_index] += weight_vec_np[pos_index] - self.cur_range_pos[i][pos_index]\r\n        dendritic_index = np.where(self.dendritic_count_pos[i] >= count_thre)  # count>threshold\r\n        self.out = self.out + len(dendritic_index[0])\r\n        self.cur_range_pos[i][dendritic_index]+=(self.dsum_pos_out[i][dendritic_index] / count_thre)  #  self.cur_range_pos[i][dendritic_index]+0.025 # \r\n        self.dendritic_count_pos[i][dendritic_index] = 0  # intrsection count set to 0\r\n        self.dsum_pos_out[i][dendritic_index] = 0\r\n        self.dendritic_previous_pos[i] = dendritic_pos_tmp[0]  # update previous\r\n        # <r-\r\n        dendritic_neg_tmp = np.where((weight_vec_np <= self.cur_range_neg[i]))  # find the weight beyong range\r\n        neg_index = set(dendritic_neg_tmp[0]) & set(\r\n            self.dendritic_previous_neg[i])  # calculate intersection  consectively\r\n        neg_zero = set([i for i in range(length)]) - neg_index  # non-intersection weight will be count from 0\r\n        
neg_index = np.array(list(neg_index))\r\n        neg_zero = np.array(list(neg_zero))\r\n        if neg_zero.size > 0:\r\n            self.dendritic_count_neg[i][neg_zero] = 0\r\n            self.dsum_neg_out[i][neg_zero] = 0\r\n        if neg_index.size > 0:\r\n            self.dendritic_count_neg[i][neg_index] = self.dendritic_count_neg[i][neg_index] + 1  # intewrsection +1\r\n            self.dsum_neg_out[i][neg_index] += weight_vec_np[neg_index] - self.cur_range_neg[i][neg_index]\r\n        dendritic_index_neg = np.where(self.dendritic_count_neg[i] >= count_thre)  # count>threshold\r\n        self.out = self.out + len(dendritic_index_neg[0])\r\n        self.cur_range_neg[i][dendritic_index_neg] += (self.dsum_neg_out[i][dendritic_index_neg] / count_thre)  #self.cur_range_neg[i][dendritic_index_neg]-0.025 #\r\n        self.dendritic_count_neg[i][dendritic_index_neg] = 0  # intrsection count set to 0\r\n        self.dsum_neg_out[i][dendritic_index_neg] = 0\r\n        self.dendritic_previous_neg[i] = dendritic_neg_tmp[0]  # update previous\r\n        # r-~r+\r\n        dendritic_in_tmp = np.where((weight_np_abs < self.weight_previous[i]))  # find the weight beyong range\r\n        in_index = set(dendritic_in_tmp[0]) & set(self.dendritic_previous_in[i])  # calculate intersection  consectively\r\n        in_zero = set([i for i in range(length)]) - in_index  # non-intersection weight will be count from 0\r\n        in_index = np.array(list(in_index))\r\n        in_zero = np.array(list(in_zero))\r\n        if in_zero.size > 0:\r\n            self.dendritic_count_in[i][in_zero] = 0\r\n            self.dsum_neg_in[i][in_zero] = 0\r\n        if in_index.size > 0:\r\n            self.dendritic_count_in[i][in_index] = self.dendritic_count_in[i][in_index] + 1  # intewrsection +1\r\n            self.dsum_neg_in[i][in_index] += weight_np_abs[in_index] - self.weight_previous[i][in_index]\r\n        dendritic_index_in = np.where(self.dendritic_count_in[i] >= count_thre)  # 
count>threshold\r\n        self.cur_range_pos[i][dendritic_index_in] =  0.75*self.cur_range_pos[i][dendritic_index_in]  # self.cur_range_pos[i][dendritic_index_in]-0.025\r\n        self.cur_range_neg[i][dendritic_index_in] =  0.75*self.cur_range_neg[i][dendritic_index_in]  # self.cur_range_neg[i][dendritic_index_in]+0.025\r\n        self.dendritic_count_in[i][dendritic_index_in] = 0  # intrsection count set to 0\r\n        self.dsum_neg_in[i][dendritic_index_in] = 0\r\n        self.dendritic_previous_in[i] = dendritic_in_tmp[0]  # update previous\r\n        self.weight_previous[i] = weight_np_abs\r\n        print('dendritic dynamics done', np.mean(self.cur_range_pos[i]), np.mean(self.cur_range_neg[i][0]))\r\n        return self.cur_range_pos[i], self.cur_range_neg[i]\r\n\r\n    def do_mask_dsd(self):\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                ww=self.feature[index].conv.weight\r\n            if index==self.convlayer[-1]:\r\n                ww=self.fc1.fc.weight\r\n            a = ww.data.view(self.model_length[index])\r\n            a = a.cpu().numpy()\r\n            a[a > self.cur_range_pos[index]] = self.cur_range_pos[index][\r\n                a > self.cur_range_pos[index]]  # weight beyond range set to range\r\n            a[a < self.cur_range_neg[index]] = self.cur_range_neg[index][a < self.cur_range_neg[index]]\r\n            a = torch.FloatTensor(a).cuda()\r\n            ww.data = a.view(self.model_size[index])\r\n        print(\"mask Done\")\r\n\r\n    def do_pruning_dsd(self,epoch):\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                rate = int(self.model_size[index][0]* self.pruncc[index]/100)\r\n                b=unit(self.cur_range_pos[index]+abs(self.cur_range_neg[index]))\r\n                b=b*self.codebook[index].view(-1).cpu().numpy()\r\n                a=torch.tensor(b)\r\n                prer = a.reshape(self.model_size[index][0], 
-1)\r\n                presum = torch.sum(prer, dim=1)\r\n                n_range =presum\r\n                n_index = torch.argsort(n_range)\r\n                n_index=n_index[0:rate]\r\n                n_index=n_index.cpu().numpy()\r\n                ind=set(n_index)-set(self.his_prun[index])\r\n                ind=np.array(list(ind))\r\n                self.codebook[index][ind] = 0\r\n                ww=self.feature[index].conv.weight\r\n                a = ww.data*self.codebook[index].cuda()\r\n                ww.data = a\r\n                self.his_prun[index]=set(self.his_prun[index]) | set(ind)\r\n                self.cal_prune(index,epoch)\r\n            if index==self.convlayer[-1]:\r\n                rate = int(self.model_size[index][0] * self.prune[index]/100)\r\n                b=unit(self.cur_range_pos[index]+abs(self.cur_range_neg[index]))\r\n                b=b*self.codebook[index].view(-1).cpu().numpy()\r\n                a=torch.tensor(b)\r\n                prer = a.reshape(self.model_size[index][0], -1)\r\n                presum = torch.sum(prer, dim=1)\r\n                n_range =presum\r\n                n_index = torch.argsort(n_range)[0:rate]\r\n                n_index=n_index.cpu().numpy()\r\n                ind=set(n_index)-set(self.his_prun[index])\r\n                ind=np.array(list(ind))\r\n                self.codebook[index][ind] = 0\r\n                ww=self.fc1.fc.weight\r\n                a = ww.data*self.codebook[index].cuda()\r\n                ww.data = a\r\n                self.his_prun[index]=set(self.his_prun[index]) | set(ind)\r\n                self.cal_prune(index,epoch)\r\n\r\n    def cal_prune(self,index,epoch):\r\n        if index<self.convlayer[-1]:\r\n            ind=self.convlayer.index(index)\r\n            code=self.codebook[self.convlayer[ind+1]].view(self.model_size[self.convlayer[ind+1]][0],-1)\r\n            sumbook=torch.sum(code,dim=1)\r\n            prun=torch.where(sumbook<20)[0]\r\n            
prunn=prun.size()[0]\r\n            codei=self.codebook[index].view(self.model_size[index][0],-1)\r\n            sumbooki=torch.sum(codei,dim=1)\r\n            pruni=torch.where(sumbooki<20)[0]\r\n            prunni=pruni.size()[0]\r\n            self.pruncc[index]=self.pruncc[index]+math.exp(-(epoch-2*self.count_thre-1)/5)*(self.model_size[index][0]-prunni)/(self.model_size[self.convlayer[ind+1]][0]-prunn)\r\n            print(prunn,self.pruncc[index])\r\n        else:\r\n            if epoch<=60:\r\n                aphla=0.5*math.exp(-(epoch-37)) #0.25**(epoch-37) math.exp(-(epoch-37))\r\n            if epoch>60:\r\n                aphla=0.0005\r\n            sumbook=torch.sum(self.codebook[self.convlayer[-1]],dim=1)\r\n            prun=torch.where(sumbook<50)[0]\r\n            prunn=prun.size()[0]\r\n            self.prune[index]=self.prune[index]+aphla*(512-prunn)/10\r\n            print(prunn,self.prune[index])\r\n\r\n    def do_growth_ww(self,epoch):\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                ww=self.feature[index].conv.weight\r\n                ww=ww.data\r\n                ww=torch.sum(torch.sum(ww,dim=2),dim=2).view(-1).cpu().numpy()\r\n                code=torch.sum(torch.sum(self.codebook[index],dim=2),dim=2)\r\n                p_index=np.where(code.view(-1).cpu().numpy()==0)[0]\r\n                rate=65+1.1**(epoch- 2*self.count_thre-1)\r\n                if rate>99:\r\n                    rate=99\r\n                gg=np.percentile(ww, rate)\r\n                grow=np.where(ww>gg)[0]\r\n                growth_ind=set(grow) & set(p_index)\r\n                growth_index=growth_ind & set(self.his_groth[index])\r\n                zero_index=set([i for i in range(ww.size)]) - growth_index\r\n                growth_index=np.array(list(growth_index))\r\n                zero_index=np.array(list(zero_index))\r\n                if zero_index.size>0:\r\n                    
self.his_gro_count[index][zero_index]=0\r\n                if growth_index.size>0:\r\n                    self.his_gro_count[index][growth_index]=self.his_gro_count[index][growth_index]+1\r\n                gr_index=np.where(self.his_gro_count[index]> self.count_thre)[0]\r\n                self.codebook[index]=self.codebook[index].view(-1)\r\n                for x in range(len(gr_index)):\r\n                    self.codebook[index][gr_index[x]*9:(gr_index[x]+1)*9]=1\r\n                self.codebook[index]=self.codebook[index].view(self.model_size[index])\r\n                print(len(gr_index),len(growth_ind),len(p_index))\r\n                self.his_groth[index]=growth_ind\r\n                self.his_gro_count[index][gr_index]=0\r\n            if index==self.convlayer[-1]:\r\n                ww=self.fc1.fc.weight\r\n                ww=ww.data\r\n                ww=ww.view(-1).cpu().numpy()\r\n                p_index=np.where(self.codebook[index].view(-1).cpu().numpy()==0)[0]\r\n                rate=60+1.1**(epoch- 2*self.count_thre-1)\r\n                if rate>99:\r\n                    rate=99\r\n                gg=np.percentile(ww, rate)\r\n                grow=np.where(ww>gg)[0]\r\n                growth_ind=set(grow) & set(p_index)\r\n                growth_index=growth_ind & set(self.his_groth[index])\r\n                zero_index=set([i for i in range(ww.size)]) - growth_index\r\n                growth_index=np.array(list(growth_index))\r\n                zero_index=np.array(list(zero_index))\r\n                if zero_index.size>0:\r\n                    self.his_gro_count[index][zero_index]=0\r\n                if growth_index.size>0:\r\n                    self.his_gro_count[index][growth_index]=self.his_gro_count[index][growth_index]+1\r\n                gr_index=np.where(self.his_gro_count[index]> self.count_thre)[0]\r\n                self.codebook[index]=self.codebook[index].view(-1)\r\n                self.codebook[index][gr_index]=1\r\n              
  self.codebook[index]=self.codebook[index].view(self.model_size[index][0],-1)\r\n                print(len(gr_index),len(growth_ind),len(p_index))\r\n                self.his_groth[index]=growth_ind\r\n                self.his_gro_count[index][gr_index]=0\r\n\r\n    def if_zero(self):\r\n        cc=[]\r\n        for index in self.convlayer:\r\n            if index<self.convlayer[-1]:\r\n                ww=self.feature[index].conv.weight\r\n            if index==self.convlayer[-1]:\r\n                ww=self.fc1.fc.weight\r\n            if len(ww.size()) > 1:\r\n                a = ww.data.view(self.model_length[index])\r\n                b = a.cpu().numpy()\r\n                print(\r\n                    \"number of nonzero weight is %d, zero is %d\" % (np.count_nonzero(b), len(b) - np.count_nonzero(b)))\r\n                cc.append(len(b) - np.count_nonzero(b))\r\n        return cc\r\n                \r\n\r\n"
  },
  {
    "path": "examples/Structural_Development/SD-SNN/snn_model.py",
    "content": "import abc\nfrom functools import partial\nfrom torch.nn import functional as F\nimport torchvision\nfrom timm.models import register_model\n\nfrom braincog.base.node.node import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom utils import *\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nclass my_cifar_model(BaseModule):\n    def __init__(self,\n                 num_classes=10,\n                 step=8,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n\n        self.num_classes = num_classes\n\n        self.feature = nn.Sequential(\n            BaseConvModule(3, 128, kernel_size=(3, 3), padding=(1, 1)),\n            BaseConvModule(128,128, kernel_size=(3, 3), padding=(1, 1)),\n            nn.MaxPool2d(2),\n            BaseConvModule(128,256, kernel_size=(3, 3), padding=(1, 1)),\n            BaseConvModule(256, 256, kernel_size=(3, 3), padding=(1, 1)),\n            nn.MaxPool2d(2),\n            BaseConvModule(256, 512, kernel_size=(3, 3), padding=(1, 1)),\n            BaseConvModule(512, 512, kernel_size=(3, 3), padding=(1, 1)),\n        )\n\n        self.convlayer = [0,1,3,4,6,7,8]\n\n        self.cfla=self._cflatten()\n        self.fc_prun = self._create_fc_prun()\n        self.fc = self._create_fc()\n\n    def _cflatten(self):\n        fc = nn.Sequential(\n            nn.Flatten(),\n        )\n        return fc\n        \n    def _create_fc_prun(self):\n        fc = nn.Sequential(\n            BaseLinearModule(512*8*8, 512)\n        )\n        return fc\n\n    def _create_fc(self):\n        fc = nn.Sequential(\n            BaseLinearModule(512, self.num_classes)\n        )\n        return fc\n    \n    def forward(self, inputs):\n        inputs = self.encoder(inputs)\n\n        
self.reset()\n        if not self.training:\n            self.fire_rate.clear()\n\n        outputs = []\n                \n        for t in range(self.step):\n            x = inputs[t]\n            if x.shape[-1] > 32:\n                x = F.interpolate(x, size=[64, 64])\n\n            for i in range(len(self.feature)):\n                x=self.feature[i](x)\n\n            x=self.cfla(x)\n            x=self.fc_prun(x)\n            x = self.fc(x)\n\n            outputs.append(x)\n\n        return sum(outputs) / len(outputs)\n\n\n"
  },
  {
    "path": "examples/Structural_Development/SD-SNN/utils.py",
    "content": "import torch\nimport numpy as np\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ndef unit(x):\n    if len(x.shape)>0:\n        maxx=np.max(x)\n        minx=np.min(x)\n        marge=maxx-minx\n        if marge!=0:\n            xx=(x-minx)/marge\n            xx=np.clip(xx, 0,1)\n        else:\n            xx=0.5*np.ones_like(x)\n        return xx\n    else:\n        return x\n\ndef unit_tensor(x):\n    if x.size()[0]>0:\n        maxx=torch.max(x)\n        minx=torch.min(x)\n        marge=maxx-minx\n        if marge!=0:\n            xx=(x-minx)/marge\n        else:\n            xx=0.5*torch.ones_like(x)\n        return xx\n    else:\n        return x"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/README.md",
    "content": "\n\n# Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks —— Based on BrainCog #\n\n\n**This is the BrainBog-based version of the paper. To run the original code, please go to [raw](https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Structure_Evolution/Adaptive_lsm/raw)**\n\n## Requirments ##\n* numpy\n* pytorch >= 1.12.0\n* BrainCog\n\n## Run ##\n\n```python main.py```\n\n## Citation ##\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n\n@article{pan2023adaptive,\n\ttitle = {Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks},\n\tauthor = {Pan, Wenxuan and Zhao, Feifei and Zeng, Yi and Han, Bing},\n\tjournal = {Scientific Reports},\n\tvolume = {13},\n\tnumber = {1},\n\tpages = {16924},\n\tyear = {2023},\n\turl = {https://doi.org/10.1038/s41598-023-43488-x},\n\tdoi = {10.1038/s41598-023-43488-x},\n}\n\n@article{zeng2023braincog,\n  title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier}\n}\n```\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/brid.py",
    "content": "import torch, os\nimport pygame\nfrom pygame.locals import *\nfrom collections import deque\nfrom random import randint\nimport numpy as np\nimport nsganet as engine\nfrom pymop.problem import Problem\nfrom pymoo.optimize import minimize\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\nfrom lsmmodel import SNN\nimport torch.nn.functional as F\n\nfrom tools.update_weights import stdp,bcm\nimport matplotlib.pyplot as plt\nos.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\nsteps=30000\nseeds=50\nresult=np.zeros([seeds,int(steps/2)])\nt=[i for i in range(steps)]\n\n\n\ndef randbool(size, p):\n    return torch.rand(*size) < p\n\nclass Evolve(Problem):\n    # first define the NAS problem (inherit from pymop)\n    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\n        self.xl = lb\n        self.xu = ub\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\n\n\n    def _evaluate(self, x, out, *args, **kwargs):\n        \n        objs = np.full((x.shape[0], self.n_obj), np.nan)\n        for i in range(x.shape[0]):\n            arch_id = self._n_evaluated + 1\n            print('Network= {}'.format(arch_id))\n            objs[i, 0] = np.linalg.matrix_rank(x[i])\n            objs[i, 1] = 0\n            self._n_evaluated += 1\n        out[\"F\"] = objs\n        # if your NAS problem has constraints, use the following line to set constraints\n        # out[\"G\"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Define what statistics to print or save for each generation\n# ---------------------------------------------------------------------------------------------------------\ndef 
do_every_generations(algorithm):\n    # this function will be call every generation\n    # it has access to the whole algorithm class\n    gen = algorithm.n_gen\n    pop_var = algorithm.pop.get(\"X\")\n    pop_obj = algorithm.pop.get(\"F\")\n    \n    # report generation info to files\n    print(\"generation = {}\".format(gen))\n    print(\"population error: best = {}, mean = {}, \"\n                 \"median = {}, worst = {}\".format(np.min(pop_obj[:, 0]), np.mean(pop_obj[:, 0]),\n                                                  np.median(pop_obj[:, 0]), np.max(pop_obj[:, 0])))\n    print('Best Genome id= {}'.format(np.argmin(pop_obj[:, 0])))\n\ndef load_images():\n    \"\"\"\n    Flappy Bird中load图像\n    :return:load的图像\n    \"\"\"\n\n    def load_image(img_file_name):\n        file_name = os.path.join('/home/panwenxuan/raw', 'birdimages', img_file_name)\n        img = pygame.image.load(file_name)\n        # converting all images before use speeds up blitting\n        img.convert()\n        return img\n\n    return {'background': load_image('background.png'),\n            'pipe-end': load_image('pipe_end.png'),\n            'pipe-body': load_image('pipe_body.png'),\n            # images for animating the flapping bird -- animated GIFs are\n            # not supported in pygame\n            'bird-wingup': load_image('bird_wing_up.png'),\n            'bird-wingdown': load_image('bird_wing_down.png'), }\n\n\nclass Bird(pygame.sprite.Sprite):\n    \"\"\"\n    Flappy Bird类\n    \"\"\"\n    WIDTH = HEIGHT = 32\n    SINK_SPEED = 0.2\n    Fail_SINk_SPEED = 0.6\n    CLIMB_SPEED = 0.25\n    CLIMB_DURATION = 333.3\n    REGION = CLIMB_DURATION / 3\n    NEAR_COLLIDE = 30\n    NEAR_PIPE = 0\n\n    def __init__(self, x, y, msec_to_climb, images):\n        super(Bird, self).__init__()\n        self.x, self.y = x, y\n        self.msec_to_climb = msec_to_climb\n        self._img_wingup, self._img_wingdown = images\n        self._mask_wingup = 
pygame.mask.from_surface(self._img_wingup)\n        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)\n\n    def update(self, action, state, delta_frames=1):\n        \"\"\"\n        更新小鸟的位置\n        :param action: 输入行为\n        :param state:输入状态\n        :param delta_frames:Fault\n        :return:None\n        \"\"\"\n        if self.msec_to_climb > 0 and action == 1:\n            if state == 4 or state == 5 or state == 2 or state == 3:\n                self.y -= (2 * Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\n            else:\n                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\n        else:\n            if state == 4 or state == 5 or state == 2 or state == 3:\n                self.y += 2 * Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\n            else:\n                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\n\n    def sink(self, delta_frames=1):\n        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)\n\n    @property\n    def image(self):\n        if pygame.time.get_ticks() % 500 >= 250:\n            return self._img_wingup\n        else:\n            return self._img_wingdown\n\n    @property\n    def mask(self):\n        if pygame.time.get_ticks() % 500 >= 250:\n            return self._mask_wingup\n        else:\n            return self._mask_wingdown\n\n    @property\n    def rect(self):\n        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)\n\n\nclass PipePair(pygame.sprite.Sprite):\n    \"\"\"\n    Flappy Bird 中的管子类\n    \"\"\"\n    WIDTH = 80\n    PIECE_HEIGHT = 32\n    ADD_INTERVAL = 2000\n    ADD_EVENT = pygame.USEREVENT + 1\n    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT\n\n    def __init__(self, pipe_end_img, pipe_body_img):\n        self.x = float(WIN_WIDTH - 1)\n        self.score_counted = False\n        self.isNewPipe = True\n\n        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)\n        self.image.convert()  # speeds up 
blitting\n        self.image.fill((0, 0, 0, 0))\n        total_pipe_body_pieces = int(\n            (WIN_HEIGHT -  # fill window from top to bottom\n             3 * Bird.HEIGHT -  # make room for bird to fit through\n             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece\n            PipePair.PIECE_HEIGHT  # to get number of pipe pieces\n        )\n        self.bottom_pieces = randint(1, total_pipe_body_pieces)\n        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces\n\n        # bottom pipe\n        for i in range(1, self.bottom_pieces + 1):\n            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)\n            self.image.blit(pipe_body_img, piece_pos)\n        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px\n        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)\n        self.image.blit(pipe_end_img, bottom_end_piece_pos)\n\n        # top pipe\n        for i in range(self.top_pieces):\n            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))\n        top_pipe_end_y = self.top_height_px\n        self.image.blit(pipe_end_img, (0, top_pipe_end_y))\n\n        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2\n\n        # compensate for added end pieces\n        self.top_pieces += 1\n        self.bottom_pieces += 1\n\n        # for collision detection\n        self.mask = pygame.mask.from_surface(self.image)\n        self.top_y = top_pipe_end_y\n        self.bottom_y = bottom_pipe_end_y\n\n    @property\n    def top_height_px(self):\n        return self.top_pieces * PipePair.PIECE_HEIGHT\n\n    @property\n    def bottom_height_px(self):\n        return self.bottom_pieces * PipePair.PIECE_HEIGHT\n\n    @property\n    def visible(self):\n        return -PipePair.WIDTH < self.x < WIN_WIDTH\n\n    @property\n    def rect(self):\n        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)\n\n    def update(self, delta_frames=1):\n        self.x -= 0.18 * 1000.0 * 
delta_frames / 60\n\n    def collides_with(self, bird):\n        return pygame.sprite.collide_mask(self, bird)\n\n\ndef judgeState(bird, pipes, collide):\n    \"\"\"\n    根据小鸟和管子之间的位置关系判断当前状态\n    :param bird:传入小鸟的各项属性\n    :param pipes:传入管子的各项属性\n    :param collide:是否发生碰撞\n    :return:状态，距离，是否是新的管子\n    \"\"\"\n    # bird's x and y coordinate in the left top of the image\n    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2\n    isNew = False\n    index = -1\n    state = -1\n    if collide:\n        state = 8\n        return state\n    for p in pipes:\n        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:\n            continue\n        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \\\n                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:\n\n            p_top_y = p.top_y + PipePair.PIECE_HEIGHT\n            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT\n            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:\n                state = 0\n            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:\n                state = 1\n            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:\n                state = 6\n            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:\n                state = 7\n            if state > -0.5:\n                index = 1\n        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:\n            state = blankState(bird, p.center)\n            if p.isNewPipe:\n                isNew = True\n            p.isNewPipe = False\n            index = 1\n        if index > 0:  # only judge the nearest and not passed pipe\n            dist = bird.y + Bird.HEIGHT / 2 - p.center\n            break\n    if index < -0.5:  # no pipe left, key the bird in the middle\n        pos = WIN_HEIGHT / 2\n        dist = bird.y + 
Bird.HEIGHT / 2 - pos\n        state = blankState(bird, pos)\n\n    return state, dist, isNew\n\n\ndef blankState(bird, center):\n    \"\"\"\n    judgeState中调用的判断状态的函数 根据鸟的位置和管子中心的距离来判断\n    :param bird: 传入小鸟的各项属性\n    :param center:中心\n    :return:状态\n    \"\"\"\n    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2\n    if center - bird.y - Bird.HEIGHT / 2 >= 0 and \\\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\n        state = 0\n    elif bird.y - center + Bird.HEIGHT / 2 >= 0 and \\\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\n        state = 1\n    elif center - bird.y - Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\n        state = 2\n    elif bird.y - center + Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\n        state = 3\n    elif bird.y + Bird.HEIGHT / 2 <= center - (realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION):\n        state = 4\n    elif bird.y + Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\n        state = 5\n    return state\n\n\ndef getReward(state, lastState, smallerError, isNewPipe):\n    \"\"\"\n    根据状态和距离的变化获得奖励\n    :param state: 执行行为后的当前状态\n    :param lastState:执行行为之前的上一状态\n    :param smallerError:距离是否变小\n    :param isNewPipe:是否是新的管子\n    :return:奖励\n    \"\"\"\n    if state == 0 or state == 1:\n        reward = 6\n    elif state == 2 or state == 3:\n        if lastState == state and not isNewPipe:\n            if smallerError:\n                reward = 3\n            else:\n                reward = -5\n        else:\n            reward = -3\n    elif state == 4 or state == 5:\n        if lastState == state and not isNewPipe:\n            if smallerError:\n                reward = 3\n            
else:\n                reward = -8\n        else:\n            reward = -5\n    elif state == 6 or state == 7:\n        if lastState == state and not isNewPipe:\n            if smallerError:\n                reward = 3\n            else:\n                reward = -3\n        else:\n            reward = -3\n    elif state == 8:  # collide\n        reward = -100\n    return reward\n\n\n\n\nif __name__ == \"__main__\":\n\n    n_agent=1\n    num = 8\n    p_amount = int(num * num / 10)\n    s_amount = 4\n    num_state = 9\n    num_action = 2\n    weight_exc = 1\n    weight_inh = -0.5\n    trace_decay = 0.8\n    gens=1000\n    for seed in range(seeds):\n\n        kkk = Evolve(n_var=num*num, \n                        n_obj=2, n_constr=0)\n        method = engine.nsganet(pop_size=n_agent,\n                                sampling=RandomSampling(var_type='custom'),\n                                mutation=BinaryBitflipMutation(),\n                                n_offsprings=10,\n                                eliminate_duplicates=True)\n        kres=minimize(kkk,\n                        method,\n                        callback=do_every_generations,\n                        termination=('n_gen', gens))\n    \n        pop_var = kres.X\n        pop_obj = kres.F        \n        lm=torch.from_numpy(pop_var[np.argmin(pop_obj[:, 0])].reshape(num,num))\n    \n        model = SNN(ins=9,num_classes=2,n_agent=n_agent,device='cuda:0',liquid_size=num,lsm_tau=2,lsm_th=0.2,connectivity_matrix=lm.to('cuda:0').float())\n        model.to('cuda:0')\n\n        con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)\n        for i in range(num_state):\n            for j in range(num_action):\n                con_matrix1[i, i * num_action + j] = weight_exc\n        weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\n        weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\n\n        pygame.init()\n        WIN_HEIGHT = 512\n      
  WIN_WIDTH = 284 * 2\n        heighest = 0\n        contTime = 0\n        display_frame = 0\n        display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))\n        pygame.display.set_caption('Flappy Bird')\n        images = load_images()\n        bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,\n                    (images['bird-wingup'], images['bird-wingdown']))\n\n        clock = pygame.time.Clock()\n        score_font = pygame.font.SysFont(None, 25, bold=True)\n        info_font = pygame.font.SysFont(None, 50, bold=True)\n        collide = paused = False\n        frame_clock = 0\n        pipes = deque()\n        score = 0\n        lastDist = 0\n        lastState = 0  # init\n        state = lastState\n        i = 0\n        num_reward = []\n        num_score = []\n        reward=1\n        while not collide:\n            i = i + 1\n            if i > steps:\n                break\n            clock.tick(60)\n            if frame_clock % 2 == 0 or frame_clock == 1:\n                state, dist, isNewPipe = judgeState(bird, pipes, collide)\n                lastState = state\n                lastDist = dist\n                action= model(F.one_hot(torch.tensor([state]), num_classes=num_state).to('cuda:0').float()).cpu().detach().numpy()\n                action=int(np.argmax(action,axis=1))\n                print(i)\n\n            if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):\n                pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))\n\n            for e in pygame.event.get():\n                if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):\n                    collide = True\n                elif e.type == KEYUP and e.key in (K_PAUSE, K_p):\n                    paused = not paused\n                elif e.type == PipePair.ADD_EVENT:\n                    pp = PipePair(images['pipe-end'], images['pipe-body'])\n                    pipes.append(pp)\n            if paused:\n                
continue  # don't draw anything\n            pipe_collision = any(p.collides_with(bird) for p in pipes)\n            if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:\n                collide = True\n            for x in (0, WIN_WIDTH / 2):\n                display_surface.blit(images['background'], (x, 0))\n            while pipes and not pipes[0].visible:\n                pipes.popleft()\n            for p in pipes:\n                p.update()\n                display_surface.blit(p.image, p.rect)\n            bird.update(action, state)\n            display_surface.blit(bird.image, bird.rect)\n            if frame_clock % 2 == 0 or frame_clock == 1 or collide:\n                dist = 0\n                if collide:\n                    nextState = 8\n                    isNewPipe = False\n                else:\n                    nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state\n                    print(\"next state:\", nextState)\n                print(\"lastdist, dist:\", lastDist, dist)\n                isSmallerError = False\n                if state == nextState:\n                    isSmallerError = False\n                    if lastDist <= 0:\n                        if lastDist < dist:\n                            isSmallerError = True\n                    else:\n                        if lastDist > dist:\n                            isSmallerError = True\n                if frame_clock > 0 and not collide:\n                    reward = getReward(nextState, state, isSmallerError, isNewPipe)\n                    print(\"reward:\", reward)\n                    num_reward.append(reward)\n                    bcmreward=np.array([reward, reward])\n                    # bcm(model,bcmreward, input=input)\n                state = nextState  # going on the next state\n                weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\n                weight_trace_d2 = 
torch.zeros(con_matrix1.shape, dtype=torch.float)\n                model.reset()\n                display_frame += 1\n            for p in pipes:\n                if p.x + PipePair.WIDTH < bird.x and not p.score_counted:\n                    score += 1\n                    p.score_counted = True\n            num_score.append(score)\n            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\n            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\n            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\n            if heighest < score:\n                heighest = score\n            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\n                                                (0, 0, 0))  # heighest score\n            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\n            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\n            score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score\n            score_x_i = 10\n            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\n            frame_clock += 1\n            pygame.display.flip()\n\n        #  if collide, display the fail information, for 2 frames\n        cct = 0\n        while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):\n            clock.tick(60)\n            for x in (0, WIN_WIDTH / 2):\n                display_surface.blit(images['background'], (x, 0))\n            while pipes and not pipes[0].visible:\n                pipes.popleft()\n            for p in pipes:\n                display_surface.blit(p.image, p.rect)\n            if cct >= 6:\n                bird.sink()\n            display_surface.blit(bird.image, bird.rect)\n            fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # current score\n            pos_x = WIN_WIDTH / 2 - 
fail_infor.get_width() / 2\n            pos_y = WIN_HEIGHT / 2 - 100\n            display_surface.blit(fail_infor, (pos_x, pos_y))\n            #  display the score\n            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\n            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\n            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\n            if heighest < score:\n                heighest = score\n            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\n                                                (0, 0, 0))  # heighest score\n            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\n            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\n            score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score\n            score_x_i = 10\n            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\n            pygame.display.flip()\n            cct += 1\n        if heighest < score:\n            heighest = score\n        contTime += 1\n        num_reward_np = np.array(num_reward)\n        print(num_reward_np)\n        k=num_reward_np.shape[0]\n        result[seed,:k]=num_reward_np\n\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/lsmmodel.py",
    "content": "\nfrom functools import partial\nfrom torch.nn import functional as F\nfrom torch import nn as nn\nimport torchvision, pprint\nfrom copy import deepcopy\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.base.brainarea.BrainArea import BrainArea\nfrom braincog.base.connection.CustomLinear import *\nfrom braincog.base.learningrule.STDP import *\nfrom braincog.base.learningrule.BCM import *\nimport matplotlib.pyplot as plt\n\n\n@register_model\nclass SNN(BaseModule):\n    def __init__(self,\n                 liquid_size,\n                 n_agent,\n                 device,\n                 connectivity_matrix,\n                 num_classes=3,\n                 step=1,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 lsm_th=0.3,\n                 fc_th=0.3,\n                 lsm_tau=3,\n                 fc_tau=3,\n                 tw=100,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.batchsize=n_agent\n\n        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)\n        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)\n        self.liquid_size=liquid_size\n        self.out = torch.zeros(self.batchsize, liquid_size).to(device)\n        self.device=device\n        self.con=[]\n        self.learning_rule=[]\n        self.connectivity_matrix=connectivity_matrix\n        w1tmp=nn.Linear(4,liquid_size,bias=False).to(device)\n        self.con.append(w1tmp)\n        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)\n\n        self.liquid_weight=w2tmp.weight.data\n        \n        
w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix\n        self.con.append(w2tmp)\n        self.learning_rule.append(BCM(self.node_lsm(), [self.con[0], self.con[1]]))  # pm\n        self.fc = nn.Linear(liquid_size,num_classes).to(device)\n\n \n        self.learning_rule.append(BCM(self.node_fc(), [self.fc]))  # pm\n\n        \n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], -1)\n        sum_spike=0\n        time_window=20\n        self.tw=time_window\n        self.firing_tw=torch.zeros(time_window, self.batchsize, self.liquid_size).to(self.device)\n        self.out = torch.zeros(self.batchsize, self.liquid_size).to(self.device)\n        for t in range(time_window):\n\n            self.out, self.dw = self.learning_rule[0](x, self.out)\n            # self.con[0].weight+=self.dw[0]\n            self.con[1].weight.data+=self.dw[1]\n\n            out_liquid=self.out[:,0:self.liquid_size]\n\n            xout,dw = self.learning_rule[1](out_liquid)\n            self.fc.weight.data+=dw[0]\n            sum_spike=sum_spike+xout\n            self.firing_tw[t]=out_liquid\n\n        outputs = sum_spike / time_window\n        return outputs\n\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/maze.py",
    "content": "import matplotlib.pyplot as plt\r\nimport networkx as nx\r\nimport numpy as np\r\nimport math\r\nfrom matplotlib import pyplot as plt\r\nimport matplotlib\r\nimport seaborn as sns\r\nfrom lsmmodel import SNN\r\nfrom tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival\r\nfrom tools.MazeTurnEnvVec import MazeTurnEnvVec\r\nimport torch\r\nimport brewer2mpl\r\nfrom cycler import cycler\r\nimport nsganet as engine\r\nfrom pymop.problem import Problem\r\nfrom pymoo.optimize import minimize\r\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\r\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\r\n\r\ndef randbool(size, p):\r\n    return torch.rand(*size) < p\r\n\r\nclass Evolve(Problem):\r\n    # first define the NAS problem (inherit from pymop)\r\n    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):\r\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\r\n        self.xl = lb\r\n        self.xu = ub\r\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\r\n\r\n\r\n    def _evaluate(self, x, out, *args, **kwargs):\r\n        \r\n        objs = np.full((x.shape[0], self.n_obj), np.nan)\r\n        for i in range(x.shape[0]):\r\n            arch_id = self._n_evaluated + 1\r\n            print('Network= {}'.format(arch_id))\r\n            objs[i, 0] = np.linalg.matrix_rank(x[i])\r\n            self._n_evaluated += 1\r\n        out[\"F\"] = objs\r\n        # if your NAS problem has constraints, use the following line to set constraints\r\n        # out[\"G\"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints\r\n\r\n\r\n# ---------------------------------------------------------------------------------------------------------\r\n# Define what statistics to print or save for each generation\r\n# 
---------------------------------------------------------------------------------------------------------\r\ndef do_every_generations(algorithm):\r\n    # this function will be call every generation\r\n    # it has access to the whole algorithm class\r\n    gen = algorithm.n_gen\r\n    pop_var = algorithm.pop.get(\"X\")\r\n    pop_obj = algorithm.pop.get(\"F\")\r\n    \r\n    # report generation info to files\r\n    print(\"generation = {}\".format(gen))\r\n    print(\"population error: best = {}, mean = {}, \"\r\n                 \"median = {}, worst = {}\".format(np.min(pop_obj[:, 0]), np.mean(pop_obj[:, 0]),\r\n                                                  np.median(pop_obj[:, 0]), np.max(pop_obj[:, 0])))\r\n    print('Best Genome id= {}'.format(np.argmin(pop_obj[:, 0])))\r\n\r\nif __name__ == '__main__':\r\n\r\n    device = 'cuda:8'\r\n    num = 8\r\n    n_agent = 20\r\n    steps = 500\r\n    liquid_size=80\r\n\r\n    env = MazeTurnEnvVec(n_agent, n_steps=steps)\r\n    newenv=MazeTurnEnvVec(n_agent, n_steps=steps)\r\n    data_env = ExperimentEnvGlobalNetworkSurvival(env)\r\n    newdata_env = ExperimentEnvGlobalNetworkSurvival(newenv)\r\n\r\n\r\n\r\n    gens=100\r\n    seed=0\r\n    sum_of_env = np.zeros([gens, n_agent])\r\n    env_r=np.zeros([steps,n_agent])\r\n\r\n    population = torch.zeros(n_agent,liquid_size,liquid_size)\r\n\r\n    for i in range(n_agent):\r\n        population[i] = randbool([liquid_size, liquid_size],p=0.01).to(device).float()\r\n\r\n\r\n\r\n\r\n    kkk = Evolve(n_var=liquid_size*liquid_size, \r\n                    n_obj=1, n_constr=2)\r\n    method = engine.nsganet(pop_size=n_agent,\r\n                            sampling=RandomSampling(var_type='custom'),\r\n                            mutation=BinaryBitflipMutation(),\r\n                            n_offsprings=10,\r\n                            eliminate_duplicates=True)\r\n    kres=minimize(kkk,\r\n                    method,\r\n                    
callback=do_every_generations,\r\n                    termination=('n_gen', gens))\r\n\r\n\r\n        \r\n    # lm=evolve(population, gens)\r\n\r\n    model = SNN(ins=4,n_agent=n_agent,device=device,liquid_size=liquid_size,lsm_tau=2,lsm_th=0.2,connectivity_matrix=randbool([liquid_size, liquid_size],p=0.01).to(device).float())\r\n    model.to(device)\r\n    old_dis = np.ones([n_agent,])*13\r\n\r\n    X = data_env.reset()\r\n    envreward = np.zeros([n_agent, ])\r\n    fit=np.zeros([n_agent])\r\n\r\n    for i in range(steps):\r\n        model.reset()\r\n\r\n        out = model(torch.from_numpy(X+1).float().to(device)).cpu().detach().numpy()\r\n\r\n        X_next, envreward, fitness, infos = data_env.step(np.argmax(out,axis=1))\r\n\r\n        food_pos = data_env.env.food_pos[:, 0, :2]\r\n\r\n        agent_pos = data_env.env.agents_pos\r\n        print(agent_pos)\r\n        dis = ((agent_pos - food_pos) ** 2).sum(1)\r\n\r\n        reward =np.array((np.sqrt(old_dis)-np.sqrt(dis))>0,dtype=int)\r\n\r\n        aa=np.ones_like(reward)*-1\r\n\r\n        bb = np.ones_like(reward)*3\r\n\r\n        cc = np.ones_like(reward)*-3\r\n\r\n        reward=np.where(reward == 0 , aa, reward)\r\n        reward=np.where(envreward == 1, bb, reward)\r\n        reward = np.where(envreward == -1, cc, reward)\r\n        old_dis= dis\r\n        env_r[i]=reward\r\n\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/EnuGlobalNetwork.py",
    "content": "import pickle\r\nimport time\r\n\r\nimport numpy as np\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\nimport seaborn as sns\r\nfrom matplotlib import gridspec\r\n\r\nfrom AbstractLayerBMM import AbstractLayerBMM\r\nfrom EvolvableNeuralUnitStacked import EvolvableNeuralUnitStacked\r\nfrom Tools import get_data_path\r\n\r\nsns.set_style(\"darkgrid\")\r\n\r\n\r\nclass EnuGlobalNetwork(AbstractLayerBMM):\r\n    \"\"\"Network of ENUs implementation in PyTorch, where each synapse and neuron is modeled as an ENU. \"\"\"\r\n\r\n    def __init__(self, n_offspring, n_pseudo_env, n_input_neurons, n_hidden_neurons, n_output_neurons, n_syn_per_neuron):\r\n        # offspring\r\n        self.n_offspring = n_offspring\r\n        self.n_pseudo_env = n_pseudo_env\r\n        # input channels\r\n        n_input_channels = 16\r\n        self.n_input_channels = n_input_channels\r\n        n_dynamic_param = 32\r\n        # total neurons\r\n        n_neurons = n_output_neurons + n_hidden_neurons\r\n        self.n_neurons = n_neurons\r\n        super().__init__(n_offspring, n_neurons, n_input_neurons, n_output_neurons)\r\n        torch.random.manual_seed(0)\r\n        #NOTE: batch dimension holds output of each neuron/synapse, allowing fast GPU MM\r\n        #NOTE neurons far less than synapses, so can be relatively bigger rnn for little cost\r\n        n_input_channels_neuron = 16\r\n        n_input_neuron, n_output_neuron = n_input_channels_neuron, n_input_channels\r\n        self.neurons = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_neurons, n_input=n_input_neuron, n_dynamic_param=n_dynamic_param, n_output=n_output_neuron)\r\n        #self.n_syn = next_power_of_2(int(n_neurons * (rel_connectivity*n_neurons)))\r\n        self.n_syn_per_neuron = n_syn_per_neuron\r\n        self.n_syn = n_neurons * n_syn_per_neuron\r\n        n_input_syn, n_output_syn = n_input_channels * 2, n_input_channels_neuron # * 2 for neuron feedback (which same channel as 
n_channel input)\r\n        self.synapses = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_syn, n_input=n_input_syn, n_dynamic_param=n_dynamic_param, n_output=n_output_syn)\r\n        # just randomly connect synapses to neurons\r\n        self.synapse_connections = torch.randint(n_input_neurons + n_neurons, size=(n_neurons, n_syn_per_neuron), device='cuda', dtype=torch.long)\r\n        # fixed predefined connection patterns\r\n        if n_input_neurons==2 and n_output_neurons==2 and n_hidden_neurons==2:\r\n            print(\"Fixed connection Network 2-2-2\")\r\n            self.synapse_connections = torch.tensor([[0, 1],\r\n                                                     [0, 1],\r\n                                                     [2, 3],\r\n                                                     [2, 3]], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons == 4 and n_output_neurons == 3 and n_hidden_neurons == 3 and n_syn_per_neuron==3:\r\n            print(\"Fixed connection Network 4-3-3 (3syn)\")\r\n            self.synapse_connections = torch.tensor([[0, 1, 3],# hidden connections #4\r\n                                                     [0, 2, 3], #5\r\n                                                     [1, 2, 3],# 6\r\n                                                     [4, 5, 6], # output connections #7\r\n                                                     [4, 5, 6],#8\r\n                                                     [4, 5, 6]#9\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==5 and n_hidden_neurons==0 and n_output_neurons==4:\r\n            print(\"Fixed connection Network 5-0-4 (5syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# output connections\r\n                                                     [0, 1, 2, 3, 
4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4]\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==2:\r\n            print(\"Fixed connection Network 1-0-2 (1syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0],# output connections\r\n                                                     [0]\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==4 and n_hidden_neurons==0 and n_output_neurons==3 and n_syn_per_neuron==4:\r\n            print(\"Sparse connection Network 4-0-3 (4syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3],# output connections #4\r\n                                                     [0, 1, 2, 3], #5\r\n                                                     [0, 1, 2, 3],# 6\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons == 4 and n_hidden_neurons == 3 and n_output_neurons == 3 and n_syn_per_neuron == 4:\r\n            print(\"Sparse connection Network 4-3-3 (3syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 3],  # hidden connections #4\r\n                                                     [0, 2, 3],  # 5\r\n                                                     [1, 2, 3],  # 6\r\n                                                     [4, 5, 3],  # output connections #7\r\n                                                     [4, 6, 3],  # 8\r\n                          
                           [5, 6, 3]  # 9\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==4 and n_hidden_neurons==3 and n_output_neurons==3 and n_syn_per_neuron==8:\r\n            print(\"Sparse connection Network 4-3-3 (8syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 5, 6, 7, 8, 3, 4],# hidden connections #4\r\n                                                     [0, 2, 4, 6, 7, 9, 3, 5], #5\r\n                                                     [1, 2, 4, 5, 8, 9, 3, 6],# 6\r\n                                                     [4, 5, 8, 9, 0, 1, 3, 7], # output connections #7\r\n                                                     [4, 6, 7, 9, 0, 2, 3, 8],#8\r\n                                                     [5, 6, 7, 8, 1, 2, 3, 9]#9\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==4 and n_hidden_neurons==4 and n_output_neurons==4 and n_syn_per_neuron==8:\r\n            print(\"Fixed connection Network 4-4-4 (8syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],# hidden connections\r\n                                                     [0, 1, 2, 3, 4, 5, 6, 7],\r\n                                                     [0, 1, 2, 3, 4, 5, 6, 7],\r\n                                                     [0, 1, 2, 3, 4, 5, 6, 7],\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 11], # output connections\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 11],\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 11],\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 
11],\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==5 and n_hidden_neurons==5 and n_output_neurons==4:\r\n            print(\"Fixed connection Network 5-5-4 (5syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# hidden connections\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [5, 6, 7, 8, 9], # output connections\r\n                                                     [5, 6, 7, 8, 9],\r\n                                                     [5, 6, 7, 8, 9],\r\n                                                     [5, 6, 7, 8, 9],\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==1:\r\n            print(\"Fixed connection Single\")\r\n            self.synapse_connections = torch.tensor([[0]], device='cuda', dtype=torch.long)\r\n        else:\r\n            print(\"Random connections\")\r\n        # each synapse is connected also to its post-synaptic neuron, to allow STDP type learning to emerge\r\n        self.synapse_connections_post = torch.arange(n_neurons, device='cuda', dtype=torch.long).reshape(n_neurons, -1).repeat(1, n_syn_per_neuron)\r\n        # define compartments\r\n        self.compartments = [self.neurons, self.synapses]\r\n        self.trainable_layers = self.neurons.trainable_layers + self.synapses.trainable_layers\r\n        self.track_data = False\r\n\r\n\r\n    def dump_model(self, e, exp_name):\r\n        \"\"\"Dump model to 
restore\"\"\"\r\n        with open(get_data_path(e, exp_name, \"Model\"), 'wb') as f:\r\n            parameters = {}\r\n            parameters[\"neuron\"] = [layer.base_parameters.cpu().numpy() for layer in self.neurons.trainable_layers]\r\n            parameters[\"synapse\"] = [layer.base_parameters.cpu().numpy() for layer in self.synapses.trainable_layers]\r\n            pickle.dump(parameters, f)\r\n\r\n    def restore_model(self, e, exp_name):\r\n        \"\"\"Restore model\"\"\"\r\n        with open(get_data_path(e, exp_name, \"Model\"), 'rb') as f:\r\n            parameters = pickle.load(f)\r\n        #TODO: refactor to dump/restore at ENU level and just call those functions\r\n        assert len(self.neurons.trainable_layers) == len(parameters[\"neuron\"])\r\n        for i in range(len(parameters[\"neuron\"])):\r\n            self.neurons.trainable_layers[i].base_parameters = torch.from_numpy(parameters[\"neuron\"][i].astype(np.float32)).cuda()\r\n        assert len(self.synapses.trainable_layers) == len(parameters[\"synapse\"])\r\n        for i in range(len(parameters[\"synapse\"])):\r\n            self.synapses.trainable_layers[i].base_parameters = torch.from_numpy(parameters[\"synapse\"][i].astype(np.float32)).cuda()\r\n\r\n    @staticmethod\r\n    def plot_weights(e, exp_name):\r\n        \"\"\"Visualize weights of ENU gates\"\"\"\r\n        sns.set_style(\"dark\")\r\n        def calc_average(start, stop):\r\n            weights_average = None\r\n            for e in range(start, stop, 1000):\r\n                with open(get_data_path(e, exp_name, \"Model\"), 'rb') as f:\r\n                    parameters = pickle.load(f)\r\n                weights = []\r\n                for i in range(len(parameters[\"neuron\"])):\r\n                    weights += [parameters[\"neuron\"][i].astype(np.float32)]\r\n                if weights_average is None:\r\n                    weights_average = weights\r\n                else:\r\n                    for i in 
range(len(weights_average)):\r\n                        weights_average[i] += weights[i]\r\n            return weights_average\r\n        weights_mean1 = calc_average(20000, 30000)\r\n        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')\r\n        for i in range(len(weights_mean1)):\r\n            ax[i].imshow(weights_mean1[i], cmap=\"gray\")\r\n        weights_mean2 = calc_average(30000, 40000)\r\n        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')\r\n        for i in range(len(weights_mean2)):\r\n            ax[i].imshow(weights_mean2[i], cmap=\"gray\")\r\n        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')\r\n        for i in range(len(weights_mean2)):\r\n            ax[i].imshow((weights_mean2[i] - weights_mean1[i])**5, cmap=\"gray\")\r\n        plt.show()\r\n\r\n\r\n    def dump_network_activity(self, e, exp_name):\r\n        \"\"\"Dump raw data for visualization\"\"\"\r\n        with open(get_data_path(e, exp_name, \"GlobalNetwork\"), 'wb') as f:\r\n            pickle.dump(self.vis_data, f)\r\n\r\n    def print(self):\r\n        print(\"--Neurons--\")\r\n        self.neurons.print()\r\n        print(\"--Synapses--\")\r\n        self.synapses.print()\r\n\r\n    def reset(self):\r\n        self.vis_data = []\r\n        if self.track_data:\r\n            print(\"Tracking network activity\")\r\n        for compartment in self.compartments:\r\n            compartment.reset()\r\n\r\n    def forward(self, X):\r\n        \"\"\"Main computation forward pass\"\"\"\r\n        # transfer to GPU\r\n        X_raw_gpu = torch.from_numpy(X.astype(np.float32)).cuda()\r\n        X_gpu = torch.zeros((X.shape[0], X.shape[1], self.n_input_channels), device='cuda', dtype=torch.float32)\r\n        X_gpu[:, :, :X_raw_gpu.shape[2]] = X_raw_gpu\r\n        # first compute synapses, set input to previous output of connected neuron\r\n        # concat our input spiking pattern directly to input to our synapses (the neurons)\r\n        # NOTE: this 
concats in batch dimension, meaning it feeds into input neurons directly spiking pattern, while rest receive input from network\r\n        input_to_synapses = torch.cat([X_gpu, self.neurons.out_mem], dim=1)\r\n        # connect each synapse randomly to multiple inputs\r\n        input_to_synapses_connected = input_to_synapses[:, self.synapse_connections.flatten(), :]\r\n        # need feedback connection from neuron to synapse, to allow stdp type rules to emerge (else it has to do it through feedback connections, but less guarentee on connections and cannot distinguise type)\r\n        # one synapse has 1 pre-synaptic neuron and 1 post-synaptic neuron, connectection defined in synapse_connections, synapse_connections[i, :] gives all input synapses of that neuron\r\n        # so feedback to all it's input synapses through broadcasting backwards\r\n        post_neuron_backprop_connected = self.neurons.out_mem[:, self.synapse_connections_post.flatten(), :]\r\n        input_to_synapses_connected = torch.cat([input_to_synapses_connected, post_neuron_backprop_connected], dim=-1)\r\n        # compute synapse\r\n        self.synapses.forward(input_to_synapses_connected)\r\n        # then integrate(sum) all outputs of a neurons input synapses, can just reshape into valid shape, since we already randomly connected when computing synapses\r\n        # NOTE: each neuron then requires same number of synapses, then reshape by modifying batch dim (which contains syn outputs)\r\n        integration = torch.sum(self.synapses.out.reshape((self.n_offspring, self.n_neurons, -1, self.synapses.shape[-1])), dim=2)\r\n        # scale by number of synapses\r\n        integration /= self.n_syn_per_neuron\r\n        self.out_integration = integration\r\n        # finally set neuron input to summated connected synapses output\r\n        input_to_neurons = integration\r\n        out = self.neurons.forward(input_to_neurons)\r\n        # output is last neuron output, NOTE: just first channel is 
returned, since we reshape neurons to channels\r\n        self.out = out[:, -self.n_output:, 0].reshape(self.n_offspring, self.n_output)\r\n        if self.track_data:\r\n            self._track_vis_data(X, input_to_synapses_connected, input_to_neurons)\r\n        return self.out\r\n\r\n    def _track_vis_data(self, X, input_to_synapses_connected, input_to_neurons):\r\n        offspring_idx = 0\r\n        self.vis_data += [(X[offspring_idx], input_to_neurons[offspring_idx].cpu().numpy(), self.neurons.out[offspring_idx].cpu().numpy(),\r\n                           input_to_synapses_connected[offspring_idx].cpu().numpy(), self.synapses.out[offspring_idx].cpu().numpy())]\r\n\r\n    @staticmethod\r\n    def plot_network_activity(e, exp_name):\r\n        with open(get_data_path(e, exp_name, \"GlobalNetwork\"), 'rb') as f:\r\n            vis_data = pickle.load(f)\r\n            X, input_to_neurons, neurons_out, input_to_synapses, synapses_out = map(np.array, zip(*vis_data))\r\n        def plot_enu_activity(input, output, title):\r\n            n_cells = output.shape[1]\r\n            n_cells = np.minimum(10, output.shape[1])\r\n            fig, grid = plt.subplots(2, n_cells, sharex='col', sharey='row')\r\n            if n_cells==1:\r\n                grid[0].plot(input[:, 0, :])\r\n                grid[1].plot(output[:, 0, :])\r\n            else:\r\n                for i in range(n_cells):\r\n                    grid[0, i].plot(input[:, i, :])\r\n                    grid[1, i].plot(output[:, i, :])\r\n            plt.xlabel(\"t\")\r\n            plt.title(title)\r\n            #plt.ylabel(\"\")\r\n            plt.legend()\r\n        plt.figure()\r\n        plt.plot(X[:, :, 0])\r\n        plot_enu_activity(input_to_neurons, neurons_out, \"ENU neuron activity\")\r\n        plot_enu_activity(input_to_synapses, synapses_out, \"ENU synapse activity\")\r\n\r\n        plt.figure()\r\n        spike_points = np.where(neurons_out[:, :, 0] > 0)\r\n        
plt.scatter(spike_points[0], spike_points[1], marker='|')\r\n\r\n        plt.show()\r\n\r\n\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/ExperimentEnvGlobalNetworkSurvival.py",
    "content": "import pickle\r\n\r\nimport numpy as np\r\n\r\nfrom tools.Tools import get_data_path\r\n\r\n\r\nclass ExperimentEnvGlobalNetworkSurvival:\r\n    \"\"\"Wrapper around a given RL environment for a Network of ENUs model,\r\n    turns reward into fitness and dumps relevant data\"\"\"\r\n\r\n\r\n    def __init__(self, env, exp_name='maze'):\r\n        self.env = env\r\n        self.exp_name = exp_name\r\n        self.n_output = self.env.n_actions\r\n        #NOTE: +1 reward neuron\r\n        self.n_input_neurons = self.env.n_obs + 1\r\n        self.n_agents = self.env.n_agents\r\n\r\n    def _convert_obs(self, obs, rewards):\r\n        n_input_channels_used = 3\r\n        X = np.zeros((self.n_agents, self.n_input_neurons, n_input_channels_used))\r\n        #X[:, :obs.shape[1], 0] = obs\r\n        # Shuffle only obs to avoid topology exploitation, reward neuron linked to EnuGlobal synapse connectivity\r\n        X[:, :obs.shape[1], 0] = np.take_along_axis(obs, self.obs_shuffle, axis=1)\r\n        # split pos and negative reward to different channels, And set to last input neuron\r\n        if rewards is not None:\r\n            X[rewards>0, -1, 1] = np.abs(rewards[rewards>0])\r\n            X[rewards<=0, -1, 2] = np.abs(rewards[rewards<=0])\r\n        return X\r\n\r\n    def _convert_reward(self, obs, actions, rewards, infos, dones):\r\n        fitness = np.copy(rewards)\r\n        # first poison is considered positive reward, since learning to learn\r\n        #NOTE: dead by env means less reward can be obtained so should implictely reduce overall fitness automatically\r\n        fitness[np.logical_and(self._prev_reward_count == 1, rewards != 0)] = 1\r\n        # include episode length as extra fitness, since not taking poison would allow survive longer, so should try avoid take poison\r\n        fitness[dones==0] += 0.1/4\r\n        return fitness\r\n\r\n    def step(self, y):\r\n\r\n        # if self.t % 3 != 0:\r\n        #     actions = 
np.zeros((self.n_agents), dtype=np.int32) - 1\r\n        # else:\r\n            # winner take all, in given time window\r\n        actions = y\r\n            # if all same output, do nothing\r\n            # equal_actions = self.y_hist.shape[1] == np.sum(self.y_hist == np.take_along_axis(self.y_hist, actions.reshape(-1, 1), axis=1), axis=-1)\r\n            # actions[equal_actions] = -1\r\n            # self.y_hist[:] = 0\r\n        # take env step\r\n\r\n        allobs, obs, rewards, dones, infos = self.env.step(actions)\r\n        # X = self._convert_obs(obs, rewards)\r\n        X=allobs\r\n        self._prev_reward_count += rewards!=0\r\n        fitness = self._convert_reward(obs, actions, rewards, infos, dones)\r\n        self._prev_action = actions\r\n        self._prev_obs = obs\r\n        return X, rewards, fitness, None\r\n\r\n    def reset(self):\r\n        self.t = 0\r\n        self.y_hist = np.zeros((self.n_agents, self.n_output), dtype=np.float32)\r\n        self._prev_action = None\r\n        self._prev_obs = None\r\n        self._prev_reward_count = np.zeros((self.n_agents), dtype=np.float32)\r\n        # each time different input/output neurons should have different meaning, to have learning to learn\r\n        self.obs_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_input_neurons - 1), axis=1, kind='mergesort')\r\n        self.action_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_output), axis=1, kind='mergesort')\r\n        # reset env\r\n        self.allobs,self.obs = self.env.reset()\r\n        # return self._convert_obs(self.obs, None)\r\n        return self.allobs\r\n\r\n    def render(self):\r\n        if self.t%4==0:\r\n            self.env.render()\r\n\r\n    def track_vis_data(self, vis_data, model, X, y_est, t):\r\n        n_fetch = 128\r\n        # TODO: also get our gates from the model\r\n        vis_data+=[(X[:n_fetch, :], y_est[:n_fetch, :])]\r\n\r\n    def dump_vis_data(self, vis_data, fitness_per_offspring, 
e):\r\n        with open(get_data_path(e, self.exp_name, \"output\"), 'wb') as f:\r\n            pickle.dump((vis_data, fitness_per_offspring), f)\r\n\r\n    @staticmethod\r\n    def load_vis_data(e, exp_name):\r\n        with open(get_data_path(e, exp_name, \"output\"), 'rb') as f:\r\n            vis_data, fitness_per_offspring = pickle.load(f)\r\n        return vis_data, fitness_per_offspring\r\n\r\n    @staticmethod\r\n    def plot_vis_data(e, exp_name):\r\n        vis_data, fitness_per_offspring = ExperimentEnvGlobalNetworkSurvival.load_vis_data(e, exp_name)"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/MazeTurnEnvVec.py",
    "content": "import pickle\r\n\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nimport seaborn as sns\r\n\r\nfrom tools.Tools import save_fig, get_data_path\r\n\r\n# np.random.seed(0)\r\n\r\nclass MazeTurnEnvVec:\r\n    \"\"\"Vectorized RL T-Maze environment written in pure Numpy. We require an efficient environment since we need to evaluate\r\n    and run up to thousands of offspring in parallel\"\"\"\r\n\r\n    def __init__(self, n_agents, n_steps):\r\n        # 4 important points, start point, decision point, food point, dead point.\r\n        # just generate a very large matrix that could fit any maze of any size, then can generate smaller maze as well\r\n        self.n_actions = 3\r\n        self.n_obs = 3\r\n        self.max_size = 7\r\n        self.n_agents = n_agents\r\n        self.n_steps = n_steps\r\n        self.window = plt.figure()\r\n        self.t_maze = True\r\n        self.turn_based = False\r\n        # steps can be longer if poison and need to turn around\r\n        self.steps_to_food = 2\r\n        if self.t_maze:\r\n            self.steps_to_food = 3\r\n        self.steps_to_food += self.steps_to_food*2\r\n        # give some extra leniency\r\n        self.steps_to_food *= 2\r\n    def step(self, actions):\r\n        # L R U D\r\n        # TODO: check legal action or not..\r\n        pos_copy = np.copy(self.agents_pos)\r\n        actions = np.copy(actions)\r\n\r\n        # GIVE TIME UPDATE WEIGHTS\r\n        # actions[self.agents_reset > 0] = -1\r\n        # actions[self.agent_energy<=0] = -1\r\n        self.agents_reset[self.agents_reset > 0] -= 1\r\n        # if turn based\r\n        if self.turn_based:\r\n            Forward = actions == 0\r\n            self.agents_pos[np.logical_and(Forward, self.agent_directions == 0), 1] += 1\r\n            self.agents_pos[np.logical_and(Forward, self.agent_directions == 2), 1] -= 1\r\n            self.agents_pos[np.logical_and(Forward, self.agent_directions == 1), 0] -= 1\r\n            
self.agents_pos[np.logical_and(Forward, self.agent_directions == 3), 0] += 1\r\n            L = actions == 1\r\n            self.agent_directions[L] += 1\r\n            R = actions == 2\r\n            self.agent_directions[R] -= 1\r\n            self.agent_directions[self.agent_directions > 3] = 0\r\n            self.agent_directions[self.agent_directions < 0] = 3\r\n        else:\r\n            # or just direct movement\r\n            U = actions == 2\r\n            D = actions == 1\r\n            R = actions == 0\r\n            if self.agents_pos[U].size>0:\r\n                self.agents_pos[U, 0] += 1\r\n                self.agent_directions[U] = 3\r\n            if self.agents_pos[D].size>0:\r\n                self.agents_pos[D, 0] -= 1\r\n                self.agent_directions[D] = 1\r\n            if self.t_maze and self.agents_pos[R].size>0:\r\n                self.agents_pos[R, 1] += 1\r\n                self.agent_directions[R] = 0\r\n\r\n        # UNDO MOVES THAT GOT AGENT INTO WALL\r\n        self.current_cells = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]]\r\n        self.agents_pos[self.current_cells==1] = pos_copy[self.current_cells==1]\r\n        movement_loss = np.prod(self.agents_pos==pos_copy, axis=-1)\r\n        # CHECK IF FOOD CONSUMED, is reward + pos reset\r\n        consumed_food = np.prod(self.agents_pos==self.food_pos[:, 0, :2], axis=-1).astype(np.bool)\r\n        consumed_pois = np.prod(self.agents_pos==self.food_pos[:, 1, :2], axis=-1).astype(np.bool)\r\n        self.consumed_count += consumed_food.astype(np.int32)\r\n        self.consumed_count_total += consumed_food.astype(np.int32)\r\n        self.consumed_count_pois += consumed_pois.astype(np.int32)\r\n        self._reset_pos(np.logical_or(consumed_food, consumed_pois))\r\n        # self._reset_pos_pois(consumed_pois)\r\n        # self._reset_food(self.consumed_count==self.swap_limit, prob=0.0)\r\n        # reset food for agents that ate 
food, and swap with some probability\r\n        self._reset_food(self.consumed_count==5, prob=0.5)\r\n        self.rewards = consumed_food.astype(np.float32) - consumed_pois.astype(np.float32) #* 0.5 #- movement_loss.astype(np.float32) * 0.01\r\n        # get observation from current position of each agent\r\n        self.agent_allobs,self.obs = self._get_obs_from_pos()\r\n        # instant dead on second poison\r\n        self.agent_energy += np.where(self.consumed_count_pois<=1, np.abs(self.rewards) * self.steps_to_food, -self.agent_energy * np.abs(self.rewards))\r\n        # energy decay to encourage exploration, agent dies if running out of energy\r\n        self.agent_energy = np.minimum(self.agent_energy, self.steps_to_food)\r\n        self.agent_energy -= 1.0/4\r\n        dones = self.agent_energy<=0\r\n\r\n        return self.agent_allobs,self.obs, self.rewards, dones, None\r\n\r\n    def _reset_pos(self, idxs):\r\n        self.agents_pos[idxs] = [self.start_point, 2]  # set X pos\r\n        self.agent_directions[idxs] = 0\r\n        if self.max_size==5:\r\n            self.agent_directions[idxs] = 1\r\n        self.agents_reset[idxs] = 0\r\n        self.agents_reset_count[idxs] += 1\r\n\r\n    def _reset_pos_pois(self, idxs):\r\n        #NOTE: 2 since we call reset twice!\r\n        #NOTE: turning around already cost 8 steps, then 4x4 more is 16+8, 24, so reset should be much worse\r\n        self.agents_reset[np.logical_and(idxs, self.agents_reset_count>2)] = 64\r\n\r\n    def _reset_food(self, idxs, prob=0.5):\r\n        # swap food with some probability, avoids agent overfitting on environment\r\n        swap = np.take_along_axis(self.random_swap_matrix, self.consumed_count_total.reshape(-1, 1), axis=1).ravel()\r\n        swap_idxs = swap * idxs\r\n        food_loc = np.copy(self.food_pos[swap_idxs, 0, :])\r\n        pois_loc = np.copy(self.food_pos[swap_idxs, 1, :])\r\n        self.food_pos[swap_idxs, 0, :] = pois_loc\r\n        
self.food_pos[swap_idxs, 1, :] = food_loc\r\n        # set maze value\r\n        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 0, 0], self.food_pos[:, 0, 1]] = 2\r\n        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 1, 0], self.food_pos[:, 1, 1]] = 3\r\n        self.consumed_count[idxs] = 0\r\n\r\n    def reset(self):\r\n        self.consumed_count = np.zeros((self.n_agents), dtype=np.int32)\r\n        self.consumed_count_total = np.zeros_like(self.consumed_count)\r\n        # consistent swapping such that if agent eat food once for all agents swapped with same seed, fair fitness comparison\r\n        max_eat = self.n_steps\r\n        self.random_swap_matrix = np.random.uniform(0, 1, size=(1, max_eat)) >= 0.5\r\n        self.random_swap_matrix = np.repeat(self.random_swap_matrix, int(self.n_agents), axis=0)\r\n\r\n        self.agent_energy = np.zeros((self.n_agents), dtype=np.float32) + self.steps_to_food\r\n        self.consumed_count_pois = np.zeros_like(self.consumed_count)\r\n        #self.swap_limit = np.random.randint(1, 5, size=1)\r\n        #self.swap_limit = np.random.randint(1, 4, size=self.n_agents)\r\n        self.mazes = np.ones((self.n_agents, self.max_size, self.max_size), dtype=np.int32)\r\n        #TODO: support variable maze length\r\n        self.start_point = int(self.max_size/2)\r\n        if self.t_maze:\r\n            self.mazes[:, self.start_point, 2:-1] = 0\r\n            self.mazes[:, 1:-1, -2] = 0\r\n            # FOOD either at -1,-1 or -1,1?\r\n            # two foods: x, y, value\r\n            self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)\r\n            self.food_pos[:, :, 1] = self.max_size - 2\r\n            self.food_pos[:, 1, 0] = 1\r\n            self.food_pos[:, 0, 0] = self.max_size - 2\r\n            self._reset_food(np.ones(self.food_pos.shape[0], dtype=np.bool), prob=0.5)\r\n        else:\r\n            self.mazes[:, 1:-1, 1] = 0\r\n            # two foods: x, y, value\r\n     
       self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)\r\n            self.food_pos[:, :, 1] = 1\r\n            self.food_pos[:, 0, 0] = 1\r\n            self.food_pos[:, 1, 0] = (self.max_size - 2)\r\n            self._reset_food(np.ones(self.food_pos.shape[0], dtype=np.bool), prob=0.5)\r\n        # AGENT\r\n        self.agents_pos = np.ones((self.n_agents, 2), dtype=np.int32)\r\n        self.agents_reset = np.zeros((self.n_agents), dtype=np.int32)\r\n        self.agents_reset_count = np.zeros_like(self.agents_reset)\r\n        self.agent_directions = np.zeros((self.n_agents), dtype=np.int32)\r\n        self._reset_pos(np.arange(self.agents_pos.shape[0]))\r\n        # OBS\r\n        self.agent_allobs,self.obs = self._get_obs_from_pos()\r\n        return self.agent_allobs,self.obs\r\n\r\n    def _get_obs_from_pos(self):\r\n        # obs is neighbouring cell states around agent\r\n        obs = np.zeros((self.n_agents, self.n_obs), dtype=np.float32)\r\n        raw_obs = np.zeros(self.n_agents, dtype=np.int32)\r\n        # get observation based on direction agent is facing\r\n        leftobs = np.zeros(self.n_agents)\r\n        rightobs = np.zeros(self.n_agents)\r\n        backobs = np.zeros(self.n_agents)\r\n\r\n        # get observation based on direction agent is facing\r\n        D = self.agent_directions == 0\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]-1][D]\r\n\r\n\r\n        D = self.agent_directions == 2\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 
1][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]+1][D]\r\n\r\n        D = self.agent_directions == 1\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0]+1, self.agents_pos[:, 1]][D]\r\n\r\n        D = self.agent_directions == 3\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0]-1, self.agents_pos[:, 1]][D]\r\n\r\n        # mark what was observed at different index\r\n        obs[raw_obs == 1, 0] = 1\r\n        obs[raw_obs == 2, 1] = 1\r\n        obs[raw_obs == 3, 2] = 1\r\n        allobs=np.squeeze(np.dstack((leftobs,raw_obs,rightobs,backobs)))\r\n        return allobs, obs\r\n\r\n    def render(self):\r\n        plt.clf()\r\n        sns.set_style(\"white\")\r\n        #TODO: support render all mazes? 
can reshape to square?\r\n        max_render = 1\r\n        flattened_render = np.dstack(np.split(self.mazes[18, :], max_render, axis=0)).reshape(self.mazes.shape[1], -1)\r\n        flattened_render[flattened_render>1] = 0\r\n        plt.axis('off')\r\n\r\n        plt.imshow(flattened_render,cmap='bone')\r\n\r\n        for j in range(1):\r\n            i=18\r\n            marker = \">\"\r\n            if self.agent_directions[i] == 1:\r\n                marker = \"^\"\r\n            if self.agent_directions[i] == 2:\r\n                marker = \"<\"\r\n            if self.agent_directions[i] == 3:\r\n                marker = \"v\"\r\n            obs_color = \"black\"\r\n            if self.obs[i, 0] == 1:\r\n                obs_color = \"gray\"\r\n            if self.obs[i, 1] == 1:\r\n                obs_color = \"green\"\r\n            if self.obs[i, 2] == 1:\r\n                obs_color = \"red\"\r\n            alpha = 1\r\n            if self.agent_energy[i]<=0:\r\n                alpha = 1\r\n            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color=\"skyblue\", alpha=alpha, marker=marker)\r\n            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color=obs_color,alpha=alpha, marker=marker, s=3)\r\n            plt.scatter(self.food_pos[i, 0, 1] + j * self.mazes.shape[1], self.food_pos[i, 0, 0], color=\"green\", alpha=1, marker=\"o\")\r\n            plt.scatter(self.food_pos[i, 1, 1] + j * self.mazes.shape[1], self.food_pos[i, 1, 0], color=\"red\", alpha=1, marker=\"o\")\r\n        plt.pause(0.001)\r\n        #plt.pause(2)\r\n\r\n    @staticmethod\r\n    def load_vis_data(e, exp_name):\r\n        with open(get_data_path(e, exp_name, \"output\"), 'rb') as f:\r\n            vis_data, fitness_per_offspring = pickle.load(f)\r\n        return vis_data, fitness_per_offspring\r\n\r\n    @staticmethod\r\n    def plot_vis_data(e, exp_name):\r\n        import matplotlib.pyplot as plt\r\n  
      from cycler import cycler\r\n        import seaborn as sns\r\n        sns.set_style(\"whitegrid\")\r\n\r\n        vis_data, fitness_per_offspring = MazeTurnEnvVec.load_vis_data(e, exp_name)\r\n\r\n        offspring_idx = 0\r\n        #x, y_est, y = np.array(vis_data).transpose(1, 2, 0, 3)\r\n        X, Y_est = map(np.array, zip(*vis_data))\r\n        X_base, y_est_base = X[:, offspring_idx], Y_est[:, offspring_idx]\r\n        #X_base, y_est_base = X_base[:300], y_est_base[:300]\r\n        #X_base = np.max(X_base, axis=1)\r\n        # --- normal output single example---\r\n        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))\r\n        # OLD METHOD!\r\n        plt.figure()\r\n        #NOTE: last neuron is always reward neuron\r\n        plt.plot(-X_base[:, :-1, 0], label=\"N-ENUs input\", alpha=0.7)\r\n        plt.plot(- np.max(X_base, axis=1)[:, 1], label=\"Positive reward\", alpha=0.7, color='#2ca02c')\r\n        plt.plot(- np.max(X_base, axis=1)[:, 2], label=\"Negative reward\", alpha=0.7, color='#d62728')\r\n        #plt.gca().set_color_cycle(['orange', 'purple', 'brown'])\r\n        plt.plot(y_est_base[:, :], label=\"N-ENUs output\", alpha=0.8)\r\n        plt.legend(loc='upper right')\r\n        save_fig(e, exp_name, \"single_episode\")\r\n\r\n        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))\r\n        # new method!\r\n        fig, grid = plt.subplots(2, sharex=True)\r\n        # input\r\n        #grid[1].set_prop_cycle(cycler('color', ['gray', '#ff7f0e', '#1f77b4']))\r\n        grid[1].set_prop_cycle(cycler('color', ['gray', '#F5B041', '#2E86C1']))\r\n        grid[1].plot(X_base[:, :-1, 0], label=\"N-ENUs input\", alpha=0.7, linewidth=2)\r\n        grid[1].plot(np.max(X_base, axis=1)[:, 1], label=\"Positive reward\", alpha=0.7, color='#1ABC9C', linewidth=2)\r\n        grid[1].plot(np.max(X_base, axis=1)[:, 2], 
label=\"Negative reward\", alpha=0.7, color='#CB4335', linewidth=2)\r\n        grid[1].legend(['Sensor (wall)', 'Sensor (red)', 'Sensor (green)', 'Positive reward', 'Negative reward'], loc='upper right')\r\n        grid[1].set_ylabel('Neuron output')\r\n        # output\r\n        grid[0].set_prop_cycle(cycler('color', ['#9467bd', '#e377c2', '#17becf','#8c564b']))\r\n        grid[0].plot(y_est_base[:, :], label=\"N-ENUs output\", alpha=0.8, linewidth=2)\r\n        grid[0].legend(['ENU-NN (left)', 'ENU-NN (right)', 'ENU-NN (forward)'],loc='upper right')\r\n        plt.xlabel('t')\r\n        grid[0].set_ylabel('ENU neuron output')\r\n        plt.xlim(-5, X_base.shape[0]+10)\r\n        save_fig(e, exp_name, \"single_episode_dual\")\r\n        #plt.show()\r\n\r\n    @staticmethod\r\n    def plot_rollout_data(e, exp_name):\r\n        #TODO: dump rollout as array not the actual plots\r\n        import matplotlib.pyplot as plt\r\n        import seaborn as sns\r\n        sns.set_style(\"white\")\r\n        import os\r\n\r\n        rollout_path = \"./\" + get_data_path(e, exp_name, \"rollout\").split(\".\")[1][:-1]+\"/\"\r\n        rollout_files = sorted(os.listdir(rollout_path))\r\n        rollouts = []\r\n        for i in range(4, 200, 4):\r\n            file = \"rollout_{}_.png\".format(i)\r\n            print(file)\r\n            rollout = plt.imread(rollout_path + file)\r\n            rollout = rollout[80:450, 200:500]\r\n            rollout = np.where(rollout[:, :, [0]]> 0.98, 1, rollout)\r\n            rollouts.append(rollout)\r\n            # plt.imshow(rollout)\r\n            # plt.show()\r\n        # for rollout in rollouts:\r\n        vert_bar = np.zeros((rollout.shape[0], 5, 4))\r\n        vert_bar[::] += 0.5\r\n        # get red -> learned go other way\r\n        # food swapped -> sees red -> learned turn around\r\n        rollouts1 = [rollouts[1], rollouts[9], rollouts[10], rollouts[11],  rollouts[17], rollouts[18]]#, rollouts[19]\r\n        #, 
rollouts[28]\r\n        #rollouts[24],\r\n        rollouts2 = [rollouts[25], rollouts[29], rollouts[30], rollouts[31], rollouts[32], rollouts[33]]\r\n        plt.axis('off')\r\n        plt.imshow(np.vstack([np.column_stack(rollouts1), np.column_stack(rollouts2)]), cmap='gray')\r\n        save_fig(e, exp_name, \"rollout_combined\")\r\n        #plt.show()\r\n\r\n\r\n\r\n\r\nif __name__ == '__main__':\r\n    \"\"\"Test function\"\"\"\r\n    n_offspring = 1024\r\n    envs = MazeTurnEnvVec(n_offspring, n_steps=400)\r\n    envs.n_pseudo_env = 8\r\n    while True:\r\n        envs.reset()\r\n        opt_actions_up = [0,0,0,0,1,0]\r\n        opt_actions_down = [0,0,0,0,2,0]\r\n        opt_actions = [opt_actions_up, opt_actions_down]\r\n        opt_current = 0\r\n        total_reward = 0\r\n        rewards = np.zeros((n_offspring))\r\n        rewards_all = np.zeros_like(rewards)\r\n        k = 0\r\n        n_steps = 200\r\n        for i in range(n_steps):\r\n            actions = np.random.randint(0, envs.n_actions, size=n_offspring)\r\n            #actions[:] = opt_actions_up[i%len(opt_actions_up)]\r\n            #actions[0] = opt_actions[opt_current][k % len(opt_actions_up)]\r\n            obs, rewards,_,_ = envs.step(actions=actions)\r\n            total_reward += (rewards[0] * 100) + 1\r\n            rewards_all += (rewards * 100) + 1\r\n            #print(rewards[0])\r\n            #print(obs[0], rewards[0])\r\n\r\n            envs.render()\r\n            plt.pause(0.5)\r\n\r\n            k += 1\r\n            if rewards[0] < 0:\r\n                #print(rewards_last[0])\r\n                if opt_current==0:\r\n                    opt_current = 1\r\n                else:\r\n                    opt_current = 0\r\n            if rewards[0] != 0:\r\n                k = 0\r\n        total_reward/=n_steps\r\n        rewards_all/=n_steps\r\n        print(rewards_all[0], np.mean(rewards_all), np.std(rewards_all), np.std(rewards_all)/np.mean(rewards_all))\r\n        
print(rewards_all[:10])\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/BrainCog-Version/tools/nsganet.py",
    "content": "import numpy as np\n\nfrom pymoo.algorithms.genetic_algorithm import GeneticAlgorithm\nfrom pymoo.docs import parse_doc_string\nfrom pymoo.model.individual import Individual\nfrom pymoo.model.survival import Survival\nfrom pymoo.operators.crossover.point_crossover import PointCrossover\nfrom pymoo.operators.mutation.polynomial_mutation import PolynomialMutation\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.selection.tournament_selection import compare, TournamentSelection\nfrom pymoo.util.display import disp_multi_objective\nfrom pymoo.util.dominator import Dominator\nfrom pymoo.util.non_dominated_sorting import NonDominatedSorting\nfrom pymoo.util.randomized_argsort import randomized_argsort\n\n\n# =========================================================================================================\n# Implementation\n# based on nsga2 from https://github.com/msu-coinlab/pymoo\n# =========================================================================================================\n\n\nclass NSGANet(GeneticAlgorithm):\n\n    def __init__(self, **kwargs):\n        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)\n        super().__init__(**kwargs)\n\n        self.tournament_type = 'comp_by_dom_and_crowding'\n        self.func_display_attrs = disp_multi_objective\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Binary Tournament Selection Function\n# ---------------------------------------------------------------------------------------------------------\n\n\ndef binary_tournament(pop, P, algorithm, **kwargs):\n    if P.shape[1] != 2:\n        raise ValueError(\"Only implemented for binary tournament!\")\n\n    tournament_type = algorithm.tournament_type\n    S = np.full(P.shape[0], np.nan)\n\n    for i in range(P.shape[0]):\n\n        a, b = P[i, 
0], P[i, 1]\n\n        # if at least one solution is infeasible\n        if pop[a].CV > 0.0 or pop[b].CV > 0.0:\n            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)\n\n        # both solutions are feasible\n        else:\n\n            if tournament_type == 'comp_by_dom_and_crowding':\n                rel = Dominator.get_relation(pop[a].F, pop[b].F)\n                if rel == 1:\n                    S[i] = a\n                elif rel == -1:\n                    S[i] = b\n\n            elif tournament_type == 'comp_by_rank_and_crowding':\n                S[i] = compare(a, pop[a].rank, b, pop[b].rank,\n                               method='smaller_is_better')\n\n            else:\n                raise Exception(\"Unknown tournament type.\")\n\n            # if rank or domination relation didn't make a decision compare by crowding\n            if np.isnan(S[i]):\n                S[i] = compare(a, pop[a].get(\"crowding\"), b, pop[b].get(\"crowding\"),\n                               method='larger_is_better', return_random_if_equal=True)\n\n    return S[:, None].astype(np.int)\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Survival Selection\n# ---------------------------------------------------------------------------------------------------------\n\n\nclass RankAndCrowdingSurvival(Survival):\n\n    def __init__(self) -> None:\n        super().__init__(True)\n\n    def _do(self, pop, n_survive, D=None, **kwargs):\n\n        # get the objective space values and objects\n        F = pop.get(\"F\")\n\n        # the final indices of surviving individuals\n        survivors = []\n\n        # do the non-dominated sorting until splitting front\n        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)\n\n        for k, front in enumerate(fronts):\n\n            # calculate the crowding distance of the front\n            
crowding_of_front = calc_crowding_distance(F[front, :])\n\n            # save rank and crowding in the individual class\n            for j, i in enumerate(front):\n                pop[i].set(\"rank\", k)\n                pop[i].set(\"crowding\", crowding_of_front[j])\n\n            # current front sorted by crowding distance if splitting\n            if len(survivors) + len(front) > n_survive:\n                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')\n                I = I[:(n_survive - len(survivors))]\n\n            # otherwise take the whole front unsorted\n            else:\n                I = np.arange(len(front))\n\n            # extend the survivors by all or selected individuals\n            survivors.extend(front[I])\n\n        return pop[survivors]\n\n\ndef calc_crowding_distance(F):\n    infinity = 1e+14\n\n    n_points = F.shape[0]\n    n_obj = F.shape[1]\n\n    if n_points <= 2:\n        return np.full(n_points, infinity)\n    else:\n\n        # sort each column and get index\n        I = np.argsort(F, axis=0, kind='mergesort')\n\n        # now really sort the whole array\n        F = F[I, np.arange(n_obj)]\n\n        # get the distance to the last element in sorted list and replace zeros with actual values\n        dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \\\n               - np.concatenate([np.full((1, n_obj), -np.inf), F])\n\n        index_dist_is_zero = np.where(dist == 0)\n\n        dist_to_last = np.copy(dist)\n        for i, j in zip(*index_dist_is_zero):\n            dist_to_last[i, j] = dist_to_last[i - 1, j]\n\n        dist_to_next = np.copy(dist)\n        for i, j in reversed(list(zip(*index_dist_is_zero))):\n            dist_to_next[i, j] = dist_to_next[i + 1, j]\n\n        # normalize all the distances\n        norm = np.max(F, axis=0) - np.min(F, axis=0)\n        norm[norm == 0] = np.nan\n        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm\n\n        # 
if we divided by zero because all values in one columns are equal replace by none\n        dist_to_last[np.isnan(dist_to_last)] = 0.0\n        dist_to_next[np.isnan(dist_to_next)] = 0.0\n\n        # sum up the distance to next and last and norm by objectives - also reorder from sorted list\n        J = np.argsort(I, axis=0)\n        crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj\n\n    # replace infinity with a large number\n    crowding[np.isinf(crowding)] = infinity\n\n    return crowding\n\n\n# =========================================================================================================\n# Interface\n# =========================================================================================================\n\n\ndef nsganet(\n        pop_size=100,\n        sampling=RandomSampling(var_type=np.int),\n        selection=TournamentSelection(func_comp=binary_tournament),\n        crossover=PointCrossover(n_points=2),\n        mutation=PolynomialMutation(eta=3, var_type=np.int),\n        \n        eliminate_duplicates=True,\n        n_offsprings=None,\n        **kwargs):\n    \"\"\"\n\n    Parameters\n    ----------\n    pop_size : {pop_size}\n    sampling : {sampling}\n    selection : {selection}\n    crossover : {crossover}\n    mutation : {mutation}\n    eliminate_duplicates : {eliminate_duplicates}\n    n_offsprings : {n_offsprings}\n\n    Returns\n    -------\n    nsganet : :class:`~pymoo.model.algorithm.Algorithm`\n        Returns an NSGANet algorithm object.\n\n\n    \"\"\"\n\n    return NSGANet(pop_size=pop_size,\n                   sampling=sampling,\n                   selection=selection,\n                   crossover=crossover,\n                   mutation=mutation,\n                   survival=RankAndCrowdingSurvival(),\n                   eliminate_duplicates=eliminate_duplicates,\n                   n_offsprings=n_offsprings,\n                   
**kwargs)\n\n\nparse_doc_string(nsganet)\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/BCM.py",
    "content": "import argparse, math, os, sys\r\nimport numpy as np\r\nimport gym\r\nfrom gym import wrappers\r\nimport matplotlib.pyplot as plt\r\nimport nsganet as engine\r\nfrom pymop.problem import Problem\r\nfrom pymoo.optimize import minimize\r\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\r\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\r\nfrom tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival\r\nfrom tools.MazeTurnEnvVec import MazeTurnEnvVec\r\nimport torch\r\nimport torch.nn.utils as utils\r\nfrom torch.distributions import Categorical\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch.optim as optim\r\nfrom torch.autograd import Variable\r\nfrom itertools import product\r\n\r\nfrom functools import partial\r\nimport torchvision, pprint\r\nfrom timm.models import register_model\r\nfrom braincog.base.node.node import *\r\nfrom braincog.base.connection.layer import *\r\nfrom braincog.model_zoo.base_module import BaseModule\r\nfrom braincog.base.learningrule.BCM import *\r\nfrom braincog.base.learningrule.STDP import *\r\n\r\nparser = argparse.ArgumentParser(description='PyTorch REINFORCE example')\r\nparser.add_argument('--gamma', type=float, default=0.98, metavar='G')\r\nparser.add_argument('--seed', type=int, default=1, metavar='N')\r\nparser.add_argument('--num_steps', type=int, default=500, metavar='N')\r\nparser.add_argument('--num_episodes', type=int, default=100, metavar='N')\r\nparser.add_argument('--render', action='store_true')\r\nargs = parser.parse_args()\r\n\r\n\r\nn_agent = 1\r\nsteps = 500\r\nhidden_size=64\r\nenv = MazeTurnEnvVec(n_agent, n_steps=steps)\r\ndata_env = ExperimentEnvGlobalNetworkSurvival(env)\r\ns_dim = 4\r\na_dim = 3\r\ndef randbool(size, p):\r\n    return torch.rand(*size) < p\r\n\r\n\r\ndef fit(agent):\r\n    states = list(product([0, 1], repeat=4))\r\n    ls_list=[]\r\n    for state in states:\r\n        
agent.model.reset()\r\n        state_tensor = torch.tensor(state).float().reshape(1, -1)\r\n        la, ls = agent.model(Variable(state_tensor.float()).reshape(-1,)) \r\n        ls_np = ls.detach().numpy() \r\n        ls_list.append(ls_np)\r\n\r\n    ls_matrix = np.vstack(ls_list)\r\n\r\n    rank = np.linalg.matrix_rank(ls_matrix)\r\n\r\n    return rank\r\n\r\n@register_model\r\nclass SNN(BaseModule):                                        \r\n    def __init__(self,\r\n                 hidden_size,\r\n                 n_agent,\r\n                 connectivity_matrix,\r\n                 num_classes=3,\r\n                 step=1,\r\n                 node_type=LIFNode,\r\n                 encode_type='direct',\r\n                 ins=4,\r\n                 lsm_th=0.3,\r\n                 fc_th=0.3,\r\n                 lsm_tau=3,\r\n                 fc_tau=3,\r\n                 tw=100,\r\n                 *args,\r\n                 **kwargs):\r\n        super().__init__(step, encode_type, *args, **kwargs)\r\n        self.linear1 = nn.Linear(s_dim, hidden_size)      \r\n        self.node=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)    \r\n        self.linear2 = nn.Linear(hidden_size, a_dim)\r\n\r\n        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)\r\n        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)\r\n        self.hidden_size=hidden_size\r\n        self.out = torch.zeros(hidden_size)\r\n        self.con=[]\r\n        self.learning_rule=[]\r\n        self.connectivity_matrix=connectivity_matrix\r\n        w1tmp=nn.Linear(ins,hidden_size,bias=False)\r\n        self.con.append(w1tmp)\r\n        w2tmp=nn.Linear(hidden_size,hidden_size,bias=False)\r\n        self.liquid_weight=w2tmp.weight.data\r\n        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix\r\n        self.con.append(w2tmp)\r\n        self.learning_rule.append(BCM(self.node_lsm(), [self.con[0], 
self.con[1]])) \r\n        self.fc = nn.Linear(hidden_size,num_classes)\r\n        self.learning_rule.append(BCM(self.node_fc(), [self.fc])) \r\n\r\n\r\n    def forward(self, x):\r\n        sum_spike=0\r\n        time_window=20\r\n        self.tw=time_window\r\n        self.firing_tw=torch.zeros(time_window, self.hidden_size)\r\n        self.out = torch.zeros(self.hidden_size)\r\n        for t in range(time_window):\r\n            self.out, self.dw = self.learning_rule[0](x, self.out)\r\n            self.con[1].weight.data+=self.dw[1]\r\n            out_liquid=self.out[0:self.hidden_size]\r\n            xout,dw = self.learning_rule[1](out_liquid)\r\n            self.fc.weight.data+=dw[0]\r\n            sum_spike=sum_spike+xout\r\n            self.firing_tw[t]=out_liquid\r\n        outputs = sum_spike+0.0001 / time_window\r\n        return outputs,out_liquid\r\n\r\nclass REINFORCE:\r\n    def __init__(self, lm):\r\n        self.model = SNN(ins=4,n_agent=n_agent,hidden_size=hidden_size,lsm_tau=2,lsm_th=0.2,connectivity_matrix=lm)\r\n        self.model.train()\r\n\r\n\r\n    def select_action(self, state):\r\n        # mu, sigma_sq = self.model(Variable(state).cuda())\r\n        prob,_= self.model(Variable(state).reshape(-1,))\r\n        dist = Categorical(probs=prob)\r\n        action = dist.sample()\r\n        log_prob = prob[action.item()].log()\r\n        entropy = dist.entropy()\r\n        return action, log_prob, entropy\r\n\r\n\r\n\r\n\r\nclass Evolve(Problem):\r\n    # first define the NAS problem (inherit from pymop)\r\n    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):\r\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\r\n        self.xl = lb\r\n        self.xu = ub\r\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\r\n\r\n\r\n    def _evaluate(self, x, out, *args, **kwargs):\r\n        \r\n        objs = np.full((x.shape[0], self.n_obj), np.nan)\r\n        for i 
in range(x.shape[0]):\r\n            arch_id = self._n_evaluated + 1\r\n            print('Network= {}'.format(arch_id))\r\n\r\n            agent = REINFORCE(torch.from_numpy(x[i].reshape(hidden_size,hidden_size)).float())\r\n            log_reward = []\r\n            log_smooth = []\r\n            # gamma=np.linspace(0.9,1.0,100)\r\n            gam=0.9\r\n            # for gam in gamma:\r\n            for i_episode in range(100):\r\n                state = torch.tensor(data_env.reset()).unsqueeze(0)\r\n                entropies = []\r\n                log_probs = []\r\n                rewards = []\r\n                old_dis = np.ones([1,])*13\r\n                reawrd_perstep=[]\r\n                ss=0\r\n                allrewards=[]\r\n                for t in range(500): \r\n                    action, log_prob, entropy = agent.select_action(state.float())\r\n                    action=action.unsqueeze(0).numpy()\r\n                    next_state, envreward, done, _ = data_env.step(action)\r\n                    entropies.append(entropy)\r\n                    log_probs.append(log_prob)\r\n                    state = torch.Tensor([next_state])\r\n                    rewards.append(envreward[0])\r\n                print(\"Episode: {}, reward: {}\".format(i_episode, np.sum(rewards)))\r\n                log_reward.append(np.sum(rewards))\r\n                if i_episode == 0:\r\n                    log_smooth.append(log_reward[-1])\r\n                else:\r\n                    log_smooth.append(log_smooth[-1]*0.99+0.01*np.sum(rewards))\r\n                plt.plot(log_smooth)\r\n                plt.plot(log_reward)\r\n                plt.pause(1e-5)\r\n\r\n            objs[i, 0] = fit(agent)\r\n            self._n_evaluated += 1\r\n        out[\"F\"] = objs\r\n\r\ndef do_every_generations(algorithm):\r\n    gen = algorithm.n_gen\r\n    pop_var = algorithm.pop.get(\"X\")\r\n    pop_obj = algorithm.pop.get(\"F\")\r\n\r\nif __name__ == \"__main__\":\r\n    
n_agent=1\r\n    kkk = Evolve(n_var=hidden_size*hidden_size, \r\n                    n_obj=1, n_constr=0)\r\n    method = engine.nsganet(pop_size=n_agent,\r\n                            sampling=RandomSampling(var_type='custom'),\r\n                            mutation=BinaryBitflipMutation(),\r\n                            n_offsprings=10,\r\n                            eliminate_duplicates=True)\r\n    kres=minimize(kkk,\r\n                    method,\r\n                    callback=do_every_generations,\r\n                    termination=('n_gen', 1000))\r\n    \r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/README.md",
    "content": "\n# Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks #\n\n## Requirments ##\n* numpy\n* pytorch >= 1.12.0\n\n## Run ##\n\n```python BCM.py```\n\n## Citation ##\nIf you find the code and dataset useful in your research, please consider citing:\n```\n\n@article{pan2023adaptive,\n\ttitle = {Adaptive structure evolution and biologically plausible synaptic plasticity for recurrent spiking neural networks},\n\tauthor = {Pan, Wenxuan and Zhao, Feifei and Zeng, Yi and Han, Bing},\n\tjournal = {Scientific Reports},\n\tvolume = {13},\n\tnumber = {1},\n\tpages = {16924},\n\tyear = {2023},\n\turl = {https://doi.org/10.1038/s41598-023-43488-x},\n\tdoi = {10.1038/s41598-023-43488-x},\n}\n\n@article{zeng2023braincog,\n  title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier}\n}\n```\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/lstm.py",
    "content": "import argparse, math, os, sys\r\n\r\nfrom re import S\r\nfrom aiohttp import ServerDisconnectedError\r\nimport numpy as np\r\nimport gym\r\nfrom gym import wrappers\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival\r\nfrom tools.MazeTurnEnvVec import MazeTurnEnvVec\r\nimport torch\r\nfrom torch.autograd import Variable\r\nimport torch.autograd as autograd\r\nimport torch.nn.utils as utils\r\nfrom torch.distributions import Categorical\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch.optim as optim\r\n\r\n\r\nparser = argparse.ArgumentParser(description='PyTorch REINFORCE example')\r\nparser.add_argument('--gamma', type=float, default=0.98, metavar='G')\r\nparser.add_argument('--seed', type=int, default=598, metavar='N')\r\nparser.add_argument('--num_steps', type=int, default=500, metavar='N')\r\nparser.add_argument('--num_episodes', type=int, default=1000, metavar='N')\r\nparser.add_argument('--hidden_size', type=int, default=128, metavar='N')\r\nparser.add_argument('--render', action='store_true')\r\nargs = parser.parse_args()\r\n\r\n\r\nn_agent = 1\r\nsteps = 500\r\n\r\nenv = MazeTurnEnvVec(n_agent, n_steps=steps)\r\ndata_env = ExperimentEnvGlobalNetworkSurvival(env)\r\ns_dim = 4\r\na_dim = 3\r\n\r\n\r\nclass Policy(nn.Module):                                           \r\n    def __init__(self, hidden_size, s_dim, a_dim):\r\n        super(Policy, self).__init__()\r\n        self.lstm = nn.LSTM(s_dim, hidden_size, batch_first = True)\r\n        self.linear1 = nn.Linear(hidden_size, hidden_size)          \r\n        self.linear2 = nn.Linear(hidden_size, a_dim)\r\n\r\n\r\n    def forward(self, x,hidden):\r\n        x, hidden = self.lstm(x, hidden)\r\n        x = F.relu(self.linear1(x))\r\n        p = F.softmax(self.linear2(x),-1)                            \r\n        return p,hidden\r\n\r\nclass REINFORCE:\r\n    def __init__(self, 
hidden_size, s_dim, a_dim):\r\n        self.model = Policy(hidden_size, s_dim, a_dim)   \r\n        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-2) # \r\n        self.model.train()\r\n        self.pi = Variable(torch.FloatTensor([math.pi])) # \r\n\r\n\r\n    def select_action(self, state,hx,cx):\r\n        # mu, sigma_sq = self.model(Variable(state).cuda())\r\n        prob,(hx,cx) = self.model(Variable(state),(hx,cx))\r\n        dist = Categorical(probs=prob)\r\n        action = dist.sample()\r\n        log_prob = prob[0][0,action.item()].log()\r\n        # log_prob = prob.log()\r\n        entropy = dist.entropy()\r\n        \r\n        return action, log_prob, entropy\r\n\r\n    def update_parameters(self, rewards, log_probs, entropies, gamma):# 更新参数\r\n        R = torch.tensor(0)\r\n        loss = 0\r\n        for i in reversed(range(len(rewards))):\r\n            R = gamma * R + rewards[i]                                \r\n            loss = loss - (log_probs[i]*Variable(R)) - 0.005*entropies[i][0]\r\n            \r\n        loss = loss / len(rewards)\r\n\r\n        self.optimizer.zero_grad()\r\n        loss.backward()\r\n        utils.clip_grad_norm_(self.model.parameters(), 2)             \r\n        self.optimizer.step()\r\n\r\nseeds=20\r\nfor seed in range(seeds):\r\n    log_reward = []\r\n    log_smooth = []\r\n    gamma=np.linspace(0.9,1.0,100)\r\n    for g in range(100):\r\n        agent = REINFORCE(args.hidden_size,s_dim,a_dim)\r\n        result=np.zeros([100,args.num_steps])\r\n        for i_episode in range(args.num_episodes):\r\n            state = torch.tensor(data_env.reset()).unsqueeze(0)\r\n            entropies = []\r\n            log_probs = []\r\n            rewards = []\r\n            old_dis = np.ones([1,])*13\r\n            reawrd_perstep=[]\r\n            allrewards=[]\r\n            hx = torch.zeros(args.hidden_size).unsqueeze(0).unsqueeze(0)\r\n            cx = torch.zeros(args.hidden_size).unsqueeze(0).unsqueeze(0)\r\n     
       for t in range(args.num_steps): # 1个episode最长num_steps\r\n                action, log_prob, entropy = agent.select_action(state.unsqueeze(0).float(),hx,cx)\r\n                action = action.cpu().numpy()\r\n                next_state, envreward, done, _ = data_env.step(action[0])\r\n                entropies.append(entropy)\r\n                log_probs.append(log_prob)\r\n                state = torch.Tensor([next_state])\r\n\r\n                rewards.append(envreward[0])\r\n            agent.update_parameters(rewards, log_probs, entropies, gamma[g])\r\n\r\n            print(\"Episode: {}, reward: {}\".format(i_episode, np.sum(rewards)))\r\n\r\n            log_reward.append(np.sum(rewards))\r\n            if i_episode == 0:\r\n                log_smooth.append(log_reward[-1])\r\n            else:\r\n                log_smooth.append(log_smooth[-1]*0.99+0.01*np.sum(rewards))\r\n\r\n            plt.plot(log_smooth)\r\n            plt.plot(log_reward)\r\n            plt.pause(1e-5)\r\n        result[g]=np.array(allrewards).squeeze(1)\r\n\r\n    np.save('./lstm.npy',result)\r\n\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/main.py",
    "content": "import argparse, math, os, sys\r\nimport numpy as np\r\nimport gym\r\nfrom gym import wrappers\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival\r\nfrom tools.MazeTurnEnvVec import MazeTurnEnvVec\r\nimport torch\r\nfrom torch.autograd import Variable\r\nimport torch.autograd as autograd\r\nimport torch.nn.utils as utils\r\nfrom torch.distributions import Categorical\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch.optim as optim\r\n\r\n\r\nparser = argparse.ArgumentParser(description='PyTorch REINFORCE example')\r\nparser.add_argument('--gamma', type=float, default=0.98, metavar='G')\r\nparser.add_argument('--seed', type=int, default=1, metavar='N')\r\nparser.add_argument('--num_steps', type=int, default=500, metavar='N')\r\nparser.add_argument('--num_episodes', type=int, default=100, metavar='N')\r\nparser.add_argument('--hidden_size', type=int, default=128, metavar='N')\r\nparser.add_argument('--render', action='store_true')\r\nargs = parser.parse_args()\r\n\r\n\r\nn_agent = 1\r\nsteps = 500\r\n\r\nenv = MazeTurnEnvVec(n_agent, n_steps=steps)\r\ndata_env = ExperimentEnvGlobalNetworkSurvival(env)\r\ns_dim = 4\r\na_dim = 3\r\n\r\n\r\n\r\n\r\nclass Policy(nn.Module):                                        \r\n    def __init__(self, hidden_size, s_dim, a_dim):\r\n        super(Policy, self).__init__()\r\n\r\n        self.linear1 = nn.Linear(s_dim, hidden_size)          \r\n        self.linear2 = nn.Linear(hidden_size, a_dim)\r\n\r\n\r\n    def forward(self, x):\r\n        x = F.relu(self.linear1(x))\r\n        p = F.softmax(self.linear2(x),-1)                          \r\n        return p\r\n\r\nclass REINFORCE:\r\n    def __init__(self, hidden_size, s_dim, a_dim):\r\n        self.model = Policy(hidden_size, s_dim, a_dim)    \r\n        # self.model = self.model.cuda()                              \r\n        self.optimizer = 
optim.Adam(self.model.parameters(), lr=1e-2) \r\n        self.model.train()\r\n        self.pi = Variable(torch.FloatTensor([math.pi])) \r\n\r\n\r\n    def select_action(self, state):\r\n        # mu, sigma_sq = self.model(Variable(state).cuda())\r\n        prob = self.model(Variable(state))\r\n        dist = Categorical(probs=prob)\r\n        action = dist.sample()\r\n        log_prob = prob[0,action.item()].log()\r\n        # log_prob = prob.log()\r\n        entropy = dist.entropy()\r\n        \r\n        return action, log_prob, entropy\r\n\r\n    def update_parameters(self, rewards, log_probs, entropies, gamma):\r\n        R = torch.tensor(0)\r\n        loss = 0\r\n        for i in reversed(range(len(rewards))):\r\n            R = gamma * R + rewards[i]                         \r\n            loss = loss - (log_probs[i]*Variable(R)) - 0.005*entropies[i][0]\r\n            \r\n        loss = loss / len(rewards)\r\n\r\n        self.optimizer.zero_grad()\r\n        loss.backward()\r\n        utils.clip_grad_norm_(self.model.parameters(), 2)           \r\n        self.optimizer.step()\r\nseeds=20\r\nfor seed in range(seeds):\r\n\t# torch.manual_seed(args.seed)                                   \r\n\t# np.random.seed(args.seed)\r\n    agent = REINFORCE(args.hidden_size,s_dim,a_dim)\r\n    log_reward = []\r\n    log_smooth = []\r\n    gamma=np.linspace(0.9,1.0,100)\r\n    for gam in gamma:\r\n        for i_episode in range(args.num_episodes):\r\n            state = torch.tensor(data_env.reset()).unsqueeze(0)\r\n            entropies = []\r\n            log_probs = []\r\n            rewards = []\r\n            old_dis = np.ones([1,])*13\r\n            reawrd_perstep=[]\r\n            ss=0\r\n            allrewards=[]\r\n            for t in range(args.num_steps): \r\n                action, log_prob, entropy = agent.select_action(state.float())\r\n                action = action.cpu().numpy()\r\n                next_state, envreward, done, _ = data_env.step(action)\r\n 
               entropies.append(entropy)\r\n                log_probs.append(log_prob)\r\n                state = torch.Tensor([next_state])\r\n                rewards.append(envreward[0])\r\n            agent.update_parameters(rewards, log_probs, entropies, gam)\r\n            print(\"Episode: {}, reward: {}\".format(i_episode, np.sum(rewards)))\r\n            log_reward.append(np.sum(rewards))\r\n            if i_episode == 0:\r\n                log_smooth.append(log_reward[-1])\r\n            else:\r\n                log_smooth.append(log_smooth[-1]*0.99+0.01*np.sum(rewards))\r\n            plt.plot(log_smooth)\r\n            plt.plot(log_reward)\r\n            plt.pause(1e-5)\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/pltbcm.py",
    "content": "import numpy as np\r\nimport matplotlib.pyplot as plt\r\nfrom matplotlib import pyplot\r\nimport matplotlib as mpl\r\nfrom scipy.ndimage import gaussian_filter1d\r\nsigm=3\r\n# mpl.rcParams['font.size']=x\r\nplt.style.use('seaborn-whitegrid')\r\nplt.figure( figsize=(8,8) )\r\nax = plt.subplot()\r\npalette = pyplot.get_cmap('Set1')\r\nfont1 = {'family': 'Times New Roman',\r\n         'weight': 'normal',\r\n         'size': 14,\r\n         }\r\nsteps=500\r\nt = [i for i in range(steps)]\r\n########################################BCM+BCM\r\nbcm=np.load('./10rewards.npy')\r\nfor e in range(bcm.shape[0]):\r\n    sum1=0\r\n    sum2=0\r\n    best_agent_id=np.argmax(np.sum(bcm[e,:,:],axis=0))\r\n    best_agent=bcm[e,:,best_agent_id]\r\n    best_agent=best_agent[:steps]\r\n    for i in range(steps): #累积\r\n        sum2=sum2+best_agent[i]\r\n        best_agent[i]=sum2\r\n    bcm[e]=best_agent\r\navg = np.mean(bcm, axis=0)\r\nstd = np.std(bcm, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=40)\r\nr1 = gaussian_filter1d(r1, sigma=40)\r\nr2 = gaussian_filter1d(r2, sigma=40)\r\ncolor = palette(0)  \r\nax.plot(t, y_smoothed, color=color, label=\"Evolved model with DA-BCM\", linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"Evolved model with DA-BCM\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n########################################unbcm\r\nunbcm=np.load('./unevolved_with_bcm.npy')\r\navg = np.mean(unbcm, axis=0)\r\nstd = np.std(unbcm, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=sigm)\r\nr1 = gaussian_filter1d(r1, sigma=sigm)\r\nr2 = gaussian_filter1d(r2, sigma=sigm)\r\ncolor = palette(1)  \r\nax.plot(t, y_smoothed, color=color, label=\"Unevolved model with DA-BCM\", 
linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"Unevolved model with DA-BCM+DA-BCM\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n########################################none+bcm\r\nnonbcm=np.load('./none_bcm.npy')\r\navg = np.mean(nonbcm, axis=0)\r\nstd = np.std(nonbcm, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=sigm)\r\nr1 = gaussian_filter1d(r1, sigma=sigm)\r\nr2 = gaussian_filter1d(r2, sigma=sigm)\r\ncolor = palette(2)  \r\nax.plot(t, y_smoothed, color=color, label=\"Evolved model with NONE+DA-BCM\", linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"Evolved model with none+DA-BCM\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n########################################stdp+bcm\r\nstdpbcm=np.load('./stdp_bcm.npy')\r\navg = np.mean(stdpbcm, axis=0)\r\nstd = np.std(stdpbcm, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=sigm)\r\nr1 = gaussian_filter1d(r1, sigma=sigm)\r\nr2 = gaussian_filter1d(r2, sigma=sigm)\r\ncolor = palette(5)  \r\nax.plot(t, y_smoothed, color=color, label=\"Evolved model with STDP+DA-BCM\", linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"Evolved model with STDP+DA-BCM\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n########################################LSTM\r\nlstm=np.load('./lstm.npy')\r\navg = np.mean(lstm, axis=0)\r\nstd = np.std(lstm, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=sigm)\r\nr1 = gaussian_filter1d(r1, sigma=sigm)\r\nr2 = gaussian_filter1d(r2, sigma=sigm)\r\ncolor = palette(3)  \r\nax.plot(t, y_smoothed, color=color, label=\"LSTM\", 
linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"LSTM\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n########################################Q-learning\r\nql=np.load('./ql.npy')\r\navg = np.mean(ql, axis=0)\r\nstd = np.std(ql, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=sigm)\r\nr1 = gaussian_filter1d(r1, sigma=sigm)\r\nr2 = gaussian_filter1d(r2, sigma=sigm)\r\ncolor = palette(6)  \r\nax.plot(t, y_smoothed, color=color, label=\"Q-learning\", linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"Q-learning\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n########################################STDP\r\nstdp=np.load('./inac.npy')\r\navg = np.mean(stdp, axis=0)\r\nstd = np.std(stdp, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\ny_smoothed = gaussian_filter1d(avg, sigma=sigm)\r\nr1 = gaussian_filter1d(r1, sigma=sigm)\r\nr2 = gaussian_filter1d(r2, sigma=sigm)\r\ncolor = palette(4)  \r\nax.plot(t, y_smoothed, color=color, label=\"Evolved STDP\", linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nprint(\"Evolved STDP\")\r\nprint(avg[-1],avg[-1]-r1[-1],r2[-1]-avg[-1])\r\n\r\n\r\n\r\n\r\nax.tick_params(labelsize=16)\r\nax.spines['right'].set_color('black')\r\nax.spines['top'].set_color('black')\r\nax.spines['left'].set_color('black')\r\nax.spines['bottom'].set_color('black')\r\nax.legend(loc='upper left', prop=font1)\r\nplt.xlabel('Steps', fontsize=18)\r\nplt.ylabel('Average Reward', fontsize=18)\r\nplt.savefig('./bcm.png')\r\nplt.show()\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/pltrank.py",
    "content": "import numpy as np\r\nimport matplotlib.pyplot as plt\r\nfrom matplotlib import pyplot\r\nimport matplotlib as mpl\r\nfrom scipy.ndimage import gaussian_filter1d\r\nplt.figure( figsize=(8,8) ) \r\nsteps=1000\r\nt = [i for i in range(steps)]\r\nplt.style.use('seaborn-whitegrid')\r\npalette = pyplot.get_cmap('Set1')\r\nfont1 = {'family': 'Times New Roman',\r\n         'weight': 'normal',\r\n         'size': 18,\r\n         }\r\n\r\nkk=np.load('./rank.npy')\r\navg = np.mean(kk, axis=0)\r\nstd = np.std(kk, axis=0)\r\nr1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))  \r\nr2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))  \r\nr1 = gaussian_filter1d(r1, sigma=20)\r\nr2 = gaussian_filter1d(r2, sigma=20)\r\ny_smoothed = gaussian_filter1d(avg, sigma=20)\r\n\r\ncolor = palette(0)  \r\nax = plt.subplot()\r\nax.plot(t, y_smoothed, color=color, label=\"Average Fitness\", linewidth=3.0)\r\nax.fill_between(t, r1, r2, color=color, alpha=0.2)\r\nax.tick_params(labelsize=18)\r\nax.spines['right'].set_color('black')\r\nax.spines['top'].set_color('black')\r\nax.spines['left'].set_color('black')\r\nax.spines['bottom'].set_color('black')\r\n\r\nax.legend(loc='lower right', prop=font1)\r\nplt.xlabel('generations', fontsize=18)\r\nplt.ylabel('SP', fontsize=18)\r\nplt.savefig('./rank.png')\r\nplt.show()\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/q_l.py",
    "content": "import random\r\nimport time\r\nimport tkinter as tk\r\nimport pandas as pd\r\nfrom tools.ExperimentEnvGlobalNetworkSurvival import ExperimentEnvGlobalNetworkSurvival\r\nfrom tools.MazeTurnEnvVec import MazeTurnEnvVec\r\nimport numpy as np\r\nfrom matplotlib import pyplot as plt\r\nsteps=500\r\nt=[i for i in range(steps)]\r\nclass Agent(object):\r\n    '''个体类'''\r\n    MAZE_R = 6  \r\n    MAZE_C = 6 \r\n\r\n    def __init__(self, env,alpha=0.1, gamma=0.9):\r\n        '''初始化'''\r\n        self.states = {} \r\n        self.actions = 3  \r\n    \r\n        self.alpha = alpha\r\n        self.gamma = gamma\r\n        self.q_table = np.zeros([32,3])\r\n\r\n    def choose_action(self,state,epsilon=0.8):\r\n        '''选择相应的动作。根据当前状态，随机或贪婪，按照参数epsilon'''\r\n\r\n        if random.uniform(0, 1) > epsilon:  \r\n            action = random.choice([0,1,2])\r\n        else:\r\n            max_index=(self.q_table[state] == self.q_table[state].max()).nonzero()\r\n            if len(max_index)==1:\r\n                max_qvalue_actions=max_index[0]\r\n            else:\r\n                max_qvalue_actions=max_index[:][1]\r\n            action = random.choice(np.array(max_qvalue_actions))\r\n        return np.array([action])\r\n\r\n\r\n    def update_q_value(self, state, action, next_state_reward, next_state_q_values):\r\n        self.q_table[state, action] += self.alpha * (\r\n                next_state_reward + self.gamma * next_state_q_values.max() - self.q_table[state, action])\r\n\r\n    def add_state(self,X_next):\r\n        x_str = ','.join(str(i) for i in X_next.astype(int))\r\n        if (x_str in self.states) == False:\r\n            self.states[x_str] = max(self.states.values()) + 1\r\n        return self.states[x_str]\r\n\r\n\r\n    def learn(self, env, episode=100, epsilon=0.8):\r\n        '''q-learning算法'''\r\n        env.reset()\r\n        X=np.array([0,1,0,0])\r\n        sss = ','.join(str(i) for i in X.astype(int))\r\n        self.states[sss] = 0\r\n  
      for i in range(episode):\r\n            steps=0\r\n            current_state = np.array([0])\r\n            env.env.current_cell=np.array([0])\r\n            X_next, envreward, fitness, infos=env.step(current_state)\r\n            self.add_state(X_next)\r\n            next_state_reward=0\r\n            while next_state_reward==0 and steps<1000:\r\n                current_action = self.choose_action(current_state, epsilon) \r\n                X_next, next_state_reward, fitness, infos = env.step(current_action)\r\n                next_state_number=self.add_state(X_next)\r\n                next_state_q_values = self.q_table[next_state_number]\r\n                self.update_q_value(current_state, current_action, next_state_reward, next_state_q_values)\r\n                current_state = next_state_number\r\n                steps+=1\r\n\r\n    def play(self, env):\r\n        step=0\r\n        self.learn(env, epsilon=0.8)\r\n        current_state = np.array([0])\r\n        env.env.current_cell = np.array([0])\r\n        X_next, envreward, fitness, infos = env.step(current_state)\r\n        self.add_state(X_next)\r\n        env_r=[]\r\n        rsum=0\r\n        old_dis=13\r\n\r\n        while step<steps:\r\n            current_action = self.choose_action(current_state, 1)\r\n            X_next, envreward, fitness, infos = env.step(current_action)\r\n            envreward=envreward[0]\r\n            food_pos = env.env.food_pos[:, 0, :2]\r\n            agent_pos = env.env.agents_pos\r\n            dis = ((agent_pos - food_pos) ** 2).sum(1)\r\n            reward =np.array((np.sqrt(old_dis)-np.sqrt(dis))>0,dtype=int)[0]\r\n            if reward==0:\r\n                reward=-1\r\n            elif reward==1:\r\n                reward=1\r\n            if envreward==1:\r\n                reward=3\r\n            elif envreward==-1:\r\n                reward=-3\r\n            next_state_number = self.add_state(X_next)\r\n            rsum+=reward\r\n            current_state = 
next_state_number\r\n            env_r.append(rsum)\r\n            step+=1\r\n        return np.array(env_r)\r\n\r\n\r\ndef QQ():\r\n    steps=500\r\n    env = MazeTurnEnvVec(1, n_steps=steps)\r\n    data_env = ExperimentEnvGlobalNetworkSurvival(env)\r\n    agent = Agent(data_env)  \r\n    r=agent.play(data_env)\r\n    return r\r\n\r\nnp.save('./ql.npy',QQ())\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/tools/EnuGlobalNetwork.py",
    "content": "import pickle\r\nimport time\r\n\r\nimport numpy as np\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\nimport seaborn as sns\r\nfrom matplotlib import gridspec\r\n\r\nfrom AbstractLayerBMM import AbstractLayerBMM\r\nfrom EvolvableNeuralUnitStacked import EvolvableNeuralUnitStacked\r\nfrom Tools import get_data_path\r\n\r\nsns.set_style(\"darkgrid\")\r\n\r\n\r\nclass EnuGlobalNetwork(AbstractLayerBMM):\r\n    \"\"\"Network of ENUs implementation in PyTorch, where each synapse and neuron is modeled as an ENU. \"\"\"\r\n\r\n    def __init__(self, n_offspring, n_pseudo_env, n_input_neurons, n_hidden_neurons, n_output_neurons, n_syn_per_neuron):\r\n        # offspring\r\n        self.n_offspring = n_offspring\r\n        self.n_pseudo_env = n_pseudo_env\r\n        # input channels\r\n        n_input_channels = 16\r\n        self.n_input_channels = n_input_channels\r\n        n_dynamic_param = 32\r\n        # total neurons\r\n        n_neurons = n_output_neurons + n_hidden_neurons\r\n        self.n_neurons = n_neurons\r\n        super().__init__(n_offspring, n_neurons, n_input_neurons, n_output_neurons)\r\n        torch.random.manual_seed(0)\r\n        #NOTE: batch dimension holds output of each neuron/synapse, allowing fast GPU MM\r\n        #NOTE neurons far less than synapses, so can be relatively bigger rnn for little cost\r\n        n_input_channels_neuron = 16\r\n        n_input_neuron, n_output_neuron = n_input_channels_neuron, n_input_channels\r\n        self.neurons = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_neurons, n_input=n_input_neuron, n_dynamic_param=n_dynamic_param, n_output=n_output_neuron)\r\n        #self.n_syn = next_power_of_2(int(n_neurons * (rel_connectivity*n_neurons)))\r\n        self.n_syn_per_neuron = n_syn_per_neuron\r\n        self.n_syn = n_neurons * n_syn_per_neuron\r\n        n_input_syn, n_output_syn = n_input_channels * 2, n_input_channels_neuron # * 2 for neuron feedback (which same channel as 
n_channel input)\r\n        self.synapses = EvolvableNeuralUnitStacked(n_offspring, batch_size=self.n_syn, n_input=n_input_syn, n_dynamic_param=n_dynamic_param, n_output=n_output_syn)\r\n        # just randomly connect synapses to neurons\r\n        self.synapse_connections = torch.randint(n_input_neurons + n_neurons, size=(n_neurons, n_syn_per_neuron), device='cuda', dtype=torch.long)\r\n        # fixed predefined connection patterns\r\n        if n_input_neurons==2 and n_output_neurons==2 and n_hidden_neurons==2:\r\n            print(\"Fixed connection Network 2-2-2\")\r\n            self.synapse_connections = torch.tensor([[0, 1],\r\n                                                     [0, 1],\r\n                                                     [2, 3],\r\n                                                     [2, 3]], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons == 4 and n_output_neurons == 3 and n_hidden_neurons == 3 and n_syn_per_neuron==3:\r\n            print(\"Fixed connection Network 4-3-3 (3syn)\")\r\n            self.synapse_connections = torch.tensor([[0, 1, 3],# hidden connections #4\r\n                                                     [0, 2, 3], #5\r\n                                                     [1, 2, 3],# 6\r\n                                                     [4, 5, 6], # output connections #7\r\n                                                     [4, 5, 6],#8\r\n                                                     [4, 5, 6]#9\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==5 and n_hidden_neurons==0 and n_output_neurons==4:\r\n            print(\"Fixed connection Network 5-0-4 (5syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# output connections\r\n                                                     [0, 1, 2, 3, 
4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4]\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==2:\r\n            print(\"Fixed connection Network 1-0-2 (1syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0],# output connections\r\n                                                     [0]\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==4 and n_hidden_neurons==0 and n_output_neurons==3 and n_syn_per_neuron==4:\r\n            print(\"Sparse connection Network 4-0-3 (4syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3],# output connections #4\r\n                                                     [0, 1, 2, 3], #5\r\n                                                     [0, 1, 2, 3],# 6\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons == 4 and n_hidden_neurons == 3 and n_output_neurons == 3 and n_syn_per_neuron == 4:\r\n            print(\"Sparse connection Network 4-3-3 (3syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 3],  # hidden connections #4\r\n                                                     [0, 2, 3],  # 5\r\n                                                     [1, 2, 3],  # 6\r\n                                                     [4, 5, 3],  # output connections #7\r\n                                                     [4, 6, 3],  # 8\r\n                          
                           [5, 6, 3]  # 9\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==4 and n_hidden_neurons==3 and n_output_neurons==3 and n_syn_per_neuron==8:\r\n            print(\"Sparse connection Network 4-3-3 (8syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 5, 6, 7, 8, 3, 4],# hidden connections #4\r\n                                                     [0, 2, 4, 6, 7, 9, 3, 5], #5\r\n                                                     [1, 2, 4, 5, 8, 9, 3, 6],# 6\r\n                                                     [4, 5, 8, 9, 0, 1, 3, 7], # output connections #7\r\n                                                     [4, 6, 7, 9, 0, 2, 3, 8],#8\r\n                                                     [5, 6, 7, 8, 1, 2, 3, 9]#9\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==4 and n_hidden_neurons==4 and n_output_neurons==4 and n_syn_per_neuron==8:\r\n            print(\"Fixed connection Network 4-4-4 (8syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],# hidden connections\r\n                                                     [0, 1, 2, 3, 4, 5, 6, 7],\r\n                                                     [0, 1, 2, 3, 4, 5, 6, 7],\r\n                                                     [0, 1, 2, 3, 4, 5, 6, 7],\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 11], # output connections\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 11],\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 11],\r\n                                                     [4, 5, 6, 7, 8, 9, 10, 
11],\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==5 and n_hidden_neurons==5 and n_output_neurons==4:\r\n            print(\"Fixed connection Network 5-5-4 (5syn)\")\r\n            # neuron i connected to neuron j and k, neuron 0..input_neurons is index\r\n            self.synapse_connections = torch.tensor([[0, 1, 2, 3, 4],# hidden connections\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [0, 1, 2, 3, 4],\r\n                                                     [5, 6, 7, 8, 9], # output connections\r\n                                                     [5, 6, 7, 8, 9],\r\n                                                     [5, 6, 7, 8, 9],\r\n                                                     [5, 6, 7, 8, 9],\r\n                                                     ], device='cuda', dtype=torch.long)\r\n        elif n_input_neurons==1 and n_hidden_neurons==0 and n_output_neurons==1:\r\n            print(\"Fixed connection Single\")\r\n            self.synapse_connections = torch.tensor([[0]], device='cuda', dtype=torch.long)\r\n        else:\r\n            print(\"Random connections\")\r\n        # each synapse is connected also to its post-synaptic neuron, to allow STDP type learning to emerge\r\n        self.synapse_connections_post = torch.arange(n_neurons, device='cuda', dtype=torch.long).reshape(n_neurons, -1).repeat(1, n_syn_per_neuron)\r\n        # define compartments\r\n        self.compartments = [self.neurons, self.synapses]\r\n        self.trainable_layers = self.neurons.trainable_layers + self.synapses.trainable_layers\r\n        self.track_data = False\r\n\r\n\r\n    def dump_model(self, e, exp_name):\r\n        \"\"\"Dump model to 
restore\"\"\"\r\n        with open(get_data_path(e, exp_name, \"Model\"), 'wb') as f:\r\n            parameters = {}\r\n            parameters[\"neuron\"] = [layer.base_parameters.cpu().numpy() for layer in self.neurons.trainable_layers]\r\n            parameters[\"synapse\"] = [layer.base_parameters.cpu().numpy() for layer in self.synapses.trainable_layers]\r\n            pickle.dump(parameters, f)\r\n\r\n    def restore_model(self, e, exp_name):\r\n        \"\"\"Restore model\"\"\"\r\n        with open(get_data_path(e, exp_name, \"Model\"), 'rb') as f:\r\n            parameters = pickle.load(f)\r\n        #TODO: refactor to dump/restore at ENU level and just call those functions\r\n        assert len(self.neurons.trainable_layers) == len(parameters[\"neuron\"])\r\n        for i in range(len(parameters[\"neuron\"])):\r\n            self.neurons.trainable_layers[i].base_parameters = torch.from_numpy(parameters[\"neuron\"][i].astype(np.float32)).cuda()\r\n        assert len(self.synapses.trainable_layers) == len(parameters[\"synapse\"])\r\n        for i in range(len(parameters[\"synapse\"])):\r\n            self.synapses.trainable_layers[i].base_parameters = torch.from_numpy(parameters[\"synapse\"][i].astype(np.float32)).cuda()\r\n\r\n    @staticmethod\r\n    def plot_weights(e, exp_name):\r\n        \"\"\"Visualize weights of ENU gates\"\"\"\r\n        sns.set_style(\"dark\")\r\n        def calc_average(start, stop):\r\n            weights_average = None\r\n            for e in range(start, stop, 1000):\r\n                with open(get_data_path(e, exp_name, \"Model\"), 'rb') as f:\r\n                    parameters = pickle.load(f)\r\n                weights = []\r\n                for i in range(len(parameters[\"neuron\"])):\r\n                    weights += [parameters[\"neuron\"][i].astype(np.float32)]\r\n                if weights_average is None:\r\n                    weights_average = weights\r\n                else:\r\n                    for i in 
range(len(weights_average)):\r\n                        weights_average[i] += weights[i]\r\n            return weights_average\r\n        weights_mean1 = calc_average(20000, 30000)\r\n        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')\r\n        for i in range(len(weights_mean1)):\r\n            ax[i].imshow(weights_mean1[i], cmap=\"gray\")\r\n        weights_mean2 = calc_average(30000, 40000)\r\n        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')\r\n        for i in range(len(weights_mean2)):\r\n            ax[i].imshow(weights_mean2[i], cmap=\"gray\")\r\n        fig, ax = plt.subplots(1, 2, sharex='col', sharey='row')\r\n        for i in range(len(weights_mean2)):\r\n            ax[i].imshow((weights_mean2[i] - weights_mean1[i])**5, cmap=\"gray\")\r\n        plt.show()\r\n\r\n\r\n    def dump_network_activity(self, e, exp_name):\r\n        \"\"\"Dump raw data for visualization\"\"\"\r\n        with open(get_data_path(e, exp_name, \"GlobalNetwork\"), 'wb') as f:\r\n            pickle.dump(self.vis_data, f)\r\n\r\n    def print(self):\r\n        print(\"--Neurons--\")\r\n        self.neurons.print()\r\n        print(\"--Synapses--\")\r\n        self.synapses.print()\r\n\r\n    def reset(self):\r\n        self.vis_data = []\r\n        if self.track_data:\r\n            print(\"Tracking network activity\")\r\n        for compartment in self.compartments:\r\n            compartment.reset()\r\n\r\n    def forward(self, X):\r\n        \"\"\"Main computation forward pass\"\"\"\r\n        # transfer to GPU\r\n        X_raw_gpu = torch.from_numpy(X.astype(np.float32)).cuda()\r\n        X_gpu = torch.zeros((X.shape[0], X.shape[1], self.n_input_channels), device='cuda', dtype=torch.float32)\r\n        X_gpu[:, :, :X_raw_gpu.shape[2]] = X_raw_gpu\r\n        # first compute synapses, set input to previous output of connected neuron\r\n        # concat our input spiking pattern directly to input to our synapses (the neurons)\r\n        # NOTE: this 
concats in batch dimension, meaning it feeds into input neurons directly spiking pattern, while rest receive input from network\r\n        input_to_synapses = torch.cat([X_gpu, self.neurons.out_mem], dim=1)\r\n        # connect each synapse randomly to multiple inputs\r\n        input_to_synapses_connected = input_to_synapses[:, self.synapse_connections.flatten(), :]\r\n        # need feedback connection from neuron to synapse, to allow stdp type rules to emerge (else it has to do it through feedback connections, but less guarentee on connections and cannot distinguise type)\r\n        # one synapse has 1 pre-synaptic neuron and 1 post-synaptic neuron, connectection defined in synapse_connections, synapse_connections[i, :] gives all input synapses of that neuron\r\n        # so feedback to all it's input synapses through broadcasting backwards\r\n        post_neuron_backprop_connected = self.neurons.out_mem[:, self.synapse_connections_post.flatten(), :]\r\n        input_to_synapses_connected = torch.cat([input_to_synapses_connected, post_neuron_backprop_connected], dim=-1)\r\n        # compute synapse\r\n        self.synapses.forward(input_to_synapses_connected)\r\n        # then integrate(sum) all outputs of a neurons input synapses, can just reshape into valid shape, since we already randomly connected when computing synapses\r\n        # NOTE: each neuron then requires same number of synapses, then reshape by modifying batch dim (which contains syn outputs)\r\n        integration = torch.sum(self.synapses.out.reshape((self.n_offspring, self.n_neurons, -1, self.synapses.shape[-1])), dim=2)\r\n        # scale by number of synapses\r\n        integration /= self.n_syn_per_neuron\r\n        self.out_integration = integration\r\n        # finally set neuron input to summated connected synapses output\r\n        input_to_neurons = integration\r\n        out = self.neurons.forward(input_to_neurons)\r\n        # output is last neuron output, NOTE: just first channel is 
returned, since we reshape neurons to channels\r\n        self.out = out[:, -self.n_output:, 0].reshape(self.n_offspring, self.n_output)\r\n        if self.track_data:\r\n            self._track_vis_data(X, input_to_synapses_connected, input_to_neurons)\r\n        return self.out\r\n\r\n    def _track_vis_data(self, X, input_to_synapses_connected, input_to_neurons):\r\n        offspring_idx = 0\r\n        self.vis_data += [(X[offspring_idx], input_to_neurons[offspring_idx].cpu().numpy(), self.neurons.out[offspring_idx].cpu().numpy(),\r\n                           input_to_synapses_connected[offspring_idx].cpu().numpy(), self.synapses.out[offspring_idx].cpu().numpy())]\r\n\r\n    @staticmethod\r\n    def plot_network_activity(e, exp_name):\r\n        with open(get_data_path(e, exp_name, \"GlobalNetwork\"), 'rb') as f:\r\n            vis_data = pickle.load(f)\r\n            X, input_to_neurons, neurons_out, input_to_synapses, synapses_out = map(np.array, zip(*vis_data))\r\n        def plot_enu_activity(input, output, title):\r\n            n_cells = output.shape[1]\r\n            n_cells = np.minimum(10, output.shape[1])\r\n            fig, grid = plt.subplots(2, n_cells, sharex='col', sharey='row')\r\n            if n_cells==1:\r\n                grid[0].plot(input[:, 0, :])\r\n                grid[1].plot(output[:, 0, :])\r\n            else:\r\n                for i in range(n_cells):\r\n                    grid[0, i].plot(input[:, i, :])\r\n                    grid[1, i].plot(output[:, i, :])\r\n            plt.xlabel(\"t\")\r\n            plt.title(title)\r\n            #plt.ylabel(\"\")\r\n            plt.legend()\r\n        plt.figure()\r\n        plt.plot(X[:, :, 0])\r\n        plot_enu_activity(input_to_neurons, neurons_out, \"ENU neuron activity\")\r\n        plot_enu_activity(input_to_synapses, synapses_out, \"ENU synapse activity\")\r\n\r\n        plt.figure()\r\n        spike_points = np.where(neurons_out[:, :, 0] > 0)\r\n        
plt.scatter(spike_points[0], spike_points[1], marker='|')\r\n\r\n        plt.show()\r\n\r\n\r\n"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/tools/ExperimentEnvGlobalNetworkSurvival.py",
    "content": "import pickle\r\n\r\nimport numpy as np\r\n\r\nfrom tools.Tools import get_data_path\r\n\r\n\r\nclass ExperimentEnvGlobalNetworkSurvival:\r\n    \"\"\"Wrapper around a given RL environment for a Network of ENUs model,\r\n    turns reward into fitness and dumps relevant data\"\"\"\r\n\r\n\r\n    def __init__(self, env, exp_name='maze'):\r\n        self.env = env\r\n        self.exp_name = exp_name\r\n        self.n_output = self.env.n_actions\r\n        #NOTE: +1 reward neuron\r\n        self.n_input_neurons = self.env.n_obs + 1\r\n        self.n_agents = self.env.n_agents\r\n\r\n    def _convert_obs(self, obs, rewards):\r\n        n_input_channels_used = 3\r\n        X = np.zeros((self.n_agents, self.n_input_neurons, n_input_channels_used))\r\n        #X[:, :obs.shape[1], 0] = obs\r\n        # Shuffle only obs to avoid topology exploitation, reward neuron linked to EnuGlobal synapse connectivity\r\n        X[:, :obs.shape[1], 0] = np.take_along_axis(obs, self.obs_shuffle, axis=1)\r\n        # split pos and negative reward to different channels, And set to last input neuron\r\n        if rewards is not None:\r\n            X[rewards>0, -1, 1] = np.abs(rewards[rewards>0])\r\n            X[rewards<=0, -1, 2] = np.abs(rewards[rewards<=0])\r\n        return X\r\n\r\n    def _convert_reward(self, obs, actions, rewards, infos, dones):\r\n        fitness = np.copy(rewards)\r\n        # first poison is considered positive reward, since learning to learn\r\n        #NOTE: dead by env means less reward can be obtained so should implictely reduce overall fitness automatically\r\n        fitness[np.logical_and(self._prev_reward_count == 1, rewards != 0)] = 1\r\n        # include episode length as extra fitness, since not taking poison would allow survive longer, so should try avoid take poison\r\n        fitness[dones==0] += 0.1/4\r\n        return fitness\r\n\r\n    def step(self, y):\r\n\r\n        # if self.t % 3 != 0:\r\n        #     actions = 
np.zeros((self.n_agents), dtype=np.int32) - 1\r\n        # else:\r\n            # winner take all, in given time window\r\n        actions = y\r\n            # if all same output, do nothing\r\n            # equal_actions = self.y_hist.shape[1] == np.sum(self.y_hist == np.take_along_axis(self.y_hist, actions.reshape(-1, 1), axis=1), axis=-1)\r\n            # actions[equal_actions] = -1\r\n            # self.y_hist[:] = 0\r\n        # take env step\r\n\r\n        allobs, obs, rewards, dones, infos = self.env.step(actions)\r\n        # X = self._convert_obs(obs, rewards)\r\n        X=allobs\r\n        self._prev_reward_count += rewards!=0\r\n        fitness = self._convert_reward(obs, actions, rewards, infos, dones)\r\n        self._prev_action = actions\r\n        self._prev_obs = obs\r\n        return X, rewards, fitness, None\r\n\r\n    def reset(self):\r\n        self.t = 0\r\n        self.y_hist = np.zeros((self.n_agents, self.n_output), dtype=np.float32)\r\n        self._prev_action = None\r\n        self._prev_obs = None\r\n        self._prev_reward_count = np.zeros((self.n_agents), dtype=np.float32)\r\n        # each time different input/output neurons should have different meaning, to have learning to learn\r\n        self.obs_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_input_neurons - 1), axis=1, kind='mergesort')\r\n        self.action_shuffle = np.argsort(np.random.randn(self.n_agents, self.n_output), axis=1, kind='mergesort')\r\n        # reset env\r\n        self.allobs,self.obs = self.env.reset()\r\n        # return self._convert_obs(self.obs, None)\r\n        return self.allobs\r\n\r\n    def render(self):\r\n        if self.t%4==0:\r\n            self.env.render()\r\n\r\n    def track_vis_data(self, vis_data, model, X, y_est, t):\r\n        n_fetch = 128\r\n        # TODO: also get our gates from the model\r\n        vis_data+=[(X[:n_fetch, :], y_est[:n_fetch, :])]\r\n\r\n    def dump_vis_data(self, vis_data, fitness_per_offspring, 
e):\r\n        with open(get_data_path(e, self.exp_name, \"output\"), 'wb') as f:\r\n            pickle.dump((vis_data, fitness_per_offspring), f)\r\n\r\n    @staticmethod\r\n    def load_vis_data(e, exp_name):\r\n        with open(get_data_path(e, exp_name, \"output\"), 'rb') as f:\r\n            vis_data, fitness_per_offspring = pickle.load(f)\r\n        return vis_data, fitness_per_offspring\r\n\r\n    @staticmethod\r\n    def plot_vis_data(e, exp_name):\r\n        vis_data, fitness_per_offspring = ExperimentEnvGlobalNetworkSurvival.load_vis_data(e, exp_name)"
  },
  {
    "path": "examples/Structure_Evolution/Adaptive_lsm/raw/tools/MazeTurnEnvVec.py",
    "content": "import pickle\r\n\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nimport seaborn as sns\r\n\r\nfrom tools.Tools import save_fig, get_data_path\r\n\r\n# np.random.seed(0)\r\n\r\nclass MazeTurnEnvVec:\r\n    \"\"\"Vectorized RL T-Maze environment written in pure Numpy. We require an efficient environment since we need to evaluate\r\n    and run up to thousands of offspring in parallel\"\"\"\r\n\r\n    def __init__(self, n_agents, n_steps):\r\n        # 4 important points, start point, decision point, food point, dead point.\r\n        # just generate a very large matrix that could fit any maze of any size, then can generate smaller maze as well\r\n        self.n_actions = 3\r\n        self.n_obs = 3\r\n        self.max_size = 7\r\n        self.n_agents = n_agents\r\n        self.n_steps = n_steps\r\n        self.window = plt.figure()\r\n        self.t_maze = True\r\n        self.turn_based = False\r\n        # steps can be longer if poison and need to turn around\r\n        self.steps_to_food = 2\r\n        if self.t_maze:\r\n            self.steps_to_food = 3\r\n        self.steps_to_food += self.steps_to_food*2\r\n        # give some extra leniency\r\n        self.steps_to_food *= 2\r\n    def step(self, actions):\r\n        # L R U D\r\n        # TODO: check legal action or not..\r\n        pos_copy = np.copy(self.agents_pos)\r\n        actions = np.copy(actions)\r\n\r\n        # GIVE TIME UPDATE WEIGHTS\r\n        # actions[self.agents_reset > 0] = -1\r\n        # actions[self.agent_energy<=0] = -1\r\n        self.agents_reset[self.agents_reset > 0] -= 1\r\n        # if turn based\r\n        if self.turn_based:\r\n            Forward = actions == 0\r\n            self.agents_pos[np.logical_and(Forward, self.agent_directions == 0), 1] += 1\r\n            self.agents_pos[np.logical_and(Forward, self.agent_directions == 2), 1] -= 1\r\n            self.agents_pos[np.logical_and(Forward, self.agent_directions == 1), 0] -= 1\r\n            
self.agents_pos[np.logical_and(Forward, self.agent_directions == 3), 0] += 1\r\n            L = actions == 1\r\n            self.agent_directions[L] += 1\r\n            R = actions == 2\r\n            self.agent_directions[R] -= 1\r\n            self.agent_directions[self.agent_directions > 3] = 0\r\n            self.agent_directions[self.agent_directions < 0] = 3\r\n        else:\r\n            # or just direct movement\r\n            U = actions == 2\r\n            D = actions == 1\r\n            R = actions == 0\r\n            if self.agents_pos[U].size>0:\r\n                self.agents_pos[U, 0] += 1\r\n                self.agent_directions[U] = 3\r\n            if self.agents_pos[D].size>0:\r\n                self.agents_pos[D, 0] -= 1\r\n                self.agent_directions[D] = 1\r\n            if self.t_maze and self.agents_pos[R].size>0:\r\n                self.agents_pos[R, 1] += 1\r\n                self.agent_directions[R] = 0\r\n\r\n        # UNDO MOVES THAT GOT AGENT INTO WALL\r\n        self.current_cells = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]]\r\n        self.agents_pos[self.current_cells==1] = pos_copy[self.current_cells==1]\r\n        movement_loss = np.prod(self.agents_pos==pos_copy, axis=-1)\r\n        # CHECK IF FOOD CONSUMED, is reward + pos reset\r\n        consumed_food = np.prod(self.agents_pos==self.food_pos[:, 0, :2], axis=-1).astype(np.bool)\r\n        consumed_pois = np.prod(self.agents_pos==self.food_pos[:, 1, :2], axis=-1).astype(np.bool)\r\n        self.consumed_count += consumed_food.astype(np.int32)\r\n        self.consumed_count_total += consumed_food.astype(np.int32)\r\n        self.consumed_count_pois += consumed_pois.astype(np.int32)\r\n        self._reset_pos(np.logical_or(consumed_food, consumed_pois))\r\n        # self._reset_pos_pois(consumed_pois)\r\n        # self._reset_food(self.consumed_count==self.swap_limit, prob=0.0)\r\n        # reset food for agents that ate 
food, and swap with some probability\r\n        self._reset_food(self.consumed_count==5, prob=0.5)\r\n        self.rewards = consumed_food.astype(np.float32) - consumed_pois.astype(np.float32) #* 0.5 #- movement_loss.astype(np.float32) * 0.01\r\n        # get observation from current position of each agent\r\n        self.agent_allobs,self.obs = self._get_obs_from_pos()\r\n        # instant dead on second poison\r\n        self.agent_energy += np.where(self.consumed_count_pois<=1, np.abs(self.rewards) * self.steps_to_food, -self.agent_energy * np.abs(self.rewards))\r\n        # energy decay to encourage exploration, agent dies if running out of energy\r\n        self.agent_energy = np.minimum(self.agent_energy, self.steps_to_food)\r\n        self.agent_energy -= 1.0/4\r\n        dones = self.agent_energy<=0\r\n\r\n        return self.agent_allobs,self.obs, self.rewards, dones, None\r\n\r\n    def _reset_pos(self, idxs):\r\n        self.agents_pos[idxs] = [self.start_point, 2]  # set X pos\r\n        self.agent_directions[idxs] = 0\r\n        if self.max_size==5:\r\n            self.agent_directions[idxs] = 1\r\n        self.agents_reset[idxs] = 0\r\n        self.agents_reset_count[idxs] += 1\r\n\r\n    def _reset_pos_pois(self, idxs):\r\n        #NOTE: 2 since we call reset twice!\r\n        #NOTE: turning around already cost 8 steps, then 4x4 more is 16+8, 24, so reset should be much worse\r\n        self.agents_reset[np.logical_and(idxs, self.agents_reset_count>2)] = 64\r\n\r\n    def _reset_food(self, idxs, prob=0.5):\r\n        # swap food with some probability, avoids agent overfitting on environment\r\n        swap = np.take_along_axis(self.random_swap_matrix, self.consumed_count_total.reshape(-1, 1), axis=1).ravel()\r\n        swap_idxs = swap * idxs\r\n        food_loc = np.copy(self.food_pos[swap_idxs, 0, :])\r\n        pois_loc = np.copy(self.food_pos[swap_idxs, 1, :])\r\n        self.food_pos[swap_idxs, 0, :] = pois_loc\r\n        
self.food_pos[swap_idxs, 1, :] = food_loc\r\n        # set maze value\r\n        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 0, 0], self.food_pos[:, 0, 1]] = 2\r\n        self.mazes[np.arange(self.mazes.shape[0]), self.food_pos[:, 1, 0], self.food_pos[:, 1, 1]] = 3\r\n        self.consumed_count[idxs] = 0\r\n\r\n    def reset(self):\r\n        self.consumed_count = np.zeros((self.n_agents), dtype=np.int32)\r\n        self.consumed_count_total = np.zeros_like(self.consumed_count)\r\n        # consistent swapping such that if agent eat food once for all agents swapped with same seed, fair fitness comparison\r\n        max_eat = self.n_steps\r\n        self.random_swap_matrix = np.random.uniform(0, 1, size=(1, max_eat)) >= 0.5\r\n        self.random_swap_matrix = np.repeat(self.random_swap_matrix, int(self.n_agents), axis=0)\r\n\r\n        self.agent_energy = np.zeros((self.n_agents), dtype=np.float32) + self.steps_to_food\r\n        self.consumed_count_pois = np.zeros_like(self.consumed_count)\r\n        #self.swap_limit = np.random.randint(1, 5, size=1)\r\n        #self.swap_limit = np.random.randint(1, 4, size=self.n_agents)\r\n        self.mazes = np.ones((self.n_agents, self.max_size, self.max_size), dtype=np.int32)\r\n        #TODO: support variable maze length\r\n        self.start_point = int(self.max_size/2)\r\n        if self.t_maze:\r\n            self.mazes[:, self.start_point, 2:-1] = 0\r\n            self.mazes[:, 1:-1, -2] = 0\r\n            # FOOD either at -1,-1 or -1,1?\r\n            # two foods: x, y, value\r\n            self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)\r\n            self.food_pos[:, :, 1] = self.max_size - 2\r\n            self.food_pos[:, 1, 0] = 1\r\n            self.food_pos[:, 0, 0] = self.max_size - 2\r\n            self._reset_food(np.ones(self.food_pos.shape[0], dtype=np.bool), prob=0.5)\r\n        else:\r\n            self.mazes[:, 1:-1, 1] = 0\r\n            # two foods: x, y, value\r\n     
       self.food_pos = np.zeros((self.n_agents, 2, 2), dtype=np.int32)\r\n            self.food_pos[:, :, 1] = 1\r\n            self.food_pos[:, 0, 0] = 1\r\n            self.food_pos[:, 1, 0] = (self.max_size - 2)\r\n            self._reset_food(np.ones(self.food_pos.shape[0], dtype=np.bool), prob=0.5)\r\n        # AGENT\r\n        self.agents_pos = np.ones((self.n_agents, 2), dtype=np.int32)\r\n        self.agents_reset = np.zeros((self.n_agents), dtype=np.int32)\r\n        self.agents_reset_count = np.zeros_like(self.agents_reset)\r\n        self.agent_directions = np.zeros((self.n_agents), dtype=np.int32)\r\n        self._reset_pos(np.arange(self.agents_pos.shape[0]))\r\n        # OBS\r\n        self.agent_allobs,self.obs = self._get_obs_from_pos()\r\n        return self.agent_allobs,self.obs\r\n\r\n    def _get_obs_from_pos(self):\r\n        # obs is neighbouring cell states around agent\r\n        obs = np.zeros((self.n_agents, self.n_obs), dtype=np.float32)\r\n        raw_obs = np.zeros(self.n_agents, dtype=np.int32)\r\n        # get observation based on direction agent is facing\r\n        leftobs = np.zeros(self.n_agents)\r\n        rightobs = np.zeros(self.n_agents)\r\n        backobs = np.zeros(self.n_agents)\r\n\r\n        # get observation based on direction agent is facing\r\n        D = self.agent_directions == 0\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]-1][D]\r\n\r\n\r\n        D = self.agent_directions == 2\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 
1][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1]+1][D]\r\n\r\n        D = self.agent_directions == 1\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] - 1, self.agents_pos[:, 1]][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0]+1, self.agents_pos[:, 1]][D]\r\n\r\n        D = self.agent_directions == 3\r\n        raw_obs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0] + 1, self.agents_pos[:, 1]][D]\r\n        leftobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] + 1][D]\r\n        rightobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0], self.agents_pos[:, 1] - 1][D]\r\n        backobs[D] = self.mazes[np.arange(self.mazes.shape[0]), self.agents_pos[:, 0]-1, self.agents_pos[:, 1]][D]\r\n\r\n        # mark what was observed at different index\r\n        obs[raw_obs == 1, 0] = 1\r\n        obs[raw_obs == 2, 1] = 1\r\n        obs[raw_obs == 3, 2] = 1\r\n        allobs=np.squeeze(np.dstack((leftobs,raw_obs,rightobs,backobs)))\r\n        return allobs, obs\r\n\r\n    def render(self):\r\n        plt.clf()\r\n        sns.set_style(\"white\")\r\n        #TODO: support render all mazes? 
can reshape to square?\r\n        max_render = 1\r\n        flattened_render = np.dstack(np.split(self.mazes[18, :], max_render, axis=0)).reshape(self.mazes.shape[1], -1)\r\n        flattened_render[flattened_render>1] = 0\r\n        plt.axis('off')\r\n\r\n        plt.imshow(flattened_render,cmap='bone')\r\n\r\n        for j in range(1):\r\n            i=18\r\n            marker = \">\"\r\n            if self.agent_directions[i] == 1:\r\n                marker = \"^\"\r\n            if self.agent_directions[i] == 2:\r\n                marker = \"<\"\r\n            if self.agent_directions[i] == 3:\r\n                marker = \"v\"\r\n            obs_color = \"black\"\r\n            if self.obs[i, 0] == 1:\r\n                obs_color = \"gray\"\r\n            if self.obs[i, 1] == 1:\r\n                obs_color = \"green\"\r\n            if self.obs[i, 2] == 1:\r\n                obs_color = \"red\"\r\n            alpha = 1\r\n            if self.agent_energy[i]<=0:\r\n                alpha = 1\r\n            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color=\"skyblue\", alpha=alpha, marker=marker)\r\n            plt.scatter(self.agents_pos[i, 1] + j * self.mazes.shape[1], self.agents_pos[i, 0], color=obs_color,alpha=alpha, marker=marker, s=3)\r\n            plt.scatter(self.food_pos[i, 0, 1] + j * self.mazes.shape[1], self.food_pos[i, 0, 0], color=\"green\", alpha=1, marker=\"o\")\r\n            plt.scatter(self.food_pos[i, 1, 1] + j * self.mazes.shape[1], self.food_pos[i, 1, 0], color=\"red\", alpha=1, marker=\"o\")\r\n        plt.pause(0.001)\r\n        #plt.pause(2)\r\n\r\n    @staticmethod\r\n    def load_vis_data(e, exp_name):\r\n        with open(get_data_path(e, exp_name, \"output\"), 'rb') as f:\r\n            vis_data, fitness_per_offspring = pickle.load(f)\r\n        return vis_data, fitness_per_offspring\r\n\r\n    @staticmethod\r\n    def plot_vis_data(e, exp_name):\r\n        import matplotlib.pyplot as plt\r\n  
      from cycler import cycler\r\n        import seaborn as sns\r\n        sns.set_style(\"whitegrid\")\r\n\r\n        vis_data, fitness_per_offspring = MazeTurnEnvVec.load_vis_data(e, exp_name)\r\n\r\n        offspring_idx = 0\r\n        #x, y_est, y = np.array(vis_data).transpose(1, 2, 0, 3)\r\n        X, Y_est = map(np.array, zip(*vis_data))\r\n        X_base, y_est_base = X[:, offspring_idx], Y_est[:, offspring_idx]\r\n        #X_base, y_est_base = X_base[:300], y_est_base[:300]\r\n        #X_base = np.max(X_base, axis=1)\r\n        # --- normal output single example---\r\n        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))\r\n        # OLD METHOD!\r\n        plt.figure()\r\n        #NOTE: last neuron is always reward neuron\r\n        plt.plot(-X_base[:, :-1, 0], label=\"N-ENUs input\", alpha=0.7)\r\n        plt.plot(- np.max(X_base, axis=1)[:, 1], label=\"Positive reward\", alpha=0.7, color='#2ca02c')\r\n        plt.plot(- np.max(X_base, axis=1)[:, 2], label=\"Negative reward\", alpha=0.7, color='#d62728')\r\n        #plt.gca().set_color_cycle(['orange', 'purple', 'brown'])\r\n        plt.plot(y_est_base[:, :], label=\"N-ENUs output\", alpha=0.8)\r\n        plt.legend(loc='upper right')\r\n        save_fig(e, exp_name, \"single_episode\")\r\n\r\n        plt.rc('axes', prop_cycle=(cycler('color', ['gray', '#ff7f0e', '#9467bd', '#8c564b', '#e377c2', '#17becf'])))\r\n        # new method!\r\n        fig, grid = plt.subplots(2, sharex=True)\r\n        # input\r\n        #grid[1].set_prop_cycle(cycler('color', ['gray', '#ff7f0e', '#1f77b4']))\r\n        grid[1].set_prop_cycle(cycler('color', ['gray', '#F5B041', '#2E86C1']))\r\n        grid[1].plot(X_base[:, :-1, 0], label=\"N-ENUs input\", alpha=0.7, linewidth=2)\r\n        grid[1].plot(np.max(X_base, axis=1)[:, 1], label=\"Positive reward\", alpha=0.7, color='#1ABC9C', linewidth=2)\r\n        grid[1].plot(np.max(X_base, axis=1)[:, 2], 
label=\"Negative reward\", alpha=0.7, color='#CB4335', linewidth=2)\r\n        grid[1].legend(['Sensor (wall)', 'Sensor (red)', 'Sensor (green)', 'Positive reward', 'Negative reward'], loc='upper right')\r\n        grid[1].set_ylabel('Neuron output')\r\n        # output\r\n        grid[0].set_prop_cycle(cycler('color', ['#9467bd', '#e377c2', '#17becf','#8c564b']))\r\n        grid[0].plot(y_est_base[:, :], label=\"N-ENUs output\", alpha=0.8, linewidth=2)\r\n        grid[0].legend(['ENU-NN (left)', 'ENU-NN (right)', 'ENU-NN (forward)'],loc='upper right')\r\n        plt.xlabel('t')\r\n        grid[0].set_ylabel('ENU neuron output')\r\n        plt.xlim(-5, X_base.shape[0]+10)\r\n        save_fig(e, exp_name, \"single_episode_dual\")\r\n        #plt.show()\r\n\r\n    @staticmethod\r\n    def plot_rollout_data(e, exp_name):\r\n        #TODO: dump rollout as array not the actual plots\r\n        import matplotlib.pyplot as plt\r\n        import seaborn as sns\r\n        sns.set_style(\"white\")\r\n        import os\r\n\r\n        rollout_path = \"./\" + get_data_path(e, exp_name, \"rollout\").split(\".\")[1][:-1]+\"/\"\r\n        rollout_files = sorted(os.listdir(rollout_path))\r\n        rollouts = []\r\n        for i in range(4, 200, 4):\r\n            file = \"rollout_{}_.png\".format(i)\r\n            print(file)\r\n            rollout = plt.imread(rollout_path + file)\r\n            rollout = rollout[80:450, 200:500]\r\n            rollout = np.where(rollout[:, :, [0]]> 0.98, 1, rollout)\r\n            rollouts.append(rollout)\r\n            # plt.imshow(rollout)\r\n            # plt.show()\r\n        # for rollout in rollouts:\r\n        vert_bar = np.zeros((rollout.shape[0], 5, 4))\r\n        vert_bar[::] += 0.5\r\n        # get red -> learned go other way\r\n        # food swapped -> sees red -> learned turn around\r\n        rollouts1 = [rollouts[1], rollouts[9], rollouts[10], rollouts[11],  rollouts[17], rollouts[18]]#, rollouts[19]\r\n        #, 
rollouts[28]\r\n        #rollouts[24],\r\n        rollouts2 = [rollouts[25], rollouts[29], rollouts[30], rollouts[31], rollouts[32], rollouts[33]]\r\n        plt.axis('off')\r\n        plt.imshow(np.vstack([np.column_stack(rollouts1), np.column_stack(rollouts2)]), cmap='gray')\r\n        save_fig(e, exp_name, \"rollout_combined\")\r\n        #plt.show()\r\n\r\n\r\n\r\n\r\nif __name__ == '__main__':\r\n    \"\"\"Test function\"\"\"\r\n    n_offspring = 1024\r\n    envs = MazeTurnEnvVec(n_offspring, n_steps=400)\r\n    envs.n_pseudo_env = 8\r\n    while True:\r\n        envs.reset()\r\n        opt_actions_up = [0,0,0,0,1,0]\r\n        opt_actions_down = [0,0,0,0,2,0]\r\n        opt_actions = [opt_actions_up, opt_actions_down]\r\n        opt_current = 0\r\n        total_reward = 0\r\n        rewards = np.zeros((n_offspring))\r\n        rewards_all = np.zeros_like(rewards)\r\n        k = 0\r\n        n_steps = 200\r\n        for i in range(n_steps):\r\n            actions = np.random.randint(0, envs.n_actions, size=n_offspring)\r\n            #actions[:] = opt_actions_up[i%len(opt_actions_up)]\r\n            #actions[0] = opt_actions[opt_current][k % len(opt_actions_up)]\r\n            obs, rewards,_,_ = envs.step(actions=actions)\r\n            total_reward += (rewards[0] * 100) + 1\r\n            rewards_all += (rewards * 100) + 1\r\n            #print(rewards[0])\r\n            #print(obs[0], rewards[0])\r\n\r\n            envs.render()\r\n            plt.pause(0.5)\r\n\r\n            k += 1\r\n            if rewards[0] < 0:\r\n                #print(rewards_last[0])\r\n                if opt_current==0:\r\n                    opt_current = 1\r\n                else:\r\n                    opt_current = 0\r\n            if rewards[0] != 0:\r\n                k = 0\r\n        total_reward/=n_steps\r\n        rewards_all/=n_steps\r\n        print(rewards_all[0], np.mean(rewards_all), np.std(rewards_all), np.std(rewards_all)/np.mean(rewards_all))\r\n        
print(rewards_all[:10])\r\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/acc_predictor/adaptive_switching.py",
    "content": "import utils\nimport numpy as np\nfrom acc_predictor.factory import get_acc_predictor\n\n\nclass AdaptiveSwitching:\n    \"\"\" ensemble surrogate model \"\"\"\n    \"\"\" try all available models, pick one based on 10-fold crx vld \"\"\"\n    def __init__(self, n_fold=10):\n        # self.model_pool = ['rbf', 'gp', 'mlp', 'carts']\n        self.model_pool = ['rbf', 'gp', 'carts']\n        self.n_fold = n_fold\n        self.name = 'adaptive switching'\n        self.model = None\n        # self.predictor_pool = []\n\n    def fit(self, train_data, train_target):\n        self._n_fold_validation(train_data, train_target, n=self.n_fold)\n        # for p in self.predictor_pool:\n        #     p.fit(train_data,train_target)\n\n    def _n_fold_validation(self, train_data, train_target, n=10):\n\n        n_samples = len(train_data)\n        perm = np.random.permutation(n_samples)\n\n\n        kendall_tau = np.full((n, len(self.model_pool)), np.nan)\n\n        all_predict_result=[]\n\n        for i, tst_split in enumerate(np.array_split(perm, n)):\n            trn_split = np.setdiff1d(perm, tst_split, assume_unique=True)\n            rl=[]\n            # loop over all considered surrogate model in pool\n            for j, model in enumerate(self.model_pool):\n                acc_predictor = get_acc_predictor(model, train_data[trn_split], train_target[trn_split])                \n                result = acc_predictor.predict(train_data[tst_split])\n                rl.append(result)\n                \n                rmse, rho, tau = utils.get_correlation(result, train_target[tst_split])\n\n                kendall_tau[i, j] = tau\n\n            all_predict_result.append(rl)\n            \n        winner = int(np.argmax(np.mean(kendall_tau, axis=0) - np.std(kendall_tau, axis=0)))\n        print(\"winner model = {}, tau = {}\".format(self.model_pool[winner],\n                                                   np.mean(kendall_tau, axis=0)[winner]))\n        
self.winner = self.model_pool[winner]\n        # re-fit the winner model with entire data\n\n        # acc_predictor = get_acc_predictor(self.model_pool[winner], train_data, train_target)\n        # self.model = acc_predictor\n\n    def predict(self, test_data):\n        \n\n\n        return self.model.predict(test_data)\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/acc_predictor/carts.py",
    "content": "# implementation based on\n# https://github.com/yn-sun/e2epp/blob/master/build_predict_model.py\n# and https://github.com/HandingWang/RF-CMOCO\nimport numpy as np\nfrom sklearn.tree import DecisionTreeRegressor\n\n\nclass CART:\n    \"\"\" Classification and Regression Tree \"\"\"\n    def __init__(self, n_tree=1000):\n        self.n_tree = n_tree\n        self.name = 'carts'\n        self.model = None\n\n    @staticmethod\n    def _make_decision_trees(train_data, train_label, n_tree):\n        feature_record = []\n        tree_record = []\n\n        for i in range(n_tree):\n            sample_idx = np.arange(train_data.shape[0])\n            np.random.shuffle(sample_idx)\n            train_data = train_data[sample_idx, :]\n            train_label = train_label[sample_idx]\n\n            feature_idx = np.arange(train_data.shape[1])\n            np.random.shuffle(feature_idx)\n            n_feature = np.random.randint(1, train_data.shape[1] + 1)\n            selected_feature_ids = feature_idx[0:n_feature]\n            feature_record.append(selected_feature_ids)\n\n            dt = DecisionTreeRegressor()\n            dt.fit(train_data[:, selected_feature_ids], train_label)\n            tree_record.append(dt)\n\n        return tree_record, feature_record\n\n    def fit(self, train_data, train_label):\n        self.model = self._make_decision_trees(train_data, train_label, self.n_tree)\n\n    def predict(self, test_data):\n        assert self.model is not None, \"carts does not exist, call fit to obtain cart first\"\n\n        # redundant variable device\n        trees, features = self.model[0], self.model[1]\n        test_num, n_tree = len(test_data), len(trees)\n\n        predict_labels = np.zeros((test_num, 1))\n        for i in range(test_num):\n            this_test_data = test_data[i, :]\n            predict_this_list = np.zeros(n_tree)\n\n            for j, (tree, feature) in enumerate(zip(trees, features)):\n                predict_this_list[j] 
= tree.predict([this_test_data[feature]])[0]\n\n            # find the top 100 prediction\n            predict_this_list = np.sort(predict_this_list)\n            predict_this_list = predict_this_list[::-1]\n            this_predict = np.mean(predict_this_list)\n            predict_labels[i, 0] = this_predict\n\n        return predict_labels\n\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/acc_predictor/factory.py",
    "content": "def get_acc_predictor(model, inputs, targets):\n\n    if model == 'rbf':\n        from acc_predictor.rbf import RBF\n        acc_predictor = RBF()\n        acc_predictor.fit(inputs, targets)\n\n    elif model == 'carts':\n        from acc_predictor.carts import CART\n        acc_predictor = CART(n_tree=5000)\n        acc_predictor.fit(inputs, targets)\n\n    elif model == 'gp':\n        from acc_predictor.gp import GP\n        acc_predictor = GP()\n        acc_predictor.fit(inputs, targets)\n\n    elif model == 'mlp':\n        from acc_predictor.mlp import MLP\n        acc_predictor = MLP(n_feature=inputs.shape[1])\n        acc_predictor.fit(x=inputs, y=targets)\n\n    elif model == 'as':\n        from acc_predictor.adaptive_switching import AdaptiveSwitching\n        acc_predictor = AdaptiveSwitching()\n        acc_predictor.fit(inputs, targets)\n\n    else:\n        raise NotImplementedError\n\n    return acc_predictor\n\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/acc_predictor/gp.py",
    "content": "from pydacefit.regr import regr_constant\nfrom pydacefit.dace import DACE, regr_linear, regr_quadratic\nfrom pydacefit.corr import corr_gauss, corr_cubic, corr_exp, corr_expg, corr_spline, corr_spherical\n\n\nclass GP:\n    \"\"\" Gaussian Process (Kriging) \"\"\"\n    def __init__(self, regr='linear', corr='gauss'):\n        self.regr = regr\n        self.corr = corr\n        self.name = 'gp'\n        self.model = None\n\n    def fit(self, train_data, train_label):\n        if self.regr == 'linear':\n            regr = regr_linear\n        elif self.regr == 'constant':\n            regr = regr_constant\n        elif self.regr == 'quadratic':\n            regr = regr_quadratic\n        else:\n            raise NotImplementedError(\"unknown GP regression\")\n\n        if self.corr == 'gauss':\n            corr = corr_gauss\n        elif self.corr == 'cubic':\n            corr = corr_cubic\n        elif self.corr == 'exp':\n            corr = corr_exp\n        elif self.corr == 'expg':\n            corr = corr_expg\n        elif self.corr == 'spline':\n            corr = corr_spline\n        elif self.corr == 'spherical':\n            corr = corr_spherical\n        else:\n            raise NotImplementedError(\"unknown GP correlation\")\n\n        self.model = DACE(\n            regr=regr, corr=corr, theta=1.0, thetaL=0.00001, thetaU=100)\n        self.model.fit(train_data, train_label)\n\n    def predict(self, test_data):\n        assert self.model is not None, \"GP does not exist, call fit to obtain GP first\"\n        return self.model.predict(test_data)\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/acc_predictor/mlp.py",
    "content": "import copy\nimport torch\nimport numpy as np\nimport torch.nn as nn\nfrom utils import get_correlation\n\n\nclass Net(nn.Module):\n    # N-layer MLP\n    def __init__(self, n_feature, n_layers=2, n_hidden=300, n_output=1, drop=0.2):\n        super(Net, self).__init__()\n\n        self.stem = nn.Sequential(nn.Linear(n_feature, n_hidden), nn.ReLU())\n\n        hidden_layers = []\n        for _ in range(n_layers):\n            hidden_layers.append(nn.Linear(n_hidden, n_hidden))\n            hidden_layers.append(nn.ReLU())\n        self.hidden = nn.Sequential(*hidden_layers)\n\n        self.regressor = nn.Linear(n_hidden, n_output)  # output layer\n        self.drop = nn.Dropout(p=drop)\n\n    def forward(self, x):\n        x = self.stem(x)\n        x = self.hidden(x)\n        x = self.drop(x)\n        x = self.regressor(x)  # linear output\n        return x\n\n    @staticmethod\n    def init_weights(m):\n        if type(m) == nn.Linear:\n            n = m.in_features\n            y = 1.0 / np.sqrt(n)\n            m.weight.data.uniform_(-y, y)\n            m.bias.data.fill_(0)\n\n\nclass MLP:\n    \"\"\" Multi Layer Perceptron \"\"\"\n    def __init__(self, **kwargs):\n        self.model = Net(**kwargs)\n        self.name = 'mlp'\n\n    def fit(self, **kwargs):\n        self.model = train(self.model, **kwargs)\n\n    def predict(self, test_data, device='cpu'):\n        return predict(self.model, test_data, device=device)\n\n\ndef train(net, x, y, trn_split=0.8, pretrained=None, device='cpu',\n          lr=8e-4, epochs=2000, verbose=False):\n\n    n_samples = x.shape[0]\n    target = torch.zeros(n_samples, 1)\n    perm = torch.randperm(target.size(0))\n    trn_idx = perm[:int(n_samples * trn_split)]\n    vld_idx = perm[int(n_samples * trn_split):]\n\n    inputs = torch.from_numpy(x).float()\n    target[:, 0] = torch.from_numpy(y).float()\n\n    # back-propagation training of a NN\n    if pretrained is not None:\n        print(\"Constructing MLP 
surrogate model with pre-trained weights\")\n        init = torch.load(pretrained, map_location='cpu')\n        net.load_state_dict(init)\n        best_net = copy.deepcopy(net)\n    else:\n        # print(\"Constructing MLP surrogate model with \"\n        #       \"sample size = {}, epochs = {}\".format(x.shape[0], epochs))\n\n        # initialize the weights\n        # net.apply(Net.init_weights)\n        net = net.to(device)\n        optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n        criterion = nn.SmoothL1Loss()\n        # criterion = nn.MSELoss()\n\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs), eta_min=0)\n\n        best_loss = 1e33\n        for epoch in range(epochs):\n            trn_inputs = inputs[trn_idx]\n            trn_labels = target[trn_idx]\n            loss_trn = train_one_epoch(net, trn_inputs, trn_labels, criterion, optimizer, device)\n            loss_vld = infer(net, inputs[vld_idx], target[vld_idx], criterion, device)\n            scheduler.step()\n\n            # if epoch % 500 == 0 and verbose:\n            #     print(\"Epoch {:4d}: trn loss = {:.4E}, vld loss = {:.4E}\".format(epoch, loss_trn, loss_vld))\n\n            if loss_vld < best_loss:\n                best_loss = loss_vld\n                best_net = copy.deepcopy(net)\n\n    validate(best_net, inputs, target, device=device)\n\n    return best_net.to('cpu')\n\n\ndef train_one_epoch(net, data, target, criterion, optimizer, device):\n    net.train()\n    optimizer.zero_grad()\n\n    data, target = data.to(device), target.to(device)\n    pred = net(data)\n    loss = criterion(pred, target)\n    loss.backward()\n    optimizer.step()\n\n    return loss.item()\n\n\ndef infer(net, data, target, criterion, device):\n    net.eval()\n\n    with torch.no_grad():\n        data, target = data.to(device), target.to(device)\n        pred = net(data)\n        loss = criterion(pred, target)\n\n    return loss.item()\n\n\ndef validate(net, data, 
target, device):\n    net.eval()\n\n    with torch.no_grad():\n        data, target = data.to(device), target.to(device)\n        pred = net(data)\n        pred, target = pred.cpu().detach().numpy(), target.cpu().detach().numpy()\n\n        rmse, rho, tau = get_correlation(pred, target)\n\n    # print(\"Validation RMSE = {:.4f}, Spearman's Rho = {:.4f}, Kendall’s Tau = {:.4f}\".format(rmse, rho, tau))\n    return rmse, rho, tau, pred, target\n\n\ndef predict(net, query, device):\n\n    if query.ndim < 2:\n        data = torch.zeros(1, query.shape[0])\n        data[0, :] = torch.from_numpy(query).float()\n    else:\n        data = torch.from_numpy(query).float()\n\n    net = net.to(device)\n    net.eval()\n    with torch.no_grad():\n        data = data.to(device)\n        pred = net(data)\n\n    return pred.cpu().detach().numpy()"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/acc_predictor/rbf.py",
    "content": "from pySOT.surrogate import RBFInterpolant, CubicKernel, TPSKernel, LinearTail, ConstantTail\n\n\nclass RBF:\n    \"\"\" Radial Basis Function \"\"\"\n\n    def __init__(self, kernel='cubic', tail='linear'):\n        self.kernel = kernel\n        self.tail = tail\n        self.name = 'rbf'\n        self.model = None\n\n    def fit(self, train_data, train_label):\n        if self.kernel == 'cubic':\n            kernel = CubicKernel\n        elif self.kernel == 'tps':\n            kernel = TPSKernel\n        else:\n            raise NotImplementedError(\"unknown RBF kernel\")\n\n        if self.tail == 'linear':\n            tail = LinearTail\n        elif self.tail == 'constant':\n            tail = ConstantTail\n        else:\n            raise NotImplementedError(\"unknown RBF tail\")\n\n        self.model = RBFInterpolant(dim=train_data.shape[1], kernel=kernel(), tail=tail(train_data.shape[1]))\n\n        for i in range(len(train_data)):\n            self.model.add_points(train_data[i, :], train_label[i])\n\n    def predict(self, test_data):\n        assert self.model is not None, \"RBF model does not exist, call fit to obtain rbf model first\"\n        return self.model.predict(test_data)\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/cellmodel.py",
    "content": "import os\nfrom functools import partial\nfrom typing import List, Type\n\nfrom operations import *\nfrom motifs import *\nfrom utils import drop_path\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.model_zoo.base_module import BaseModule\nfrom torchvision import transforms\n\nEVO=True\nclass EvoCell2(nn.Module):\n    def __init__(self,motif, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun):\n        # print(C_prev_prev, C_prev, C, reduction)\n        super(EvoCell2, self).__init__()\n        self.act_fun = act_fun\n        self.reduction = reduction\n        self.motif=motif\n        self.back_connection=False\n        if reduction:\n            self.fun = FactorizedReduce(\n                C_prev, C * 3, act_fun=act_fun\n            )\n            self.multiplier = 3\n        else:\n            if reduction_prev:\n                self.preprocess0 = FactorizedReduce(\n                    C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(\n                    C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n            self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n            op_names, indices = zip(*motif.normal)\n            concat = motif.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, (name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                self.back_connection=True\n                back_begin_index = i\n                
break\n            stride = 2 if reduction and index < 2 else 1\n            op = OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 2\n\n    def forward(self, s0, s1, drop_prob):\n        if self.reduction:\n            return self.fun(s1)\n        # print('s0',s0.shape)\n        s0 = self.preprocess0(s0)\n        # print(s0.shape)\n        # print('s1',s1.shape)\n        s1 = self.preprocess1(s1)\n        # print(s1.shape)\n\n        states = [s0, s1]\n        for i in range(self._steps):\n            i1=self._indices_forward[2 * i]\n            i2=self._indices_forward[2 * i + 1]\n            h1 = states[i1]\n            h2 = states[i2]\n            op1 = self._ops[2 * i]\n            op2 = self._ops[2 * i + 1]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)\n            s = h1 + h2\n            \n            if self.back_connection:\n                if i != 0:\n                    s_back = self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        \n            \n        \n        
outputs = torch.cat([states[i]\n                            for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\n\nclass EvoCell3(nn.Module):\n    def __init__(self,motif, C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev, act_fun):\n        # print(C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev)\n\n        super(EvoCell3, self).__init__()\n        self.act_fun = act_fun\n        self.reduction = reduction\n        self.motif=motif\n        self.back_connection=False\n        if reduction:\n            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)\n            self.multiplier = 3\n        else:\n\n            if reduction_prev:\n                self.preprocess1 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess1 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            if int(reduction_prev_prev)+int(reduction_prev)==1:\n                self.preprocess0 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)\n            elif int(reduction_prev_prev)+int(reduction_prev)==2:\n                self.preprocess0 = F0(C_prev_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            self.preprocess2 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            op_names, indices = zip(*motif.normal)\n            concat = motif.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, 
(name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                self.back_connection=True\n                back_begin_index = i\n                break\n            stride = 2 if reduction and index < 2 else 1\n            op = OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 3\n\n    def forward(self, s0, s1, s2, drop_prob):\n        if self.reduction:\n            return self.fun(s2)\n\n        s0 = self.preprocess0(s0)\n\n        s1 = self.preprocess1(s1)\n        s2 = self.preprocess2(s2)\n\n        states = [s0, s1, s2]\n\n        for i in range(self._steps):\n            i1=self._indices_forward[3 * i]\n            i2=self._indices_forward[3 * i + 1]\n            i3=self._indices_forward[3 * i + 2]\n\n            h1 = states[i1]\n            h2 = states[i2]\n            h3 = states[i3]\n\n            op1 = self._ops[3 * i]\n            op2 = self._ops[3 * i + 1]\n            op3 = self._ops[3 * i + 2]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            h3 = op3(h3)\n\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)                \n                if not 
isinstance(op3, Identity):\n                    h3 = drop_path(h3, drop_prob)\n            s = h1 + h2 + h3\n            \n            if self.back_connection:\n                if i != 0:\n                    s_back = self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        \n            \n        \n        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\nclass EvoCell4(nn.Module):\n    def __init__(self,motif, C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev,reduction_prev_prev_prev, act_fun):\n        # print(C_prev_prev_prev_prev,C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev,reduction_prev_prev_prev)\n\n        super(EvoCell4, self).__init__()\n        self.act_fun = act_fun\n        self.reduction = reduction\n        self.motif=motif\n        self.back_connection=False\n        if reduction:\n            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)\n            self.multiplier = 3\n        else:\n\n            if reduction_prev:\n                self.preprocess2 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess2 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n            if int(reduction_prev_prev)+int(reduction_prev)==1:\n                self.preprocess1 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)\n            elif int(reduction_prev_prev)+int(reduction_prev)==2:\n                self.preprocess1 = F0(C_prev_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess1 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n            \n            if int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==1:\n              
  self.preprocess0 = FactorizedReduce(C_prev_prev_prev_prev, C, act_fun=act_fun)\n            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==2:\n                self.preprocess0 = F0(C_prev_prev_prev_prev, C, act_fun=act_fun)            \n            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==3:\n                self.preprocess0 = F1(C_prev_prev_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(C_prev_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n\n            self.preprocess3 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            op_names, indices = zip(*motif.normal)\n            # print(self.preprocess0)\n            # print(self.preprocess1)\n            # print(self.preprocess2)\n            # print(self.preprocess3)\n            concat = motif.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, (name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                self.back_connection=True\n                back_begin_index = i\n                break\n            stride = 2 if reduction and index < 2 else 1\n            op = OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        
if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 4\n\n    def forward(self, s0, s1, s2, s3, drop_prob):\n        if self.reduction:\n            return self.fun(s3)\n\n        s0 = self.preprocess0(s0)\n        s3 = self.preprocess3(s3)\n        s1 = self.preprocess1(s1)\n        s2 = self.preprocess2(s2)\n\n        # if s1.shape[1]!=s3.shape[1]:\n        #     s1 = nn.Conv2d(s1.shape[1], s3.shape[1], 3, stride=2, padding=1, bias=False)\n\n        states = [s0, s1, s2,s3]\n\n        for i in range(self._steps):\n            i1=self._indices_forward[4 * i]\n            i2=self._indices_forward[4 * i + 1]\n            i3=self._indices_forward[4 * i + 2]\n            i4=self._indices_forward[4 * i + 3]\n\n            h1 = states[i1]\n            h2 = states[i2]\n            h3 = states[i3]\n            h4 = states[i4]\n\n            op1 = self._ops[4 * i]\n            op2 = self._ops[4 * i + 1]\n            op3 = self._ops[4 * i + 2]\n            op4 = self._ops[4 * i + 3]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            h3 = op3(h3)\n            h4 = op4(h4)\n\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)                \n                if not isinstance(op3, Identity):\n                    h3 = drop_path(h3, drop_prob)                \n                if not isinstance(op4, Identity):\n                    h4= drop_path(h4, drop_prob)\n            s = h1 + h2 + h3 + h4\n            \n            if self.back_connection:\n                if i != 0:\n                    s_back = 
self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        \n            \n        \n        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\n\n\n@register_model\n\nclass NetworkCIFAR(BaseModule):\n\n    def __init__(self,\n                 C,\n                 num_classes,\n                 layers,\n                 auxiliary,\n                 motif,\n                 cell_type,\n                 parse_method='darts',\n                 step=5,\n                 node_type='ReLUNode',\n                 **kwargs):\n        super(NetworkCIFAR, self).__init__(\n            step=step,\n            num_classes=num_classes,\n            **kwargs\n        )\n        self.node_type=node_type\n        if isinstance(node_type, str):\n            self.act_fun = eval(node_type)\n        else:\n            self.act_fun = node_type\n        self.act_fun = partial(self.act_fun, **kwargs)\n        \n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n        self.dataset = kwargs['dataset']\n\n        if self.layer_by_layer:\n            self.flatten = nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n\n        self._layers = layers\n        self.cell_type = cell_type\n        self._auxiliary = auxiliary\n\n        self.drop_path_prob = 0\n\n        stem_multiplier = 3\n        C_curr = stem_multiplier * C\n        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':\n            self.stem = nn.Sequential(\n                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            # self.reduce_idx = [\n            #     layers // 4,\n            #     layers // 2,\n            #     3 * layers // 4\n   
         # ]\n            self.reduce_idx = [1, 3, 5, 7]\n        else:\n            self.stem = nn.Sequential(\n                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            self.reduce_idx = [layers // 4,\n                               layers // 2,\n                               3 * layers // 4]\n        C_prev_prev_prev = C_curr\n        C_prev_prev_prev_prev = C_curr\n\n        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C\n        self.cells = nn.ModuleList()\n        reduction_prev = False\n        reduction_prev_prev = False\n        reduction_prev_prev_prev = False\n\n\n        for i in range(layers):\n            if i in self.reduce_idx:\n                C_curr *= 2\n                reduction = True\n            else:\n                reduction = False\n\n            if cell_type==2:\n                # print(C_prev_prev, C_prev, C_curr)\n\n                cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr,reduction, reduction_prev,act_fun=self.act_fun)\n                self.cells += [cell]\n                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n            if cell_type==3:\n                cell = EvoCell3(motif[i], C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,act_fun=self.act_fun)  \n                self.cells += [cell]\n                C_prev_prev_prev = C_prev_prev\n                reduction_prev_prev = reduction_prev\n\n                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n            if cell_type==4:\n                cell = EvoCell4(motif[i], C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,reduction_prev_prev_prev,act_fun=self.act_fun)  \n                self.cells += [cell]\n                C_prev_prev_prev_prev = C_prev_prev_prev\n                C_prev_prev_prev = C_prev_prev\n                
reduction_prev_prev_prev = reduction_prev_prev\n                reduction_prev_prev = reduction_prev\n\n                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n\n            reduction_prev = reduction\n\n\n        self.global_pooling = nn.Sequential(\n            self.act_fun(), nn.AdaptiveAvgPool2d(1))\n\n        if self.spike_output:\n            self.classifier = nn.Sequential(\n                nn.Linear(C_prev, 10 * num_classes),\n                self.act_fun())\n            self.vote = VotingLayer(10)\n        else:\n            self.classifier = nn.Linear(C_prev, num_classes)\n            self.vote = nn.Identity()\n\n        # self.classifier = nn.Linear(C_prev, num_classes)\n        # self.vote = nn.Identity()\n\n    def forward(self, inputs):\n        logits_aux = None\n        inputs = self.encoder(inputs)\n        if not self.layer_by_layer:\n            outputs = []\n            output_aux = []\n            self.reset()\n\n            if self.cell_type==2:\n\n                for t in range(self.step):\n                    x = inputs[t]\n                    s0 = s1 = self.stem(x)\n                    for i, cell in enumerate(self.cells):\n                        s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                    out = self.global_pooling(s1)\n                    out = self.classifier(self.flatten(out))\n                    logits = self.vote(out)\n                    outputs.append(logits)\n                    output_aux.append(logits_aux)\n                return sum(outputs) / len(outputs)\n\n            if self.cell_type==3:\n                for t in range(self.step):\n                    x = inputs[t]\n                    s0 = s1 = s2= self.stem(x)\n                    for i, cell in enumerate(self.cells):\n                        s0, s1, s2 = s1, s2, cell(s0, s1, s2, self.drop_path_prob)\n\n                    out = self.global_pooling(s2)\n                    out = self.classifier(self.flatten(out))\n                
    logits = self.vote(out)\n                    outputs.append(logits)\n                    output_aux.append(logits_aux)\n                return sum(outputs) / len(outputs)\n\n            if self.cell_type==4:\n                for t in range(self.step):\n                    x = inputs[t]\n                    s0 = s1 = s2= s3=self.stem(x)\n                    for i, cell in enumerate(self.cells):\n                        s0, s1, s2,s3= s1, s2, s3,cell(s0, s1, s2,s3 ,self.drop_path_prob)\n\n                    out = self.global_pooling(s3)\n                    out = self.classifier(self.flatten(out))\n                    logits = self.vote(out)\n                    outputs.append(logits)\n                    output_aux.append(logits_aux)\n                return sum(outputs) / len(outputs)\n                \n            \n\n\n\n            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))\n        else:\n            s0 = s1 = self.stem(inputs)\n            for i, cell in enumerate(self.cells):\n                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                if i == 2 * self._layers // 3:\n                    if self._auxiliary and self.training:\n                        logits_aux = self.auxiliary_head(s1)\n            out = self.global_pooling(s1)\n            out = self.classifier(self.flatten(out))\n            out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)\n            logits = self.vote(out)\n            return logits\n\n\n@register_model\n\nclass NetworkCIFAR_(BaseModule):\n\n    def __init__(self,\n                 C,\n                 num_classes,\n                 layers,\n                 glob,\n                 auxiliary,\n                 motif,\n                 parse_method='darts',\n                 step=5,\n                 node_type='ReLUNode',\n                 **kwargs):\n        super(NetworkCIFAR_, self).__init__(\n            step=step,\n            num_classes=num_classes,\n            
**kwargs\n        )\n        self.node_type=node_type\n        if isinstance(node_type, str):\n            self.act_fun = eval(node_type)\n        else:\n            self.act_fun = node_type\n        self.act_fun = partial(self.act_fun, **kwargs)\n        \n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n        self.dataset = kwargs['dataset']\n\n        if self.layer_by_layer:\n            self.flatten = nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n        self.glob = glob\n        self._layers = layers\n        self._auxiliary = auxiliary\n\n        self.drop_path_prob = 0\n\n        stem_multiplier = 3\n        C_curr = stem_multiplier * C\n        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':\n            self.stem = nn.Sequential(\n                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            # self.reduce_idx = [\n            #     layers // 4,\n            #     layers // 2,\n            #     3 * layers // 4\n            # ]\n            self.reduce_idx = [1, 3, 5, 7]\n        else:\n            self.stem = nn.Sequential(\n                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            self.reduce_idx = [layers // 4,\n                               layers // 2,\n                               3 * layers // 4]\n        C_prev_prev_prev = C_curr\n        C_prev_prev_prev_prev = C_curr\n\n        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C\n        self.cells = nn.ModuleList()\n        reduction_prev = False\n        reduction_prev_prev = False\n        reduction_prev_prev_prev = False\n\n\n        for i in range(layers):\n            if i in self.reduce_idx:\n                C_curr *= 2\n                reduction = True\n            else:\n             
   reduction = False\n\n            cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr,reduction, reduction_prev,act_fun=self.act_fun)\n            self.cells += [cell]\n            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n            reduction_prev = reduction\n\n\n        self.global_pooling = nn.Sequential(\n            self.act_fun(), nn.AdaptiveAvgPool2d(1))\n\n        if self.spike_output:\n            self.classifier = nn.Sequential(\n                nn.Linear(C_prev, 10 * num_classes),\n                self.act_fun())\n            self.vote = VotingLayer(10)\n        else:\n            self.classifier = nn.Linear(C_prev, num_classes)\n            self.vote = nn.Identity()\n\n        # self.classifier = nn.Linear(C_prev, num_classes)\n        # self.vote = nn.Identity()\n\n    def forward(self, inputs):\n        logits_aux = None\n        inputs = self.encoder(inputs)\n        if not self.layer_by_layer:\n            outputs = []\n            output_aux = []\n            self.reset()\n\n            zzz=[]\n            kkk=[]\n\n            for t in range(self.step):\n\n                x = inputs[t]\n                s0 = s1 = self.stem(x)\n                # print(s1.shape)\n                for i, cell in enumerate(self.cells):\n                    \n                    if t>0 and i%5==4:\n                        qw = np.where(self.glob[:,int(i//5)]==1)\n                        if qw[0].shape[0]!=0:\n\n                            for m in qw:\n                                if zzz[m[0]].shape[-1]>s1.shape[-1]:\n                                    ks=zzz[m[0]].shape[-1] - (s1.shape[-1]-1)*2+2\n                                    if ks<0:\n                                        ks=zzz[m[0]].shape[-1] - (s1.shape[-1]-1)+2\n                                        bb=nn.Conv2d(zzz[m[0]].shape[1], s1.shape[1], kernel_size=ks, stride=1,padding=1, bias=False).to(zzz[m[0]].device)\n                                    else:\n                          
              bb=nn.Conv2d(zzz[m[0]].shape[1], s1.shape[1], kernel_size=ks, stride=2,padding=1, bias=False).to(zzz[m[0]].device)\n                                    aa=bb(zzz[m[0]])\n                                    s1=aa+s1\n                                elif zzz[m[0]].shape[-1]<s1.shape[-1]:\n\n                                    aa=nn.functional.interpolate(zzz[m[0]],[s1.shape[-1], s1.shape[-1]], mode='bilinear', align_corners=False)\n                                    bb=nn.Conv2d(zzz[m[0]].shape[1], s1.shape[1],kernel_size=1).to(zzz[m[0]].device)\n                                    aa=bb(aa)\n                                    s1=aa+s1\n\n\n\n\n                    s0,s1=s1,cell(s0, s1, self.drop_path_prob)\n                    if i%5==4:\n                        if t==0:\n                            zzz.append(s1)\n                            kkk.append(s0)\n                        else:\n                            zzz[int(i//5)] = s1\n                            kkk[int(i//5)] = s0\n\n\n                out = self.global_pooling(s1)\n                out = self.classifier(self.flatten(out))\n                logits = self.vote(out)\n                outputs.append(logits)\n                output_aux.append(logits_aux)\n            return sum(outputs) / len(outputs)\n\n\n            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))\n        else:\n            s0 = s1 = self.stem(inputs)\n            for i, cell in enumerate(self.cells):\n                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                if i == 2 * self._layers // 3:\n                    if self._auxiliary and self.training:\n                        logits_aux = self.auxiliary_head(s1)\n            out = self.global_pooling(s1)\n            out = self.classifier(self.flatten(out))\n            out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)\n            logits = self.vote(out)\n            return logits\n\n\ndef 
occumpy_mem(cuda_device):\n    total, used = os.popen('\"/usr/bin/nvidia-smi\" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader').read().strip().split(\"\\n\")[int(cuda_device)].split(',')\n    # total, used = check_mem(cuda_device)\n    total = int(total)\n    used = int(used)\n    max_mem = int(total * 1)\n    block_mem = int((max_mem - used)*0.85)\n    x = torch.cuda.FloatTensor(256,1024,block_mem)\n    del x\n\nif __name__ == '__main__':\n    torch.cuda.set_device('cuda:3')\n    # occumpy_mem(str(3))\n\n    x = torch.rand(128, 3, 32, 32)\n    glob = np.array([[0,1,0,0],[1,0,1,0],[1,0,0,1],[0,1,0,0]])\n    # glob = np.array([[0,1],[1,0]])\n    glob = np.array([[0,1,0,0],[0,0,0,0],[0,0,0,1],[0,0,0,0]])\n    glob = np.array([[0,1,1,1],[1,0,1,1],[1,1,0,1],[1,1,1,0]])\n    glob = np.array([[0,1,0],[1,0,0],[1,1,1]])\n\n    motifs=[mm2,mm3,mm4,mm5,mm1,mm5,mm3,mm4,mm2,mm1,mm2,mm3,mm4,mm5,mm1]\n    # motifs=[m1,m2,m3,m1,m5,m4,m1,m2,m3,m1,m5,m4,m5,m4,m1]##3\n    # motifs=[t2,t3,t4,t5,t1,t5,t3,t4,t2,t1,t2,t3,t4,t5,t1,t5,t3,t4,t2,t1]\n    # motifs=[t2,t3,t4,t5,t1]\n\n    # motifs=[subnet,subnet,subnet]\n\n    net=NetworkCIFAR_(C=12,num_classes=10,motif=motifs,layers=len(motifs),auxiliary=True,dataset='cifar10',glob=glob)\n    # net=NetworkCIFAR(C=12,num_classes=10,motif=motifs,layers=len(motifs),auxiliary=True,dataset='cifar10',cell_type=2)\n\n    net=net.cuda()\n    layers=int(len(motifs)/5)\n    out=net(x.to('cuda:3'))\n    print(out.shape)\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/ebnas.py",
    "content": "import sys\nimport numpy as np\nimport argparse\nimport time\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom random import choice\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nfrom micro_encoding import ops\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\n# from braincog.model_zoo.reactnet import *\n# from braincog.model_zoo.convxnet import *\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\nimport micro_encoding\nimport nsganet as engine\nfrom pymop.problem import Problem\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom pymoo.optimize import minimize\nfrom tm import train_motifs\nfrom timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\nfrom cellmodel import NetworkCIFAR\nbits=20\n\n# from ptflops import get_model_complexity_info\n# from thop import profile, clever_format\n\n\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('')\n# The first arg parser parses out only thei --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', 
add_help=False)\ndevices=[4]\n\nmax_gen = 100\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n# Model parameters\nparser.add_argument('--seed', type=int, default=99, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--eval_epochs', type=int, default=1)\nparser.add_argument('--bns', action='store_true', default=True)\nparser.add_argument('--mid', type=int, default=5)\nparser.add_argument('--trainning_epochs', type=int, default=600, metavar='N',help='number of epochs to train (default: 2)')\nparser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--init-channels', type=int, default=36)\nparser.add_argument('--layers', type=int, default=2)\nparser.add_argument('--output', default='', type=str, metavar='PATH')\nparser.add_argument('--spike-rate', action='store_true', default=False)\n\nparser.add_argument('--n_gens', type=int, default=max_gen, help='population size')\nparser.add_argument('--bs', type=int, default=100)\n\nparser.add_argument('--n_offspring', type=int, default=60, help='number of offspring created per generation')\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\nparser.add_argument('--dataset', default='cifar10', type=str)\nparser.add_argument('--num-classes', type=int, default=10, metavar='N')\nparser.add_argument('--model', default='NetworkCIFAR', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: 
none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\nparser.add_argument('--strgenome', default='4,0,0,1,1,1,0,1,0,1,0,0,0,1,0,1,0,1,0,0,4,1,1,1,1,1,0,1,1,0,1,1,0,1,1,1,0,1,0,0,4,1,1,0,1,0,0,1,0,1,1,0,1,1,0,0,1,0,1,2,4,1,0,0,1,1,1,0,0,1,0,0,1,0,0,1,0,0,1,1,4,0,1,1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,3,4,1,0,1,0,0,1,1,1,0,0,1,0,1,0,0,0,1,0,0,3', type=str)\n\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', 
'--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 
1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n                    help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\n\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. 
(default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset 
drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                    help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\n\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=len(devices),\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\n\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: 
\"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=devices[0])\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n 
(default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\n\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\n# DARTS parameters\n\nparser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')\n# parser.add_argument('--arch', default='dvsc10_new_skip19', type=str)\n# parser.add_argument('--motif', default='m1', type=str)\n\nparser.add_argument('--parse_method', default='darts', type=str)\nparser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')\n\n# parser.add_argument('--back-connection', action='store_true',default=True)\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 
'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\n\nclass NAS(Problem):\n    # first define the NAS problem (inherit from pymop)\n    def __init__(self, args,n_var=20, n_obj=1, n_constr=0, lb=None, ub=None,\n                 init_channels=24, layers=8):\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\n        self.xl = lb\n        self.xu = ub\n        self._lr =args.lr\n\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\n        self.args=args\n    def _evaluate(self, x, out, *args, **kwargs):\n        objs = np.full((x.shape[0], self.n_obj), np.nan)\n        for i in range(x.shape[0]):\n            arch_id = self._n_evaluated + 1\n            print('\\n')\n            _logger.info('Network= {}'.format(arch_id))\n\n            genome = x[i, :]\n            arch_dir=os.path.join(self.args.output_dir,i)\n            if os.path.exists(arch_dir) is False:\n                os.makedirs(arch_dir,exist_ok = True)\n            self.args.lr=self._lr\n            performance,acc=train_motifs(args=self.args,gen=0,arch_dir=arch_dir,genome=genome,_logger=_logger,args_text=args_text,devices=devices,bits=bits)\n\n\n            objs[i, 0] = 1000 - performance\n            _logger.info('performance= {}'.format(objs[i, 0]))\n            self._n_evaluated += 1\n\n\n        out[\"F\"] = objs\n        # if your NAS problem has constraints, use the following line to set constraints\n        # out[\"G\"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints\n\ndef do_every_generations(algorithm):\n    # this function will be call every generation\n    # it has access to the whole algorithm class\n    gen = algorithm.n_gen\n    pop_var = algorithm.pop.get(\"X\")\n    pop_obj = algorithm.pop.get(\"F\")\n\n    # report generation info to files\n    _logger.info(\"generation = {}\".format(gen))\n    _logger.info(\"population error: best = {}, mean = {}, \"\n     
            \"median = {}, worst = {}\".format(np.min(pop_obj[:, 0]), np.mean(pop_obj[:, 0]),\n                                                  np.median(pop_obj[:, 0]), np.max(pop_obj[:, 0])))\n    _logger.info('Best Genome= {}'.format(pop_var[np.argmin(pop_obj[:, 0])]))\n\n\ndef _parse_args():\n    args_config, remaining = config_parser.parse_known_args()\n    args = parser.parse_args(remaining)\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\nif __name__ == '__main__':\n\n\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n\n    if args.local_rank == 0:\n        output_base = args.output if args.output else './output'\n        exp_name = '-'.join([\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            # args.model,\n            # args.dataset,\n            str(args.layers)+'layers',\n            str(args.init_channels)+'channels',\n            'motif'+str(args.mid),\n            str(args.step)+'steps',\n            # args.suffix\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base,str(args.dataset),exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    else:\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % 
args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                        % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n    defalut_lr = args.lr\n    sn = np.arange(1,6) \n    np.random.shuffle(sn)\n    args.subnet=sn\n\n    args.layers*=5\n    len_motifs=args.layers*bits\n\n    low = np.zeros(len_motifs)\n    up=[]\n    for i in range(0,args.layers*bits,bits):\n        t=[args.mid]\n        low[i]=args.mid\n        t=t+[(ops-1) for j in range(bits-1)]\n        t[-1]=2*(ops-1)\n        up.extend(t)\n\n    up=np.array(up).reshape(-1,)\n\n    kkk = NAS(args,n_var=len_motifs, \n                  n_obj=2, n_constr=0, lb=low, ub=up,\n                  init_channels=args.init_channels, layers=args.layers)\n    method = engine.nsganet(pop_size=args.pop_size,\n                            n_offsprings=args.n_offspring,\n                            eliminate_duplicates=True)\n    kres=minimize(kkk,\n                   method,\n                   callback=do_every_generations,\n                   termination=('n_gen', args.n_gens))"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/micro_encoding.py",
    "content": "# NASNet Search Space https://arxiv.org/pdf/1707.07012.pdf\n# code modified from DARTS https://github.com/quark0/darts\nimport numpy as np\nfrom collections import namedtuple\nfrom random import choice\nfrom numpy.linalg import matrix_rank\nimport itertools\nimport torch\n# from models.micro_models import NetworkCIFAR as Network\n\nimport motifs\n\n# Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')\n# Genotype_norm = namedtuple('Genotype', 'normal normal_concat')\n# Genotype_redu = namedtuple('Genotype', 'reduce reduce_concat')\nGenotype = namedtuple('Genotype', 'normal normal_concat')\n\n# what you want to search should be defined here and in micro_operations\n\nPRIMITIVES = [\n    'max_pool_3x3',\n    'avg_pool_3x3',\n    'skip_connect',\n    'sep_conv_3x3',\n    'sep_conv_5x5',\n    'dil_conv_3x3',\n    'dil_conv_5x5',\n    'sep_conv_7x7',\n    'conv_7x1_1x7',\n]\nOPERATIONS_back = [\n    # 'max_pool_3x3_p_back',\n    # 'avg_pool_3x3_p_back',\n    'conv_3x3_p_back',\n    'conv_5x5_p_back',\n    # 'avg_pool_3x3_n_back',\n    'conv_3x3_n_back',\n    'conv_5x5_n_back',\n    # 'sep_conv_3x3_p_back',\n    # 'sep_conv_5x5_p_back',\n    # 'dil_conv_3x3_p_back',\n    # 'dil_conv_5x5_p_back',\n    # 'def_conv_3x3_p_back',\n    # 'def_conv_5x5_p_back',\n]\nOPERATIONS_p = [\n    # 'max_pool_3x3_p',\n    # 'avg_pool_3x3_p',\n    'conv_3x3_p',\n    'conv_5x5_p',\n    # 'sep_conv_3x3_p',\n    # 'sep_conv_5x5_p',\n    # 'dil_conv_3x3_p',\n    # 'dil_conv_5x5_p',\n    # 'def_conv_3x3_p',\n    # 'def_conv_5x5_p',\n]\nops=len(OPERATIONS_p)\n\nOPERATIONS_n = [\n    # 'max_pool_3x3_n',\n    # 'avg_pool_3x3_n',\n    'conv_3x3_n',\n    'conv_5x5_n',\n    # 'sep_conv_3x3_n',\n    # 'sep_conv_5x5_n',\n    # 'dil_conv_3x3_n',\n    # 'dil_conv_5x5_n',\n    # 'def_conv_3x3_n',\n    # 'def_conv_5x5_n',\n\n    # 'transformer',\n]\nmids=(3,3,2,2,1)\nmids=(1,2,3,4,5)\n\nms=len(mids)\npermutations = 
list({}.fromkeys(list(itertools.permutations(mids))).keys())\nmotifdict_c = dict(enumerate(permutations, 1))\nmotifdict = dict(zip(motifdict_c.values(),motifdict_c.keys()))\n\ndef convert_cell(cell_bit_string):\n    # convert cell bit-string to genome\n    tmp = [cell_bit_string[i:i + 2] for i in range(0, len(cell_bit_string), 2)]\n    return [tmp[i:i + 2] for i in range(0, len(tmp), 2)]\n\ndef filt(sn):\n    for i in range(sn.shape[0]):\n        if sn[i]<1:\n            sn[i]=1\n        if sn[i]>120:\n            sn[i]=120\n    return sn\n# def convert(bit_string):\n#     # convert network bit-string (norm_cell + redu_cell) to genome\n#     norm_gene = convert_cell(bit_string[:len(bit_string)//2])\n#     redu_gene = convert_cell(bit_string[len(bit_string)//2:])\n#     return [norm_gene, redu_gene]\n\ndef shuffle_along_axis(a, axis):\n    idx = np.random.rand(*a.shape).argsort(axis=axis)\n    return np.take_along_axis(a,idx,axis=axis)\n\ndef sample(pops, layers, bits , ):\n    sn = np.tile(mids,pops*layers).reshape(-1,ms)\n    sn=shuffle_along_axis(sn,axis=1)\n    sn=sn.reshape(-1,1)\n    # bigmotifs=(np.random.rand(pops*layers*ms,bits-1)<0.5).astype(int).reshape(-1,bits-1)\n    bigmotifs=(np.random.rand(pops*layers,bits-1)<0.5).astype(int).reshape(-1,bits-1)\n    # bigmotifs=np.repeat(bigmotifs,ms,axis=0)\n    sn=sn.reshape(pops*layers,-1)\n\n    # glob=np.array([2 for i in range(pops)])[:,np.newaxis]\n\n    genome=np.concatenate((sn, bigmotifs),axis=1).reshape(pops,-1)\n    glob = (np.random.rand(pops,layers*layers)<0.5).astype(int).reshape(pops,layers,layers)\n    # for i in range(layers):\n    #     glob[:,i,i]=0\n\n    genome=np.concatenate((genome, glob.reshape(pops,layers*layers)),axis=1).reshape(pops,-1)\n\n\n    return genome,sn,bigmotifs,ms,glob.reshape(pops,layers*layers)\n\ndef reencode(sn,bigmotifs,pops):\n    sn=sn.reshape(-1,ms)\n    mnumber=np.array([motifdict[tuple(x)] for x in sn]).reshape(pops,-1)\n    
mgenome=np.concatenate((bigmotifs.reshape(pops,-1), mnumber),axis=1).reshape(pops,-1)\n\n    return mgenome\n\ndef convert(mgenome, layers, bits):\n    bigmotifs = mgenome[:,0:-layers].reshape(-1,bits-1)\n    sn=mgenome[:,-layers:].reshape(-1)\n    sn=filt(sn)\n    ssn=np.array([list(motifdict_c[sn[i]]) for i in range(sn.shape[0])]).reshape(-1,1)\n    result=np.concatenate((ssn,np.repeat(bigmotifs,ms,axis=0)),axis=1).reshape(-1,layers*bits*ms)\n\n    return np.c_[result,np.ones((result.shape[0],1))*2]\n\ndef convert_single(mgenome, layers, bits):\n    bigmotifs = mgenome[0:-layers].reshape(-1,bits-1)\n    sn=mgenome[-layers:].reshape(-1)\n    sn=filt(sn)\n    ssn=np.array([list(motifdict_c[sn[i]]) for i in range(sn.shape[0])]).reshape(-1,1)\n    result=list(np.concatenate((ssn,np.repeat(bigmotifs,ms,axis=0)),axis=1).reshape(layers*bits*ms))\n    return np.array(result)\n\ndef c_single(mgenome, layers, bits):\n    glob = mgenome[-layers*layers:]\n    mgenome = mgenome[:-layers*layers]\n    big=mgenome.reshape(layers,-1)\n    bigmotifs = big[:,ms:]\n    sn=big[:,0:ms]\n    result = np.concatenate((sn.reshape(-1,1),np.repeat(bigmotifs,ms,axis=0)),axis=1).reshape(-1,).tolist()\n    result.append(glob.reshape(layers,layers))\n    return result\n# def decode_cell(genome, norm=True):\n\n#     cell, cell_concat = [], list(range(2, len(genome)+2))\n#     for block in genome:\n#         for unit in block:\n#             cell.append((PRIMITIVES[unit[0]], unit[1]))\n#             if unit[1] in cell_concat:\n#                 cell_concat.remove(unit[1])\n\n#     if norm:\n#         return Genotype_norm(normal=cell, normal_concat=cell_concat)\n#     else:\n#         return Genotype_redu(reduce=cell, reduce_concat=cell_concat)\n\n\ndef decode(genome):\n    # decodes genome to architecture\n    normal_cell = genome[0]\n    reduce_cell = genome[1]\n\n    normal, normal_concat = [], list(range(2, len(normal_cell)+2))\n    reduce, reduce_concat = [], list(range(2, 
len(reduce_cell)+2))\n\n    for block in normal_cell:\n        for unit in block:\n            normal.append((PRIMITIVES[unit[0]], unit[1]))\n            if unit[1] in normal_concat:\n                normal_concat.remove(unit[1])\n\n    for block in reduce_cell:\n        for unit in block:\n            reduce.append((PRIMITIVES[unit[0]], unit[1]))\n            if unit[1] in reduce_concat:\n                reduce_concat.remove(unit[1])\n\n    return Genotype(\n        normal=normal, normal_concat=normal_concat,\n        reduce=reduce, reduce_concat=reduce_concat\n    )\n\ndef decode_motif(layers,bits,genome):\n    # decodes genome to architecture\n    motif_list=[]\n    motif_ids=[]\n\n    for b in range(0,layers*bits,bits):\n        motif_id='mm'+str(genome[b])\n\n        \n        motif_ids.append(genome[b])\n\n        normalcell=eval('motifs.%s' % motif_id)\n    \n        newnormal=[]\n        for i in range(0,len(normalcell.normal)):\n            op=normalcell.normal[i]\n            if 'skip' in op[0]:\n                newnormal.append(op)\n                continue\n            elif 'back' in op[0]:\n                newnormal.append((OPERATIONS_back[genome[b+1+len(normalcell.normal)-1]],op[1]))\n                continue            \n            elif '_n' in op[0]:\n                newnormal.append((OPERATIONS_n[genome[b+1+i]],op[1]))\n                continue\n            elif '_p' in op[0]:\n                newnormal.append((OPERATIONS_p[genome[b+1+i]],op[1]))\n                continue\n        m=Genotype(normal=newnormal, normal_concat=normalcell.normal_concat,)\n        motif_list.append(m)\n\n\n    return motif_list,motif_ids\n\ndef compare_cell(cell_string1, cell_string2):\n    cell_genome1 = convert_cell(cell_string1)\n    cell_genome2 = convert_cell(cell_string2)\n    cell1, cell2 = cell_genome1[:], cell_genome2[:]\n\n    for block1 in cell1:\n        for block2 in cell2:\n            if block1 == block2 or block1 == block2[::-1]:\n                
cell2.remove(block2)\n                break\n    if len(cell2) > 0:\n        return False\n    else:\n        return True\n\n\ndef compare(string1, string2):\n\n    if compare_cell(string1[:len(string1)//2],\n                    string2[:len(string2)//2]):\n        if compare_cell(string1[len(string1)//2:],\n                        string2[len(string2)//2:]):\n            return True\n\n    return False\n\n\n# def debug():\n#     # design to debug the encoding scheme\n#     seed = 0\n#     np.random.seed(seed)\n#     budget = 2000\n#     B, n_ops, n_cell = 5, 7, 2\n#     networks = []\n#     design_id = 1\n#     while len(networks) < budget:\n#         bit_string = []\n#         for c in range(n_cell):\n#             for b in range(B):\n#                 bit_string += [np.random.randint(n_ops),\n#                                np.random.randint(b + 2),\n#                                np.random.randint(n_ops),\n#                                np.random.randint(b + 2)\n#                                ]\n\n#         genome = convert(bit_string)\n#         # check against evaluated networks in case of duplicates\n#         doTrain = True\n#         for network in networks:\n#             if compare(genome, network):\n#                 doTrain = False\n#                 break\n\n#         if doTrain:\n#             genotype = decode(genome)\n#             model = Network(16, 10, 8, False, genotype)\n#             model.drop_path_prob = 0.0\n#             data = torch.randn(1, 3, 32, 32)\n#             output, output_aux = model(torch.autograd.Variable(data))\n#             networks.append(genome)\n#             design_id += 1\n#             print(design_id)\n\n\nif __name__ == \"__main__\":\n    # debug()\n    # genome1 = [[[[3, 0], [3, 1]], [[3, 0], [3, 1]],\n    #             [[3, 1], [2, 0]], [[2, 0], [5, 2]]],\n    #            [[[0, 0], [0, 1]], [[2, 2], [0, 1]],\n    #             [[0, 0], [2, 2]], [[2, 2], [0, 1]]]]\n    # genome2 = [[[[3, 1], [3, 0]], [[3, 
1], [3, 0]],\n    #             [[3, 1], [2, 0]], [[2, 0], [5, 2]]],\n    #            [[[0, 1], [0, 0]], [[2, 2], [0, 1]],\n    #             [[0, 0], [2, 2]], [[2, 2], [0, 0]]]]\n    #\n    # print(compare(genome1, genome2))\n    # print(genome1)\n    # print(genome2)\n    # bit_string1 = [3,1,3,0,3,1,3,0,3,1,2,0,2,0,5,2,0,0,0,1,2,2,0,1,0,0,2,2,2,2,0,1]\n    # bit_string2 = [3, 0, 3, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2, 0, 5, 2,\n    #                0, 0, 0, 1, 2, 2, 0, 1, 0, 0, 2, 2, 2, 2, 0, 1]\n    # # print(convert(bit_string1))\n    # print(compare(bit_string1, bit_string2))\n    # print(decode(convert(bit_string)))\n\n    cell_bit_string = [3, 0, 3, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2, 0, 5, 2]\n    # print(decode_cell(convert_cell(cell_bit_string), norm=False))\n    genome,sn,bigmotifs,ms=sample(195,2,20)\n    mgeno=reencode(sn,bigmotifs,195)\n    convert(mgeno,2,20)\n\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/motifs.py",
    "content": "from collections import namedtuple\n\nimport torch\n\nGenotype = namedtuple('Genotype', 'normal normal_concat')\n\n\"\"\"\nOperation sets\n\"\"\"\n\nPRIMITIVES = [\n    'skip_connect',\n    # 'max_pool_3x3',\n    # 'avg_pool_3x3',\n    # 'def_conv_3x3',\n    # 'def_conv_5x5',\n    # 'sep_conv_3x3',\n    # 'sep_conv_5x5',\n    # 'dil_conv_3x3',\n    # 'dil_conv_5x5',\n\n    # 'max_pool_3x3_p',\n    # 'avg_pool_3x3_p',\n    'conv_3x3_p',\n    'conv_5x5_p',\n    # 'skip_connect_p',\n    # 'sep_conv_3x3_p',\n    # 'sep_conv_5x5_p',\n    # 'dil_conv_3x3_p',\n    # 'dil_conv_5x5_p',\n    # 'def_conv_3x3_p',\n    # 'def_conv_5x5_p',n\n\n    # 'max_pool_3x3_n',\n    # 'avg_pool_3x3_n',\n    'conv_3x3_n',\n    'conv_5x5_n',\n    # 'skip_connect_n',\n    # 'sep_conv_3x3_n',\n    # 'sep_conv_5x5_n',\n    # 'dil_conv_3x3_n',\n    # 'dil_conv_5x5_n',\n    # 'def_conv_3x3_n',\n    # 'def_conv_5x5_n',\n\n    # 'transformer',\n]\nm0=Genotype(\n    normal=[\n        ('skip', 0), ('skip', 1),('skip', 2),\n    ],\n    normal_concat=range(3, 4)\n)\nmm0=Genotype(\n    normal=[\n        ('skip', 0), ('skip', 1),('skip', 2),\n    ],\n    normal_concat=range(2, 3)\n)\nmm1=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),\n        ('skip_connect', 0), ('conv_5x5_p', 2),\n    ],\n    normal_concat=range(2, 4)\n)\n\n\nmm2=Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),\n        ('skip_connect', 0), ('conv_5x5_n', 2),\n        ('conv_5x5_p', 2), ('conv_3x3_n', 3),\n\n    ],\n    normal_concat=range(2, 5)\n)\n\nmm4=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),#2\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),#3\n        ('conv_5x5_p', 2), ('conv_5x5_p', 3),#4\n        ('skip_connect', 0), ('conv_3x3_p', 4),#5\n        ('skip_connect', 0), ('conv_3x3_p', 4),#6\n        ],\n    normal_concat=range(2, 7)\n)\n\n\n\nmm3=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),#2\n        
('skip_connect', 0), ('conv_5x5_n', 2),#3\n        ('skip_connect', 0), ('conv_5x5_p', 3),#4\n\n        ('skip_connect_back', 2),#3\n        ('conv_3x3_p_back', 3),#4\n\n    ],\n    normal_concat=range(2, 5)\n)\n\n\nmm5=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),#2\n        ('skip_connect', 0), ('conv_5x5_p', 2),#3\n\n        ('skip_connect_back', 2),#3\n    ],\n    normal_concat=range(2, 4)\n)\n\n\n\nm1=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), #B3\n        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1), #C4\n    ],\n    normal_concat=range(3, 5)\n)\n\n\nm2=Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), #B3\n        ('skip', 0), ('conv_5x5_n', 3), ('skip', 1),#C4\n        ('conv_5x5_p', 3), ('conv_3x3_n', 4), ('skip', 1), #D5\n\n    ],\n    normal_concat=range(3, 6)\n)\n\nm4=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), #3\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),('conv_5x5_p', 2), #4\n        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_p', 4), #5\n        ('skip', 0), ('conv_3x3_p', 3),('conv_3x3_n', 5),#6\n        ('skip', 0), ('conv_3x3_p', 4),('conv_3x3_n', 5),#7\n        ],\n    normal_concat=range(3, 8)\n)\n\n\n\nm3=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2), #3\n        ('skip', 0), ('conv_5x5_p', 3),('skip', 1), #4\n        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1), #5\n\n        ('conv_3x3_n_back', 3),#4\n        ('skip_back', 2),#5\n\n    ],\n    normal_concat=range(3, 6)\n)\n\n\nm5=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),#3\n        ('skip', 0),('skip', 1), ('conv_5x5_n', 3), #4\n\n        ('skip_connect_back', 3),#4\n    ],\n    normal_concat=range(3, 5)\n)\n\n\nt1=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  ('conv_5x5_p', 3), #4\n        ('skip', 0), 
('conv_5x5_p', 4), ('skip', 1), ('skip', 2), #5\n        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2), #6\n        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2), #7\n    ],\n    normal_concat=range(4, 8)\n)\n\n\nt2=Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), ('conv_5x5_p', 3), #4\n        ('skip', 0), ('conv_5x5_n', 4), ('skip', 1),('skip', 2),#5\n        ('conv_5x5_p', 4), ('conv_3x3_n', 5), ('skip', 1),('skip', 2), #6\n\n    ],\n    normal_concat=range(4, 7)\n)\n\nt4=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), ('conv_5x5_p', 3), #4\n        ('conv_5x5_p', 0), ('skip', 1),('conv_5x5_n', 4), ('skip', 3), #5\n        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_n', 4), ('skip', 2),#6\n\n        ],\n    normal_concat=range(4, 7)\n)\n\n\n\nt3=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2), ('conv_3x3_p', 3),#4\n        ('skip', 0), ('skip', 2),('skip', 1), ('skip', 3),('conv_3x3_p', 4),#5\n        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1),('skip', 2), #6\n\n        ('conv_3x3_n_back', 4),#5\n        ('skip_back', 4),#6\n\n    ],\n    normal_concat=range(4, 7)\n)\n\n\nt5=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),('conv_5x5_p', 3),#4\n        ('skip', 0),('skip', 1), ('skip', 2), ('conv_5x5_n', 4), #5\n        ('skip', 0),('skip', 1), ('conv_5x5_n', 4),('conv_5x5_n', 5), #6\n        ('skip', 0),('skip', 1), ('conv_5x5_n', 4),('conv_5x5_n', 5), #7\n\n        ('conv_3x3_n_back', 4),#5\n        ('skip_back', 4),#6\n        ('skip_back', 5),#7\n\n    ],\n    normal_concat=range(4, 8)\n)"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/nsganet.py",
    "content": "import numpy as np\n\nfrom pymoo.algorithms.genetic_algorithm import GeneticAlgorithm\nfrom pymoo.docs import parse_doc_string\nfrom pymoo.model.individual import Individual\nfrom pymoo.model.survival import Survival\nfrom pymoo.operators.crossover.point_crossover import PointCrossover\nfrom pymoo.operators.mutation.polynomial_mutation import PolynomialMutation\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.selection.tournament_selection import compare, TournamentSelection\nfrom pymoo.util.display import disp_multi_objective\nfrom pymoo.util.dominator import Dominator\nfrom pymoo.util.non_dominated_sorting import NonDominatedSorting\nfrom pymoo.util.randomized_argsort import randomized_argsort\n\n\n# =========================================================================================================\n# Implementation\n# based on nsga2 from https://github.com/msu-coinlab/pymoo\n# =========================================================================================================\n\n\nclass NSGANet(GeneticAlgorithm):\n\n    def __init__(self, **kwargs):\n        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)\n        super().__init__(**kwargs)\n\n        self.tournament_type = 'comp_by_dom_and_crowding'\n        self.func_display_attrs = disp_multi_objective\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Binary Tournament Selection Function\n# ---------------------------------------------------------------------------------------------------------\n\n\ndef binary_tournament(pop, P, algorithm, **kwargs):\n    if P.shape[1] != 2:\n        raise ValueError(\"Only implemented for binary tournament!\")\n\n    tournament_type = algorithm.tournament_type\n    S = np.full(P.shape[0], np.nan)\n\n    for i in range(P.shape[0]):\n\n        a, b = P[i, 0], P[i, 1]\n\n        # if at least one solution is infeasible\n        if 
pop[a].CV > 0.0 or pop[b].CV > 0.0:\n            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)\n\n        # both solutions are feasible\n        else:\n\n            if tournament_type == 'comp_by_dom_and_crowding':\n                rel = Dominator.get_relation(pop[a].F, pop[b].F)\n                if rel == 1:\n                    S[i] = a\n                elif rel == -1:\n                    S[i] = b\n\n            elif tournament_type == 'comp_by_rank_and_crowding':\n                S[i] = compare(a, pop[a].rank, b, pop[b].rank,\n                               method='smaller_is_better')\n\n            else:\n                raise Exception(\"Unknown tournament type.\")\n\n            # if rank or domination relation didn't make a decision compare by crowding\n            if np.isnan(S[i]):\n                S[i] = compare(a, pop[a].get(\"crowding\"), b, pop[b].get(\"crowding\"),\n                               method='larger_is_better', return_random_if_equal=True)\n\n    return S[:, None].astype(np.int)\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Survival Selection\n# ---------------------------------------------------------------------------------------------------------\n\n\nclass RankAndCrowdingSurvival(Survival):\n\n    def __init__(self) -> None:\n        super().__init__(True)\n\n    def _do(self, pop, n_survive, D=None, **kwargs):\n\n        # get the objective space values and objects\n        F = pop.get(\"F\")\n\n        # the final indices of surviving individuals\n        survivors = []\n\n        # do the non-dominated sorting until splitting front\n        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)\n\n        for k, front in enumerate(fronts):\n\n            # calculate the crowding distance of the front\n            crowding_of_front = calc_crowding_distance(F[front, :])\n\n            # save rank and 
crowding in the individual class\n            for j, i in enumerate(front):\n                pop[i].set(\"rank\", k)\n                pop[i].set(\"crowding\", crowding_of_front[j])\n\n            # current front sorted by crowding distance if splitting\n            if len(survivors) + len(front) > n_survive:\n                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')\n                I = I[:(n_survive - len(survivors))]\n\n            # otherwise take the whole front unsorted\n            else:\n                I = np.arange(len(front))\n\n            # extend the survivors by all or selected individuals\n            survivors.extend(front[I])\n\n        return pop[survivors]\n\n\ndef calc_crowding_distance(F):\n    infinity = 1e+14\n\n    n_points = F.shape[0]\n    n_obj = F.shape[1]\n\n    if n_points <= 2:\n        return np.full(n_points, infinity)\n    else:\n\n        # sort each column and get index\n        I = np.argsort(F, axis=0, kind='mergesort')\n\n        # now really sort the whole array\n        F = F[I, np.arange(n_obj)]\n\n        # get the distance to the last element in sorted list and replace zeros with actual values\n        dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \\\n               - np.concatenate([np.full((1, n_obj), -np.inf), F])\n\n        index_dist_is_zero = np.where(dist == 0)\n\n        dist_to_last = np.copy(dist)\n        for i, j in zip(*index_dist_is_zero):\n            dist_to_last[i, j] = dist_to_last[i - 1, j]\n\n        dist_to_next = np.copy(dist)\n        for i, j in reversed(list(zip(*index_dist_is_zero))):\n            dist_to_next[i, j] = dist_to_next[i + 1, j]\n\n        # normalize all the distances\n        norm = np.max(F, axis=0) - np.min(F, axis=0)\n        norm[norm == 0] = np.nan\n        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm\n\n        # if we divided by zero because all values in one columns are equal replace by none\n    
    dist_to_last[np.isnan(dist_to_last)] = 0.0\n        dist_to_next[np.isnan(dist_to_next)] = 0.0\n\n        # sum up the distance to next and last and norm by objectives - also reorder from sorted list\n        J = np.argsort(I, axis=0)\n        crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj\n\n    # replace infinity with a large number\n    crowding[np.isinf(crowding)] = infinity\n\n    return crowding\n\n\n# =========================================================================================================\n# Interface\n# =========================================================================================================\n\n\ndef nsganet(\n        pop_size=100,\n        sampling=RandomSampling(var_type=np.int),\n        selection=TournamentSelection(func_comp=binary_tournament),\n        crossover=PointCrossover(n_points=2),\n        mutation=PolynomialMutation(eta=3, var_type=np.int),\n        eliminate_duplicates=True,\n        n_offsprings=None,\n        **kwargs):\n    \"\"\"\n\n    Parameters\n    ----------\n    pop_size : {pop_size}\n    sampling : {sampling}\n    selection : {selection}\n    crossover : {crossover}\n    mutation : {mutation}\n    eliminate_duplicates : {eliminate_duplicates}\n    n_offsprings : {n_offsprings}\n\n    Returns\n    -------\n    nsganet : :class:`~pymoo.model.algorithm.Algorithm`\n        Returns an NSGANet algorithm object.\n\n\n    \"\"\"\n\n    return NSGANet(pop_size=pop_size,\n                   sampling=sampling,\n                   selection=selection,\n                   crossover=crossover,\n                   mutation=mutation,\n                   survival=RankAndCrowdingSurvival(),\n                   eliminate_duplicates=eliminate_duplicates,\n                   n_offsprings=n_offsprings,\n                   **kwargs)\n\n\nparse_doc_string(nsganet)\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/operations.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import *\nimport torch.nn.functional as F\nfrom torch import einsum\nfrom einops import rearrange\n# from braincog.model_zoo.base_module import DeformConvPack\nfrom braincog.model_zoo.base_module import BaseLinearModule\n\n\n# from mmcv.ops import ModulatedDeformConv2dPack\n\n\ndef si_relu(x, positive):\n    if positive == 1:\n        return torch.where(x > 0., x, torch.zeros_like(x))\n    elif positive == 0:\n        return x\n    elif positive == -1:\n        return torch.where(x < 0., x, torch.zeros_like(x))\n    else:\n        raise ValueError\n\n\n\nclass SiReLU(nn.Module):\n    def __init__(self, positive=0):\n        super().__init__()\n        self.positive = positive\n\n    def forward(self, x):\n        return si_relu(x, self.positive)\n\n\ndef weight_init(m):\n    if isinstance(m, nn.Conv2d):\n        torch.nn.init.xavier_normal(m.weight.data, gain=0.1)\n        torch.nn.init.constant(m.bias.data, 0.)\n\nOPS_Mlp = {\n    'mlp': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=0),\n    'mlp_p': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=1),\n    'mlp_n': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=-1),\n\n    'skip_connect': lambda C, act_fun:\n        Identity(positive=0),\n    'skip_connect_p': lambda C, act_fun:\n        Identity(positive=1),\n    'skip_connect_n': lambda C, act_fun:\n        Identity(positive=-1),\n}\n\nOPS = {\n    'avg_pool_3x3': lambda C, stride, affine, act_fun: nn.AvgPool2d(3, stride=stride, padding=1,\n                                                                    count_include_pad=False),\n    'conv_3x3': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=0),\n    'conv_5x5': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, 
stride=stride, affine=affine, act_fun=act_fun, positive=0),\n    'max_pool_3x3': lambda C, stride, affine, act_fun: nn.MaxPool2d(3, stride=stride, padding=1),\n    'skip_connect': lambda C, stride, affine, act_fun:\n        Identity(positive=0) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun),\n    'sep_conv_3x3': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),\n    'sep_conv_5x5': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),\n    'sep_conv_7x7': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=0),\n    'dil_conv_3x3': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=0),\n    'dil_conv_5x5': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=0),\n    'def_conv_3x3': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),\n    'def_conv_5x5': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),\n\n    'avg_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),\n        SiReLU(positive=1)\n    ),\n    'max_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.MaxPool2d(3, stride=stride, padding=1),\n        SiReLU(positive=1)\n    ),\n    'conv_3x3_p': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=1),\n    'conv_5x5_p': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, 
positive=1),\n    'skip_connect_p': lambda C, stride, affine, act_fun:\n        Identity(positive=1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_3x3_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_5x5_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_7x7_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=1),\n    'dil_conv_3x3_p': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=1),\n    'dil_conv_5x5_p': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=1),\n    'def_conv_3x3_p': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),\n    'def_conv_5x5_p': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),\n\n    'avg_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),\n        SiReLU(positive=-1)\n    ),\n    'max_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(\n            nn.MaxPool2d(3, stride=stride, padding=1),\n            SiReLU(positive=-1)\n    ),\n    'conv_3x3_n': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=-1),\n    'conv_5x5_n': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=-1),\n    'skip_connect_n': lambda C, stride, affine, act_fun:\n        Identity(positive=-1) if stride == 1 
else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_3x3_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_5x5_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_7x7_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=-1),\n    'dil_conv_3x3_n': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'dil_conv_5x5_n': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'def_conv_3x3_n': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),\n    'def_conv_5x5_n': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),\n\n    'conv_7x1_1x7': lambda C, stride, affine, act_fun: nn.Sequential(\n        # nn.ReLU(inplace=False),\n        act_fun(),\n        nn.Conv2d(C, C, (1, 7), stride=(1, stride),\n                  padding=(0, 3), bias=False),\n        nn.Conv2d(C, C, (7, 1), stride=(stride, 1),\n                  padding=(3, 0), bias=False),\n        nn.BatchNorm2d(C, affine=affine)\n    ),\n    'skip': lambda C, stride, affine, act_fun:\n        Zero(stride) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),\n    'transformer': lambda C, stride, affine, act_fun:\n        FactorizedReduce(\n            C, C, affine=affine, act_fun=act_fun) if stride != 1 else TransformerEncoderLayer(C),\n}\n\n\nclass SiMLP(nn.Module):\n    def __init__(self, c_in, c_out, act_fun=nn.ReLU, positive=0, *args, **kwargs):\n        super(SiMLP, self).__init__()\n        self.op = nn.Sequential(\n            
nn.Linear(c_in, c_out, bias=True),\n            act_fun()\n        )\n        self.positive = positive\n\n    def forward(self, x):\n        out = self.op(si_relu(x, self.positive))\n        return out\n\n\n\n\nclass DilConv(nn.Module):\n    \"\"\"\n    Dilation Convolution ： ReLU -> DilConv -> Conv2d -> BatchNorm2d\n    \"\"\"\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, act_fun=nn.ReLU, positive=0):\n        super(DilConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,\n                      groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine),\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass SepConv(nn.Module):\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(SepConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,\n                      stride=stride, padding=padding, groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_in, affine=affine),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,\n                      stride=1, padding=padding, groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine),\n        )\n        self.positive = positive\n        # if positive 
== -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass Identity(nn.Module):\n\n    def __init__(self, positive=0):\n        super(Identity, self).__init__()\n        self.positive = positive\n\n    def forward(self, x):\n        return si_relu(x, self.positive)\n\n\nclass Zero(nn.Module):\n\n    def __init__(self, stride):\n        super(Zero, self).__init__()\n        self.stride = stride\n\n    def forward(self, x):\n        if self.stride == 1:\n            return x.mul(0.)\n        return x[:, :, ::self.stride, ::self.stride].mul(0.)  # N * C * W * H\n\n\nclass FactorizedReduce(nn.Module):\n\n    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):\n        super(FactorizedReduce, self).__init__()\n        assert C_out % 2 == 0\n        # self.relu = nn.ReLU(inplace=False)\n        self.activation = act_fun()\n        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.bn = nn.BatchNorm2d(C_out, affine=affine)\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        # x = self.relu(x)\n        x = self.activation(x)\n        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.bn(out)\n        out = si_relu(out, self.positive)\n        return out\n\nclass F0(nn.Module):\n\n    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):\n        super(F0, self).__init__()\n        assert C_out % 2 == 0\n        # self.relu = nn.ReLU(inplace=False)\n        self.activation = act_fun()\n        self.op=nn.Conv2d(C_out, C_out, 3, stride=2, padding=1, bias=False)\n        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.conv_2 = nn.Conv2d(C_in, C_out // 
2, 3,stride=2, padding=1, bias=False)\n        self.bn = nn.BatchNorm2d(C_out, affine=affine)\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        # x = self.relu(x)\n        x = self.activation(x)\n        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.bn(out)\n        out = si_relu(out, self.positive)\n        out=self.op(out)\n        return out\n\nclass F1(nn.Module):\n\n    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):\n        super(F1, self).__init__()\n        assert C_out % 2 == 0\n        # self.relu = nn.ReLU(inplace=False)\n        self.activation = act_fun()\n        self.op=nn.Conv2d(C_out, C_out, 3, stride=2, padding=1, bias=False)\n        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.bn = nn.BatchNorm2d(C_out, affine=affine)\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        # x = self.relu(x)\n        x = self.activation(x)\n        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.bn(out)\n        out = si_relu(out, self.positive)\n        out=self.op(out)\n        return out\n        \nclass ReLUConvBN(nn.Module):\n    \"\"\"\n    ReLu -> Conv2d -> BatchNorm2d\n    \"\"\"\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(ReLUConvBN, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,\n                      padding=padding, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine)\n        )\n        self.positive = positive\n        
# if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\n# class DeformConv(nn.Module):\n#     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n#         super(DeformConv, self).__init__()\n#         self.op = nn.Sequential(\n#             # nn.ReLU(inplace=False),\n#             act_fun(),\n#             DeformConvPack(C_in, C_out, kernel_size=kernel_size,\n#                            stride=stride, padding=padding, bias=True),\n#             nn.BatchNorm2d(C_out, affine=affine)\n#         )\n#         self.positive = positive\n#         # if positive == -1:\n#         #     weight_init(self.op)\n\n#     def forward(self, x):\n#         out = self.op(x)\n#         return si_relu(out, self.positive)\n\n\nclass Attention(Module):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    \"\"\"\n\n    def __init__(self, dim, num_heads=4, attention_dropout=0.1, projection_dropout=0.1):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = Linear(dim, dim * 3, bias=False)\n        self.attn_drop = Dropout(attention_dropout)\n        self.proj = Linear(dim, dim)\n        self.proj_drop = Dropout(projection_dropout)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //\n                                  self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass TransformerEncoderLayer(Module):\n    \"\"\"\n    Inspired by 
torch.nn.TransformerEncoderLayer and\n    rwightman's timm package.\n    \"\"\"\n\n    def __init__(self, d_model, nhead=4, dim_feedforward=256, dropout=0.1,\n                 attention_dropout=0.1, drop_path_rate=0.1):\n        super(TransformerEncoderLayer, self).__init__()\n        self.pre_norm = LayerNorm(d_model)\n        self.self_attn = Attention(dim=d_model, num_heads=nhead,\n                                   attention_dropout=attention_dropout, projection_dropout=dropout)\n        dim_feedforward = d_model\n        self.linear1 = Linear(d_model, dim_feedforward)\n        self.dropout1 = Dropout(dropout)\n        self.norm1 = LayerNorm(d_model)\n        self.linear2 = Linear(dim_feedforward, d_model)\n        self.dropout2 = Dropout(dropout)\n\n        self.drop_path = DropPath(\n            drop_path_rate) if drop_path_rate > 0 else Identity()\n\n        self.activation = F.gelu\n\n    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        # print(src.shape)\n        c = src.shape[-1]\n        src = rearrange(src, 'b d r c -> b (r c) d')\n        # print(src.shape)\n        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))\n        src = src + self.drop_path(self.dropout2(src2))\n        src = rearrange(src, 'b (r c) d -> b d r c', c=c)\n        return src\n\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0. or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    # work with diff dim tensors, not just 2D ConvNets\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n    random_tensor = keep_prob + \\\n        torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(Module):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/readme.md",
    "content": "\n\n\n\n# Brain-Inspired Evolutionary Architectures for Spiking Neural Networks —— Based on BrainCog #\n\n\n\n## Requirements ##\n* numpy\n* pytorch >= 1.12.0\n* pymoo = 0.4.0\n* BrainCog\n\n## Run ##\n\n```python ebnas.py```\n\n## Citation ##\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{pan2024brain,\n  title={Brain-inspired Evolutionary Architectures for Spiking Neural Networks},\n  author={Pan, Wenxuan and Zhao, Feifei and Zhao, Zhuoya and Zeng, Yi},\n  journal={IEEE Transactions on Artificial Intelligence},\n  year={2024},\n  publisher={IEEE}\n}\n\n@article{zeng2023braincog,\n  title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier}\n}\n```\n"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/single_genome.py",
    "content": "import os\nimport json\nimport shutil\nimport argparse\nimport subprocess\nimport numpy as np\nimport torch\nfrom tm import get_net_info\nimport micro_encoding\nimport logging\nimport yaml\nfrom braincog.utils import save_feature_map, setup_seed\nfrom tm import train_motifs\nfrom datetime import datetime\nfrom timm.utils import *\nfrom pymoo.optimize import minimize\nfrom pymoo.model.problem import Problem\nfrom pymoo.factory import get_performance_indicator\nfrom pymoo.algorithms.so_genetic_algorithm import GA\nfrom pymoo.util.nds.non_dominated_sorting import NonDominatedSorting\nfrom pymoo.factory import get_algorithm, get_crossover, get_mutation\n\nfrom search_space.ofa import OFASearchSpace\nfrom acc_predictor.factory import get_acc_predictor\n\n_DEBUG = True\nif _DEBUG: from pymoo.visualization.scatter import Scatter\n\ndevices=[0]\nos.environ['CUDA_VISIBLE_DEVICES']='3'\n_logger = logging.getLogger('')\n\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                help='random seed (default: 42)')\n\nparser.add_argument('--ocrate', type=float, default=0.0)\n\nparser.add_argument('--dataset', type=str, default='dvsg',\n                    help='imagenet, cifar10, cifar100, dvsg, dvsc10')\nparser.add_argument('--step', type=int, default=8, help='Simulation time step (default: 10)')\nparser.add_argument('--num-classes', type=int, default=11, metavar='N')\nparser.add_argument('--layers', type=int, default=2)\nparser.add_argument('--bits', type=int, default=20)\nparser.add_argument('--eval_epochs', type=int, default=6000)\nparser.add_argument('--trainning_epochs', type=int, default=600)\nparser.add_argument('--iterations', type=int, default=20,\n                    help='number of search iterations')\nparser.add_argument('--n_doe', type=int, default=100,\n 
                   help='initial sample size for DOE')\nparser.add_argument('--bns', action='store_true', default=True)\nparser.add_argument('--init-channels', type=int, default=16)\nparser.add_argument('--spike-rate', action='store_true', default=False)\nparser.add_argument('--save', type=str, default='',\n                    help='location of dir to save')\nparser.add_argument('--sec_obj', type=str, default='flops',\n                    help='second objective to optimize simultaneously')\nparser.add_argument('--n_iter', type=int, default=8,\n                    help='number of architectures to high-fidelity eval (low level) in each iteration')\nparser.add_argument('--predictor', type=str, default='rbf',\n                    help='which accuracy predictor model to fit (rbf/gp/carts/mlp/as)')\nparser.add_argument('--n_gpus', type=int, default=len(devices),\n                    help='total number of available gpus')\nparser.add_argument('--gpu', type=int, default=1,\n                    help='number of gpus per evaluation job')\nparser.add_argument('--supernet_path', type=str, default='./ofa_mbv3_d234_e346_k357_w1.0',\n                    help='file path to supernet weights')\nparser.add_argument('--n_workers', type=int, default=1,\n                    help='number of workers for dataloader per evaluation job')\nparser.add_argument('--vld_size', type=int, default=5000,\n                    help='validation set size, randomly sampled from training set')\nparser.add_argument('--trn_batch_size', type=int, default=128,\n                    help='train batch size for training')\nparser.add_argument('--vld_batch_size', type=int, default=200,\n                    help='test batch size for inference')\nparser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\n\nparser.add_argument('--test', action='store_true', default=False,\n                    help='evaluation performance on testing 
set')\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\nparser.add_argument('--model', default='NetworkCIFAR_', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n                    help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n        
            help='epochs to warmup LR, if scheduler supports')\n\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. 
Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\n\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=len(devices),\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                    help='use NVIDIA Apex AMP or Native AMP for mixed precision 
training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\n\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=devices[0])\n\n# Spike parameters\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. 
ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surogate Function in node. Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add 
Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                        'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. 
(default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\n\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\n# DARTS parameters\n\nparser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')\nparser.add_argument('--parse_method', default='darts', type=str)\nparser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\n\n\n\n\n\n\ndef _parse_args():\n    args_config, remaining = config_parser.parse_known_args()\n    args = parser.parse_args(remaining)\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\nif __name__ == '__main__':\n\n    args, args_text = _parse_args()\n    args.no_spike_output = True\n    output_dir = ''\n    if args.bns:\n        from cellmodel import NetworkCIFAR_\n    else:\n        from cell123model import NetworkCIFAR_\n    # if 'dvs' in args.dataset:\n    #     args.step=10\n    # else:\n    #     args.step=4\n\n    if args.local_rank == 0:\n        output_base = args.save\n        exp_name = '-'.join([\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            # args.model,\n            args.dataset,\n            str(args.layers)+'layers',\n            str(args.init_channels)+'channels',\n            str(args.step)+'steps',\n            # args.suffix\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base,str(args.dataset),exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    else:\n        
setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:0')\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                        % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    setup_seed(args.seed + args.rank)\n    defalut_lr = args.lr\n\n    arch_doe,sn,bigmotifs,ms ,glob_con= micro_encoding.sample(pops=10, layers=args.layers, bits=20)\n    i=0\n    for geno in arch_doe:\n        arch_dir=os.path.join('train',args.output_dir,str(i))\n        if os.path.exists(arch_dir) is False:\n            os.makedirs(arch_dir,exist_ok = True)\n        args.lr=defalut_lr\n        geno=micro_encoding.c_single(geno,layers=args.layers,bits=args.bits)\n        acc,info=train_motifs(args=args,gen=0,arch_dir=arch_dir,genome=np.array(geno[:-1]),_logger=_logger,args_text=args_text,devices=devices,ms=ms,glob=geno[-1])\n        i+=1"
  },
  {
    "path": "examples/Structure_Evolution/EB-NAS/tm.py",
    "content": "import sys\nimport numpy as np\nimport argparse\nimport time\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom random import choice\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nfrom micro_encoding import ops\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\nimport micro_encoding\nfrom pymop.problem import Problem\nimport torch\nfrom thop import profile\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom pymoo.optimize import minimize\nfrom timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\nfrom torchprofile import profile_macs\nimport copy\nimport torch.backends.cudnn as cudnn\nimport warnings\nwarnings.simplefilter(\"ignore\")\n\ndef train_motifs(args,gen,arch_dir,genome,_logger,args_text,devices,ms,glob):\n    if args.bns:\n        from cellmodel import NetworkCIFAR_\n    else:\n        from cell123model import NetworkCIFAR_\n    \n    # qw=np.where(args.glob_con[0]==1)\n    # ccc=np.array([1,0,0,0])\n    # for i in qw:\n    #     ccc[i[0]]=1\n    #     ddd=np.where(args.glob_con[i[0]]==1)\n    #     if len(ddd[0])!=0:\n    #         for j in ddd:\n    #             ccc[j[0]]=1\n    #    
         www=np.where(args.glob_con[j[0]]==1)\n    #             if len(www[0])!=0:\n    #                 for k in www:\n    #                     ccc[k[0]]=1\n\n\n    test_motifs,ids = micro_encoding.decode_motif(args.layers*ms,args.bits,genome.astype(int))\n\n    \n\n    \n\n\n        \n\n\n    # if gen==-1:\n    args.epochs=args.eval_epochs\n    # else:\n    #     args.epochs=args.eval_epochs\n\n    try:\n        model = create_model(\n            args.model,\n            pretrained=args.pretrained,\n            num_classes=args.num_classes,\n            dataset=args.dataset,\n            step=args.step,\n            encode_type=args.encode,\n            node_type=eval(args.node_type),\n            threshold=args.threshold,\n            tau=args.tau,\n            sigmoid_thres=args.sigmoid_thres,\n            requires_thres_grad=args.requires_thres_grad,\n            spike_output=not args.no_spike_output,\n            C=args.init_channels,\n            layers=args.layers*ms,\n            auxiliary=args.auxiliary,\n            motif=test_motifs,\n            parse_method=args.parse_method,\n            act_fun=args.act_fun,\n            temporal_flatten=args.temporal_flatten,\n            layer_by_layer=args.layer_by_layer,\n            n_groups=args.n_groups,\n            glob=glob,\n        )\n\n        if 'dvs' in args.dataset:\n            args.channels = 2\n        elif 'mnist' in args.dataset:\n            args.channels = 1\n        else:\n            args.channels = 3\n        # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n        # _logger.info('flops = %fM', flops / 1e6)\n        # _logger.info('param size = %fM', params / 1e6)\n        flops=0\n        params=0\n        linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n        args.lr = linear_scaled_lr\n        _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n        if args.local_rank == 
0:\n            sumpram=sum([m.numel() for m in model.parameters()])\n            _logger.info('Model %s created, param count: %d' %\n                        (args.model, sumpram))\n\n\n        num_aug_splits = 0\n        if args.aug_splits > 0:\n            assert args.aug_splits > 1, 'A split of 1 makes no sense'\n            num_aug_splits = args.aug_splits\n\n        if args.split_bn:\n            assert num_aug_splits > 1 or args.resplit\n            model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n        use_amp = None\n        if args.amp:\n            # for backwards compat, `--amp` arg tries apex before native amp\n            if has_apex:\n                args.apex_amp = True\n            elif has_native_amp:\n                args.native_amp = True\n        if args.apex_amp and has_apex:\n            use_amp = 'apex'\n        elif args.native_amp and has_native_amp:\n            use_amp = 'native'\n        elif args.apex_amp or args.native_amp:\n            _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                            \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n        if args.num_gpu > 1:\n            if use_amp == 'apex':\n                _logger.warning(\n                    'Apex AMP does not work well with nn.DataParallel, disabling. 
Use DDP or Torch AMP.')\n                use_amp = None\n            model = nn.DataParallel(model, device_ids=devices).cuda()\n            assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n        else:\n            model = model.cuda()\n            if args.channels_last:\n                model = model.to(memory_format=torch.channels_last)\n\n        optimizer = create_optimizer(args, model)\n\n        amp_autocast = suppress  # do nothing\n        loss_scaler = None\n        if use_amp == 'apex':\n            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n            loss_scaler = ApexScaler()\n            if args.local_rank == 0:\n                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n        elif use_amp == 'native':\n            amp_autocast = torch.cuda.amp.autocast\n            loss_scaler = NativeScaler()\n            if args.local_rank == 0:\n                _logger.info('Using native Torch AMP. Training in mixed precision.')\n        else:\n            if args.local_rank == 0:\n                _logger.info('AMP not enabled. 
Training in float32.')\n\n        # optionally resume from a checkpoint\n        resume_epoch = None\n        if args.resume and args.eval_checkpoint == '':\n            args.eval_checkpoint = args.resume\n        if args.resume:\n            args.eval = True\n            # checkpoint = torch.load(args.resume, map_location='cpu')\n            # model.load_state_dict(checkpoint['state_dict'], False)\n            resume_epoch = resume_checkpoint(\n                model, args.resume,\n                optimizer=None if args.no_resume_opt else optimizer,\n                loss_scaler=None if args.no_resume_opt else loss_scaler,\n                log_info=args.local_rank == 0)\n            # print(model.get_attr('mu'))\n            # print(model.get_attr('sigma'))\n\n        if args.critical_loss or args.spike_rate:\n            if args.num_gpu>1:\n                model.module.set_requires_fp(True)\n            else:\n                model.set_requires_fp(True)\n\n        model_ema = None\n        if args.model_ema:\n            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n            model_ema = ModelEma(\n                model,\n                decay=args.model_ema_decay,\n                device='cpu' if args.model_ema_force_cpu else '',\n                resume=args.resume)\n\n        if args.node_resume:\n            ckpt = torch.load(args.node_resume, map_location='cpu')\n            model.load_node_weight(ckpt, args.node_trainable)\n\n        model_without_ddp = model\n        if args.distributed:\n            if args.sync_bn:\n                assert not args.split_bn\n                try:\n                    if has_apex and use_amp != 'native':\n                        # Apex SyncBN preferred unless native amp is activated\n                        model = convert_syncbn_model(model)\n                    else:\n                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                 
   if args.local_rank == 0:\n                        _logger.info(\n                            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '\n                            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n                except Exception as e:\n                    _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n            if has_apex and use_amp != 'native':\n                # Apex DDP preferred unless native amp is activated\n                if args.local_rank == 0:\n                    _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n                model = ApexDDP(model, delay_allreduce=True)\n            else:\n                if args.local_rank == 0:\n                    _logger.info(\"Using native Torch DistributedDataParallel.\")\n                model = NativeDDP(model, device_ids=[args.local_rank],\n                                find_unused_parameters=True)  # can use device str in Torch >= 1.1\n            model_without_ddp = model.module\n        # NOTE: EMA model does not need to be wrapped by DDP\n\n        lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n        start_epoch = 0\n        if args.start_epoch is not None:\n            # a specified start_epoch will always override the resume epoch\n            start_epoch = args.start_epoch\n        elif resume_epoch is not None:\n            start_epoch = resume_epoch\n        if lr_scheduler is not None and start_epoch > 0:\n            lr_scheduler.step(start_epoch)\n\n        if args.local_rank == 0:\n            _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n        # now config only for imnet\n        data_config = resolve_data_config(vars(args), model=model, verbose=False)\n        loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n            batch_size=args.batch_size,\n            
step=args.step,\n            args=args,\n            _logge=_logger,\n            data_config=data_config,\n            num_aug_splits=num_aug_splits,\n            size=args.event_size,\n            mix_up=args.mix_up,\n            cut_mix=args.cut_mix,\n            event_mix=args.event_mix,\n            beta=args.cutmix_beta,\n            prob=args.cutmix_prob,\n            num=args.cutmix_num,\n            noise=args.cutmix_noise,\n            num_classes=args.num_classes,\n            rand_aug=args.rand_aug,\n            randaug_n=args.randaug_n,\n            randaug_m=args.randaug_m,\n            temporal_flatten=args.temporal_flatten,\n            portion=args.train_portion,\n            _logger=_logger,\n\n        )\n\n        if args.loss_fn == 'mse':\n            train_loss_fn = UnilateralMse(1.)\n            validate_loss_fn = UnilateralMse(1.)\n\n        else:\n            if args.jsd:\n                assert num_aug_splits > 1  # JSD only valid with aug splits set\n                train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n            elif mixup_active:\n                # smoothing is handled with mixup target transform\n                train_loss_fn = SoftTargetCrossEntropy().cuda()\n            elif args.smoothing:\n                train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n            else:\n                train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n            validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        if args.loss_fn == 'mix':\n            train_loss_fn = MixLoss(train_loss_fn)\n            validate_loss_fn = MixLoss(validate_loss_fn)\n\n        eval_metric = args.eval_metric\n        best_metric = None\n        best_epoch = None\n\n        if args.eval:  # evaluate the model\n            if args.distributed:\n                state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n                new_state_dict = OrderedDict()\n          
      # add module prefix for DDP\n                for k, v in state_dict.items():\n                    k = 'module.' + k\n                    new_state_dict[k] = v\n\n                model.load_state_dict(new_state_dict)\n            # else:\n            #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)\n            for i in range(1):\n                val_metrics,_ = validate(start_epoch, model, loader_eval, validate_loss_fn, args,arch_dir,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat)\n                print(f\"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n            # return\n\n        saver = None\n        if args.local_rank == 0:\n            decreasing = True if eval_metric == 'loss' else False\n\n            saver = CheckpointSaver(\n                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n                checkpoint_dir=arch_dir, recovery_dir=arch_dir, decreasing=decreasing)\n            with open(os.path.join(arch_dir, 'args.yaml'), 'w') as f:\n                f.write(args_text)\n        f=open(os.path.join(arch_dir, 'direct_genome.txt'), 'a')\n        f.write(\",\".join(str(k) for k in genome))\n        f.write('\\n')\n        f.close()\n        try:  # train the model\n            if args.reset_drop:\n                model_without_ddp.reset_drop_path(0.0)\n            for epoch in range(start_epoch, args.epochs):\n                if epoch == 0 and args.reset_drop:\n                    model_without_ddp.reset_drop_path(args.drop_path)\n\n                if args.distributed:\n                    loader_train.sampler.set_epoch(epoch)\n\n                train_metrics = train_epoch(\n                    epoch, model, loader_train, optimizer, train_loss_fn, args,_logger=_logger,\n                    lr_scheduler=lr_scheduler, saver=saver, 
output_dir=arch_dir,\n                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)\n\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    if args.local_rank == 0:\n                        _logger.info(\"Distributing BatchNorm running means and vars\")\n                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n                eval_metrics,_ = validate(epoch, model, loader_eval, validate_loss_fn, args, arch_dir,amp_autocast=amp_autocast,_logger=_logger,\n                                        visualize=args.visualize, spike_rate=args.spike_rate,\n                                        tsne=args.tsne, conf_mat=args.conf_mat)\n\n                if model_ema is not None and not args.model_ema_force_cpu:\n                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                    ema_eval_metrics,_ = validate(\n                        epoch, model_ema.ema, loader_eval, validate_loss_fn, args, arch_dir,amp_autocast=amp_autocast, log_suffix=' (EMA)',_logger=_logger,\n                        visualize=args.visualize, spike_rate=args.spike_rate,\n                        tsne=args.tsne, conf_mat=args.conf_mat)\n                    eval_metrics = ema_eval_metrics\n\n                if lr_scheduler is not None:\n                    # step LR for next epoch\n                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n                update_summary(\n                    epoch, train_metrics, eval_metrics, os.path.join(arch_dir, 'summary.csv'),\n                    write_header=best_metric is None)\n\n                # if saver is not None and epoch >= args.n_warm_up:\n                if saver is not None:\n                    # save proper checkpoint with eval metric\n                    save_metric = 
eval_metrics[eval_metric]\n                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n                best_metric, best_epoch = eval_metrics[eval_metric],epoch\n                _logger.info('Train: {0} '.format(best_metric))\n\n            f=open(os.path.join(arch_dir, 'direct.txt'), 'a')\n            f.write(str(best_metric))\n            f.write('\\n')\n            f.close()\n\n\n\n        except KeyboardInterrupt:\n            pass\n    except MemoryError:\n        return -10000, 0\n    except RuntimeError:\n        # return -10000, {'flops': flops / 1e6, 'param': params / 1e6}\n        return -10000, 0\n\n    # if best_metric is not None:\n    #     _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\n    # info=get_net_info(model)\n    \n    val_metrics,spikes = validate(start_epoch, model, loader_eval, validate_loss_fn, args,arch_dir,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat,_logger=_logger,)\n\n    return best_metric,spikes\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,_logger,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    # t, k = 
adjust_surrogate_coeff(100, args.epochs)\n    # model.set_attr('t', t)\n    # model.set_attr('k', k)\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            loss = loss_fn(output, target)\n        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':\n            # print(output.shape, target.shape)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        closs = torch.tensor([0.], device=loss.device)\n\n        if args.critical_loss:\n            closs = calc_critical_loss(model)\n        loss = loss + .1 * closs\n\n        spike_rate_avg_layer_str = ''\n        threshold_str = ''\n\n\n\n        if not args.distributed:\n            losses_m.update(loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), inputs.size(0))\n            top5_m.update(acc5.item(), inputs.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n            if args.num_gpu>1:\n                spike_rate_avg_layer = model.module.get_fire_rate().tolist()\n                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                threshold = model.module.get_threshold()\n            \n            else:\n                spike_rate_avg_layer = 
model.get_fire_rate().tolist()\n                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                threshold = model.get_threshold()                \n            \n            threshold_str = ['{:.3f}'.format(i) for i in threshold]\n\n\n                  \n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            if args.opt == 'lamb':\n                optimizer.step(epoch=epoch)\n            else:\n                optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            mu_str = ''\n            sigma_str = ''\n            if not args.distributed:\n                if 'Noise' in args.node_type:\n                    mu, sigma = model.get_noise_param()\n                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                losses_m.update(reduced_loss.item(), inputs.size(0))\n                closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                if args.distributed:\n                   
 _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                            epoch,\n                            batch_idx, len(loader),\n                            100. * batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m\n                        ))\n                else:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\\n'\n                 
       'Fire_rate: {spike_rate}\\n'\n                        # 'Thres: {threshold}\\n'\n                        # 'Mu: {mu_str}\\n'\n                        # 'Sigma: {sigma_str}\\n'\n                        .format(\n                            epoch,\n                            batch_idx, len(loader),\n                            100. * batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            # threshold=threshold_str,\n                            # mu_str=mu_str,\n                            # sigma_str=sigma_str\n                        ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(epoch, model, loader, loss_fn, args, 
arch_dir,_logger,amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:\n                    if args.num_gpu>1:\n                        model.module.set_requires_fp(True)\n                    else:\n                        model.set_requires_fp(True)\n\n                    # if not args.critical_loss:\n                    #     model.set_requires_fp(False)\n\n            with amp_autocast():\n                output = model(inputs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            if not args.distributed:\n                if visualize:\n                    x = model.get_fp()\n                    feature_path = os.path.join(arch_dir, 'feature_map')\n                    if os.path.exists(feature_path) is False:\n                        os.mkdir(feature_path)\n                    save_feature_map(x, feature_path)\n                    # if not args.critical_loss:\n                    #     model_config.set_requires_fp(False)\n\n                if 
tsne:\n                    x = model.get_fp(temporal_info=False)[-1]\n                    x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)\n                    x = x.reshape(x.shape[0], -1)\n                    feature_vec.append(x)\n                    feature_cls.append(target)\n\n                if conf_mat:\n                    logits_vec.append(output)\n                    labels_vec.append(target)\n\n                if spike_rate:\n                    if args.num_gpu>1:\n                        avg, var, spike, avg_per_step = model.module.get_spike_info()\n\n                    else:\n                        avg, var, spike, avg_per_step = model.get_spike_info()\n                    save_spike_info(\n                        os.path.join(arch_dir, 'spike_info.csv'),\n                        epoch, batch_idx,\n                        args.step, avg, var,\n                        spike, avg_per_step)\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n\n            closs = torch.tensor([0.], device=loss.device)\n\n            if not args.distributed:\n                if args.num_gpu>1:\n                    spike_rate_avg_layer = model.module.get_fire_rate().tolist()\n                    threshold = model.module.get_threshold()\n                    threshold_str = ['{:.3f}'.format(i) for i in threshold]\n                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                    tot_spike = model.module.get_tot_spike()\n                else:\n                    spike_rate_avg_layer = model.get_fire_rate().tolist()\n                    threshold = 
model.get_threshold()\n                    threshold_str = ['{:.3f}'.format(i) for i in threshold]\n                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                    tot_spike = model.get_tot_spike()                    \n\n            if args.critical_loss:\n                closs = calc_critical_loss(model)\n            loss = loss + .1 * closs\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n                mu_str = ''\n                sigma_str = ''\n                if not args.distributed:\n                    if 'Noise' in args.node_type:\n                        mu, sigma = model.get_noise_param()\n                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n                if args.distributed:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n        
                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            ))\n                else:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\\n'\n                        'Fire_rate: {spike_rate}\\n'\n                        'Tot_spike: {tot_spike}\\n'\n                        'Thres: {threshold}\\n'\n                        'Mu: {mu_str}\\n'\n                        'Sigma: {sigma_str}\\n'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            tot_spike=tot_spike,\n                            threshold=threshold_str,\n                            mu_str=mu_str,\n                            sigma_str=sigma_str\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if 
not args.distributed:\n        if tsne:\n            feature_vec = torch.cat(feature_vec)\n            feature_cls = torch.cat(feature_cls)\n            plot_tsne(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-2d.eps'))\n            plot_tsne_3d(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-3d.eps'))\n        if conf_mat:\n            logits_vec = torch.cat(logits_vec)\n            labels_vec = torch.cat(labels_vec)\n            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(arch_dir, 'confusion_matrix.eps'))\n\n    return metrics,tot_spike\n\n\n\ndef get_net_info(args, gen,genome,ms):\n    \"\"\"\n    Modified from https://github.com/mit-han-lab/once-for-all/blob/\n    35ddcb9ca30905829480770a6a282d49685aa282/ofa/imagenet_codebase/utils/pytorch_utils.py#L139\n    \"\"\"\n    from ofa.imagenet_codebase.utils.pytorch_utils import count_parameters, measure_net_latency\n\n    # artificial input data\n\n    if args.bns:\n        from cellmodel import NetworkCIFAR\n    else:\n        from cell123model import NetworkCIFAR\n    \n\n    test_motifs,ids = micro_encoding.decode_motif(args.layers*ms,args.bits,genome.astype(int))\n    net = create_model(\n            args.model,\n            pretrained=args.pretrained,\n            num_classes=args.num_classes,\n            dataset=args.dataset,\n            step=args.step,\n            encode_type=args.encode,\n            node_type=eval(args.node_type),\n            threshold=args.threshold,\n            tau=args.tau,\n            sigmoid_thres=args.sigmoid_thres,\n            requires_thres_grad=args.requires_thres_grad,\n            spike_output=not args.no_spike_output,\n            C=args.init_channels,\n            layers=args.layers*ms,\n            auxiliary=args.auxiliary,\n            motif=test_motifs,\n            parse_method=args.parse_method,\n            act_fun=args.act_fun,\n            temporal_flatten=args.temporal_flatten,\n            layer_by_layer=args.layer_by_layer,\n   
         n_groups=args.n_groups,\n            cell_type=genome[-1],\n        )\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    inputs = torch.randn(1, args.channels, 224, 224)\n\n\n    # move network to GPU if available\n    if torch.cuda.is_available():\n        device = torch.device('cuda:0')\n        net = net.to(device)\n        cudnn.benchmark = True\n        inputs = inputs.to(device)\n\n    net_info = {}\n    if isinstance(net, nn.DataParallel):\n        net = net.module\n\n    # parameters\n    net_info['params'] = count_parameters(net)\n\n    # flops\n    net_info['flops'] = int(profile_macs(copy.deepcopy(net), inputs))\n\n\n    return net_info\n"
  },
  {
    "path": "examples/Structure_Evolution/ELSM/README.md",
    "content": "\n\n\n\n# Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution —— Based on BrainCog #\n\n\n\n## Requirments ##\n* numpy\n* pytorch >= 1.12.0\n* BrainCog\n\n## Run ##\n\n```python evolve.py```\n\n## Citation ##\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{pan2024emergence,\n  title={Emergence of Brain-inspired Small-world Spiking Neural Network through Neuroevolution},\n  author={Pan, Wenxuan and Zhao, Feifei and Han, Bing and Dong, Yiting and Zeng, Yi},\n  journal={iScience},\n  year={2024},\n  publisher={Elsevier}\n}\n\n@article{zeng2023braincog,\n  title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier}\n}\n```\n"
  },
  {
    "path": "examples/Structure_Evolution/ELSM/evolve.py",
    "content": "import time\nimport threading\nfrom threading import Thread\nimport os\nimport networkx as nx\nimport numpy as np\nfrom population import *\nimport nsganet as engine\nfrom pymop.problem import Problem\nfrom pymoo.optimize import minimize\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\nimport logging\nfrom model import *\nfrom spikes import calc_f2\nfrom multiprocessing import Process,Pool\nfrom datetime import datetime\nimport time\n\n\n_logger = logging.getLogger('')\nconfig_parser = parser = argparse.ArgumentParser(description='Evolution Config', add_help=False)\n\nparser = argparse.ArgumentParser(description='SNN Evoving')\nparser.add_argument('--device', type=int, default=2)\nparser.add_argument('--seed', type=int, default=68, metavar='S')\nparser.add_argument('--datapath', default='/data/', type=str, metavar='PATH')\nparser.add_argument('--output', default='/data/LSM/Eresult/new', type=str, metavar='PATH')\nparser.add_argument('--liquid-size', type=int, default=8000)\nparser.add_argument('--pop-size', type=int, default=20)\nparser.add_argument('--up', type=int, default=32000000)\nparser.add_argument('--low', type=int, default=320000)\n\nparser.add_argument('--n_offspring', type=int, default=200)\nparser.add_argument('--n_gens', type=int, default=2000)\nparser.add_argument('--arand', type=float, default=285)\nparser.add_argument('--brand', type=float, default=1.8)\n\n\ndef _parse_args():\n    args_config, remaining = config_parser.parse_known_args()\n    args = parser.parse_args(remaining)\n    return args\n\ndef calc_f1(dirs):\n    ci=[]\n    G=nx.read_gpickle(dirs)\n    largest_component = max(nx.connected_components(G), key=len)\n    G = G.subgraph(largest_component)\n    for u in G.nodes:\n        ci.append(nx.clustering(G,u))\n    a=sum(ci)\n    print(\"start\")\n    print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))\n    
path=nx.average_shortest_path_length(G)\n    print(\"end\")\n    print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))\n    return a,path\n\ndef mul_f1(pop,steps,rootdir):\n    result=[]\n    for i in range(0,pop,steps):\n        p = Pool(steps)\n        dirs=[os.path.join(rootdir,str(i)+'.pkl') for i in range(i,i+steps)]\n        ret = p.map(calc_f1,dirs)\n        result.extend(ret)\n        print(ret)\n        p.close()\n        p.join()\n    return result\n\nclass Evolve(Problem):\n    # first define the NAS problem (inherit from pymop)\n    def __init__(self, args,n_var=20, n_obj=1, n_constr=0, lb=None, ub=None):\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\n        self.xl = lb\n        self.xu = ub\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\n        self.args=args\n\n\n    def _evaluate(self, x, out, *args, **kwargs):\n        \n\n        objs = np.full((x.shape[0], self.n_obj), np.nan)\n        g1 = np.full((x.shape[0]), np.nan)\n        g2 = np.full((x.shape[0]), np.nan)\n        gen_dir=os.path.join(self.args.output,'generaion'+str(kwargs['algorithm'].n_gen))\n        os.makedirs(gen_dir,exist_ok = True)\n        # np.save(os.path.join(gen_dir,\"x.npy\"),x)\n        lsms = x.reshape(x.shape[0],self.args.liquid_size,self.args.liquid_size)\n        for i in range(x.shape[0]):\n            temp_G = nx.Graph(lsms[i])\n            nx.write_gpickle(temp_G, os.path.join(gen_dir,str(i)+\".pkl\"))\n        self.ob1=mul_f1(pop=x.shape[0],steps=10,rootdir=gen_dir)\n\n        for i in range(x.shape[0]):\n            arch_id = self._n_evaluated + 1\n            print('\\n')\n            _logger.info('Network= {}'.format(arch_id))\n            genome = x[i, :]\n\n            g1[i]= genome.sum()-self.args.up\n            g2[i]= self.args.low-genome.sum()\n            lsmm = genome.reshape(self.args.liquid_size,self.args.liquid_size)\n            
small_coe_a,small_coe_b=self.ob1[i]\n            lsmm=torch.tensor(lsmm,device='cuda:%d' % self.args.device).float()\n            crit = calc_f2(lsmm,'cuda:%d' % self.args.device)\n            objs[i, 1] = abs(crit-1)\n            # all objectives assume to be MINIMIZED !!!!!                \n            objs[i, 0] = -(small_coe_a/self.args.arand)/(small_coe_b/self.args.brand)\n            \n\n            _logger.info('small word= {}'.format(objs[i, 0]))\n            _logger.info('criticality= {}'.format(objs[i, 1]))\n\n            self._n_evaluated += 1\n\n        out[\"F\"] = objs\n        out[\"G\"] = np.column_stack([g1,g2])\n        # if your NAS problem has constraints, use the following line to set constraints\n        # out[\"G\"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Define what statistics to print or save for each generation\n# ---------------------------------------------------------------------------------------------------------\ndef do_every_generations(algorithm):\n    # this function will be call every generation\n    # it has access to the whole algorithm class\n    gen = algorithm.n_gen\n    pop_var = algorithm.pop.get(\"X\")\n    pop_obj = algorithm.pop.get(\"F\")\n    \n    # report generation info to files\n    _logger.info(\"generation = {}\".format(gen))\n    _logger.info(\"population error1: best = {}, mean = {}, \"\n                 \"median1 = {}, worst1 = {}\".format(np.min(pop_obj[:, 0]), np.mean(pop_obj[:, 0]),\n                                                  np.median(pop_obj[:, 0]), np.max(pop_obj[:, 0])))\n    _logger.info('Best1 Genome id= {}'.format(np.argmin(pop_obj[:, 0])))\n\n    _logger.info(\"population error2: best = {}, mean = {}, \"\n                 \"median2 = {}, worst2 = {}\".format(np.min(pop_obj[:, 1]), np.mean(pop_obj[:, 1]),\n                                                
  np.median(pop_obj[:, 1]), np.max(pop_obj[:, 1])))\n    _logger.info('Best2 Genome id= {}'.format(np.argmin(pop_obj[:, 1])))\n    if gen%20==0:\n        best_sid=np.argmin(pop_obj[:, 0])\n        best_sname='-'.join([\n                'gen'+str(gen),\n                's'+str(float('%.4f' % pop_obj[best_sid, 0])),\n                'c'+str(float('%.4f' % pop_obj[best_sid, 1])),\n            ])\n        best_cid=np.argmin(pop_obj[:, 1])\n        best_cname='-'.join([\n                'gen'+str(gen),\n                's'+str(float('%.4f' % pop_obj[best_cid, 0])),\n                'c'+str(float('%.4f' % pop_obj[best_cid, 1])),\n            ])\n        \n        np.save(os.path.join('/data/save/genome',best_sname+datetime.now().strftime(\"%Y%m%d-%H%M%S\")),pop_var[np.argmin(pop_obj[:, 0])])\n        np.save(os.path.join('/data/save/genome',best_cname+datetime.now().strftime(\"%Y%m%d-%H%M%S\")),pop_var[np.argmin(pop_obj[:, 1])])\n\nif __name__ == '__main__':\n    args = _parse_args()\n    out_base_dir= os.path.join(args.output, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n    os.makedirs(out_base_dir,exist_ok = True)\n    args.output=out_base_dir\n    setup_default_logging(log_path=os.path.join(out_base_dir, 'log.txt'))\n\n    kkk = Evolve(args,n_var=args.liquid_size*args.liquid_size, \n                  n_obj=2, n_constr=2)\n    method = engine.nsganet(pop_size=args.pop_size,\n                            sampling=RandomSampling(var_type='custom'),\n                            mutation=BinaryBitflipMutation(),\n                            n_offsprings=args.n_offspring,\n                            eliminate_duplicates=True)\n    kres=minimize(kkk,\n                   method,\n                   callback=do_every_generations,\n                   termination=('n_gen', args.n_gens))\n\n\n"
  },
  {
    "path": "examples/Structure_Evolution/ELSM/lsm.py",
    "content": "from __future__ import print_function\nimport torchvision\nimport torchvision.transforms as transforms\nimport os\nimport time\nimport numpy as np\nimport torch\nfrom torch import nn as nn\nfrom mnistmodel import *\nfrom tqdm import tqdm\nimport argparse\nfrom datetime import datetime\nimport logging\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy\nfrom braincog.base.utils import UnilateralMse, MixLoss\nfrom braincog.base.learningrule.STDP import *\n\ndevice='cuda:7'\n\ndef lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):\n    \"\"\"Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.\"\"\"\n    if epoch % lr_decay_epoch == 0 and epoch > 1:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = param_group['lr'] * 0.1\n    return optimizer   \n\n\nbatch_size=100\nliquid_size=8000\n\nlearning_rate = 1e-3\nnum_epochs = 100  # max epoch\n\ndata_path = '/data'  \nload_path=''\ntrain_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=False, transform=transforms.ToTensor())\ntrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)\n\ntest_set = torchvision.datasets.MNIST(root=data_path, train=False, download=False, transform=transforms.ToTensor())\ntest_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)\n\nsnn = SNN(ins=784,\n        batchsize=batch_size,\n        device=device,\n        liquid_size=liquid_size,\n        lsm_tau=lsm_tau,\n        
lsm_th=lsm_th)\nsnn.load_state_dict(torch.load(load_path)['fc'])\nsnn.learning_rule=[]\nsnn.con[0].load_state_dict(torch.load(load_path)['lsm0'])\nw2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)\nsnn.connectivity_matrix=torch.load(load_path)['connectivity_matrix'].to(device)\nw2tmp.weight.data=(torch.load(load_path)['liquid_weight'].to(device))*snn.connectivity_matrix\nsnn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp]))  # pm\nsnn.eval()\nsnn.to(device)\n\nclass LabelSmoothingBCEWithLogitsLoss(nn.Module):\n\n    def __init__(self, smoothing=0.1):\n        \"\"\"\n        Constructor for the LabelSmoothing module.\n        :param smoothing: label smoothing factor\n        \"\"\"\n        super(LabelSmoothingBCEWithLogitsLoss, self).__init__()\n        assert smoothing < 1.0\n        self.smoothing = smoothing\n        self.confidence = 1. - smoothing\n        self.BCELoss = nn.BCEWithLogitsLoss()\n\n    def forward(self, x, target):\n        target = torch.eye(x.shape[-1], device=x.device)[target]\n        nll = torch.ones_like(x) / x.shape[-1]\n        return self.BCELoss(x, target) * self.confidence + self.BCELoss(x, nll) * self.smoothing\n\n\nls = 'mse'\n\nif ls == 'ce':\n    criterion = nn.CrossEntropyLoss()\nelif ls == 'bce':\n    criterion = nn.BCEWithLogitsLoss()\nelif ls == 'mse':\n    criterion = UnilateralMse(1.)\nelif ls == 'sce':\n    criterion = LabelSmoothingCrossEntropy()\nelif ls == 'sbce':\n    criterion = LabelSmoothingBCEWithLogitsLoss()\nelif ls == 'umse':\n    criterion = UnilateralMse(.5)\n\noptimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)\n\nl=[]\nbest_acc=0\nfor epoch in range(num_epochs):\n    running_loss = 0\n    start_time = time.time()\n    for i, (images, labels) in enumerate(tqdm(train_loader)):\n        snn.zero_grad()\n        optimizer.zero_grad()\n        images = images.float().to(device)\n        outputs = snn(images)\n        labels=labels.to(device)\n     
   loss = criterion(outputs, labels)\n        running_loss += loss.item()\n        loss.backward()\n\n        optimizer.step()\n        snn.reset()\n        if (i + 1) % 100 == 0:\n            running_loss = 0\n\n    correct = 0\n    total = 0\n    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)\n\n    for batch_idx, (inputs, targets) in enumerate(test_loader):\n        inputs = inputs.float().to(device)\n        snn.zero_grad()\n        optimizer.zero_grad()\n        outputs = snn(inputs)\n        targets=targets.to(device)\n        loss = criterion(outputs, targets)\n\n        _, predicted = outputs.max(1)\n        total += float(targets.size(0))\n        correct += float(predicted.eq(targets).sum().item())\n        snn.reset()\n        if batch_idx % 100 == 0:\n            acc = 100. * float(correct) / float(total)\n            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)\n    print('Test Accuracy: %.3f' % (100 * correct / total))\n    acc = 100. * float(correct) / float(total)\n    if best_acc < acc:\n        best_acc = acc\n    print(best_acc)\n    l.append(best_acc)\n\n\n\n"
  },
  {
    "path": "examples/Structure_Evolution/ELSM/model.py",
    "content": "from functools import partial\nfrom torch.nn import functional as F\nfrom torch import nn as nn\nimport torchvision, pprint\nfrom copy import deepcopy\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.encoder.encoder import *\nfrom braincog.model_zoo.base_module import BaseModule, BaseConvModule, BaseLinearModule\nfrom braincog.base.brainarea.BrainArea import BrainArea\nfrom braincog.base.connection.CustomLinear import *\nfrom braincog.base.learningrule.STDP import *\nimport matplotlib.pyplot as plt\n\n\n\n\n@register_model\nclass nSNN(BaseModule):\n    def __init__(self,\n                 batchsize,\n                 liquid_size,\n                 device,\n                 connectivity_matrix,\n                 num_classes=10,\n                 step=1,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 lsm_th=0.3,\n                 fc_th=0.3,\n                 lsm_tau=3,\n                 fc_tau=3,\n                 ins=1156,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.batchsize=batchsize\n        self.ins=ins\n        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)\n        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)\n        self.liquid_size=liquid_size\n        self.device=device\n        self.con=[]\n        self.learning_rule=[]\n        self.connectivity_matrix=connectivity_matrix\n        w1tmp=nn.Linear(ins,liquid_size,bias=False).to(device)\n        self.con.append(w1tmp)\n        w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)\n        self.liquid_weight=w2tmp.weight.data\n        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix\n        self.con.append(w2tmp)\n\n        
self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]]))  # pm\n\n        self.fc = nn.Sequential(\n            nn.Linear(liquid_size,num_classes),\n            self.node_fc()\n        )\n\n    def forward(self, x):\n        sum_spike=0\n        self.out = torch.zeros(x.shape[0], self.liquid_size).to(self.device)\n        tw=x.shape[1]\n        self.tw=tw\n        self.firing_tw=torch.zeros(tw, self.batchsize, self.liquid_size).to(self.device)\n\n        for t in range(tw):\n            self.out, self.dw = self.learning_rule[0](x[:,t,:], self.out)\n            out_liquid=self.out[:,0:self.liquid_size]\n            xout = self.fc(out_liquid)\n            sum_spike=sum_spike+xout\n            self.firing_tw[t]=out_liquid\n        outputs = sum_spike / tw\n        return outputs\n\n\n@register_model\nclass mSNN(BaseModule):\n    def __init__(self,\n                 batchsize,\n                 liquid_size,\n                 device,\n                 connectivity_matrix,\n                 num_classes=10,\n                 step=1,\n                 node_type=LIFNode,\n                 encode_type='direct',\n                 lsm_th=0.3,\n                 fc_th=0.3,\n                 lsm_tau=3,\n                 fc_tau=3,\n                 tw=100,\n                 *args,\n                 **kwargs):\n        super().__init__(step, encode_type, *args, **kwargs)\n        self.batchsize=batchsize\n\n        self.node_lsm=partial(node_type, **kwargs, step=step,tau=lsm_tau,threshold=lsm_th)\n        self.node_fc = partial(node_type, **kwargs, step=step,tau=fc_tau,threshold=fc_th)\n        self.liquid_size=liquid_size\n        self.out = torch.zeros(self.batchsize, liquid_size).to(device)\n        self.device=device\n        self.con=[]\n        self.learning_rule=[]\n        self.connectivity_matrix=connectivity_matrix\n        w1tmp=nn.Linear(784,liquid_size,bias=False).to(device)\n        self.con.append(w1tmp)\n        
w2tmp=nn.Linear(liquid_size,liquid_size,bias=False).to(device)\n\n        self.liquid_weight=w2tmp.weight.data\n        \n        w2tmp.weight.data=w2tmp.weight.data*self.connectivity_matrix\n        self.con.append(w2tmp)\n        self.learning_rule.append(MutliInputSTDP(self.node_lsm(), [self.con[0], self.con[1]]))  # pm\n\n        self.fc = nn.Sequential(\n            nn.Linear(liquid_size,num_classes),\n            self.node_fc()\n        )\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], -1)\n        sum_spike=0\n        time_window=20\n        self.tw=time_window\n        self.firing_tw=torch.zeros(time_window, self.batchsize, self.liquid_size).to(self.device)\n        self.out = torch.zeros(self.batchsize, self.liquid_size).to(self.device)\n        for t in range(time_window):\n\n            self.out, self.dw = self.learning_rule[0](x, self.out)\n\n            out_liquid=self.out[:,0:self.liquid_size]\n            xout = self.fc(out_liquid)\n            sum_spike=sum_spike+xout\n            self.firing_tw[t]=out_liquid\n        # print(out_liquid.sum())\n        # print(xout.sum())\n        outputs = sum_spike / time_window\n        return outputs\n\n"
  },
  {
    "path": "examples/Structure_Evolution/ELSM/nsganet.py",
    "content": "import numpy as np\n\nfrom pymoo.algorithms.genetic_algorithm import GeneticAlgorithm\nfrom pymoo.docs import parse_doc_string\nfrom pymoo.model.individual import Individual\nfrom pymoo.model.survival import Survival\nfrom pymoo.operators.crossover.point_crossover import PointCrossover\nfrom pymoo.operators.mutation.polynomial_mutation import PolynomialMutation\nfrom pymoo.operators.mutation.bitflip_mutation import BinaryBitflipMutation\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.selection.tournament_selection import compare, TournamentSelection\nfrom pymoo.util.display import disp_multi_objective\nfrom pymoo.util.dominator import Dominator\nfrom pymoo.util.non_dominated_sorting import NonDominatedSorting\nfrom pymoo.util.randomized_argsort import randomized_argsort\n\n\n# =========================================================================================================\n# Implementation\n# based on nsga2 from https://github.com/msu-coinlab/pymoo\n# =========================================================================================================\n\n\nclass NSGANet(GeneticAlgorithm):\n\n    def __init__(self, **kwargs):\n        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)\n        super().__init__(**kwargs)\n\n        self.tournament_type = 'comp_by_dom_and_crowding'\n        self.func_display_attrs = disp_multi_objective\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Binary Tournament Selection Function\n# ---------------------------------------------------------------------------------------------------------\n\n\ndef binary_tournament(pop, P, algorithm, **kwargs):\n    if P.shape[1] != 2:\n        raise ValueError(\"Only implemented for binary tournament!\")\n\n    tournament_type = algorithm.tournament_type\n    S = np.full(P.shape[0], np.nan)\n\n    for i in range(P.shape[0]):\n\n        a, b = P[i, 
0], P[i, 1]\n\n        # if at least one solution is infeasible\n        if pop[a].CV > 0.0 or pop[b].CV > 0.0:\n            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)\n\n        # both solutions are feasible\n        else:\n\n            if tournament_type == 'comp_by_dom_and_crowding':\n                rel = Dominator.get_relation(pop[a].F, pop[b].F)\n                if rel == 1:\n                    S[i] = a\n                elif rel == -1:\n                    S[i] = b\n\n            elif tournament_type == 'comp_by_rank_and_crowding':\n                S[i] = compare(a, pop[a].rank, b, pop[b].rank,\n                               method='smaller_is_better')\n\n            else:\n                raise Exception(\"Unknown tournament type.\")\n\n            # if rank or domination relation didn't make a decision compare by crowding\n            if np.isnan(S[i]):\n                S[i] = compare(a, pop[a].get(\"crowding\"), b, pop[b].get(\"crowding\"),\n                               method='larger_is_better', return_random_if_equal=True)\n\n    return S[:, None].astype(np.int)\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Survival Selection\n# ---------------------------------------------------------------------------------------------------------\n\n\nclass RankAndCrowdingSurvival(Survival):\n\n    def __init__(self) -> None:\n        super().__init__(True)\n\n    def _do(self, pop, n_survive, D=None, **kwargs):\n\n        # get the objective space values and objects\n        F = pop.get(\"F\")\n\n        # the final indices of surviving individuals\n        survivors = []\n\n        # do the non-dominated sorting until splitting front\n        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)\n\n        for k, front in enumerate(fronts):\n\n            # calculate the crowding distance of the front\n            
crowding_of_front = calc_crowding_distance(F[front, :])\n\n            # save rank and crowding in the individual class\n            for j, i in enumerate(front):\n                pop[i].set(\"rank\", k)\n                pop[i].set(\"crowding\", crowding_of_front[j])\n\n            # current front sorted by crowding distance if splitting\n            if len(survivors) + len(front) > n_survive:\n                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')\n                I = I[:(n_survive - len(survivors))]\n\n            # otherwise take the whole front unsorted\n            else:\n                I = np.arange(len(front))\n\n            # extend the survivors by all or selected individuals\n            survivors.extend(front[I])\n\n        return pop[survivors]\n\n\ndef calc_crowding_distance(F):\n    infinity = 1e+14\n\n    n_points = F.shape[0]\n    n_obj = F.shape[1]\n\n    if n_points <= 2:\n        return np.full(n_points, infinity)\n    else:\n\n        # sort each column and get index\n        I = np.argsort(F, axis=0, kind='mergesort')\n\n        # now really sort the whole array\n        F = F[I, np.arange(n_obj)]\n\n        # get the distance to the last element in sorted list and replace zeros with actual values\n        dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \\\n               - np.concatenate([np.full((1, n_obj), -np.inf), F])\n\n        index_dist_is_zero = np.where(dist == 0)\n\n        dist_to_last = np.copy(dist)\n        for i, j in zip(*index_dist_is_zero):\n            dist_to_last[i, j] = dist_to_last[i - 1, j]\n\n        dist_to_next = np.copy(dist)\n        for i, j in reversed(list(zip(*index_dist_is_zero))):\n            dist_to_next[i, j] = dist_to_next[i + 1, j]\n\n        # normalize all the distances\n        norm = np.max(F, axis=0) - np.min(F, axis=0)\n        norm[norm == 0] = np.nan\n        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm\n\n        # 
if we divided by zero because all values in one columns are equal replace by none\n        dist_to_last[np.isnan(dist_to_last)] = 0.0\n        dist_to_next[np.isnan(dist_to_next)] = 0.0\n\n        # sum up the distance to next and last and norm by objectives - also reorder from sorted list\n        J = np.argsort(I, axis=0)\n        crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj\n\n    # replace infinity with a large number\n    crowding[np.isinf(crowding)] = infinity\n\n    return crowding\n\n\n# =========================================================================================================\n# Interface\n# =========================================================================================================\n\n\ndef nsganet(\n        pop_size=100,\n        sampling=RandomSampling(var_type=np.int),\n        selection=TournamentSelection(func_comp=binary_tournament),\n        crossover=PointCrossover(n_points=2),\n        mutation=PolynomialMutation(eta=3, var_type=np.int),\n        \n        eliminate_duplicates=True,\n        n_offsprings=None,\n        **kwargs):\n    \"\"\"\n\n    Parameters\n    ----------\n    pop_size : {pop_size}\n    sampling : {sampling}\n    selection : {selection}\n    crossover : {crossover}\n    mutation : {mutation}\n    eliminate_duplicates : {eliminate_duplicates}\n    n_offsprings : {n_offsprings}\n\n    Returns\n    -------\n    nsganet : :class:`~pymoo.model.algorithm.Algorithm`\n        Returns an NSGANet algorithm object.\n\n\n    \"\"\"\n\n    return NSGANet(pop_size=pop_size,\n                   sampling=sampling,\n                   selection=selection,\n                   crossover=crossover,\n                   mutation=mutation,\n                   survival=RankAndCrowdingSurvival(),\n                   eliminate_duplicates=eliminate_duplicates,\n                   n_offsprings=n_offsprings,\n                   
**kwargs)\n\n\nparse_doc_string(nsganet)\n"
  },
  {
    "path": "examples/Structure_Evolution/ELSM/spikes.py",
    "content": "from __future__ import print_function\nimport torchvision\nimport torchvision.transforms as transforms\nimport os\nimport numpy as np\nimport torch\nfrom torch import nn as nn\nfrom model import *\nfrom tqdm import tqdm\nimport argparse\nfrom datetime import datetime\nimport logging\nfrom timm.utils import *\nfrom spikingjelly.datasets.n_mnist import NMNIST\nfrom timm.loss import LabelSmoothingCrossEntropy\nfrom braincog.base.utils.criterions import *\nimport networkx as nx\nimport time\nfrom braincog.base.learningrule.STDP import *\n\ndef randbool(size, p=0.5):\n    return torch.rand(*size) < p\n\ndef calc_f2(con,device):       \n    batch_size=1\n    liquid_size=8000\n    images=torch.load('/1000images.pt')\n    labels=torch.load('/1000labels.pt')\n\n    load_path='970.t7'\n\n\n    snn = nSNN(ins=2312,\n            batchsize=batch_size,\n            device=device,\n            liquid_size=liquid_size,\n            lsm_tau=2.0,\n            lsm_th=0.20,\n            connectivity_matrix=randbool([liquid_size, liquid_size],p=0.01).to(device).int())\n\n    snn.load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['fc'])\n    snn.con[0].load_state_dict(torch.load(load_path,map_location={'cuda:2':device})['lsm0'])\n\n    snn.to(device)\n    criterion = UnilateralMse(1.)\n\n    optimizer = torch.optim.AdamW(snn.fc.parameters(),lr=0.001, weight_decay=1e-4)\n\n    k=0\n    sbr=0\n    snn.connectivity_matrix=con\n    snn.learning_rule=[]\n    w2tmp=nn.Linear(liquid_size,liquid_size,bias=False,device=device)\n\n    w2tmp.weight.data=(torch.load(load_path,map_location={'cuda:2':device})['liquid_weight'])*snn.connectivity_matrix\n    snn.learning_rule.append(MutliInputSTDP(snn.node_lsm(), [snn.con[0], w2tmp])) \n    snn.eval()\n    for label,data in zip(labels,images):\n        running_loss = 0\n        snn.zero_grad()\n        optimizer.zero_grad()\n        data = data.to(device)\n        label = label.to(device)\n        
data=data.reshape(batch_size,data.shape[0],-1) \n        output = snn(data)\n        # print(torch.argmax(output)==label)\n\n        out_liquid=snn.firing_tw.squeeze(-2)\n\n        mupost=torch.matmul(con,out_liquid.unsqueeze(-1))\n        mupre=torch.matmul(con.t(),out_liquid.unsqueeze(-1))\n        for t in range(snn.tw):\n            if t>5 and t<snn.tw-5:\n                mupost[t] = torch.sum(mupost[t+1:t+5],dim=0)\n                mupre[t] = torch.sum(mupre[t-5:t-1],dim=0)\n        br=mupost/mupre\n        br[torch.isnan(br)] = 0\n        br[torch.isinf(br)] = 0\n        br=(torch.sum(out_liquid*br.squeeze(-1),dim=1)/torch.sum(out_liquid,dim=1)).sum()/snn.tw\n        if torch.isnan(br):\n            continue\n        k+=1\n        if k==500:\n            break\n\n        sbr+=br\n\n        snn.reset()\n    # print(sbr/k)\n\n    return sbr/k\n\n    \n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/auto_augment.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# modified from: https://github.com/DeepVoltaire/AutoAugment/\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nfrom PIL import Image, ImageEnhance, ImageOps\nimport random\n\n\nclass ImageNetPolicy(object):\n    \"\"\" Randomly choose one of the best 24 Sub-policies on ImageNet.\n\n            Example:\n                    policy = ImageNetPolicy()\n                    transformed = policy(image)\n\n            Example as a PyTorch Transform:\n                    transform = transforms.Compose([\n                            transforms.Resize(256),\n                            ImageNetPolicy(),\n                            transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.4, \"posterize\", 8, 0.6, \"rotate\", 9, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"posterize\", 7, 0.6, \"posterize\", 6, fillcolor),\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n\n            SubPolicy(0.4, \"equalize\", 4, 0.8, \"rotate\", 8, fillcolor),\n            SubPolicy(0.6, \"solarize\", 3, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.8, \"posterize\", 5, 1.0, \"equalize\", 2, fillcolor),\n            SubPolicy(0.2, \"rotate\", 3, 0.6, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"equalize\", 8, 0.4, \"posterize\", 6, fillcolor),\n\n            SubPolicy(0.8, \"rotate\", 8, 0.4, \"color\", 0, fillcolor),\n            SubPolicy(0.4, \"rotate\", 9, 0.6, \"equalize\", 2, fillcolor),\n            SubPolicy(0.0, \"equalize\", 7, 0.8, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            
SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n\n            SubPolicy(0.8, \"rotate\", 8, 1.0, \"color\", 2, fillcolor),\n            SubPolicy(0.8, \"color\", 8, 0.8, \"solarize\", 7, fillcolor),\n            SubPolicy(0.4, \"sharpness\", 7, 0.6, \"invert\", 8, fillcolor),\n            SubPolicy(0.6, \"shearX\", 5, 1.0, \"equalize\", 9, fillcolor),\n            SubPolicy(0.4, \"color\", 0, 0.6, \"equalize\", 3, fillcolor),\n\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor)]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment ImageNet Policy\"\n\n\nclass CIFAR10Policy(object):\n    \"\"\" Randomly choose one of the best 25 Sub-policies on CIFAR10.\n\n            Example:\n                    policy = CIFAR10Policy()\n                    transformed = policy(image)\n\n            Example as a PyTorch Transform:\n                    transform = transforms.Compose([\n                            transforms.Resize(256),\n                            CIFAR10Policy(),\n                            transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.1, \"invert\", 7, 0.2, \"contrast\", 6, fillcolor),\n            SubPolicy(0.7, \"rotate\", 2, 0.3, \"translateX\", 9, fillcolor),\n            SubPolicy(0.8, \"sharpness\", 1, 0.9, \"sharpness\", 3, fillcolor),\n            SubPolicy(0.5, \"shearY\", 8, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.5, \"autocontrast\", 8, 
0.9, \"equalize\", 2, fillcolor),\n\n            SubPolicy(0.2, \"shearY\", 7, 0.3, \"posterize\", 7, fillcolor),\n            SubPolicy(0.4, \"color\", 3, 0.6, \"brightness\", 7, fillcolor),\n            SubPolicy(0.3, \"sharpness\", 9, 0.7, \"brightness\", 9, fillcolor),\n            SubPolicy(0.6, \"equalize\", 5, 0.5, \"equalize\", 1, fillcolor),\n            SubPolicy(0.6, \"contrast\", 7, 0.6, \"sharpness\", 5, fillcolor),\n\n            SubPolicy(0.7, \"color\", 7, 0.5, \"translateX\", 8, fillcolor),\n            SubPolicy(0.3, \"equalize\", 7, 0.4, \"autocontrast\", 8, fillcolor),\n            SubPolicy(0.4, \"translateY\", 3, 0.2, \"sharpness\", 6, fillcolor),\n            SubPolicy(0.9, \"brightness\", 6, 0.2, \"color\", 8, fillcolor),\n            SubPolicy(0.5, \"solarize\", 2, 0.0, \"invert\", 3, fillcolor),\n\n            SubPolicy(0.2, \"equalize\", 0, 0.6, \"autocontrast\", 0, fillcolor),\n            SubPolicy(0.2, \"equalize\", 8, 0.8, \"equalize\", 4, fillcolor),\n            SubPolicy(0.9, \"color\", 9, 0.6, \"equalize\", 6, fillcolor),\n            SubPolicy(0.8, \"autocontrast\", 4, 0.2, \"solarize\", 8, fillcolor),\n            SubPolicy(0.1, \"brightness\", 3, 0.7, \"color\", 0, fillcolor),\n\n            SubPolicy(0.4, \"solarize\", 5, 0.9, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.9, \"translateY\", 9, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"autocontrast\", 2, 0.8, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.1, \"invert\", 3, fillcolor),\n            SubPolicy(0.7, \"translateY\", 9, 0.9, \"autocontrast\", 1, fillcolor)]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment CIFAR10 Policy\"\n\n\nclass SVHNPolicy(object):\n    \"\"\" Randomly choose one of the best 25 Sub-policies on SVHN.\n\n            Example:\n           
         policy = SVHNPolicy()\n                    transformed = policy(image)\n\n            Example as a PyTorch Transform:\n                    transform = transforms.Compose([\n                            transforms.Resize(256),\n                            SVHNPolicy(),\n                            transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.9, \"shearX\", 4, 0.2, \"invert\", 3, fillcolor),\n            SubPolicy(0.9, \"shearY\", 8, 0.7, \"invert\", 5, fillcolor),\n            SubPolicy(0.6, \"equalize\", 5, 0.6, \"solarize\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 3, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"equalize\", 1, 0.9, \"rotate\", 3, fillcolor),\n\n            SubPolicy(0.9, \"shearX\", 4, 0.8, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.9, \"shearY\", 8, 0.4, \"invert\", 5, fillcolor),\n            SubPolicy(0.9, \"shearY\", 5, 0.2, \"solarize\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 6, 0.8, \"autocontrast\", 1, fillcolor),\n            SubPolicy(0.6, \"equalize\", 3, 0.9, \"rotate\", 3, fillcolor),\n\n            SubPolicy(0.9, \"shearX\", 4, 0.3, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"shearY\", 8, 0.7, \"invert\", 4, fillcolor),\n            SubPolicy(0.9, \"equalize\", 5, 0.6, \"translateY\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 4, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.3, \"contrast\", 3, 0.8, \"rotate\", 4, fillcolor),\n\n            SubPolicy(0.8, \"invert\", 5, 0.0, \"translateY\", 2, fillcolor),\n            SubPolicy(0.7, \"shearY\", 6, 0.4, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 0.8, \"rotate\", 4, fillcolor),\n            SubPolicy(0.3, \"shearY\", 7, 0.9, \"translateX\", 3, fillcolor),\n            SubPolicy(0.1, \"shearX\", 6, 0.6, \"invert\", 5, fillcolor),\n\n            
SubPolicy(0.7, \"solarize\", 2, 0.6, \"translateY\", 7, fillcolor),\n            SubPolicy(0.8, \"shearY\", 4, 0.8, \"invert\", 8, fillcolor),\n            SubPolicy(0.7, \"shearX\", 9, 0.8, \"translateY\", 3, fillcolor),\n            SubPolicy(0.8, \"shearY\", 5, 0.7, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.7, \"shearX\", 2, 0.1, \"invert\", 5, fillcolor)]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment SVHN Policy\"\n\n\nclass SubPolicy(object):\n\n    def __init__(self,\n                 p1,\n                 operation1,\n                 magnitude_idx1,\n                 p2,\n                 operation2,\n                 magnitude_idx2,\n                 fillcolor=(128, 128, 128)):\n        ranges = {\n            \"shearX\": np.linspace(0, 0.3, 10),\n            \"shearY\": np.linspace(0, 0.3, 10),\n            \"translateX\": np.linspace(0, 150 / 331, 10),\n            \"translateY\": np.linspace(0, 150 / 331, 10),\n            \"rotate\": np.linspace(0, 30, 10),\n            \"color\": np.linspace(0.0, 0.9, 10),\n            \"posterize\": np.round(np.linspace(8, 4, 10), 0).astype(np.int),\n            \"solarize\": np.linspace(256, 0, 10),\n            \"contrast\": np.linspace(0.0, 0.9, 10),\n            \"sharpness\": np.linspace(0.0, 0.9, 10),\n            \"brightness\": np.linspace(0.0, 0.9, 10),\n            \"autocontrast\": [0] * 10,\n            \"equalize\": [0] * 10,\n            \"invert\": [0] * 10}\n\n        # from https://stackoverflow.com/questions/5252170/specify-image\n        # -filling-color-when-rotating-in-python-with-pil-and-setting-expand\n        def rotate_with_fill(img, magnitude):\n            rot = img.convert(\"RGBA\").rotate(magnitude)\n            return Image.composite(\n                rot, Image.new(\"RGBA\", rot.size, (128,) * 4), rot) \\\n              
  .convert(img.mode)\n\n        func = {\n            \"shearX\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),\n                Image.BICUBIC,\n                fillcolor=fillcolor),\n            \"shearY\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),\n                Image.BICUBIC,\n                fillcolor=fillcolor),\n            \"translateX\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, 0, magnitude * img.size[0] *\n                 random.choice([-1, 1]), 0, 1, 0),\n                fillcolor=fillcolor),\n            \"translateY\": lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, 0, 0, 0, 1, magnitude *\n                 img.size[1] * random.choice([-1, 1])),\n                fillcolor=fillcolor),\n            \"rotate\": lambda img, magnitude: rotate_with_fill(img, magnitude),\n            # \"rotate\": lambda img, magnitude: \\\n            #     img.rotate(magnitude * random.choice([-1, 1])),\n            \"color\": lambda img, magnitude: \\\n            ImageEnhance.Color(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"posterize\": lambda img, magnitude: \\\n            ImageOps.posterize(img, magnitude),\n            \"solarize\": lambda img, magnitude: \\\n            ImageOps.solarize(img, magnitude),\n            \"contrast\": lambda img, magnitude: \\\n            ImageEnhance.Contrast(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"sharpness\": lambda img, magnitude: \\\n            ImageEnhance.Sharpness(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            
\"brightness\": lambda img, magnitude: \\\n            ImageEnhance.Brightness(img).enhance(\n                1 + magnitude * random.choice([-1, 1])),\n            \"autocontrast\": lambda img, magnitude: ImageOps.autocontrast(img),\n            \"equalize\": lambda img, magnitude: ImageOps.equalize(img),\n            \"invert\": lambda img, magnitude: ImageOps.invert(img)\n        }\n\n        # self.name = \"{}_{:.2f}_and_{}_{:.2f}\".format(\n        #     operation1, ranges[operation1][magnitude_idx1],\n        #     operation2, ranges[operation2][magnitude_idx2])\n        self.p1 = p1\n        self.operation1 = func[operation1]\n        self.magnitude1 = ranges[operation1][magnitude_idx1]\n        self.p2 = p2\n        self.operation2 = func[operation2]\n        self.magnitude2 = ranges[operation2][magnitude_idx2]\n\n    def __call__(self, img):\n        if random.random() < self.p1:\n            img = self.operation1(img, self.magnitude1)\n        if random.random() < self.p2:\n            img = self.operation2(img, self.magnitude2)\n        return img\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/cellmodel.py",
    "content": "from functools import partial\nfrom typing import List, Type\n\nfrom operations import *\nfrom motifs import *\nfrom utils import drop_path\nfrom timm.models import register_model\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.model_zoo.base_module import BaseModule\nfrom motifs import *\nfrom torchvision import transforms\n\nEVO=True\nclass EvoCell2(nn.Module):\n    def __init__(self,motif, C_prev_prev, C_prev, C, reduction, reduction_prev, act_fun):\n        # print(C_prev_prev, C_prev, C, reduction)\n        super(EvoCell2, self).__init__()\n        self.act_fun = act_fun\n        self.reduction = reduction\n        self.motif=motif\n        self.back_connection=False\n        if reduction:\n            self.fun = FactorizedReduce(\n                C_prev, C * 3, act_fun=act_fun\n            )\n            self.multiplier = 3\n        else:\n            if reduction_prev:\n                self.preprocess0 = FactorizedReduce(\n                    C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(\n                    C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n            self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n            op_names, indices = zip(*motif.normal)\n            concat = motif.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, (name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                self.back_connection=True\n                back_begin_index = i\n     
           break\n            stride = 2 if reduction and index < 2 else 1\n            op = OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 2\n\n    def forward(self, s0, s1, drop_prob):\n        if self.reduction:\n            return self.fun(s1)\n        # print('s0',s0.shape)\n        s0 = self.preprocess0(s0)\n        # print(s0.shape)\n        # print('s1',s1.shape)\n        s1 = self.preprocess1(s1)\n        # print(s1.shape)\n\n        states = [s0, s1]\n        for i in range(self._steps):\n            i1=self._indices_forward[2 * i]\n            i2=self._indices_forward[2 * i + 1]\n            h1 = states[i1]\n            h2 = states[i2]\n            op1 = self._ops[2 * i]\n            op2 = self._ops[2 * i + 1]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)\n            s = h1 + h2\n            \n            if self.back_connection:\n                if i != 0:\n                    s_back = self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        \n            \n        \n  
      outputs = torch.cat([states[i]\n                            for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\n\nclass EvoCell3(nn.Module):\n    def __init__(self,motif, C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev, act_fun):\n        # print(C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev)\n\n        super(EvoCell3, self).__init__()\n        self.act_fun = act_fun\n        self.reduction = reduction\n        self.motif=motif\n        self.back_connection=False\n        if reduction:\n            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)\n            self.multiplier = 3\n        else:\n\n            if reduction_prev:\n                self.preprocess1 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess1 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            if int(reduction_prev_prev)+int(reduction_prev)==1:\n                self.preprocess0 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)\n            elif int(reduction_prev_prev)+int(reduction_prev)==2:\n                self.preprocess0 = F0(C_prev_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            self.preprocess2 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            op_names, indices = zip(*motif.normal)\n            concat = motif.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for 
i, (name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                self.back_connection=True\n                back_begin_index = i\n                break\n            stride = 2 if reduction and index < 2 else 1\n            op = OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 3\n\n    def forward(self, s0, s1, s2, drop_prob):\n        if self.reduction:\n            return self.fun(s2)\n\n        s0 = self.preprocess0(s0)\n\n        s1 = self.preprocess1(s1)\n        s2 = self.preprocess2(s2)\n\n        states = [s0, s1, s2]\n\n        for i in range(self._steps):\n            i1=self._indices_forward[3 * i]\n            i2=self._indices_forward[3 * i + 1]\n            i3=self._indices_forward[3 * i + 2]\n\n            h1 = states[i1]\n            h2 = states[i2]\n            h3 = states[i3]\n\n            op1 = self._ops[3 * i]\n            op2 = self._ops[3 * i + 1]\n            op3 = self._ops[3 * i + 2]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            h3 = op3(h3)\n\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)                \n                if not 
isinstance(op3, Identity):\n                    h3 = drop_path(h3, drop_prob)\n            s = h1 + h2 + h3\n            \n            if self.back_connection:\n                if i != 0:\n                    s_back = self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        \n            \n        \n        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\nclass EvoCell4(nn.Module):\n    def __init__(self,motif, C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C, reduction, reduction_prev, reduction_prev_prev,reduction_prev_prev_prev, act_fun):\n        # print(C_prev_prev_prev_prev,C_prev_prev_prev,C_prev_prev, C_prev, C, reduction,reduction_prev, reduction_prev_prev,reduction_prev_prev_prev)\n\n        super(EvoCell4, self).__init__()\n        self.act_fun = act_fun\n        self.reduction = reduction\n        self.motif=motif\n        self.back_connection=False\n        if reduction:\n            self.fun = FactorizedReduce(C_prev, C * 3, act_fun=act_fun)\n            self.multiplier = 3\n        else:\n\n            if reduction_prev:\n                self.preprocess2 = FactorizedReduce(C_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess2 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n            if int(reduction_prev_prev)+int(reduction_prev)==1:\n                self.preprocess1 = FactorizedReduce(C_prev_prev_prev, C, act_fun=act_fun)\n            elif int(reduction_prev_prev)+int(reduction_prev)==2:\n                self.preprocess1 = F0(C_prev_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess1 = ReLUConvBN(C_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n            \n            if int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==1:\n              
  self.preprocess0 = FactorizedReduce(C_prev_prev_prev_prev, C, act_fun=act_fun)\n            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==2:\n                self.preprocess0 = F0(C_prev_prev_prev_prev, C, act_fun=act_fun)            \n            elif int(reduction_prev_prev_prev)+int(reduction_prev_prev)+int(reduction_prev)==3:\n                self.preprocess0 = F1(C_prev_prev_prev_prev, C, act_fun=act_fun)\n            else:\n                self.preprocess0 = ReLUConvBN(C_prev_prev_prev_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n\n            self.preprocess3 = ReLUConvBN(C_prev, C, 1, 1, 0, act_fun=act_fun)\n\n\n            op_names, indices = zip(*motif.normal)\n            # print(self.preprocess0)\n            # print(self.preprocess1)\n            # print(self.preprocess2)\n            # print(self.preprocess3)\n            concat = motif.normal_concat\n            self._compile(C, op_names, indices, concat, reduction)\n    def _compile(self, C, op_names, indices, concat, reduction):\n        assert len(op_names) == len(indices)\n        # self._steps = len(op_names) // 2\n        self._concat = concat\n        self.multiplier = len(concat)\n\n        self._ops = nn.ModuleList()\n        self._ops_back = nn.ModuleList()\n        back_begin_index = 0\n        for i, (name, index) in enumerate(zip(op_names, indices)):\n            # print(name, index)\n            if '_back' in name:\n                self.back_connection=True\n                back_begin_index = i\n                break\n            stride = 2 if reduction and index < 2 else 1\n            op = OPS[name](C, stride, True, act_fun=self.act_fun)\n            self._ops += [op]\n\n        if self.back_connection:\n            for name, index in zip(op_names[back_begin_index:], indices[back_begin_index:]):\n                op = OPS[name.replace('_back', '')](\n                    C, 1, True, act_fun=self.act_fun)\n                self._ops_back += [op]\n\n        
if self.back_connection:\n            self._indices_forward = indices[:back_begin_index]\n            self._indices_backward = indices[back_begin_index:]\n        else:\n            self._indices_backward = []\n            self._indices_forward = indices\n        self._steps = len(self._indices_forward) // 4\n\n    def forward(self, s0, s1, s2, s3, drop_prob):\n        if self.reduction:\n            return self.fun(s3)\n\n        s0 = self.preprocess0(s0)\n        s3 = self.preprocess3(s3)\n        s1 = self.preprocess1(s1)\n        s2 = self.preprocess2(s2)\n\n        # if s1.shape[1]!=s3.shape[1]:\n        #     s1 = nn.Conv2d(s1.shape[1], s3.shape[1], 3, stride=2, padding=1, bias=False)\n\n        states = [s0, s1, s2,s3]\n\n        for i in range(self._steps):\n            i1=self._indices_forward[4 * i]\n            i2=self._indices_forward[4 * i + 1]\n            i3=self._indices_forward[4 * i + 2]\n            i4=self._indices_forward[4 * i + 3]\n\n            h1 = states[i1]\n            h2 = states[i2]\n            h3 = states[i3]\n            h4 = states[i4]\n\n            op1 = self._ops[4 * i]\n            op2 = self._ops[4 * i + 1]\n            op3 = self._ops[4 * i + 2]\n            op4 = self._ops[4 * i + 3]\n            h1 = op1(h1)\n            h2 = op2(h2)\n            h3 = op3(h3)\n            h4 = op4(h4)\n\n            if self.training and drop_prob > 0.:\n                if not isinstance(op1, Identity):\n                    h1 = drop_path(h1, drop_prob)\n                if not isinstance(op2, Identity):\n                    h2 = drop_path(h2, drop_prob)                \n                if not isinstance(op3, Identity):\n                    h3 = drop_path(h3, drop_prob)                \n                if not isinstance(op4, Identity):\n                    h4= drop_path(h4, drop_prob)\n            s = h1 + h2 + h3 + h4\n            \n            if self.back_connection:\n                if i != 0:\n                    s_back = 
self._ops_back[i - 1](s)\n                    states[self._indices_backward[i - 1]] = states[self._indices_backward[i - 1]] + s_back\n            states += [s]\n        \n            \n        \n        outputs = torch.cat([states[i] for i in self._concat], dim=1)  # N，C，H, W\n        return outputs\n        # return self.node(outputs)\n\n\n\n\n@register_model\nclass NetworkCIFAR(BaseModule):\n\n    def __init__(self,\n                 C,\n                 num_classes,\n                 layers,\n                 auxiliary,\n                 motif,\n                 cell_type,\n                 parse_method='darts',\n                 step=5,\n                 node_type='ReLUNode',\n                 **kwargs):\n        super(NetworkCIFAR, self).__init__(\n            step=step,\n            num_classes=num_classes,\n            **kwargs\n        )\n        self.node_type=node_type\n        if isinstance(node_type, str):\n            self.act_fun = eval(node_type)\n        else:\n            self.act_fun = node_type\n        self.act_fun = partial(self.act_fun, **kwargs)\n        \n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n        self.dataset = kwargs['dataset']\n\n        if self.layer_by_layer:\n            self.flatten = nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n\n        self._layers = layers\n        self.cell_type = cell_type\n        self._auxiliary = auxiliary\n\n        self.drop_path_prob = 0\n\n        stem_multiplier = 3\n        C_curr = stem_multiplier * C\n        if self.dataset == 'dvsg' or self.dataset == 'dvsc10' or self.dataset == 'NCALTECH101':\n            self.stem = nn.Sequential(\n                nn.Conv2d(2 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            # self.reduce_idx = [\n            #     layers // 4,\n            #     layers // 2,\n            #     3 * layers // 4\n   
         # ]\n            self.reduce_idx = [1, 3, 5, 7]\n        else:\n            self.stem = nn.Sequential(\n                nn.Conv2d(1, 3 * self.init_channel_mul, 3, padding=1, bias=False),\n                nn.Conv2d(3 * self.init_channel_mul, C_curr, 3, padding=1, bias=False),\n                nn.BatchNorm2d(C_curr),\n            )\n            self.reduce_idx = [layers // 4,\n                               layers // 2,\n                               3 * layers // 4]\n        C_prev_prev_prev = C_curr\n        C_prev_prev_prev_prev = C_curr\n\n        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C\n        self.cells = nn.ModuleList()\n        reduction_prev = False\n        reduction_prev_prev = False\n        reduction_prev_prev_prev = False\n\n\n        for i in range(layers):\n            if i in self.reduce_idx:\n                C_curr *= 2\n                reduction = True\n            else:\n                reduction = False\n\n            if cell_type==2:\n                # print(C_prev_prev, C_prev, C_curr)\n\n                cell = EvoCell2(motif[i], C_prev_prev, C_prev, C_curr,reduction, reduction_prev,act_fun=self.act_fun)\n                self.cells += [cell]\n                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n            if cell_type==3:\n                cell = EvoCell3(motif[i], C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,act_fun=self.act_fun)  \n                self.cells += [cell]\n                C_prev_prev_prev = C_prev_prev\n                reduction_prev_prev = reduction_prev\n\n                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n            if cell_type==4:\n                cell = EvoCell4(motif[i], C_prev_prev_prev_prev,C_prev_prev_prev, C_prev_prev, C_prev, C_curr,reduction, reduction_prev,reduction_prev_prev,reduction_prev_prev_prev,act_fun=self.act_fun)  \n                self.cells += [cell]\n                C_prev_prev_prev_prev = 
C_prev_prev_prev\n                C_prev_prev_prev = C_prev_prev\n                reduction_prev_prev_prev = reduction_prev_prev\n                reduction_prev_prev = reduction_prev\n\n                C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n\n            reduction_prev = reduction\n\n\n        self.global_pooling = nn.Sequential(\n            self.act_fun(), nn.AdaptiveAvgPool2d(1))\n\n        if self.spike_output:\n            self.classifier = nn.Sequential(\n                nn.Linear(C_prev, 10 * num_classes),\n                self.act_fun())\n            self.vote = VotingLayer(10)\n        else:\n            self.classifier = nn.Linear(C_prev, num_classes)\n            self.vote = nn.Identity()\n\n        # self.classifier = nn.Linear(C_prev, num_classes)\n        # self.vote = nn.Identity()\n\n    def forward(self, inputs):\n        logits_aux = None\n        inputs = self.encoder(inputs)\n        if not self.layer_by_layer:\n            outputs = []\n            output_aux = []\n            self.reset()\n\n            if self.cell_type==2:\n\n                for t in range(self.step):\n                    x = inputs[t]\n                    s0 = s1 = self.stem(x)\n                    for i, cell in enumerate(self.cells):\n                        s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                    out = self.global_pooling(s1)\n                    out = self.classifier(self.flatten(out))\n                    logits = self.vote(out)\n                    outputs.append(logits)\n                    output_aux.append(logits_aux)\n                return sum(outputs) / len(outputs)\n\n\n            if self.cell_type==3:\n                for t in range(self.step):\n                    x = inputs[t]\n                    s0 = s1 = s2= self.stem(x)\n                    for i, cell in enumerate(self.cells):\n                        s0, s1, s2 = s1, s2, cell(s0, s1, s2, self.drop_path_prob)\n                    out = 
self.global_pooling(s2)\n                    out = self.classifier(self.flatten(out))\n                    logits = self.vote(out)\n                    outputs.append(logits)\n                    output_aux.append(logits_aux)\n                return sum(outputs) / len(outputs)\n\n            if self.cell_type==4:\n                for t in range(self.step):\n                    x = inputs[t]\n                    s0 = s1 = s2= s3=self.stem(x)\n                    for i, cell in enumerate(self.cells):\n                        s0, s1, s2,s3= s1, s2, s3,cell(s0, s1, s2,s3 ,self.drop_path_prob)\n\n                    out = self.global_pooling(s3)\n                    out = self.classifier(self.flatten(out))\n                    logits = self.vote(out)\n                    outputs.append(logits)\n                    output_aux.append(logits_aux)\n                return sum(outputs) / len(outputs)\n            # logits_aux if logits_aux is None else (sum(output_aux) / len(output_aux))\n        else:\n            self.reset()\n            if self.cell_type==2:\n\n                s0 = s1 = self.stem(inputs)\n                for i, cell in enumerate(self.cells):\n                    s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n                    if i == 2 * self._layers // 3:\n                        if self._auxiliary and self.training:\n                            logits_aux = self.auxiliary_head(s1)\n                out = self.global_pooling(s1)\n                out = self.classifier(self.flatten(out))\n                out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)\n                logits = self.vote(out)\n                return logits\n\n            if self.cell_type==3:\n                s0 = s1 = s2= self.stem(inputs)\n                for i, cell in enumerate(self.cells):\n                    s0, s1, s2 = s1, s2, cell(s0, s1, s2, self.drop_path_prob)\n                    if i == 2 * self._layers // 3:\n                        if self._auxiliary and 
self.training:\n                            logits_aux = self.auxiliary_head(s1)\n                out = self.global_pooling(s2)\n                out = self.classifier(self.flatten(out))\n                out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)\n                logits = self.vote(out)\n                return logits\n                \n            if self.cell_type==4:\n                s0 = s1 = s2=s3= self.stem(inputs)\n                for i, cell in enumerate(self.cells):\n                    s0, s1, s2,s3= s1, s2, s3,cell(s0, s1, s2,s3 ,self.drop_path_prob)\n                    if i == 2 * self._layers // 3:\n                        if self._auxiliary and self.training:\n                            logits_aux = self.auxiliary_head(s1)\n                out = self.global_pooling(s3)\n                out = self.classifier(self.flatten(out))\n                out = rearrange(out, '(t b) c -> t b c', t=self.step).mean(0)\n                logits = self.vote(out)\n                return logits\n\n\n\n@register_model\nclass NetworkImageNet(BaseModule):\n\n    def __init__(self,\n                 C,\n                 num_classes,\n                 layers,\n                 auxiliary,\n                 motif,\n                 step=1,\n                 node_type='ReLUNode',\n                 **kwargs):\n        super(NetworkImageNet, self).__init__(\n            step=step,\n            num_classes=num_classes,\n            **kwargs)\n\n        if isinstance(node_type, str):\n            self.act_fun = eval(node_type)\n        else:\n            self.act_fun = node_type\n        self.act_fun = partial(self.act_fun, **kwargs)\n\n        if 'back_connection' in kwargs.keys():\n            self.back_connection = kwargs['back_connection']\n        else:\n            self.back_connection = False\n\n        self.spike_output = kwargs['spike_output'] if 'spike_output' in kwargs else True\n\n        if self.layer_by_layer:\n            self.flatten = 
nn.Flatten(start_dim=1)\n        else:\n            self.flatten = nn.Flatten()\n\n        self._layers = layers\n        self._auxiliary = auxiliary\n        self.drop_path_prob = 0\n\n        self.stem0 = nn.Sequential(\n            nn.Conv2d(3, C // 2, kernel_size=3,\n                      stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(C // 2),\n            # nn.ReLU(inplace=True),\n            self.act_fun(),\n            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(C),\n        )\n\n        self.stem1 = nn.Sequential(\n            # nn.ReLU(inplace=True),\n            self.act_fun(),\n            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(C),\n        )\n\n        C_prev_prev, C_prev, C_curr = C, C, C\n\n        self.cells = nn.ModuleList()\n        reduction_prev = True\n        for i in range(layers):\n            if i in [layers // 3, 2 * layers // 3]:\n                C_curr *= 2\n                reduction = True\n            else:\n                reduction = False\n            cell = EvoCell2(motif[i], C_prev_prev, C_prev,C_curr, reduction, reduction_prev,act_fun=self.act_fun)\n\n            reduction_prev = reduction\n            self.cells += [cell]\n            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr\n\n        self.global_pooling = nn.AvgPool2d(7)\n        self.classifier = nn.Linear(C_prev, num_classes)\n\n    def forward(self, inputs):\n        outputs = []\n        self.reset()\n        for t in range(self.step):\n            s0 = self.stem0(inputs)\n            s1 = self.stem1(s0)\n            for i, cell in enumerate(self.cells):\n                s0, s1 = s1, cell(s0, s1, self.drop_path_prob)\n            out = self.global_pooling(s1)\n            logits = self.classifier(self.flatten(out))\n            outputs.append(logits)\n        return sum(outputs) / len(outputs)\n\n\nif __name__ == '__main__':\n\n    \n    x = torch.rand(128, 36, 
32, 32)\n    extra_edge=np.array([[3,5],[4,1]])\n\n    # sort based on head\n\n    extra_edge = extra_edge[extra_edge[:,0].argsort()]\n\n    # motifs=[mm1,mm2,mm3,mm1,mm5,mm4]\n    # motifs=[m1,m2,m3,m1,m5,m4,m1,m2,m3,m1,m5,m4,m5,m4,m1]\n    motifs=[t2,t3,t4,t5,t5,t4,t3,t4]\n\n\n    net=NetworkCIFAR(C=12,num_classes=10,motif=motifs,layers=len(motifs),auxiliary=True,dataset='cifar10',cell_type=4)\n    out=net(torch.rand(128, 3, 32, 32))\n    print(out.shape)\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/evolution.py",
    "content": "import sys\nimport numpy as np\nimport argparse\nimport time\nimport obj\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom random import choice\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nfrom micro_encoding import ops\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\n\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\nimport micro_encoding\nimport nsganet as engine\nfrom pymop.problem import Problem\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom pymoo.optimize import minimize\nfrom utils import data_transforms\nfrom tm import train_motifs\nfrom timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n# os.environ['CUDA_VISIBLE_DEVICES']='3'\nbits=20\n\n\n\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('')\n# The first arg parser parses out only thei --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\ndevices=[1]\n\nmax_gen = 100\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n# Model 
parameters\nparser.add_argument('--seed', type=int, default=99, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--eval_epochs', type=int, default=10)\nparser.add_argument('--bns', action='store_true', default=True)\nparser.add_argument('--mid', type=int, default=3)\nparser.add_argument('--trainning_epochs', type=int, default=600, metavar='N',help='number of epochs to train (default: 2)')\nparser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--init-channels', type=int, default=48)\nparser.add_argument('--layers', type=int, default=6)\nparser.add_argument('--pop_size', type=int, default=50, help='population size of networks')\nparser.add_argument('--output', default='', type=str, metavar='PATH')\nparser.add_argument('--spike-rate', action='store_true', default=False)\n\nparser.add_argument('--n_gens', type=int, default=max_gen, help='population size')\nparser.add_argument('--bs', type=int, default=100)\n\nparser.add_argument('--n_offspring', type=int, default=50, help='number of offspring created per generation')\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\nparser.add_argument('--dataset', default='dvsg', type=str)\nparser.add_argument('--num-classes', type=int, default=11, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--model', default='NetworkCIFAR', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this 
checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\n\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\nparser.add_argument('--strgenome', default='4,0,0,1,1,1,0,1,0,1,0,0,0,1,0,1,0,1,0,0,4,1,1,1,1,1,0,1,1,0,1,1,0,1,1,1,0,1,0,0,4,1,1,0,1,0,0,1,0,1,1,0,1,1,0,0,1,0,1,2,4,1,0,0,1,1,1,0,0,1,0,0,1,0,0,1,0,0,1,1,4,0,1,1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,3,4,1,0,1,0,0,1,1,1,0,0,1,0,1,0,0,0,1,0,0,3', type=str)\n\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 
128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=0.01,\n                    help='weight decay (default: 0.01 for adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 
1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n                    help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\n\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. 
(default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset 
drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                    help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\n\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 1)')\nparser.add_argument('--num-gpu', type=int, default=len(devices),\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\n\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: 
\"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=devices[0])\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: PLIF)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surogate Function in node. 
Only for Surrogate nodes (default: AtanGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometime will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=1.0, help='cutmix_beta (default: 1.)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prib for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment times n 
(default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. '\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. (default: False)')\nparser.add_argument('--node-trainable', action='store_true')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\n\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\n\n# DARTS parameters\n\nparser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')\n# parser.add_argument('--arch', default='dvsc10_new_skip19', type=str)\n# parser.add_argument('--motif', default='m1', type=str)\n\nparser.add_argument('--parse_method', default='darts', type=str)\nparser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')\n\n# parser.add_argument('--back-connection', action='store_true',default=True)\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 
'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\ndef check_mem(cuda_device):\n    devices_info = os.popen('\"/usr/bin/nvidia-smi\" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader').read().strip().split(\"\\n\")\n    total, used = devices_info[int(cuda_device)].split(',')\n    return total,used\n\ndef occumpy_mem(cuda_device):\n    total, used = check_mem(cuda_device)\n    total = int(total)\n    used = int(used)\n    max_mem = int(total * 1)\n    block_mem = int((max_mem - used)*0.3)\n    x = torch.cuda.FloatTensor(256,1024,block_mem)\n    del x\n    \n\nclass NAS(Problem):\n    # first define the NAS problem (inherit from pymop)\n    def __init__(self, args,n_var=20, n_obj=1, n_constr=0, lb=None, ub=None,\n                 init_channels=24, layers=8):\n        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=np.int64)\n        self.xl = lb\n        self.xu = ub\n        self._lr =args.lr\n\n        self._n_evaluated = 0  # keep track of how many architectures are sampled\n        self.args=args\n    def _evaluate(self, x, out, *args, **kwargs):\n        objs = np.full((x.shape[0], self.n_obj), np.nan)\n        train_data, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % self.args.dataset)(\n            batch_size=self.args.batch_size,\n            step=self.args.step,\n            args=self.args,\n            _logge=_logger,\n            size=self.args.event_size,\n            mix_up=self.args.mix_up,\n            cut_mix=self.args.cut_mix,\n            event_mix=self.args.event_mix,\n            beta=self.args.cutmix_beta,\n            prob=self.args.cutmix_prob,\n            num=self.args.cutmix_num,\n            noise=self.args.cutmix_noise,\n            num_classes=self.args.num_classes,\n            rand_aug=self.args.rand_aug,\n            randaug_n=self.args.randaug_n,\n            randaug_m=self.args.randaug_m,\n            
temporal_flatten=self.args.temporal_flatten,\n            portion=self.args.train_portion,\n            _logger=_logger)\n\n        for i in range(x.shape[0]):\n            arch_id = self._n_evaluated + 1\n            print('\\n')\n            _logger.info('Network= {}'.format(arch_id))\n            genome = x[i, :]\n            arch_dir=os.path.join(self.args.output_dir)\n            if os.path.exists(arch_dir) is False:\n                os.makedirs(arch_dir,exist_ok = True)\n            self.args.lr=self._lr\n            performance,acc=train_motifs(args=self.args,gen=0,arch_dir=arch_dir,genome=genome,_logger=_logger,args_text=args_text,devices=devices,bits=bits)\n            objs[i, 0] = 1000 - performance\n\n            # sJ,J,Cm,Ccosine,Cpe,K = obj.LSP(self.args,genome,train_data)\n            # objs[i, 0] = 1000 - Cm\n\n\n            _logger.info('performance= {}'.format(objs[i, 0]))\n            self._n_evaluated += 1\n\n\n        out[\"F\"] = objs\n        # if your NAS problem has constraints, use the following line to set constraints\n        # out[\"G\"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints\n\ndef do_every_generations(algorithm):\n    # this function will be call every generation\n    # it has access to the whole algorithm class\n    gen = algorithm.n_gen\n    pop_var = algorithm.pop.get(\"X\")\n    pop_obj = algorithm.pop.get(\"F\")\n\n    # report generation info to files\n    _logger.info(\"generation = {}\".format(gen))\n    _logger.info(\"population error: best = {}, mean = {}, \"\n                 \"median = {}, worst = {}\".format(np.min(pop_obj[:, 0]), np.mean(pop_obj[:, 0]),\n                                                  np.median(pop_obj[:, 0]), np.max(pop_obj[:, 0])))\n    _logger.info('Best Genome= {}'.format(pop_var[np.argmin(pop_obj[:, 0])]))\n \n\ndef _parse_args():\n    args_config, remaining = config_parser.parse_known_args()\n    args = parser.parse_args(remaining)\n    args_text = 
yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\nif __name__ == '__main__':\n\n\n    args, args_text = _parse_args()\n    args.no_spike_output = True\n    output_dir = ''\n    if args.bns:\n        from cellmodel import NetworkCIFAR\n    else:\n        from cell123model import NetworkCIFAR\n    if args.local_rank == 0:\n        output_base = args.output if args.output else './output'\n        exp_name = '-'.join([\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            # args.model,\n            # args.dataset,\n            str(args.layers)+'layers',\n            str(args.init_channels)+'channels',\n            'motif'+str(args.mid),\n            str(args.step)+'steps',\n            # args.suffix\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base,str(args.dataset),exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n\n    else:\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')\n            args.num_gpu = 1\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n 
       _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                        % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n    defalut_lr = args.lr\n\n    occumpy_mem(str(args.device))\n\n\n    train_data, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n        beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        temporal_flatten=args.temporal_flatten,\n        portion=args.train_portion,\n        _logger=_logger,\n\n    )\n    len_motifs=args.layers*bits+1\n    low = np.zeros(len_motifs)\n    low[-1] = 2\n    up=[]\n    for i in range(0,args.layers*bits,bits):\n        t=[args.mid]\n        t=t+[(ops-1) for j in range(bits-1)]\n        t[-1]=2*(ops-1)\n        up.extend(t)\n    up.append(3)\n    up=np.array(up).reshape(-1,)\n\n    kkk = NAS(args,n_var=len_motifs, \n                  n_obj=2, n_constr=0, lb=low, ub=up,\n                  init_channels=args.init_channels, layers=args.layers)\n    method = engine.nsganet(pop_size=args.pop_size,\n                            n_offsprings=args.n_offspring,\n                            eliminate_duplicates=True)\n    kres=minimize(kkk,\n                   method,\n                   callback=do_every_generations,\n                   termination=('n_gen', args.n_gens))"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/loss_f.py",
    "content": "import torch\nimport torch.nn.functional as f\n\n\ndef psp(inputs, n_steps,tau_s):\n    shape = inputs.shape\n    n_steps = n_steps\n    tau_s = tau_s\n\n    syn = torch.zeros(shape[0], shape[1], shape[2], shape[3]).cuda()\n    syns = torch.zeros(shape[0], shape[1], shape[2], shape[3], n_steps).cuda()\n\n    for t in range(n_steps):\n        syn = syn - syn / tau_s + inputs[..., t]\n        syns[..., t] = syn / tau_s\n\n    return syns\n\n\nclass SpikeLoss(torch.nn.Module):\n    \"\"\"\n    This class defines different spike based loss modules that can be used to optimize the SNN.\n    \"\"\"\n    def __init__(self, desired_count,undesired_count):\n        super(SpikeLoss, self).__init__()\n        self.desired_count = desired_count\n        self.desired_count = undesired_count\n        self.criterion = torch.nn.CrossEntropyLoss()\n\n    def spike_count(self, outputs, target, desired_count,undesired_count):\n        delta = loss_count.apply(outputs, target, desired_count,undesired_count)\n        return 1 / 2 * torch.sum(delta ** 2)\n\n    def spike_kernel(self, outputs, target, desired_count,undesired_count):\n        delta = loss_kernel.apply(outputs, target, desired_count,undesired_count)\n        return 1 / 2 * torch.sum(delta ** 2)\n\n    def spike_soft_max(self, outputs, target):\n        delta = f.log_softmax(outputs.sum(dim=4).squeeze(-1).squeeze(-1), dim = 1)\n        return self.criterion(delta, target)\n\n\nclass loss_count(torch.autograd.Function):  # a and u is the incremnet of each time steps\n    @staticmethod\n    def forward(ctx, outputs, target, desired_count,undesired_count):\n        desired_count = desired_count\n        undesired_count = undesired_count\n        shape = outputs.shape\n        n_steps = shape[4]\n        out_count = torch.sum(outputs, dim=4)\n\n        delta = (out_count - target) / n_steps\n        mask = torch.ones_like(out_count)\n        mask[target == undesired_count] = 0\n        mask[delta < 0] = 0\n      
  delta[mask == 1] = 0\n        mask = torch.ones_like(out_count)\n        mask[target == desired_count] = 0\n        mask[delta > 0] = 0\n        delta[mask == 1] = 0\n        delta = delta.unsqueeze_(-1).repeat(1, 1, 1, 1, n_steps)\n        return delta\n\n    @staticmethod\n    def backward(ctx, grad):\n        return grad, None, None, None\n\n\nclass loss_kernel(torch.autograd.Function):  # a and u is the incremnet of each time steps\n    @staticmethod\n    def forward(ctx, outputs, target, n_steps,tau_s):\n        # out_psp = psp(outputs, network_config)\n        target_psp = psp(target, n_steps,tau_s)\n        delta = outputs - target_psp\n        return delta\n\n    @staticmethod\n    def backward(ctx, grad):\n        return grad, None, None\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/micro_encoding.py",
    "content": "# NASNet Search Space https://arxiv.org/pdf/1707.07012.pdf\n# code modified from DARTS https://github.com/quark0/darts\nimport numpy as np\nfrom collections import namedtuple\n\nimport torch\n# from models.micro_models import NetworkCIFAR as Network\n\nimport motifs\n\n# Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')\n# Genotype_norm = namedtuple('Genotype', 'normal normal_concat')\n# Genotype_redu = namedtuple('Genotype', 'reduce reduce_concat')\nGenotype = namedtuple('Genotype', 'normal normal_concat')\n\n# what you want to search should be defined here and in micro_operations\n\nPRIMITIVES = [\n    'max_pool_3x3',\n    'avg_pool_3x3',\n    'skip_connect',\n    'sep_conv_3x3',\n    'sep_conv_5x5',\n    'dil_conv_3x3',\n    'dil_conv_5x5',\n    'sep_conv_7x7',\n    'conv_7x1_1x7',\n]\nOPERATIONS_back = [\n    # 'max_pool_3x3_p_back',\n    # 'avg_pool_3x3_p_back',\n    'conv_3x3_p_back',\n    'conv_5x5_p_back',\n    # 'avg_pool_3x3_n_back',\n    'conv_3x3_n_back',\n    'conv_5x5_n_back',\n    # 'sep_conv_3x3_p_back',\n    # 'sep_conv_5x5_p_back',\n    # 'dil_conv_3x3_p_back',\n    # 'dil_conv_5x5_p_back',\n    # 'def_conv_3x3_p_back',\n    # 'def_conv_5x5_p_back',\n]\nOPERATIONS_p = [\n    # 'max_pool_3x3_p',\n    # 'avg_pool_3x3_p',\n    'conv_3x3_p',\n    'conv_5x5_p',\n    # 'sep_conv_3x3_p',\n    # 'sep_conv_5x5_p',\n    # 'dil_conv_3x3_p',\n    # 'dil_conv_5x5_p',\n    # 'def_conv_3x3_p',\n    # 'def_conv_5x5_p',\n]\nops=len(OPERATIONS_p)\n\nOPERATIONS_n = [\n    # 'max_pool_3x3_n',\n    # 'avg_pool_3x3_n',\n    'conv_3x3_n',\n    'conv_5x5_n',\n    # 'sep_conv_3x3_n',\n    # 'sep_conv_5x5_n',\n    # 'dil_conv_3x3_n',\n    # 'dil_conv_5x5_n',\n    # 'def_conv_3x3_n',\n    # 'def_conv_5x5_n',\n\n    # 'transformer',\n]\n\ndef convert_cell(cell_bit_string):\n    # convert cell bit-string to genome\n    tmp = [cell_bit_string[i:i + 2] for i in range(0, len(cell_bit_string), 2)]\n    return [tmp[i:i + 2] for i in 
range(0, len(tmp), 2)]\n\n\ndef convert(bit_string):\n    # convert network bit-string (norm_cell + redu_cell) to genome\n    norm_gene = convert_cell(bit_string[:len(bit_string)//2])\n    redu_gene = convert_cell(bit_string[len(bit_string)//2:])\n    return [norm_gene, redu_gene]\n\n\n# def decode_cell(genome, norm=True):\n\n#     cell, cell_concat = [], list(range(2, len(genome)+2))\n#     for block in genome:\n#         for unit in block:\n#             cell.append((PRIMITIVES[unit[0]], unit[1]))\n#             if unit[1] in cell_concat:\n#                 cell_concat.remove(unit[1])\n\n#     if norm:\n#         return Genotype_norm(normal=cell, normal_concat=cell_concat)\n#     else:\n#         return Genotype_redu(reduce=cell, reduce_concat=cell_concat)\n\n\ndef decode(genome):\n    # decodes genome to architecture\n    normal_cell = genome[0]\n    reduce_cell = genome[1]\n\n    normal, normal_concat = [], list(range(2, len(normal_cell)+2))\n    reduce, reduce_concat = [], list(range(2, len(reduce_cell)+2))\n\n    for block in normal_cell:\n        for unit in block:\n            normal.append((PRIMITIVES[unit[0]], unit[1]))\n            if unit[1] in normal_concat:\n                normal_concat.remove(unit[1])\n\n    for block in reduce_cell:\n        for unit in block:\n            reduce.append((PRIMITIVES[unit[0]], unit[1]))\n            if unit[1] in reduce_concat:\n                reduce_concat.remove(unit[1])\n\n    return Genotype(\n        normal=normal, normal_concat=normal_concat,\n        reduce=reduce, reduce_concat=reduce_concat\n    )\n\ndef decode_motif(layers,bits,genome):\n    # decodes genome to architecture\n    motif_list=[]\n    motif_ids=[]\n    for b in range(0,layers*bits,bits):\n        if genome[-1]==2:\n            motif_id='mm'+str(genome[b])\n        elif genome[-1]==3:\n            motif_id='m'+str(genome[b])\n        else:\n            motif_id='t'+str(genome[b])\n\n        motif_ids.append(genome[b])\n\n        
normalcell=eval('motifs.%s' % motif_id)\n    \n        newnormal=[]\n        for i in range(0,len(normalcell.normal)):\n            op=normalcell.normal[i]\n            if 'skip' in op[0]:\n                newnormal.append(op)\n                continue\n            elif 'back' in op[0]:\n                newnormal.append((OPERATIONS_back[genome[b+1+len(normalcell.normal)-1]],op[1]))\n                continue            \n            elif '_n' in op[0]:\n                newnormal.append((OPERATIONS_n[genome[b+1+i]],op[1]))\n                continue\n            elif '_p' in op[0]:\n                newnormal.append((OPERATIONS_p[genome[b+1+i]],op[1]))\n                continue\n        m=Genotype(normal=newnormal, normal_concat=normalcell.normal_concat,)\n        motif_list.append(m)\n\n\n    return motif_list,motif_ids\n\ndef compare_cell(cell_string1, cell_string2):\n    cell_genome1 = convert_cell(cell_string1)\n    cell_genome2 = convert_cell(cell_string2)\n    cell1, cell2 = cell_genome1[:], cell_genome2[:]\n\n    for block1 in cell1:\n        for block2 in cell2:\n            if block1 == block2 or block1 == block2[::-1]:\n                cell2.remove(block2)\n                break\n    if len(cell2) > 0:\n        return False\n    else:\n        return True\n\n\ndef compare(string1, string2):\n\n    if compare_cell(string1[:len(string1)//2],\n                    string2[:len(string2)//2]):\n        if compare_cell(string1[len(string1)//2:],\n                        string2[len(string2)//2:]):\n            return True\n\n    return False\n\n\n# def debug():\n#     # design to debug the encoding scheme\n#     seed = 0\n#     np.random.seed(seed)\n#     budget = 2000\n#     B, n_ops, n_cell = 5, 7, 2\n#     networks = []\n#     design_id = 1\n#     while len(networks) < budget:\n#         bit_string = []\n#         for c in range(n_cell):\n#             for b in range(B):\n#                 bit_string += [np.random.randint(n_ops),\n#                                
np.random.randint(b + 2),\n#                                np.random.randint(n_ops),\n#                                np.random.randint(b + 2)\n#                                ]\n\n#         genome = convert(bit_string)\n#         # check against evaluated networks in case of duplicates\n#         doTrain = True\n#         for network in networks:\n#             if compare(genome, network):\n#                 doTrain = False\n#                 break\n\n#         if doTrain:\n#             genotype = decode(genome)\n#             model = Network(16, 10, 8, False, genotype)\n#             model.drop_path_prob = 0.0\n#             data = torch.randn(1, 3, 32, 32)\n#             output, output_aux = model(torch.autograd.Variable(data))\n#             networks.append(genome)\n#             design_id += 1\n#             print(design_id)\n\n\nif __name__ == \"__main__\":\n    # debug()\n    # genome1 = [[[[3, 0], [3, 1]], [[3, 0], [3, 1]],\n    #             [[3, 1], [2, 0]], [[2, 0], [5, 2]]],\n    #            [[[0, 0], [0, 1]], [[2, 2], [0, 1]],\n    #             [[0, 0], [2, 2]], [[2, 2], [0, 1]]]]\n    # genome2 = [[[[3, 1], [3, 0]], [[3, 1], [3, 0]],\n    #             [[3, 1], [2, 0]], [[2, 0], [5, 2]]],\n    #            [[[0, 1], [0, 0]], [[2, 2], [0, 1]],\n    #             [[0, 0], [2, 2]], [[2, 2], [0, 0]]]]\n    #\n    # print(compare(genome1, genome2))\n    # print(genome1)\n    # print(genome2)\n    # bit_string1 = [3,1,3,0,3,1,3,0,3,1,2,0,2,0,5,2,0,0,0,1,2,2,0,1,0,0,2,2,2,2,0,1]\n    # bit_string2 = [3, 0, 3, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2, 0, 5, 2,\n    #                0, 0, 0, 1, 2, 2, 0, 1, 0, 0, 2, 2, 2, 2, 0, 1]\n    # # print(convert(bit_string1))\n    # print(compare(bit_string1, bit_string2))\n    # print(decode(convert(bit_string)))\n\n    cell_bit_string = [3, 0, 3, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2, 0, 5, 2]\n    # print(decode_cell(convert_cell(cell_bit_string), norm=False))\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/motifs.py",
    "content": "from collections import namedtuple\n\nimport torch\n\nGenotype = namedtuple('Genotype', 'normal normal_concat')\n\nm0=Genotype(\n    normal=[\n        ('skip', 0), ('skip', 1),('skip', 2),\n    ],\n    normal_concat=range(3, 4)\n)\nmm0=Genotype(\n    normal=[\n        ('skip', 0), ('skip', 1),('skip', 2),\n    ],\n    normal_concat=range(2, 3)\n)\nmm1=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),\n        ('skip_connect', 0), ('conv_5x5_p', 2),\n    ],\n    normal_concat=range(2, 4)\n)\n\n\nmm2=Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),\n        ('skip_connect', 0), ('conv_5x5_n', 2),\n        ('conv_5x5_p', 2), ('conv_3x3_n', 3),\n\n    ],\n    normal_concat=range(2, 5)\n)\n\nmm4=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),#2\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),#3\n        ('conv_5x5_p', 2), ('conv_5x5_p', 3),#4\n        ('skip_connect', 0), ('conv_3x3_p', 4),#5\n        ('skip_connect', 0), ('conv_3x3_p', 4),#6\n        ],\n    normal_concat=range(2, 7)\n)\n\n\n\nmm3=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),#2\n        ('skip_connect', 0), ('conv_5x5_n', 2),#3\n        ('skip_connect', 0), ('conv_5x5_p', 3),#4\n\n        ('skip_connect_back', 2),#3\n        ('conv_3x3_p_back', 3),#4\n\n    ],\n    normal_concat=range(2, 5)\n)\n\n\nmm5=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),#2\n        ('skip_connect', 0), ('conv_5x5_p', 2),#3\n\n        ('skip_connect_back', 2),#3\n    ],\n    normal_concat=range(2, 4)\n)\n\n\n\nm1=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2), #B3\n        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1), #C4\n    ],\n    normal_concat=range(3, 5)\n)\n\n\nm2=Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), #B3\n        ('skip', 0), ('conv_5x5_n', 3), ('skip', 1),#C4\n        ('conv_5x5_p', 3), 
('conv_3x3_n', 4), ('skip', 1), #D5\n\n    ],\n    normal_concat=range(3, 6)\n)\n\nm4=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), #3\n        ('conv_3x3_p', 0), ('conv_3x3_p', 1),('conv_5x5_p', 2), #4\n        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_p', 4), #5\n        ('skip', 0), ('conv_3x3_p', 3),('conv_3x3_n', 5),#6\n        ('skip', 0), ('conv_3x3_p', 4),('conv_3x3_n', 5),#7\n        ],\n    normal_concat=range(3, 8)\n)\n\n\n\nm3=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2), #3\n        ('skip', 0), ('conv_5x5_p', 3),('skip', 1), #4\n        ('skip', 0), ('conv_5x5_p', 3), ('skip', 1), #5\n\n        ('conv_3x3_n_back', 3),#4\n        ('skip_back', 2),#5\n\n    ],\n    normal_concat=range(3, 6)\n)\n\n\nm5=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),#3\n        ('skip', 0),('skip', 1), ('conv_5x5_n', 3), #4\n\n        ('skip_connect_back', 3),#4\n    ],\n    normal_concat=range(3, 5)\n)\n\n\nt1=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),  ('conv_5x5_p', 3), #4\n        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1), ('skip', 2), #5\n        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2), #6\n        ('skip', 0), ('conv_5x5_p', 5), ('skip', 1), ('skip', 2), #7\n    ],\n    normal_concat=range(4, 8)\n)\n\n\nt2=Genotype(\n    normal=[\n        ('conv_5x5_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), ('conv_5x5_p', 3), #4\n        ('skip', 0), ('conv_5x5_n', 4), ('skip', 1),('skip', 2),#5\n        ('conv_5x5_p', 4), ('conv_3x3_n', 5), ('skip', 1),('skip', 2), #6\n\n    ],\n    normal_concat=range(4, 7)\n)\n\nt4=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1),('conv_5x5_p', 2), ('conv_5x5_p', 3), #4\n        ('conv_5x5_p', 0), ('skip', 1),('conv_5x5_n', 4), ('skip', 3), #5\n        ('skip', 0), ('conv_5x5_p', 3), ('conv_5x5_n', 4), ('skip', 2),#6\n\n        
],\n    normal_concat=range(4, 7)\n)\n\n\n\nt3=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_3x3_p', 2), ('conv_3x3_p', 3),#4\n        ('skip', 0), ('skip', 2),('skip', 1), ('skip', 3),('conv_3x3_p', 4),#5\n        ('skip', 0), ('conv_5x5_p', 4), ('skip', 1),('skip', 2), #6\n\n        ('conv_3x3_n_back', 4),#5\n        ('skip_back', 4),#6\n\n    ],\n    normal_concat=range(4, 7)\n)\n\n\nt5=Genotype(\n    normal=[\n        ('conv_3x3_p', 0), ('conv_5x5_p', 1), ('conv_5x5_p', 2),('conv_5x5_p', 3),#4\n        ('skip', 0),('skip', 1), ('skip', 2), ('conv_5x5_n', 4), #5\n        ('skip', 0),('skip', 1), ('conv_5x5_n', 4),('conv_5x5_n', 5), #6\n        ('skip', 0),('skip', 1), ('conv_5x5_n', 4),('conv_5x5_n', 5), #7\n\n        ('conv_3x3_n_back', 4),#5\n        ('skip_back', 4),#6\n        ('skip_back', 5),#7\n\n    ],\n    normal_concat=range(4, 8)\n)\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/nsganet.py",
    "content": "import numpy as np\n\nfrom pymoo.algorithms.genetic_algorithm import GeneticAlgorithm\nfrom pymoo.docs import parse_doc_string\nfrom pymoo.model.individual import Individual\nfrom pymoo.model.survival import Survival\nfrom pymoo.operators.crossover.point_crossover import PointCrossover\nfrom pymoo.operators.mutation.polynomial_mutation import PolynomialMutation\nfrom pymoo.operators.sampling.random_sampling import RandomSampling\nfrom pymoo.operators.selection.tournament_selection import compare, TournamentSelection\nfrom pymoo.util.display import disp_multi_objective\nfrom pymoo.util.dominator import Dominator\nfrom pymoo.util.non_dominated_sorting import NonDominatedSorting\nfrom pymoo.util.randomized_argsort import randomized_argsort\n\n\n# =========================================================================================================\n# Implementation\n# based on nsga2 from https://github.com/msu-coinlab/pymoo\n# =========================================================================================================\n\n\nclass NSGANet(GeneticAlgorithm):\n\n    def __init__(self, **kwargs):\n        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)\n        super().__init__(**kwargs)\n\n        self.tournament_type = 'comp_by_dom_and_crowding'\n        self.func_display_attrs = disp_multi_objective\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Binary Tournament Selection Function\n# ---------------------------------------------------------------------------------------------------------\n\n\ndef binary_tournament(pop, P, algorithm, **kwargs):\n    if P.shape[1] != 2:\n        raise ValueError(\"Only implemented for binary tournament!\")\n\n    tournament_type = algorithm.tournament_type\n    S = np.full(P.shape[0], np.nan)\n\n    for i in range(P.shape[0]):\n\n        a, b = P[i, 0], P[i, 1]\n\n        # if at least one solution is infeasible\n        if 
pop[a].CV > 0.0 or pop[b].CV > 0.0:\n            S[i] = compare(a, pop[a].CV, b, pop[b].CV, method='smaller_is_better', return_random_if_equal=True)\n\n        # both solutions are feasible\n        else:\n\n            if tournament_type == 'comp_by_dom_and_crowding':\n                rel = Dominator.get_relation(pop[a].F, pop[b].F)\n                if rel == 1:\n                    S[i] = a\n                elif rel == -1:\n                    S[i] = b\n\n            elif tournament_type == 'comp_by_rank_and_crowding':\n                S[i] = compare(a, pop[a].rank, b, pop[b].rank,\n                               method='smaller_is_better')\n\n            else:\n                raise Exception(\"Unknown tournament type.\")\n\n            # if rank or domination relation didn't make a decision compare by crowding\n            if np.isnan(S[i]):\n                S[i] = compare(a, pop[a].get(\"crowding\"), b, pop[b].get(\"crowding\"),\n                               method='larger_is_better', return_random_if_equal=True)\n\n    return S[:, None].astype(np.int)\n\n\n# ---------------------------------------------------------------------------------------------------------\n# Survival Selection\n# ---------------------------------------------------------------------------------------------------------\n\n\nclass RankAndCrowdingSurvival(Survival):\n\n    def __init__(self) -> None:\n        super().__init__(True)\n\n    def _do(self, pop, n_survive, D=None, **kwargs):\n\n        # get the objective space values and objects\n        F = pop.get(\"F\")\n\n        # the final indices of surviving individuals\n        survivors = []\n\n        # do the non-dominated sorting until splitting front\n        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)\n\n        for k, front in enumerate(fronts):\n\n            # calculate the crowding distance of the front\n            crowding_of_front = calc_crowding_distance(F[front, :])\n\n            # save rank and 
crowding in the individual class\n            for j, i in enumerate(front):\n                pop[i].set(\"rank\", k)\n                pop[i].set(\"crowding\", crowding_of_front[j])\n\n            # current front sorted by crowding distance if splitting\n            if len(survivors) + len(front) > n_survive:\n                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')\n                I = I[:(n_survive - len(survivors))]\n\n            # otherwise take the whole front unsorted\n            else:\n                I = np.arange(len(front))\n\n            # extend the survivors by all or selected individuals\n            survivors.extend(front[I])\n\n        return pop[survivors]\n\n\ndef calc_crowding_distance(F):\n    infinity = 1e+14\n\n    n_points = F.shape[0]\n    n_obj = F.shape[1]\n\n    if n_points <= 2:\n        return np.full(n_points, infinity)\n    else:\n\n        # sort each column and get index\n        I = np.argsort(F, axis=0, kind='mergesort')\n\n        # now really sort the whole array\n        F = F[I, np.arange(n_obj)]\n\n        # get the distance to the last element in sorted list and replace zeros with actual values\n        dist = np.concatenate([F, np.full((1, n_obj), np.inf)]) \\\n               - np.concatenate([np.full((1, n_obj), -np.inf), F])\n\n        index_dist_is_zero = np.where(dist == 0)\n\n        dist_to_last = np.copy(dist)\n        for i, j in zip(*index_dist_is_zero):\n            dist_to_last[i, j] = dist_to_last[i - 1, j]\n\n        dist_to_next = np.copy(dist)\n        for i, j in reversed(list(zip(*index_dist_is_zero))):\n            dist_to_next[i, j] = dist_to_next[i + 1, j]\n\n        # normalize all the distances\n        norm = np.max(F, axis=0) - np.min(F, axis=0)\n        norm[norm == 0] = np.nan\n        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm\n\n        # if we divided by zero because all values in one columns are equal replace by none\n    
    dist_to_last[np.isnan(dist_to_last)] = 0.0\n        dist_to_next[np.isnan(dist_to_next)] = 0.0\n\n        # sum up the distance to next and last and norm by objectives - also reorder from sorted list\n        J = np.argsort(I, axis=0)\n        crowding = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj\n\n    # replace infinity with a large number\n    crowding[np.isinf(crowding)] = infinity\n\n    return crowding\n\n\n# =========================================================================================================\n# Interface\n# =========================================================================================================\n\n\ndef nsganet(\n        pop_size=100,\n        sampling=RandomSampling(var_type=np.int),\n        selection=TournamentSelection(func_comp=binary_tournament),\n        crossover=PointCrossover(n_points=2),\n        mutation=PolynomialMutation(eta=3, var_type=np.int),\n        eliminate_duplicates=True,\n        n_offsprings=None,\n        **kwargs):\n    \"\"\"\n\n    Parameters\n    ----------\n    pop_size : {pop_size}\n    sampling : {sampling}\n    selection : {selection}\n    crossover : {crossover}\n    mutation : {mutation}\n    eliminate_duplicates : {eliminate_duplicates}\n    n_offsprings : {n_offsprings}\n\n    Returns\n    -------\n    nsganet : :class:`~pymoo.model.algorithm.Algorithm`\n        Returns an NSGANet algorithm object.\n\n\n    \"\"\"\n\n    return NSGANet(pop_size=pop_size,\n                   sampling=sampling,\n                   selection=selection,\n                   crossover=crossover,\n                   mutation=mutation,\n                   survival=RankAndCrowdingSurvival(),\n                   eliminate_duplicates=eliminate_duplicates,\n                   n_offsprings=n_offsprings,\n                   **kwargs)\n\n\nparse_doc_string(nsganet)\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/obj.py",
    "content": "import sys\nimport os\nimport numpy as np\nimport torch\nimport logging\nimport argparse\nimport torch.nn as nn\nimport torch.utils\n# import torchvision.datasets as dset\nimport torch.backends.cudnn as cudnn\nimport torchvision.transforms as transforms\nfrom timm.models import create_model\nfrom cell123model import NetworkCIFAR\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.reactnet import *\nfrom braincog.model_zoo.convxnet import *\nfrom scipy.stats import kendalltau\nfrom misc import utils\nimport micro_encoding\nfrom misc.flops_counter import add_flops_counting_methods\nfrom utils import data_transforms\nfrom datetime import datetime\nbits=20\n\ndef logdet(K):\n    s, ld = torch.linalg.slogdet(K)\n    return ld\n\n\ndef LSP(args,genome,train_data):\n\n    with torch.no_grad():\n        test_motifs,ids = micro_encoding.decode_motif(layers=args.layers,bits=bits,genome=genome)\n        pmodel = create_model(\n            args.model,\n            pretrained=args.pretrained,\n            num_classes=args.num_classes,\n            dataset=args.dataset,\n            step=args.step,\n            encode_type=args.encode,\n            node_type=eval(args.node_type),\n            threshold=args.threshold,\n            tau=args.tau,\n            sigmoid_thres=args.sigmoid_thres,\n            requires_thres_grad=args.requires_thres_grad,\n            spike_output=not args.no_spike_output,\n            C=args.init_channels,\n            layers=args.layers,\n            auxiliary=args.auxiliary,\n            motif=test_motifs,\n            parse_method=args.parse_method,\n            act_fun=args.act_fun,\n            temporal_flatten=args.temporal_flatten,\n            layer_by_layer=args.layer_by_layer,\n            n_groups=args.n_groups,\n   
         cell_type=genome[-1]\n        )\n        pmodel.to(args.device)\n\n        pmodel.K = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n        pmodel.J = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n\n        # pmodel.Cou = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n        pmodel.Ccosine = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n        pmodel.Cm = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n        pmodel.Cpe = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n\n        # pmodel.Cou = torch.zeros(args.batch_size,device=args.device)\n        # pmodel.Ccosine = torch.zeros(args.batch_size, device=args.device)\n        # pmodel.Cm = torch.zeros(args.batch_size,device=args.device)\n        pmodel.num_actfun_C = 0    \n        pmodel.num_actfun_K = 0    \n\n        def computing_LSP(module, inp, out):\n            if isinstance(out, tuple):\n                out = out[0]\n            # \n            out = out.view(out.size(0), -1)\n            batch_num , neuron_num = out.size()\n            x = (out > 0).float()\n            full_matrix = torch.ones((args.batch_size, args.batch_size)).cuda() * neuron_num\n            sparsity = (x.sum(1)/neuron_num).unsqueeze(1)\n            norm_K = ((sparsity @ (1-sparsity.t())) + ((1-sparsity) @ sparsity.t())) * neuron_num\n            rescale_factor = torch.div(0.5* torch.ones((args.batch_size, args.batch_size)).cuda(), norm_K+1e-3)\n            K1_0 = (x @ (1 - x.t()))\n            K0_1 = ((1-x) @ x.t())\n            K0_0 = (1-x) @ (1-x).t()\n            K1_1 = (1-x) @ (1-x).t()\n\n            K_total = (full_matrix - rescale_factor * (K0_1 + K1_0))\n            J_total = (K1_1+K0_0)/(K0_1+K1_0+K1_1)\n            pmodel.K = pmodel.K + K_total\n            pmodel.J = pmodel.J + J_total\n            pmodel.num_actfun_K += 1\n            # x = x / torch.norm(x, dim=-1, keepdim=True)\n            # 
similarity = torch.mm(x, x.T)  \n\n\n\n            # dis_ou=torch.zeros_like(pmodel.Cou)\n            dis_man=torch.zeros_like(pmodel.Cm)\n            dis_cosine=torch.zeros_like(pmodel.Ccosine)\n\n\n            ou_dist = nn.PairwiseDistance(p=2)\n            m_dist = nn.PairwiseDistance(p=1)\n            cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n            # cos = nn.CosineSimilarity(dim=0, eps=1e-6)\n            # for i in range(args.batch_size):\n            #     for j in range(i,args.batch_size):\n            #         input1 = x[i]\n            #         input2 = x[j]\n            #         dis_ou[i][j] = ou_dist(input1,input2)\n            #         dis_man[i][j] = m_dist(input1,input2)\n            #         dis_cosine[i][j] = cos(input1,input2)\n            \n            # pmodel.Cou = pmodel.Cou + dis_ou\n            for i in range(args.batch_size):\n                temp = x[i].repeat(args.batch_size,1)\n                dis_cosine[i] = cos(x,temp)\n                dis_man[i] = m_dist(x,temp)\n                 \n\n                \n            # pmodel.Cou = pmodel.Cou + ou_dist(x,x.flip(dims=[0]))\n            # pmodel.Cm = pmodel.Cou + m_dist(x,x.flip(dims=[0]))\n            pmodel.Ccosine = pmodel.Ccosine + dis_cosine\n            pmodel.Cm = pmodel.Cm + dis_man\n            pmodel.Cpe = pmodel.Cpe + torch.corrcoef(x / torch.norm(x, dim=-1, keepdim=True))\n\n            pmodel.num_actfun_C += 1\n            pmodel.num_actfun_K += 1\n\n\n\n        s_ou = []\n        s_m = []\n        s_pe = []\n        s_cos = []\n        s_k = []\n        s_jac=[]\n        s_sum_j=[]\n        repeat=2\n        for name,module in pmodel.named_modules():\n            if args.node_type in str(type(module)):    \n                handle = module.register_forward_hook(computing_LSP)\n\n        for j in range(repeat):\n            pmodel.K = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n            pmodel.J = torch.zeros(args.batch_size, 
args.batch_size,device=args.device)\n\n            pmodel.Ccosine = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n            pmodel.Cm = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n            pmodel.Cpe = torch.zeros(args.batch_size, args.batch_size,device=args.device)\n            pmodel.num_actfun_C = 0    \n            pmodel.num_actfun_K = 0    \n\n            data_iterator = iter(train_data)\n            inputs, targets = next(data_iterator)\n            inputs, targets = inputs.cuda(), targets.cuda()\n            outputs = pmodel(inputs)\n            tc=pmodel.Ccosine/pmodel.num_actfun_C\n            tp=pmodel.Cpe/pmodel.num_actfun_C\n            tm=pmodel.Cm/pmodel.num_actfun_C\n            tj=pmodel.J/ (pmodel.num_actfun_K)\n            Ccos = torch.where(torch.isnan(tc), torch.full_like(tc, 0), tc)\n            Cpe = torch.where(torch.isnan(tp), torch.full_like(tp, 0), tp)\n            Cm = torch.where(torch.isnan(tm), torch.full_like(tm, 0), tm)\n\n            s_k.append(float(logdet(pmodel.K/ (pmodel.num_actfun_K))))\n            s_jac.append(float(logdet(tj)))\n            s_sum_j.append(float(tj.sum()))\n            s_m.append(float(Cm.sum()))\n            s_cos.append(float(Ccos.sum()))\n            s_pe.append(float(Cpe.sum()))\n    return np.mean(np.array(s_sum_j)),np.mean(np.array(s_jac)), np.mean(np.array(s_m)),np.mean(np.array(s_cos)),np.mean(np.array(s_pe)),np.mean(np.array(s_k))\n\n\n\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/operations.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import *\nimport torch.nn.functional as F\nfrom torch import einsum\nfrom einops import rearrange\n# from braincog.model_zoo.base_module import DeformConvPack\nfrom braincog.model_zoo.base_module import BaseLinearModule\n\n\n# from mmcv.ops import ModulatedDeformConv2dPack\n\n\ndef si_relu(x, positive):\n    if positive == 1:\n        return torch.where(x > 0., x, torch.zeros_like(x))\n    elif positive == 0:\n        return x\n    elif positive == -1:\n        return torch.where(x < 0., x, torch.zeros_like(x))\n    else:\n        raise ValueError\n\n\n\nclass SiReLU(nn.Module):\n    def __init__(self, positive=0):\n        super().__init__()\n        self.positive = positive\n\n    def forward(self, x):\n        return si_relu(x, self.positive)\n\n\ndef weight_init(m):\n    if isinstance(m, nn.Conv2d):\n        torch.nn.init.xavier_normal(m.weight.data, gain=0.1)\n        torch.nn.init.constant(m.bias.data, 0.)\n\nOPS_Mlp = {\n    'mlp': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=0),\n    'mlp_p': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=1),\n    'mlp_n': lambda C, act_fun:\n        SiMLP(C, C, act_fun=act_fun, positive=-1),\n\n    'skip_connect': lambda C, act_fun:\n        Identity(positive=0),\n    'skip_connect_p': lambda C, act_fun:\n        Identity(positive=1),\n    'skip_connect_n': lambda C, act_fun:\n        Identity(positive=-1),\n}\n\nOPS = {\n    'avg_pool_3x3': lambda C, stride, affine, act_fun: nn.AvgPool2d(3, stride=stride, padding=1,\n                                                                    count_include_pad=False),\n    'conv_3x3': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=0),\n    'conv_5x5': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, 
stride=stride, affine=affine, act_fun=act_fun, positive=0),\n    'max_pool_3x3': lambda C, stride, affine, act_fun: nn.MaxPool2d(3, stride=stride, padding=1),\n    'skip_connect': lambda C, stride, affine, act_fun:\n        Identity(positive=0) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun),\n    'sep_conv_3x3': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),\n    'sep_conv_5x5': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),\n    'sep_conv_7x7': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=0),\n    'dil_conv_3x3': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=0),\n    'dil_conv_5x5': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=0),\n    'def_conv_3x3': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=0),\n    'def_conv_5x5': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=0),\n\n    'avg_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),\n        SiReLU(positive=1)\n    ),\n    'max_pool_3x3_p': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.MaxPool2d(3, stride=stride, padding=1),\n        SiReLU(positive=1)\n    ),\n    'conv_3x3_p': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=1),\n    'conv_5x5_p': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, 
positive=1),\n    'skip_connect_p': lambda C, stride, affine, act_fun:\n        Identity(positive=1) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_3x3_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_5x5_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),\n    'sep_conv_7x7_p': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=1),\n    'dil_conv_3x3_p': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=1),\n    'dil_conv_5x5_p': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=1),\n    'def_conv_3x3_p': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=1),\n    'def_conv_5x5_p': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=1),\n\n    'avg_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(\n        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),\n        SiReLU(positive=-1)\n    ),\n    'max_pool_3x3_n': lambda C, stride, affine, act_fun: nn.Sequential(\n            nn.MaxPool2d(3, stride=stride, padding=1),\n            SiReLU(positive=-1)\n    ),\n    'conv_3x3_n': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=3, padding=1, stride=stride, affine=affine, act_fun=act_fun, positive=-1),\n    'conv_5x5_n': lambda C, stride, affine, act_fun:\n        ReLUConvBN(C_in=C, C_out=C, kernel_size=5, padding=2, stride=stride, affine=affine, act_fun=act_fun, positive=-1),\n    'skip_connect_n': lambda C, stride, affine, act_fun:\n        Identity(positive=-1) if stride == 1 
else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_3x3_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_5x5_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'sep_conv_7x7_n': lambda C, stride, affine, act_fun:\n        SepConv(C, C, 7, stride, 3, affine=affine, act_fun=act_fun, positive=-1),\n    'dil_conv_3x3_n': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 3, stride, 2, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'dil_conv_5x5_n': lambda C, stride, affine, act_fun:\n        DilConv(C, C, 5, stride, 4, 2, affine=affine, act_fun=act_fun, positive=-1),\n    'def_conv_3x3_n': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 3, stride, 1, affine=affine, act_fun=act_fun, positive=-1),\n    'def_conv_5x5_n': lambda C, stride, affine, act_fun:\n        DeformConv(C, C, 5, stride, 2, affine=affine, act_fun=act_fun, positive=-1),\n\n    'conv_7x1_1x7': lambda C, stride, affine, act_fun: nn.Sequential(\n        # nn.ReLU(inplace=False),\n        act_fun(),\n        nn.Conv2d(C, C, (1, 7), stride=(1, stride),\n                  padding=(0, 3), bias=False),\n        nn.Conv2d(C, C, (7, 1), stride=(stride, 1),\n                  padding=(3, 0), bias=False),\n        nn.BatchNorm2d(C, affine=affine)\n    ),\n    'skip': lambda C, stride, affine, act_fun:\n        Zero(stride) if stride == 1 else FactorizedReduce(C, C, affine=affine, act_fun=act_fun, positive=1),\n    'transformer': lambda C, stride, affine, act_fun:\n        FactorizedReduce(\n            C, C, affine=affine, act_fun=act_fun) if stride != 1 else TransformerEncoderLayer(C),\n}\n\n\nclass SiMLP(nn.Module):\n    def __init__(self, c_in, c_out, act_fun=nn.ReLU, positive=0, *args, **kwargs):\n        super(SiMLP, self).__init__()\n        self.op = nn.Sequential(\n            
nn.Linear(c_in, c_out, bias=True),\n            act_fun()\n        )\n        self.positive = positive\n\n    def forward(self, x):\n        out = self.op(si_relu(x, self.positive))\n        return out\n\n\n\n\nclass DilConv(nn.Module):\n    \"\"\"\n    Dilation Convolution ： ReLU -> DilConv -> Conv2d -> BatchNorm2d\n    \"\"\"\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, act_fun=nn.ReLU, positive=0):\n        super(DilConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,\n                      groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine),\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass SepConv(nn.Module):\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(SepConv, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,\n                      stride=stride, padding=padding, groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_in, affine=affine),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(C_in, C_in, kernel_size=kernel_size,\n                      stride=1, padding=padding, groups=C_in, bias=False),\n            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine),\n        )\n        self.positive = positive\n        # if positive 
== -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\nclass Identity(nn.Module):\n\n    def __init__(self, positive=0):\n        super(Identity, self).__init__()\n        self.positive = positive\n\n    def forward(self, x):\n        return si_relu(x, self.positive)\n\n\nclass Zero(nn.Module):\n\n    def __init__(self, stride):\n        super(Zero, self).__init__()\n        self.stride = stride\n\n    def forward(self, x):\n        if self.stride == 1:\n            return x.mul(0.)\n        return x[:, :, ::self.stride, ::self.stride].mul(0.)  # N * C * W * H\n\n\nclass FactorizedReduce(nn.Module):\n\n    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):\n        super(FactorizedReduce, self).__init__()\n        assert C_out % 2 == 0\n        # self.relu = nn.ReLU(inplace=False)\n        self.activation = act_fun()\n        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.bn = nn.BatchNorm2d(C_out, affine=affine)\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        # x = self.relu(x)\n        x = self.activation(x)\n        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.bn(out)\n        out = si_relu(out, self.positive)\n        return out\n\nclass F0(nn.Module):\n\n    def __init__(self, C_in, C_out, affine=True, act_fun=nn.ReLU, positive=0):\n        super(F0, self).__init__()\n        assert C_out % 2 == 0\n        # self.relu = nn.ReLU(inplace=False)\n        self.activation = act_fun()\n        self.op=nn.Conv2d(C_out, C_out, 3, stride=2, padding=1, bias=False)\n        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 3,stride=2, padding=1, bias=False)\n        self.conv_2 = nn.Conv2d(C_in, C_out // 
2, 3,stride=2, padding=1, bias=False)\n        self.bn = nn.BatchNorm2d(C_out, affine=affine)\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        # x = self.relu(x)\n        x = self.activation(x)\n        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.bn(out)\n        out = si_relu(out, self.positive)\n        out=self.op(out)\n        return out\n\n\nclass ReLUConvBN(nn.Module):\n    \"\"\"\n    ReLu -> Conv2d -> BatchNorm2d\n    \"\"\"\n\n    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n        super(ReLUConvBN, self).__init__()\n        self.op = nn.Sequential(\n            # nn.ReLU(inplace=False),\n            act_fun(),\n            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,\n                      padding=padding, bias=False),\n            nn.BatchNorm2d(C_out, affine=affine)\n        )\n        self.positive = positive\n        # if positive == -1:\n        #     weight_init(self.op)\n\n    def forward(self, x):\n        out = self.op(x)\n        return si_relu(out, self.positive)\n\n\n# class DeformConv(nn.Module):\n#     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, act_fun=nn.ReLU, positive=0):\n#         super(DeformConv, self).__init__()\n#         self.op = nn.Sequential(\n#             # nn.ReLU(inplace=False),\n#             act_fun(),\n#             DeformConvPack(C_in, C_out, kernel_size=kernel_size,\n#                            stride=stride, padding=padding, bias=True),\n#             nn.BatchNorm2d(C_out, affine=affine)\n#         )\n#         self.positive = positive\n#         # if positive == -1:\n#         #     weight_init(self.op)\n\n#     def forward(self, x):\n#         out = self.op(x)\n#         return si_relu(out, self.positive)\n\n\nclass Attention(Module):\n    \"\"\"\n    Obtained from: 
github.com:rwightman/pytorch-image-models\n    \"\"\"\n\n    def __init__(self, dim, num_heads=4, attention_dropout=0.1, projection_dropout=0.1):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = Linear(dim, dim * 3, bias=False)\n        self.attn_drop = Dropout(attention_dropout)\n        self.proj = Linear(dim, dim)\n        self.proj_drop = Dropout(projection_dropout)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //\n                                  self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass TransformerEncoderLayer(Module):\n    \"\"\"\n    Inspired by torch.nn.TransformerEncoderLayer and\n    rwightman's timm package.\n    \"\"\"\n\n    def __init__(self, d_model, nhead=4, dim_feedforward=256, dropout=0.1,\n                 attention_dropout=0.1, drop_path_rate=0.1):\n        super(TransformerEncoderLayer, self).__init__()\n        self.pre_norm = LayerNorm(d_model)\n        self.self_attn = Attention(dim=d_model, num_heads=nhead,\n                                   attention_dropout=attention_dropout, projection_dropout=dropout)\n        dim_feedforward = d_model\n        self.linear1 = Linear(d_model, dim_feedforward)\n        self.dropout1 = Dropout(dropout)\n        self.norm1 = LayerNorm(d_model)\n        self.linear2 = Linear(dim_feedforward, d_model)\n        self.dropout2 = Dropout(dropout)\n\n        self.drop_path = DropPath(\n            drop_path_rate) if drop_path_rate > 0 else Identity()\n\n        self.activation = F.gelu\n\n    def forward(self, src: 
torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        # print(src.shape)\n        c = src.shape[-1]\n        src = rearrange(src, 'b d r c -> b (r c) d')\n        # print(src.shape)\n        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))\n        src = src + self.drop_path(self.dropout2(src2))\n        src = rearrange(src, 'b (r c) d -> b d r c', c=c)\n        return src\n\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0. 
or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    # work with diff dim tensors, not just 2D ConvNets\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n    random_tensor = keep_prob + \\\n        torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(Module):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/readme.md",
    "content": "\n# Brain-Inspired Multi-scale Evolutionary Architectures for Spiking Neural Networks —— Based on BrainCog #\n\n\n\n## Requirements ##\n* numpy\n* pytorch >= 1.12.0\n* pymoo = 0.4.0\n* BrainCog\n\n## Run ##\n\n```python evolution.py```\n\n## Citation ##\n\nIf you find the code and dataset useful in your research, please consider citing:\n```\n@article{pan2024brain,\n  title={Brain-Inspired Multi-Scale Evolutionary Neural Architecture Search for Deep Spiking Neural Networks},\n  author={Pan, Wenxuan and Zhao, Feifei and Shen, Guobin and Han, Bing and Zeng, Yi},\n  journal={IEEE Transactions on Evolutionary Computation},\n  year={2024},\n  publisher={IEEE}\n}\n\n@article{zeng2023braincog,\n  title={BrainCog: A spiking neural network based, brain-inspired cognitive intelligence engine for brain-inspired AI and brain simulation},\n  author={Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and others},\n  journal={Patterns},\n  volume={4},\n  number={8},\n  year={2023},\n  publisher={Elsevier}\n}\n```\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/tm.py",
    "content": "import sys\nsys.path.insert(0, '/home/panwenxuan/back')\nimport numpy as np\nimport argparse\nimport time\nimport obj\nimport timm.models\nimport yaml\nimport os\nimport logging\nfrom random import choice\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\nfrom micro_encoding import ops\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\nfrom braincog.datasets.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix\nimport micro_encoding\nfrom pymop.problem import Problem\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\nfrom pymoo.optimize import minimize\nfrom utils import data_transforms\nfrom timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n# from tn import TinyImageNet\nimport copy\nfrom sklearn.metrics import confusion_matrix,roc_auc_score\nfrom sklearn.preprocessing import label_binarize\n\ndef train_motifs(args,gen,arch_dir,genome,_logger,args_text,devices,bits):\n    if args.bns:\n        from cellmodel import NetworkCIFAR\n    else:\n        from cell123model import NetworkCIFAR\n    test_motifs,ids = micro_encoding.decode_motif(args.layers,bits,genome)\n\n\n    if gen==-1:\n        args.epochs=args.trainning_epochs\n    else:\n        args.epochs=args.eval_epochs\n\n   
 all_best=[]\n    try:\n        model = create_model(\n            args.model,\n            pretrained=args.pretrained,\n            num_classes=args.num_classes,\n            dataset=args.dataset,\n            step=args.step,\n            encode_type=args.encode,\n            node_type=eval(args.node_type),\n            threshold=args.threshold,\n            tau=args.tau,\n            sigmoid_thres=args.sigmoid_thres,\n            requires_thres_grad=args.requires_thres_grad,\n            spike_output=not args.no_spike_output,\n            C=args.init_channels,\n            layers=args.layers,\n            auxiliary=args.auxiliary,\n            motif=test_motifs,\n            parse_method=args.parse_method,\n            act_fun=args.act_fun,\n            temporal_flatten=args.temporal_flatten,\n            layer_by_layer=args.layer_by_layer,\n            n_groups=args.n_groups,\n            cell_type=genome[-1],\n        )\n\n        if 'dvs' in args.dataset:\n            args.channels = 2\n        # elif 'mnist' in args.dataset:\n        #     args.channels = 1\n        else:\n            args.channels = 3\n        # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n        # _logger.info('flops = %fM', flops / 1e6)\n        # _logger.info('param size = %fM', params / 1e6)\n\n\n        # _logger.info(model)\n\n\n        linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n        args.lr = linear_scaled_lr\n        _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n        if args.local_rank == 0:\n            sumpram=sum([m.numel() for m in model.parameters()])\n            _logger.info('Model %s created, param count: %d' %\n                        (args.model, sumpram))\n            # return\n\n            # if sumpram > 15000000:\n            #     return 0,0\n            \n        num_aug_splits = 0\n        if args.aug_splits > 0:\n            assert 
args.aug_splits > 1, 'A split of 1 makes no sense'\n            num_aug_splits = args.aug_splits\n\n        if args.split_bn:\n            assert num_aug_splits > 1 or args.resplit\n            model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n        use_amp = None\n        if args.amp:\n            # for backwards compat, `--amp` arg tries apex before native amp\n            if has_apex:\n                args.apex_amp = True\n            elif has_native_amp:\n                args.native_amp = True\n        if args.apex_amp and has_apex:\n            use_amp = 'apex'\n        elif args.native_amp and has_native_amp:\n            use_amp = 'native'\n        elif args.apex_amp or args.native_amp:\n            _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                            \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n        optimizer = create_optimizer(args, model)\n\n        amp_autocast = suppress  # do nothing\n        loss_scaler = None\n        if use_amp == 'apex':\n            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n            loss_scaler = ApexScaler()\n            if args.local_rank == 0:\n                _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n        elif use_amp == 'native':\n            amp_autocast = torch.cuda.amp.autocast\n            loss_scaler = NativeScaler()\n            if args.local_rank == 0:\n                _logger.info('Using native Torch AMP. Training in mixed precision.')\n        else:\n            if args.local_rank == 0:\n                _logger.info('AMP not enabled. 
Training in float32.')\n\n        # optionally resume from a checkpoint\n        resume_epoch = None\n        if args.resume and args.eval_checkpoint == '':\n            args.eval_checkpoint = args.resume\n        if args.resume:\n            args.eval = True\n            checkpoint = torch.load(args.resume, map_location='cpu')\n            model.load_state_dict(checkpoint['state_dict'], False)\n            resume_epoch = resume_checkpoint(\n                model, args.resume,\n                optimizer=None if args.no_resume_opt else optimizer,\n                loss_scaler=None if args.no_resume_opt else loss_scaler,\n                log_info=args.local_rank == 0)\n            # print(model.get_attr('mu'))\n            # print(model.get_attr('sigma'))\n\n        if args.num_gpu > 1:\n            if use_amp == 'apex':\n                _logger.warning(\n                    'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')\n                use_amp = None\n            model = nn.DataParallel(model, device_ids=devices).cuda()\n            assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n        else:\n            model = model.cuda()\n            if args.channels_last:\n                model = model.to(memory_format=torch.channels_last)\n\n        optimizer = create_optimizer(args, model)\n\n        if args.critical_loss or args.spike_rate:\n            if args.num_gpu>1:\n                model.module.set_requires_fp(True)\n            else:\n                model.set_requires_fp(True)\n\n        model_ema = None\n        if args.model_ema:\n            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n            model_ema = ModelEma(\n                model,\n                decay=args.model_ema_decay,\n                device='cpu' if args.model_ema_force_cpu else '',\n                resume=args.resume)\n\n        if args.node_resume:\n            
ckpt = torch.load(args.node_resume, map_location='cpu')\n            model.load_node_weight(ckpt, args.node_trainable)\n\n        model_without_ddp = model\n        if args.distributed:\n            if args.sync_bn:\n                assert not args.split_bn\n                try:\n                    if has_apex and use_amp != 'native':\n                        # Apex SyncBN preferred unless native amp is activated\n                        model = convert_syncbn_model(model)\n                    else:\n                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                    if args.local_rank == 0:\n                        _logger.info(\n                            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '\n                            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n                except Exception as e:\n                    _logger.error('Failed to enable Synchronized BatchNorm. 
Install Apex or Torch >= 1.1')\n            if has_apex and use_amp != 'native':\n                # Apex DDP preferred unless native amp is activated\n                if args.local_rank == 0:\n                    _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n                model = ApexDDP(model, delay_allreduce=True)\n            else:\n                if args.local_rank == 0:\n                    _logger.info(\"Using native Torch DistributedDataParallel.\")\n                model = NativeDDP(model, device_ids=[args.local_rank],\n                                find_unused_parameters=True)  # can use device str in Torch >= 1.1\n            model_without_ddp = model.module\n        # NOTE: EMA model does not need to be wrapped by DDP\n\n        lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n        start_epoch = 0\n        if args.start_epoch is not None:\n            # a specified start_epoch will always override the resume epoch\n            start_epoch = args.start_epoch\n        elif resume_epoch is not None:\n            start_epoch = resume_epoch\n        if lr_scheduler is not None and start_epoch > 0:\n            lr_scheduler.step(start_epoch)\n\n        if args.local_rank == 0:\n            _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n        # now config only for imnet\n        data_config = resolve_data_config(vars(args), model=model, verbose=False)\n        loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n            batch_size=args.batch_size,\n            step=args.step,\n            args=args,\n            _logge=_logger,\n            data_config=data_config,\n            num_aug_splits=num_aug_splits,\n            size=args.event_size,\n            mix_up=args.mix_up,\n            cut_mix=args.cut_mix,\n            event_mix=args.event_mix,\n            beta=args.cutmix_beta,\n            prob=args.cutmix_prob,\n            num=args.cutmix_num,\n            
noise=args.cutmix_noise,\n            num_classes=args.num_classes,\n            rand_aug=args.rand_aug,\n            randaug_n=args.randaug_n,\n            randaug_m=args.randaug_m,\n            temporal_flatten=args.temporal_flatten,\n            portion=args.train_portion,\n            _logger=_logger,\n\n        )\n\n\n        if args.loss_fn == 'mse':\n            train_loss_fn = UnilateralMse(1.)\n            validate_loss_fn = UnilateralMse(1.)\n\n        else:\n            if args.jsd:\n                assert num_aug_splits > 1  # JSD only valid with aug splits set\n                train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n            elif mixup_active:\n                # smoothing is handled with mixup target transform\n                train_loss_fn = SoftTargetCrossEntropy().cuda()\n            elif args.smoothing:\n                train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n            else:\n                train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n            validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        if args.loss_fn == 'mix':\n            train_loss_fn = MixLoss(train_loss_fn)\n            validate_loss_fn = MixLoss(validate_loss_fn)\n\n        eval_metric = args.eval_metric\n        best_metric = None\n        best_epoch = None\n\n        if args.eval:  # evaluate the model\n            if args.distributed:\n                state_dict = torch.load(args.eval_checkpoint)['state_dict_ema']\n                new_state_dict = OrderedDict()\n                # add module prefix for DDP\n                for k, v in state_dict.items():\n                    k = 'module.' 
+ k\n                    new_state_dict[k] = v\n\n                model.load_state_dict(new_state_dict)\n            # else:\n            #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)\n            for i in range(1):\n                val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,_logger,arch_dir,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat)\n                print(f\"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n            # return\n\n        saver = None\n        if args.local_rank == 0:\n            decreasing = True if eval_metric == 'loss' else False\n\n            saver = CheckpointSaver(\n                model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n                checkpoint_dir=arch_dir, recovery_dir=arch_dir, decreasing=decreasing)\n            with open(os.path.join(arch_dir, 'args.yaml'), 'w') as f:\n                f.write(args_text)\n\n        try:  # train the model\n            if args.reset_drop:\n                model_without_ddp.reset_drop_path(0.0)\n            for epoch in range(start_epoch, args.epochs):\n                if epoch == 0 and args.reset_drop:\n                    model_without_ddp.reset_drop_path(args.drop_path)\n                # if epoch == 3 and best_metric<5:\n                #     return 0,0\n\n                if args.distributed:\n                    loader_train.sampler.set_epoch(epoch)\n\n                train_metrics = train_epoch(\n                    epoch, model, loader_train, optimizer, train_loss_fn, args,_logger=_logger,\n                    lr_scheduler=lr_scheduler, saver=saver, output_dir=arch_dir,\n                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)\n\n                if args.distributed and args.dist_bn in ('broadcast', 
'reduce'):\n                    if args.local_rank == 0:\n                        _logger.info(\"Distributing BatchNorm running means and vars\")\n                    distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n                eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args,_logger, arch_dir,amp_autocast=amp_autocast,\n                                        visualize=args.visualize, spike_rate=args.spike_rate,\n                                        tsne=args.tsne, conf_mat=args.conf_mat)\n\n                if model_ema is not None and not args.model_ema_force_cpu:\n                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                        distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                    ema_eval_metrics = validate(\n                        epoch, model_ema.ema, loader_eval, validate_loss_fn, args, _logger,arch_dir,amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                        visualize=args.visualize, spike_rate=args.spike_rate,\n                        tsne=args.tsne, conf_mat=args.conf_mat)\n                    eval_metrics = ema_eval_metrics\n\n                if lr_scheduler is not None:\n                    # step LR for next epoch\n                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n                update_summary(\n                    epoch, train_metrics, eval_metrics, os.path.join(arch_dir, 'summary.csv'),\n                    write_header=best_metric is None)\n\n\n                best_metric, best_epoch = eval_metrics[eval_metric],epoch\n                _logger.info('Test: {0} '.format(best_metric))\n                all_best.append(best_metric)\n\n            f=open(os.path.join(arch_dir, 'direct.txt'), 'a')\n            f.write(str(best_metric))\n            f.write('\\n')\n            f.close()\n\n            f=open(os.path.join(arch_dir, 'direct_genome.txt'), 'a')\n            
f.write(\",\".join(str(k) for k in genome))\n            f.write('\\n')\n            f.close()\n\n        except KeyboardInterrupt:\n            pass\n    except MemoryError:\n        return -10000, all_best\n    except RuntimeError:\n        return -10000, all_best\n\n    return best_metric,all_best\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,_logger,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n\n            \n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            loss = loss_fn(output, target)\n        if not (args.cut_mix | args.mix_up | args.event_mix) and args.dataset != 'imnet':\n            # print(output.shape, target.shape)\n 
           acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        closs = torch.tensor([0.], device=loss.device)\n\n        if args.critical_loss:\n            closs = calc_critical_loss(model)\n\n        loss = loss + .1 * closs\n\n        spike_rate_avg_layer_str = ''\n        threshold_str = ''\n        if not args.distributed:\n            losses_m.update(loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), inputs.size(0))\n            top5_m.update(acc5.item(), inputs.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n            if args.num_gpu>1:\n                spike_rate_avg_layer = model.module.get_fire_rate().tolist()\n                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                threshold = model.module.get_threshold()\n            \n            else:\n                spike_rate_avg_layer = model.get_fire_rate().tolist()\n                spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                threshold = model.get_threshold()                \n            \n            threshold_str = ['{:.3f}'.format(i) for i in threshold]\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            if args.opt == 'lamb':\n                optimizer.step(epoch=epoch)\n            else:\n                optimizer.step()\n\n        torch.cuda.synchronize()\n        
if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            mu_str = ''\n            sigma_str = ''\n            if not args.distributed:\n                if 'Noise' in args.node_type:\n                    mu, sigma = model.get_noise_param()\n                    mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                    sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                losses_m.update(reduced_loss.item(), inputs.size(0))\n                closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                if args.distributed:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                            epoch,\n                            batch_idx, len(loader),\n                            100. 
* batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m\n                        ))\n                else:\n                    _logger.info(\n                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>9.6f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                        'LR: {lr:.3e}  '\n                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})\\n'\n                        # 'Fire_rate: {spike_rate}\\n'\n                        # 'Thres: {threshold}\\n'\n                        # 'Mu: {mu_str}\\n'\n                        # 'Sigma: {sigma_str}\\n'\n                        .format(\n                            epoch,\n                            batch_idx, len(loader),\n                            100. 
* batch_idx / last_idx,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            batch_time=batch_time_m,\n                            rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                            lr=lr,\n                            data_time=data_time_m,\n                            # spike_rate=spike_rate_avg_layer_str,\n                            # threshold=threshold_str,\n                            # mu_str=mu_str,\n                            # sigma_str=sigma_str\n                        ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\n\ndef validate(epoch, model, loader, loss_fn, args,_logger, arch_dir,amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n    \n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = 
[]\n    labels_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n\n    all_probs = np.array([]).reshape(0, args.num_classes)\n    all_targets = np.array([])\n\n    attack=False\n    with torch.no_grad():\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n\n            if attack:\n                data2 = copy.deepcopy(inputs)\n                inputs = pgd_attack(model, data2, target, target.device, nn.CrossEntropyLoss())\n\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat) and not args.critical_loss:\n                    if args.num_gpu>1:\n                        model.module.set_requires_fp(True)\n                    else:\n                        model.set_requires_fp(True)\n\n                    # if not args.critical_loss:\n                    #     model.set_requires_fp(False)\n\n            with amp_autocast():\n                output = model(inputs)\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            if not args.distributed:\n                if visualize:\n                    x = model.get_fp()\n                    feature_path = os.path.join(arch_dir, 'feature_map')\n                    if os.path.exists(feature_path) is False:\n                        os.mkdir(feature_path)\n                    save_feature_map(x, feature_path)\n                    # if not args.critical_loss:\n                    #     model_config.set_requires_fp(False)\n\n                if tsne:\n                    x = model.get_fp(temporal_info=False)[-1]\n                    
x = torch.nn.AdaptiveAvgPool2d((1, 1))(x)\n                    x = x.reshape(x.shape[0], -1)\n                    feature_vec.append(x)\n                    feature_cls.append(target)\n\n                if conf_mat:\n                    logits_vec.append(output)\n                    labels_vec.append(target)\n\n                if spike_rate:\n                    if args.num_gpu>1:\n                        avg, var, spike, avg_per_step = model.module.get_spike_info()\n\n                    else:\n                        avg, var, spike, avg_per_step = model.get_spike_info()\n                    save_spike_info(\n                        os.path.join(arch_dir, 'spike_info.csv'),\n                        epoch, batch_idx,\n                        args.step, avg, var,\n                        spike, avg_per_step)\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            loss = loss_fn(output, target)\n            \n            probs = output.softmax(dim=1) #\n\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n\n            all_probs = np.vstack([all_probs, probs.detach().cpu().numpy()])\n            all_targets = np.concatenate([all_targets, target.detach().cpu().numpy()])\n\n            closs = torch.tensor([0.], device=loss.device)\n\n            if not args.distributed:\n                if args.num_gpu>1:\n                    spike_rate_avg_layer = model.module.get_fire_rate().tolist()\n                    threshold = model.module.get_threshold()\n                    threshold_str = ['{:.3f}'.format(i) for i in threshold]\n                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                    tot_spike = model.module.get_tot_spike()\n         
       else:\n                    spike_rate_avg_layer = model.get_fire_rate().tolist()\n                    threshold = model.get_threshold()\n                    threshold_str = ['{:.3f}'.format(i) for i in threshold]\n                    spike_rate_avg_layer_str = ['{:.3f}'.format(i) for i in spike_rate_avg_layer]\n                    tot_spike = model.get_tot_spike()                    \n\n            if args.critical_loss:\n                closs = calc_critical_loss(model)\n            loss = loss + .1 * closs\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            closses_m.update(closs.item(), inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n                mu_str = ''\n                sigma_str = ''\n                if not args.distributed:\n                    if 'Noise' in args.node_type:\n                        mu, sigma = model.get_noise_param()\n                        mu_str = ['{:.3f}'.format(i.detach()) for i in mu]\n                        sigma_str = ['{:.3f}'.format(i.detach()) for i in sigma]\n\n                if args.distributed:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        
'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            ))\n                else:\n                    _logger.info(\n                        '{0}: [{1:>4d}/{2}]  '\n                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                        'cLoss: {closs.val:>7.4f} ({closs.avg:>6.4f})  '\n                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})\\n'\n                        'Fire_rate: {spike_rate}\\n'\n                        'Tot_spike: {tot_spike}\\n'\n                        'Thres: {threshold}\\n'\n                        'Mu: {mu_str}\\n'\n                        'Sigma: {sigma_str}\\n'.format(\n                            log_name,\n                            batch_idx,\n                            last_idx,\n                            batch_time=batch_time_m,\n                            loss=losses_m,\n                            closs=closses_m,\n                            top1=top1_m,\n                            top5=top5_m,\n                            spike_rate=spike_rate_avg_layer_str,\n                            tot_spike=tot_spike,\n                            threshold=threshold_str,\n                            mu_str=mu_str,\n                            sigma_str=sigma_str\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), 
('top1', top1_m.avg), ('top5', top5_m.avg)])\n\n    if not args.distributed:\n        if tsne:\n            feature_vec = torch.cat(feature_vec)\n            feature_cls = torch.cat(feature_cls)\n            plot_tsne(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-2d.eps'))\n            plot_tsne_3d(feature_vec, feature_cls, os.path.join(arch_dir, 't-sne-3d.eps'))\n        if conf_mat:\n            logits_vec = torch.cat(logits_vec)\n            labels_vec = torch.cat(labels_vec)\n            plot_confusion_matrix(logits_vec, labels_vec, os.path.join(arch_dir, 'confusion_matrix.eps'))\n    # 将真实标签二值化，为每个类别创建一个二进制标签\n    all_targets_binarized = label_binarize(all_targets, classes=range(args.num_classes))\n\n    # 使用roc_auc_score的multi_class和average参数来计算平均AUC\n    auc = roc_auc_score(all_targets_binarized, all_probs, multi_class='ovr', average='macro')\n    # print(\"Mean AUC: {:.2f}\".format(auc))\n\n    return OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('auc', auc)])\n\n\n"
  },
  {
    "path": "examples/Structure_Evolution/MSE-NAS/utils.py",
    "content": "import json\nimport matplotlib.pyplot as plt\nimport os\nimport numpy as np\nimport torch\nimport shutil\nimport torchvision.transforms as transforms\nfrom torch.autograd import Variable\nfrom auto_augment import CIFAR10Policy\nfrom braincog.model_zoo.darts.genotypes import PRIMITIVES\n\nforward_edge_num = sum(1 for i in range(3) for n in range(2 + i))\nbackward_edge_num = sum(1 for i in range(3) for n in range(i))\nnum_ops = len(PRIMITIVES)\ntype_num = len(PRIMITIVES) // 2\n# edge_num = [2, 3, 4]\nedge_num = [2, 3, 4, 1, 2]\n\ndef drop_path(x, drop_prob):\n    if drop_prob > 0.:\n        keep_prob = 1. - drop_prob\n        mask = Variable(torch.cuda.FloatTensor(\n            x.size(0), 1, 1, 1).bernoulli_(keep_prob))\n        x.div_(keep_prob)\n        x.mul_(mask)\n    return x\n\nclass AvgrageMeter(object):\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.avg = 0\n        self.sum = 0\n        self.cnt = 0\n\n    def update(self, val, n=1):\n        self.sum += val * n\n        self.cnt += n\n        self.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Compute the top1 and top5 accuracy\n\n\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    # Return the k largest elements of the given input tensor\n    # along a given dimension -> N * k\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        correct_k = correct[:k].reshape(-1).float().sum(0)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n\n\nclass Cutout(object):\n    def __init__(self, length):\n        self.length = length\n\n    def __call__(self, img):\n        h, w = img.size(1), img.size(2)\n        mask = np.ones((h, w), np.float32)\n        y = np.random.randint(h)\n        x = np.random.randint(w)\n\n        y1 = np.clip(y - self.length // 2, 0, h)\n        
y2 = np.clip(y + self.length // 2, 0, h)\n        x1 = np.clip(x - self.length // 2, 0, w)\n        x2 = np.clip(x + self.length // 2, 0, w)\n\n        mask[y1: y2, x1: x2] = 0.\n        mask = torch.from_numpy(mask)\n        mask = mask.expand_as(img)\n        img *= mask\n        return img\n\n\ndef _data_transforms_cifar(args):\n    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] if args.dataset == 'cifar10' else [0.50707519, 0.48654887,\n                                                                                         0.44091785]\n    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] if args.dataset == 'cifar10' else [0.26733428, 0.25643846,\n                                                                                        0.27615049]\n\n    normalize_transform = [\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]\n\n    random_transform = [\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip()]\n\n    if args.auto_aug:\n        random_transform += [CIFAR10Policy()]\n\n    if args.cutout:\n        cutout_transform = [Cutout(args.cutout_length)]\n    else:\n        cutout_transform = []\n\n    train_transform = transforms.Compose(\n        random_transform + normalize_transform + cutout_transform\n    )\n\n    valid_transform = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n    ])\n    return train_transform, valid_transform\n\n\ndef count_parameters_in_MB(model):\n    return np.sum(np.prod(v.size()) for v in model.parameters()) / 1e6\n\n\ndef save_checkpoint(state, is_best, save):\n    filename = os.path.join(save, 'checkpoint.pth.tar')\n    torch.save(state, filename)\n    if is_best:\n        best_filename = os.path.join(save, 'model_best.pth.tar')\n        shutil.copyfile(filename, best_filename)\n\n\ndef save(model, model_path):\n    torch.save(model.state_dict(), model_path)\n\n\ndef load(model, model_path):\n   
 model.load_state_dict(torch.load(model_path))\n\n\ndef drop_path(x, drop_prob):\n    if drop_prob > 0.:\n        keep_prob = 1. - drop_prob\n        mask = Variable(torch.cuda.FloatTensor(\n            x.size(0), 1, 1, 1).bernoulli_(keep_prob))\n        x.div_(keep_prob)\n        x.mul_(mask)\n    return x\n\n\ndef create_exp_dir(path, scripts_to_save=None):\n    if not os.path.exists(path):\n        os.makedirs(path)\n    print('Experiment dir : {}'.format(path))\n\n    if scripts_to_save is not None:\n        os.makedirs(os.path.join(path, 'scripts'))\n        for script in scripts_to_save:\n            dst_file = os.path.join(path, 'scripts', os.path.basename(script))\n            shutil.copyfile(script, dst_file)\n\n\ndef calc_time(seconds):\n    m, s = divmod(seconds, 60)\n    h, m = divmod(m, 60)\n    t, h = divmod(h, 24)\n    return {'day': t, 'hour': h, 'minute': m, 'second': int(s)}\n\n\ndef save_file(recoder, path='./', back_connection=False):\n    size = (forward_edge_num +\n            backward_edge_num if back_connection else forward_edge_num, num_ops)\n    fig, axs = plt.subplots(*size, figsize=(36, 98))\n    row = 0\n    col = 0\n    for (k, v) in recoder.items():\n        axs[row, col].set_title(k)\n        axs[row, col].plot(v, 'r+')\n        if col == num_ops - 1:\n            col = 0\n            row += 1\n        else:\n            col += 1\n    if not os.path.exists(path):\n        os.makedirs(path)\n    fig.savefig(os.path.join(path, 'output.png'), bbox_inches='tight')\n    plt.tight_layout()\n    print('save history weight in {}'.format(os.path.join(path, 'output.png')))\n    with open(os.path.join(path, 'history_weight.json'), 'w') as outf:\n        json.dump(recoder, outf)\n        print('save history weight in {}'.format(\n            os.path.join(path, 'history_weight.json')))\n\ndef data_transforms(args):\n    if args.dataset == 'cifar10':\n        MEAN = [0.4913, 0.4821, 0.4465]\n        STD = [0.2470, 0.2434, 0.2615]\n    elif 
args.dataset == 'cifar100':\n        MEAN = [0.5071, 0.4867, 0.4408]\n        STD = [0.2673, 0.2564, 0.2762]\n    elif args.dataset == 'tinyimagenet':\n        MEAN = [0.485, 0.456, 0.406]\n        STD = [0.229, 0.224, 0.225]\n\n    if (args.dataset== 'tinyimagenet'):\n        train_transform = transforms.Compose([\n            transforms.RandomCrop(64, padding=8),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(MEAN, STD)\n        ])\n        valid_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(MEAN, STD)\n        ])\n    else:  # cifar10 or cifar100\n        train_transform = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(MEAN, STD)\n        ])\n        valid_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(MEAN, STD)\n        ])\n    return train_transform, valid_transform\n"
  },
  {
    "path": "examples/TIM/README.md",
    "content": "# TIM: An Efficient Temporal Interaction Module for Spiking Transformer, [IJCAI2024](https://arxiv.org/abs/2401.11687)\n\n![Alt text](img/TIM.png)\n\n## Reference\n```\n@misc{shen2024tim,\n      title={TIM: An Efficient Temporal Interaction Module for Spiking Transformer}, \n      author={Sicheng Shen and Dongcheng Zhao and Guobin Shen and Yi Zeng},\n      year={2024},\n      eprint={2401.11687},\n      archivePrefix={arXiv},\n      primaryClass={cs.NE}\n}\n```\n\nHere is the official implemented code of TIM. The code is based on Pytorch and [Braincog](https://github.com/BrainCog-X/Brain-Cog)\n\n## Requirements\n### Create Braincog Virtual Environment\n```\nconda create -n braincog python=3.8\nconda activate braincog\npip install braincog\n```\n### Dataset Preparation\n**Datasets Needed**: CIFAR10-DVS, N-CALTECH101, UCF101DVS, NCARS, HMDB51DVS, SHD\nPlease unzip data to ```/data/datasets``` so that dataset.py may directly load corresponding dataset for training\n\n## Model Training\nFor most DVS data, we prefer using the event-frame size of 64 but not 128 here.\nPlease adjust your hyper parameters here. 
10 is set as the default value of time step numbers.\n```\n@register_model\ndef spikformer_dvs(pretrained=False, **kwargs):\n    model = Spikformer(TIM_alpha=0.5,step=10,if_UCF=False,num_classes=10,\n        # img_size_h=64, img_size_w=64,\n        # patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=4,\n        # in_channels=2, qkv_bias=False,\n        # depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n```\n### Training on CIFAR10-DVS\n```\npython main.py --model spikformer_dvs --dataset dvsc10 --epoch 500 --batch-size 16 --event-size 64 \n```\n### Training on N-CALTECH101\n```num_classes``` should be set to 101\n```\npython main.py --model spikformer_dvs --dataset NCALTECH101 --epoch 500 --batch-size 16 --event-size 64  --num_classes 101\n```\n\n### Training on NCARS\n```num_classes``` should be set to 2\n```\npython main.py --model spikformer_dvs --dataset NCARS --epoch 500 --batch-size 16 --event-size 64  --num_classes 2\n```\n### Training on UCF101DVS\n```num_classes``` should be set to 101,```if_UCF``` should be set to ```True``` \n```\npython main.py --model spikformer_dvs --dataset UCF101DVS --epoch 500 --batch-size 16 --event-size 64  --num_classes 101\n```\n### Training on HMDB51DVS\n```\npython main.py --model spikformer_dvs --dataset HMDBDVS --epoch 500 --batch-size 16 --event-size 64  --num_classes 51\n```\n\n### Training on SHD\n```num_classes``` should be set to 20\n```\npython main.py --model spikformer_shd --dataset SHD --epoch 500 --batch-size 16 --num_classes 20\n```\n"
  },
  {
    "path": "examples/TIM/main.py",
    "content": "import argparse\nimport time\n\nimport timm.models\nimport yaml\nimport os\nimport random as buildin_random\nimport logging\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nfrom braincog.base.node.node import *\nfrom braincog.utils import *\nfrom braincog.base.utils.criterions import *\n# from braincog.datasets.datasets import *\nfrom utils.datasets import *\nfrom braincog.model_zoo.resnet import *\nfrom braincog.model_zoo.convnet import *\nfrom braincog.model_zoo.vgg_snn import VGG_SNN, SNN5\nfrom braincog.model_zoo.resnet19_snn import resnet19\nfrom braincog.utils import save_feature_map, setup_seed\nfrom braincog.base.utils.visualization import plot_tsne_3d, plot_tsne, plot_confusion_matrix, plot_mem_distribution\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\nfrom timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model, register_model\nfrom timm.utils import *\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy\nfrom timm.optim import create_optimizer\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\nfrom torch.utils.tensorboard import SummaryWriter\nfrom models.spikformer_braincog_DVS import spikformer_dvs\nfrom models.spikformer_braincog_DVS import spikformer_shd\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', 
add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\nparser = argparse.ArgumentParser(description='SNN Training and Evaluating')\n\n# Model parameters\nparser.add_argument('--dataset', default='dvsc10', type=str)\nparser.add_argument('--model', default='spikformer', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"countception\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',\n                    help='path to eval checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=10, metavar='N',\n                    help='number of label classes (default: 1000)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.')\n\n# Dataset parameters for static datasets\nparser.add_argument('--img-size', type=int, default=224, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='inputs image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\n\n# Dataloader parameters\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='inputs batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',\n                    help='ratio of validation batch size to training batch size (default: 1)')\n\n# Optimizer parameters\nparser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"adamw\"')\nparser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=1e-4,\n                    help='weight decay (default: 0.01 for 
adamw)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--adam-epoch', type=int, default=1000, help='lamb switch to adamw')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"cosine\"')\nparser.add_argument('--lr', type=float, default=5e-3, metavar='LR',\n                    help='learning rate (default: 0.01)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit')\nparser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=400, metavar='N',\n                    help='number of epochs to train (default: 2)')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n   
                 help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',\n                    help='epochss to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\nparser.add_argument('--power', type=int, default=1, help='power')\n\n# Augmentation & regularization parameters ONLY FOR IMAGE NET\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)'),\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.25)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"const\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=0.,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=0.,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.0)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\nparser.add_argument('--newton-maxiter', default=20, type=int,\n                    help='max iterration in newton method')\nparser.add_argument('--reset-drop', action='store_true', default=False,\n                    help='whether to reset drop')\nparser.add_argument('--kernel-method', type=str, default='cuda', choices=['torch', 'cuda'],\n                    help='The implementation way of gaussian kernel method, choose from \"cuda\" and \"torch\"')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-tf', action='store_true', default=False,\n                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                  
  help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='',\n                    help='Distribute BatchNorm stats between node after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.99996,\n                    help='decay factor for model weights moving average (default: 0.99996)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 8)')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUS to use')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of inputs batches every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                
     help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='/home/shensicheng/code/TIM/logs', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--tensorboard-dir', default='./runs', type=str)\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\")')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument('--local_rank', default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--eval', action='store_true', help='Perform evaluation only')\nparser.add_argument('--device', type=int, default=0)\n\n# Spike parameters\nparser.add_argument('--step', type=int, default=10, help='Simulation time step (default: 10)')\nparser.add_argument('--encode', type=str, default='direct', help='Input encode method (default: direct)')\nparser.add_argument('--temporal-flatten', action='store_true',\n                    help='Temporal flatten to channels. ONLY FOR EVENT DATA TRAINING BY ANN')\nparser.add_argument('--adaptive-node', action='store_true')\nparser.add_argument('--critical-loss', action='store_true')\nparser.add_argument('--conv-type', type=str, default='normal')\nparser.add_argument('--sew-cnf', type=str, default='ADD')\nparser.add_argument('--rand-step', action='store_true')\n\n# neuron type\nparser.add_argument('--node-type', type=str, default='LIFNode', help='Node type in network (default: LIFNode)')\nparser.add_argument('--act-fun', type=str, default='QGateGrad',\n                    help='Surrogate Function in node. 
Only for Surrogate nodes (default: QGateGrad)')\nparser.add_argument('--threshold', type=float, default=.5, help='Firing threshold (default: 0.5)')\nparser.add_argument('--tau', type=float, default=2., help='Attenuation coefficient (default: 2.)')\nparser.add_argument('--requires-thres-grad', action='store_true')\nparser.add_argument('--sigmoid-thres', action='store_true')\n\nparser.add_argument('--loss-fn', type=str, default='ce', help='loss function (default: ce)')\nparser.add_argument('--noisy-grad', type=float, default=0.,\n                    help='Add noise to backward, sometimes will make higher accuracy (default: 0.)')\nparser.add_argument('--spike-output', action='store_true', default=False,\n                    help='Using mem output or spike output (default: False)')\nparser.add_argument('--n_groups', type=int, default=1)\nparser.add_argument('--n-encode-type', type=str, default='linear')\nparser.add_argument('--n-preact', action='store_true')\nparser.add_argument('--layer-by-layer', action='store_true',\n                    help='forward step-by-step or layer-by-layer. 
'\n                         'Larger Model with layer-by-layer will be faster (default: False)')\nparser.add_argument('--tet-loss', action='store_true')\n\n# EventData Augmentation\nparser.add_argument('--mix-up', action='store_true', help='Mix-up for event data (default: False)')\nparser.add_argument('--cut-mix', action='store_true', help='CutMix for event data (default: False)')\nparser.add_argument('--event-mix', action='store_true', help='EventMix for event data (default: False)')\nparser.add_argument('--cutmix_beta', type=float, default=2.0, help='cutmix_beta (default: 2.0)')\nparser.add_argument('--cutmix_prob', type=float, default=0.5, help='cutmix_prob for event data (default: .5)')\nparser.add_argument('--cutmix_num', type=int, default=1, help='cutmix_num for event data (default: 1)')\nparser.add_argument('--cutmix_noise', type=float, default=0.,\n                    help='Add Pepper noise after mix, sometimes work (default: 0.)')\nparser.add_argument('--gaussian-n', type=int, default=3)\nparser.add_argument('--rand-aug', action='store_true',\n                    help='Rand Augment for Event data (default: False)')\nparser.add_argument('--randaug_n', type=int, default=3,\n                    help='Rand Augment times n (default: 3)')\nparser.add_argument('--randaug_m', type=int, default=15,\n                    help='Rand Augment magnitude m (default: 15) (0-30)')\nparser.add_argument('--train-portion', type=float, default=0.9,\n                    help='Dataset portion, only for datasets which do not have validation set (default: 0.9)')\nparser.add_argument('--event-size', default=48, type=int,\n                    help='Event size. Resize event data before process (default: 48)')\nparser.add_argument('--node-resume', type=str, default='',\n                    help='resume weights in node for adaptive node. 
(default: False)')\n\n# visualize\nparser.add_argument('--visualize', action='store_true',\n                    help='Visualize spiking map for each layer, only for validate (default: False)')\nparser.add_argument('--spike-rate', action='store_true',\n                    help='Print spiking rate for each layer, only for validate(default: False)')\nparser.add_argument('--tsne', action='store_true')\nparser.add_argument('--conf-mat', action='store_true')\nparser.add_argument('--mem-dist', action='store_true')\nparser.add_argument('--adaptation-info', action='store_true')\n\nparser.add_argument('--suffix', type=str, default='',\n                    help='Add an additional suffix to the save path (default: \\'\\')')\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\n\ndef main():\n    args, args_text = _parse_args()\n    # args.no_spike_output = args.no_spike_output | args.cut_mix\n    args.no_spike_output = True\n    output_dir = ''\n    if args.local_rank == 0:\n        output_base = args.output if args.output else 
'./output'\n        exp_name = '-'.join([\n            args.model,\n            args.dataset,\n            args.node_type,\n            str(args.step),\n            args.suffix,\n            datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n            # str(args.img_size)\n        ])\n        output_dir = get_outdir(output_base, 'train', exp_name)\n        args.output_dir = output_dir\n        setup_default_logging(log_path=os.path.join(output_dir, 'log.txt'))\n        summary_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_dir, exp_name))\n        args.tensorboard_prefix = os.path.join(args.dataset, args.model)\n    else:\n        summary_writer = None\n        setup_default_logging()\n\n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n        if args.distributed and args.num_gpu > 1:\n            _logger.warning(\n                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')\n            args.num_gpu = 1\n\n    # args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.num_gpu = 1\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n    else:\n        torch.cuda.set_device('cuda:%d' % args.device)\n    assert args.rank >= 0\n\n    if args.distributed:\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on %d GPUs.' 
% args.num_gpu)\n\n    # torch.manual_seed(args.seed + args.rank)\n    setup_seed(args.seed + args.rank)\n\n    model = create_model(\n        args.model,\n        # pretrained=args.pretrained,\n        # num_classes=args.num_classes,\n        # dataset=args.dataset,\n        # step=args.step,\n        # encode_type=args.encode,\n        # node_type=eval(args.node_type),\n        # threshold=args.threshold,\n        # tau=args.tau,\n        # sigmoid_thres=args.sigmoid_thres,\n        # requires_thres_grad=args.requires_thres_grad,\n        # spike_output=not args.no_spike_output,\n        # act_fun=args.act_fun,\n        # temporal_flatten=args.temporal_flatten,\n        # layer_by_layer=args.layer_by_layer,\n        # n_groups=args.n_groups,\n        # n_encode_type=args.n_encode_type,\n        # n_preact=args.n_preact,\n        # tet_loss=args.tet_loss,\n        # sew_cnf=args.sew_cnf,\n        # conv_type=args.conv_type,\n    )\n\n    _logger.info('[MODEL ARCH]\\n{}'.format(model))\n\n    if 'dvs' in args.dataset:\n        args.channels = 2\n    elif 'mnist' in args.dataset:\n        args.channels = 1\n    else:\n        args.channels = 3\n    # flops, params = profile(model, inputs=(torch.randn(1, args.channels, args.event_size, args.event_size),), verbose=False)\n    # _logger.info('flops = %fM', flops / 1e6)\n    # _logger.info('param size = %fM', params / 1e6)\n\n    linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0\n    args.lr = linear_scaled_lr\n    _logger.info(\"learning rate is %f\" % linear_scaled_lr)\n\n    if args.local_rank == 0:\n        _logger.info('Model %s created, param count: %d' %\n                     (args.model, sum([m.numel() for m in model.parameters()])))\n\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = 
convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    use_amp = None\n    if args.amp:\n        # for backwards compat, `--amp` arg tries apex before native amp\n        if has_apex:\n            args.apex_amp = True\n        elif has_native_amp:\n            args.native_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. \"\n                        \"Install NVIDIA apex or upgrade to PyTorch 1.6\")\n\n    if args.num_gpu > 1:\n        if use_amp == 'apex':\n            _logger.warning(\n                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')\n            use_amp = None\n        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()\n        assert not args.channels_last, \"Channels last not supported with DP, use DDP.\"\n    else:\n        model = model.cuda()\n        if args.channels_last:\n            model = model.to(memory_format=torch.channels_last)\n\n    optimizer = create_optimizer(args, model)\n\n    _logger.info('[OPTIMIZER]\\n{}'.format(optimizer))\n\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. 
Training in float32.')\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume and args.eval_checkpoint == '':\n        args.eval_checkpoint = args.resume\n    if args.resume:\n        args.eval = True\n        # checkpoint = torch.load(args.resume, map_location='cpu')\n        # model.load_state_dict(checkpoint['state_dict'], False)\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n        # print(model.get_attr('mu'))\n        # print(model.get_attr('sigma'))\n        if hasattr(model, 'set_threshold'):\n            model.set_threshold(args.threshold)\n\n    if args.critical_loss or args.spike_rate:\n        model.set_requires_fp(True)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume=args.resume)\n\n    if args.node_resume:\n        ckpt = torch.load(args.node_resume, map_location='cpu')\n        model.load_node_weight(ckpt, args.node_trainable)\n\n    model_without_ddp = model\n    if args.distributed:\n        if args.sync_bn:\n            assert not args.split_bn\n            try:\n                if has_apex and use_amp != 'native':\n                    # Apex SyncBN preferred unless native amp is activated\n                    model = convert_syncbn_model(model)\n                else:\n                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n                if args.local_rank == 0:\n                    _logger.info(\n                        'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using '\n                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n            except Exception as e:\n                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')\n        if has_apex and use_amp != 'native':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model.cuda(), device_ids=[args.local_rank],\n                              find_unused_parameters=True)  # can use device str in Torch >= 1.1\n        model_without_ddp = model.module\n    # NOTE: EMA model does not need to be wrapped by DDP\n\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # now config only for imnet\n    data_config = resolve_data_config(vars(args), model=model, verbose=False)\n    loader_train, loader_eval, mixup_active, mixup_fn = eval('get_%s_data' % args.dataset)(\n        batch_size=args.batch_size,\n        step=args.step,\n        args=args,\n        _logge=_logger,\n        data_config=data_config,\n        num_aug_splits=num_aug_splits,\n        size=args.event_size,\n        mix_up=args.mix_up,\n        cut_mix=args.cut_mix,\n        event_mix=args.event_mix,\n 
       beta=args.cutmix_beta,\n        prob=args.cutmix_prob,\n        gaussian_n=args.gaussian_n,\n        num=args.cutmix_num,\n        noise=args.cutmix_noise,\n        num_classes=args.num_classes,\n        rand_aug=args.rand_aug,\n        randaug_n=args.randaug_n,\n        randaug_m=args.randaug_m,\n        portion=args.train_portion,\n        _logger=_logger,\n    )\n    # _logger.info('train_loader:\\n{}\\nval_loader:\\n{}'.format(loader_train, loader_eval))\n    if args.loss_fn == 'mse':\n        train_loss_fn = UnilateralMse(1.)\n        validate_loss_fn = UnilateralMse(1.)\n    elif args.loss_fn == 'onehot-mse':\n        train_loss_fn = OnehotMse(args.num_classes)\n        validate_loss_fn = OnehotMse(args.num_classes)\n    else:\n        if args.jsd:\n            assert num_aug_splits > 1  # JSD only valid with aug splits set\n            train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()\n        elif mixup_active:\n            # smoothing is handled with mixup target transform\n            train_loss_fn = SoftTargetCrossEntropy().cuda()\n        elif args.smoothing:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()\n        else:\n            train_loss_fn = nn.CrossEntropyLoss().cuda()\n\n        validate_loss_fn = nn.CrossEntropyLoss().cuda()\n\n    if args.loss_fn == 'mix':\n        train_loss_fn = MixLoss(train_loss_fn)\n        validate_loss_fn = MixLoss(validate_loss_fn)\n\n    if args.tet_loss:\n        train_loss_fn = TetLoss(train_loss_fn)\n        validate_loss_fn = TetLoss(validate_loss_fn)\n\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n\n    if args.eval:  # evaluate the model\n        # if args.distributed:\n        #     raise NotImplementedError('eval not has not been verified for distributed')\n        # else:\n        #     load_checkpoint(model, args.eval_checkpoint, args.model_ema)\n        model.eval()\n        for 
t in range(1, args.step * 3):\n        # for t in range(args.step, args.step + 1):\n            model.set_attr('step', t)\n            val_metrics = validate(start_epoch, model, loader_eval, validate_loss_fn, args,\n                                   visualize=args.visualize, spike_rate=args.spike_rate,\n                                   tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)\n            print(f\"[STEP:{t}], Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%\")\n        return\n\n    saver = None\n    if args.local_rank == 0:\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=3)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    try:  # train the model\n        if args.reset_drop:\n            model_without_ddp.reset_drop_path(0.0)\n        for epoch in range(start_epoch, args.epochs):\n            if epoch == 0 and args.reset_drop:\n                model_without_ddp.reset_drop_path(args.drop_path)\n\n            if args.distributed:\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler,\n                model_ema=model_ema, mixup_fn=mixup_fn, summary_writer=summary_writer\n            )\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn 
== 'reduce')\n\n            eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,\n                                    visualize=args.visualize, spike_rate=args.spike_rate,\n                                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer)\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                ema_eval_metrics = validate(\n                    epoch, model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)',\n                    visualize=args.visualize, spike_rate=args.spike_rate,\n                    tsne=args.tsne, conf_mat=args.conf_mat, summary_writer=summary_writer\n                )\n                eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            update_summary(\n                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                write_header=best_metric is None)\n\n            # if saver is not None and epoch >= args.n_warm_up:\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, 
mixup_fn=None, summary_writer=None):\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    # closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.train()\n\n    # t, k = adjust_surrogate_coeff(100, args.epochs)\n    # model.set_attr('t', t)\n    # model.set_attr('k', k)\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    iters_per_epoch = len(loader)\n    for batch_idx, (inputs, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        if args.rand_step:\n            step = buildin_random.randint(1, args.step + 2)\n            model.set_attr('step', step)\n\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher or args.dataset != 'imnet':\n            inputs, target = inputs.type(torch.FloatTensor).cuda(), target.cuda()\n            if mixup_fn is not None:\n                inputs, target = mixup_fn(inputs, target)\n        if args.channels_last:\n            inputs = inputs.contiguous(memory_format=torch.channels_last)\n        with amp_autocast():\n            output = model(inputs)\n            loss = loss_fn(output, target)\n        if args.tet_loss:\n            output = output.mean(0)\n\n        if not (args.cut_mix | args.mix_up | args.event_mix | (args.cutmix != 0.) 
| (args.mixup != 0.)):\n            # print(output.shape, target.shape)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            # acc1, = accuracy(output, target)\n        else:\n            acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.noisy_grad != 0.:\n                random_gradient(model, args.noisy_grad)\n            if args.clip_grad is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)\n            # if args.opt == 'lamb':\n            #     optimizer.step(epoch=epoch)\n            # else:\n            optimizer.step()\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n        num_updates += 1\n\n        batch_time_m.update(time.time() - end)\n\n        if args.local_rank == 0:\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)\n            summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/train/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)\n\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            if args.distributed:\n                loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n\n            
losses_m.update(loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n                # closses_m.update(reduced_loss.item(), inputs.size(0))\n\n            if args.local_rank == 0:\n                # if args.distributed:\n                _logger.info(\n                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '\n                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'LR: {lr:.3e}  '\n                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                        epoch,\n                        batch_idx, len(loader),\n                        100. * batch_idx / last_idx,\n                        loss=losses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        batch_time=batch_time_m,\n                        rate=inputs.size(0) * args.world_size / batch_time_m.val,\n                        rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,\n                        lr=lr,\n                        data_time=data_time_m\n                    ))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        inputs,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n    # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    if args.local_rank == 0:\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top1'), top1_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/top5'), top5_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/train/loss'), losses_m.avg, epoch)\n\n    if args.rand_step:\n        model.set_attr('step', args.step)\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(epoch, model, loader, loss_fn, args, amp_autocast=suppress,\n             log_suffix='', visualize=False, spike_rate=False, tsne=False, conf_mat=False, summary_writer=None):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    # closses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n    spike_m = AverageMeter()\n\n    model.eval()\n\n    feature_vec = []\n    feature_cls = []\n    logits_vec = []\n    labels_vec = []\n    mem_vec = []\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    iters_per_epoch = len(loader)\n    with torch.no_grad():\n\n        for batch_idx, (inputs, target) in enumerate(loader):\n            # inputs = inputs.type(torch.float64)\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher or args.dataset != 'imnet':\n                inputs = inputs.type(torch.FloatTensor).cuda()\n                target = target.cuda()\n            if args.channels_last:\n                inputs = inputs.contiguous(memory_format=torch.channels_last)\n\n            if not args.distributed:\n                if (visualize or spike_rate or tsne or conf_mat or args.mem_dist) and not args.critical_loss:\n                    model.set_requires_fp(True)\n\n            with 
amp_autocast():\n                output = model(inputs)\n\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            # print(args.rank, output.shape, target.shape, max(target))\n            loss = loss_fn(output, target)\n            if args.tet_loss:\n                output = output.mean(0)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), inputs.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n            # closses_m.update(closs, inputs.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n\n            if args.local_rank == 0:\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top1'), acc1.item(), epoch * iters_per_epoch + batch_idx)\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/top5'), acc5.item(), epoch * iters_per_epoch + batch_idx)\n                summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'batch/val/loss'), loss.item(), epoch * iters_per_epoch + batch_idx)\n\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n\n            if not args.distributed and 
spike_rate:\n                spike_m.update(model.get_tot_spike() / output.size(0), output.size(0))\n\n                if not args.distributed and spike_rate:\n                    _logger.info(\n                        '[Spike Info]: {spike.val} ({spike.avg})'.format(\n                            spike=spike_m\n                        )\n                    )\n            if last_batch or batch_idx % args.log_interval == 0:\n                _logger.info(\n                    'Eval : {} '\n                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})'\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                        epoch,\n                        batch_idx,\n                        last_idx,\n                        batch_time=batch_time_m,\n                        loss=losses_m,\n                        top1=top1_m,\n                        top5=top5_m,\n                        ))\n\n    # metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg)])\n\n    if args.local_rank == 0:\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top1'), top1_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/top5'), top5_m.avg, epoch)\n        summary_writer.add_scalar(os.path.join(args.tensorboard_prefix, 'epoch/val/loss'), losses_m.avg, epoch)\n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/TIM/models/TIM.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom utils.MyNode import *\n\n\nclass TIM(BaseModule):\n    def __init__(self,dim=256,encode_type='direct',in_channels=16,TIM_alpha=0.5):\n        super().__init__(step=1,encode_type=encode_type)\n        \n\n        #  channels may depends on the shape of input\n        self.interactor = nn.Conv1d(in_channels=in_channels,out_channels=in_channels,kernel_size=5, stride=1, padding=2, bias=True)\n\n        self.in_lif = MyNode(tau=2.0,v_threshold=0.3,layer_by_layer=False,step=1)  #spike-driven\n        self.out_lif = MyNode(tau=2.0,v_threshold=0.5,layer_by_layer=False,step=1)   #spike-driven\n\n        self.tim_alpha = TIM_alpha\n\n    # input [T, B, H, N, C/H]\n    def forward(self, x):\n        self.reset()\n\n        T, B, H, N, CoH = x.shape   \n\n        output = [] \n        x_tim = torch.empty_like(x[0]) \n\n        #temporal interaction \n\n        for i in range(T):\n            #1st step\n            if i == 0 :\n                x_tim = x[i]\n                output.append(x_tim)\n            \n            #other steps\n            else:\n                x_tim = self.interactor(x_tim.flatten(0,1)).reshape(B,H,N,CoH).contiguous()\n                x_tim = self.in_lif(x_tim) * self.tim_alpha + x[i] * (1-self.tim_alpha)\n                x_tim = self.out_lif(x_tim)\n              \n                \n                output.append(x_tim)\n            \n        output = torch.stack(output) # T B H, N, C/H\n\n        return output # T B H, N, C/H"
  },
  {
    "path": "examples/TIM/models/spikformer_braincog_DVS.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom functools import partial\nfrom torchvision import transforms\nfrom utils.MyNode import *\nfrom models.TIM import *\n__all__ = ['spikformer']\n\nclass MLP(BaseModule):\n    def __init__(self,in_features,step=10,encode_type='direct',hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=10,encode_type='direct')\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step,tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step,tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T,B,C,N = x.shape\n\n        x = self.fc1_conv(x.flatten(0,1))\n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N ).contiguous() # T B C N\n        x = self.fc1_lif(x.flatten(0,1)).reshape(T, B, self.c_hidden, N).contiguous() \n\n        x = self.fc2_conv(x.flatten(0,1))\n        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()\n        x = self.fc2_lif(x.flatten(0,1)).reshape(T, B, C, N ).contiguous() \n        return x\n\n\nclass SSA(BaseModule):\n    def __init__(self,dim,step=10,encode_type='direct',num_heads=16,TIM_alpha=0.5,qkv_bias=False, 
qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__(step=10,encode_type='direct')\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        self.num_heads = num_heads\n\n        self.in_channels = dim // num_heads\n\n        self.scale = 0.25\n        \n    \n        self.q_conv = nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.q_bn = nn.BatchNorm1d(dim)\n        self.q_lif = MyNode(step=step,tau=2.0)\n\n        self.k_conv = nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.k_bn = nn.BatchNorm1d(dim)\n        self.k_lif = MyNode(step=step,tau=2.0)\n\n        self.v_conv = nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.v_bn = nn.BatchNorm1d(dim)\n        self.v_lif = MyNode(step=step,tau=2.0)\n    \n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5,)\n\n        self.proj_conv =  nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.proj_bn = nn.BatchNorm1d(dim)\n        self.proj_lif = MyNode(step=step, tau=2.0,)\n\n        self.TIM = TIM(TIM_alpha=TIM_alpha,in_channels=self.in_channels)\n        \n    def forward(self, x):\n\n        self.reset()\n\n        T,B,C,N = x.shape\n\n        x_for_qkv = x.flatten(0, 1)  \n\n        q_conv_out = self.q_conv(x_for_qkv)  \n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()\n        q_conv_out = self.q_lif(q_conv_out.flatten(0,1)).reshape(T, B, C ,N).transpose(-2,-1)\n        q = q_conv_out.reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()\n        k_conv_out= 
self.k_lif(k_conv_out.flatten(0,1)).reshape(T, B, C ,N).transpose(-2,-1)\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0,1)).reshape(T, B, C ,N).transpose(-2,-1)\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        #TIM on Q\n        q = self.TIM(q) \n\n        #SSA \n        attn = (q @ k.transpose(-2, -1)) \n        x = (attn @ v) * self.scale \n        \n        x = x.transpose(3,4).reshape(T, B, C, N).contiguous() \n        x = self.attn_lif(x.flatten(0,1)) \n        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) \n        \n        return x\n\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10,TIM_alpha=0.5, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n\n\n\n        self.attn = SSA(dim, step=step,TIM_alpha=TIM_alpha,num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                              attn_drop=attn_drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(in_features=dim,step=step, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        x = x + self.attn(x)\n        x = x + self.mlp(x)\n        return x\n\nclass SPS(BaseModule):\n    def __init__(self, step=10, encode_type='direct', img_size_h=64, img_size_w=64, patch_size=4, in_channels=2,\n                 embed_dims=256,if_UCF=False):\n        super().__init__(step=10, encode_type='direct')\n        self.image_size = [img_size_h, img_size_w]\n\n        patch_size = 
to_2tuple(patch_size)  \n        self.patch_size = patch_size  \n        self.C = in_channels  \n        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] \n        self.num_patches = self.H * self.W  \n\n        self.if_UCF = if_UCF\n\n         \n        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)\n        self.proj_lif = MyNode(step=step, tau=2.0)\n        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n        \n        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False) \n        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)   \n        self.proj_lif1 = MyNode(step=step, tau=2.0)\n        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)\n        self.proj_lif2 = MyNode(step=step, tau=2.0)\n        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn3 = nn.BatchNorm2d(embed_dims)\n        self.proj_lif3 = MyNode(step=step, tau=2.0)\n        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.rpe_bn = nn.BatchNorm2d(embed_dims)\n        self.rpe_lif = MyNode(step=step, tau=2.0)\n\n    def forward(self, x):\n        self.reset()\n\n        T, B, C, H, W = x.shape\n        \n        # UCF101DVS\n        if self.if_UCF:\n        
    x = F.adaptive_avg_pool2d(x.flatten(0,1), output_size=(64, 64)).reshape(T, B, C,64,64)\n            T, B, C, H, W = x.shape\n\n\n        x = self.proj_conv(x.flatten(0, 1))  # have some fire value\n        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif(x.flatten(0, 1)).contiguous()\n        x = self.maxpool(x)\n\n        x = self.proj_conv1(x)\n        x = self.proj_bn1(x).reshape(T, B, -1, H // 2, W // 2).contiguous()\n        x = self.proj_lif1(x.flatten(0, 1)).contiguous()\n        x = self.maxpool1(x)\n        \n        x = self.proj_conv2(x)\n        x = self.proj_bn2(x).reshape(T, B, -1, H // 4, W // 4).contiguous()\n        x = self.proj_lif2(x.flatten(0, 1)).contiguous()\n        x = self.maxpool2(x)\n\n        x = self.proj_conv3(x)\n        x = self.proj_bn3(x).reshape(T, B, -1, H // 8, W // 8).contiguous()\n        x = self.proj_lif3(x.flatten(0, 1)).contiguous()\n        x = self.maxpool3(x)\n\n        x_rpe = self.rpe_bn(self.rpe_conv(x)).reshape(T, B, -1 , H // 16,W // 16).contiguous()\n        x_rpe = self.rpe_lif(x_rpe.flatten(0,1)).contiguous()\n        x = x + x_rpe\n        x = x.reshape(T, B, -1, (H//16)*(W//16)).contiguous()\n        \n       \n        return x # T B C N\n\n\nclass Spikformer(nn.Module):\n    def __init__(self, step=10,TIM_alpha=0.5,if_UCF=False,\n                 img_size_h=64, img_size_w=64, patch_size=16, in_channels=2, num_classes=10,\n                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=2, sr_ratios=4, \n                 ):\n        super().__init__()\n        self.T = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n\n\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n\n        patch_embed = SPS(       step=step, \n            
                     if_UCF=if_UCF,\n                                 img_size_h=img_size_h,\n                                 img_size_w=img_size_w,\n                                 patch_size=patch_size,\n                                 in_channels=in_channels,\n                                 embed_dims=embed_dims)\n\n        block = nn.ModuleList([Block(step=step, TIM_alpha=TIM_alpha,\n            dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,\n            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],\n            norm_layer=norm_layer, sr_ratio=sr_ratios)\n\n            for j in range(depths)])\n\n        setattr(self, f\"patch_embed\", patch_embed)\n        setattr(self, f\"block\", block)\n\n        # classification head\n        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n\n        block = getattr(self, f\"block\")\n        patch_embed = getattr(self, f\"patch_embed\")\n\n\n        x = patch_embed(x)\n        for blk in block:\n            x = blk(x)\n        return x.mean(3)\n\n    def forward(self, x):\n        x = 
x.permute(1, 0, 2, 3, 4)  \n        x = self.forward_features(x)\n        x = self.head(x.mean(0))\n        return x\n\n# Hyperparams could be adjust here\n\n@register_model\ndef spikformer_dvs(pretrained=False, **kwargs):\n    model = Spikformer(TIM_alpha=0.5,step=10,if_UCF=False,num_classes=10,\n        # img_size_h=64, img_size_w=64,\n        # patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=4,\n        # in_channels=2, qkv_bias=False,\n        # depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n\n\n"
  },
  {
    "path": "examples/TIM/models/spikformer_braincog_SHD.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple, trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nimport torch.nn.functional as F\nfrom braincog.model_zoo.base_module import BaseModule\nfrom braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom functools import partial\nfrom torchvision import transforms\nfrom utils.MyNode import *\nfrom models.TIM import *\n__all__ = ['spikformer']\n\nclass MLP(BaseModule):\n    def __init__(self,in_features,step=10,encode_type='direct',hidden_features=None, out_features=None, drop=0.):\n        super().__init__(step=10,encode_type='direct')\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)\n        self.fc1_bn = nn.BatchNorm1d(hidden_features)\n        self.fc1_lif = MyNode(step=step,tau=2.0)\n\n        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)\n        self.fc2_bn = nn.BatchNorm1d(out_features)\n        self.fc2_lif = MyNode(step=step,tau=2.0)\n\n        self.c_hidden = hidden_features\n        self.c_output = out_features\n\n    def forward(self, x):\n        self.reset()\n\n        T,B,C,N = x.shape\n\n        x = self.fc1_conv(x.flatten(0,1))\n        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N ).contiguous() # T B C N\n        x = self.fc1_lif(x.flatten(0,1)).reshape(T, B, self.c_hidden, N).contiguous() \n\n        x = self.fc2_conv(x.flatten(0,1))\n        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()\n        x = self.fc2_lif(x.flatten(0,1)).reshape(T, B, C, N ).contiguous() \n        return x\n\n\nclass SSA(BaseModule):\n    def __init__(self,dim,step=10,encode_type='direct',num_heads=16,TIM_alpha=0.5,qkv_bias=False, 
qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__(step=10,encode_type='direct')\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n\n        self.num_heads = num_heads\n\n        self.in_channels = dim // num_heads\n\n        self.scale = 0.25\n        \n    \n        self.q_conv = nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.q_bn = nn.BatchNorm1d(dim)\n        self.q_lif = MyNode(step=step,tau=2.0)\n\n        self.k_conv = nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.k_bn = nn.BatchNorm1d(dim)\n        self.k_lif = MyNode(step=step,tau=2.0)\n\n        self.v_conv = nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.v_bn = nn.BatchNorm1d(dim)\n        self.v_lif = MyNode(step=step,tau=2.0)\n    \n        self.attn_drop = nn.Dropout(0.2)\n        self.res_lif = MyNode(step=step, tau=2.0)\n        self.attn_lif = MyNode(step=step, tau=2.0, v_threshold=0.5,)\n\n        self.proj_conv =  nn.Conv1d(dim, dim,kernel_size=1, stride=1,bias=False)\n        self.proj_bn = nn.BatchNorm1d(dim)\n        self.proj_lif = MyNode(step=step, tau=2.0,)\n\n        self.TIM = TIM(TIM_alpha=TIM_alpha,in_channels=self.in_channels)\n        \n    def forward(self, x):\n\n        self.reset()\n\n        T,B,C,N = x.shape\n\n        x_for_qkv = x.flatten(0, 1)  \n\n        q_conv_out = self.q_conv(x_for_qkv)  \n        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()\n        q_conv_out = self.q_lif(q_conv_out.flatten(0,1)).reshape(T, B, C ,N).transpose(-2,-1) \n        q = q_conv_out.reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        k_conv_out = self.k_conv(x_for_qkv)\n        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()\n        k_conv_out= 
self.k_lif(k_conv_out.flatten(0,1)).reshape(T, B, C ,N).transpose(-2,-1)\n        k = k_conv_out.reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        v_conv_out = self.v_conv(x_for_qkv)\n        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()\n        v_conv_out = self.v_lif(v_conv_out.flatten(0,1)).reshape(T, B, C ,N).transpose(-2,-1)\n        v = v_conv_out.reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous()\n\n        #TIM on Q\n        q = self.TIM(q) \n\n        #SSA \n        attn = (q @ k.transpose(-2, -1)) \n        x = (attn @ v) * self.scale \n        \n        x = x.transpose(3,4).reshape(T, B, C, N).contiguous() \n        x = self.attn_lif(x.flatten(0,1)) \n        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, N) \n        \n        return x\n\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, step=10,TIM_alpha=0.5, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n\n\n\n        self.attn = SSA(dim, step=step,TIM_alpha=TIM_alpha,num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                              attn_drop=attn_drop, sr_ratio=sr_ratio)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = MLP(in_features=dim,step=step, hidden_features=mlp_hidden_dim, drop=drop)\n\n    def forward(self, x):\n        x = x + self.attn(x)\n        x = x + self.mlp(x)\n        return x\n\nclass SPS(BaseModule):\n    def __init__(self, step=10, encode_type='direct', img_size_h=64, img_size_w=64, patch_size=4, in_channels=2,\n                 embed_dims=256,if_UCF=False):\n        super().__init__(step=10, encode_type='direct')\n        self.image_size = [img_size_h, img_size_w]\n\n        patch_size = 
to_2tuple(patch_size)  \n        self.patch_size = patch_size  \n        self.C = in_channels  \n        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] \n        self.num_patches = self.H * self.W  \n\n        self.if_UCF = if_UCF\n\n         \n        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)\n        self.proj_lif = MyNode(step=step, tau=2.0)\n        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n        \n        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False) \n        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)   \n        self.proj_lif1 = MyNode(step=step, tau=2.0)\n        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)\n        self.proj_lif2 = MyNode(step=step, tau=2.0)\n        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.proj_bn3 = nn.BatchNorm2d(embed_dims)\n        self.proj_lif3 = MyNode(step=step, tau=2.0)\n        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n\n        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)\n        self.rpe_bn = nn.BatchNorm2d(embed_dims)\n        self.rpe_lif = MyNode(step=step, tau=2.0)\n\n    def forward(self, x):\n        self.reset()\n\n        # SHD\n        T, B, _ = x.shape\n        x = x.reshape(T,B,2,-1) # T B 2 350\n        
\n        x = F.interpolate(x.flatten(0,1), size=256, mode='nearest').reshape(T,B,2,16,16)\n\n\n        T, B, C, H, W = x.shape\n\n        x = self.proj_conv(x.flatten(0, 1))  # have some fire value\n        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif(x.flatten(0, 1)).contiguous()\n        # x = self.maxpool(x)\n\n        x = self.proj_conv1(x)\n        x = self.proj_bn1(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif1(x.flatten(0, 1)).contiguous()\n        # x = self.maxpool1(x)\n        \n\n        x = self.proj_conv2(x)\n        x = self.proj_bn2(x).reshape(T, B, -1, H, W).contiguous()\n        x = self.proj_lif2(x.flatten(0, 1)).contiguous()\n        x = self.maxpool2(x)\n\n        x = self.proj_conv3(x)\n        x = self.proj_bn3(x).reshape(T, B, -1, H // 2, W // 2).contiguous()\n        x = self.proj_lif3(x.flatten(0, 1)).contiguous()\n        x = self.maxpool3(x)\n\n        x_rpe = self.rpe_bn(self.rpe_conv(x)).reshape(T, B, -1 , H // 4,W // 4).contiguous()\n        x_rpe = self.rpe_lif(x_rpe.flatten(0,1)).contiguous()\n        x = x + x_rpe\n        x = x.reshape(T, B, -1, (H//4)*(W//4)).contiguous()\n        \n       \n        return x # T B C N\n\n\nclass Spikformer(nn.Module):\n    def __init__(self, step=10,TIM_alpha=0.5,if_UCF=False,\n                 img_size_h=64, img_size_w=64, patch_size=16, in_channels=2, num_classes=10,\n                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=2, sr_ratios=4, \n                 ):\n        super().__init__()\n        self.T = step  # time step\n        self.num_classes = num_classes\n        self.depths = depths\n\n\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule\n\n        patch_embed = SPS(       step=step, \n                                
 if_UCF=if_UCF,\n                                 img_size_h=img_size_h,\n                                 img_size_w=img_size_w,\n                                 patch_size=patch_size,\n                                 in_channels=in_channels,\n                                 embed_dims=embed_dims)\n\n        block = nn.ModuleList([Block(step=step, TIM_alpha=TIM_alpha,\n            dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,\n            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],\n            norm_layer=norm_layer, sr_ratio=sr_ratios)\n\n            for j in range(depths)])\n\n        setattr(self, f\"patch_embed\", patch_embed)\n        setattr(self, f\"block\", block)\n\n        # classification head\n        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    @torch.jit.ignore\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_features(self, x):\n\n        block = getattr(self, f\"block\")\n        patch_embed = getattr(self, f\"patch_embed\")\n\n\n        x = patch_embed(x)\n        for blk in block:\n            x = blk(x)\n        return x.mean(3)\n\n    def forward(self, x):\n        x = x.permute(1, 0, 2)  \n        x 
= self.forward_features(x)\n        x = self.head(x.mean(0))\n        return x\n\n# Hyperparams could be adjust here\n\n@register_model\ndef spikformer_shd(pretrained=False, **kwargs):\n    model = Spikformer(TIM_alpha=0.5,step=10,if_UCF=False,num_classes=20,\n        # img_size_h=64, img_size_w=64,\n        # patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=4,\n        # in_channels=2, qkv_bias=False,\n        # depths=2, sr_ratios=1,\n        **kwargs\n    )\n    model.default_cfg = _cfg()\n    return model\n\n\n"
  },
  {
    "path": "examples/TIM/utils/MyGrad.py",
    "content": "from braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\n\n\nclass MyGrad(SurrogateFunctionBase):\n    def __init__(self, alpha=4., requires_grad=False):\n        super().__init__(alpha, requires_grad)\n\n    @staticmethod\n    def act_fun(x, alpha):\n        return sigmoid.apply(x, alpha)\n    "
  },
  {
    "path": "examples/TIM/utils/MyNode.py",
    "content": "from braincog.base.node.node import *\nfrom braincog.base.connection.layer import *\nfrom braincog.base.strategy.surrogate import *\nfrom utils.MyGrad import MyGrad\n\n\nclass MyBaseNode(BaseNode):\n    def __init__(self,threshold=0.5,step=10,layer_by_layer=False,mem_detach=False):\n        super().__init__(threshold=threshold,step=step,layer_by_layer=layer_by_layer,mem_detach=mem_detach)\n    def rearrange2node(self, inputs):\n        if self.groups != 1:\n            if len(inputs.shape) == 4:\n                outputs = rearrange(inputs, 'b (c t) w h -> t b c w h', t=self.step)\n            elif len(inputs.shape) == 2:\n                outputs = rearrange(inputs, 'b (c t) -> t b c', t=self.step)\n            else:\n                raise NotImplementedError\n\n        elif self.layer_by_layer:\n            if len(inputs.shape) == 4:\n                outputs = rearrange(inputs, '(t b) c w h -> t b c w h', t=self.step)\n\n            #加入适配Transformer T B N C的rearange2node分支\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, '(t b) n c -> t b n c', t=self.step)\n            elif len(inputs.shape) == 2:\n                outputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)\n            else:\n                raise NotImplementedError\n\n\n        else:\n            outputs = inputs\n\n        return outputs\n\n    def rearrange2op(self, inputs):\n        if self.groups != 1:\n            if len(inputs.shape) == 5:\n                outputs = rearrange(inputs, 't b c w h -> b (c t) w h')\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, ' t b c -> b (c t)')\n            else:\n                raise NotImplementedError\n        elif self.layer_by_layer:\n            if len(inputs.shape) == 5:\n                outputs = rearrange(inputs, 't b c w h -> (t b) c w h')\n\n            # 加入适配Transformer T B N C的rearange2op分支\n            elif len(inputs.shape) == 4:\n                
outputs = rearrange(inputs, ' t b n c -> (t b) n c')\n            elif len(inputs.shape) == 3:\n                outputs = rearrange(inputs, ' t b c -> (t b) c')\n            else:\n                raise NotImplementedError\n\n        else:\n            outputs = inputs\n\n        return outputs\n    \nclass MyNode(MyBaseNode):\n    def __init__(self, threshold=1.,step=10,layer_by_layer=True,tau=2., act_fun=MyGrad, mem_detach=True,*args, **kwargs):\n        super().__init__(threshold=threshold,step=step, layer_by_layer=layer_by_layer,mem_detach=mem_detach)\n        self.tau = tau\n        if isinstance(act_fun, str):\n            act_fun = eval(act_fun)\n        self.act_fun = act_fun(alpha=4., requires_grad=False)\n    def integral(self, inputs):\n        self.mem = self.mem + (inputs - self.mem) / self.tau\n    def calc_spike(self):\n        self.spike = self.act_fun(self.mem - self.threshold)\n        self.mem = self.mem * (1 - self.spike.detach())"
  },
  {
    "path": "examples/TIM/utils/datasets.py",
    "content": "import os\nimport warnings\nimport random\nimport torchvision.datasets\n\nimport braincog.datasets.ucf101_dvs\n\ntry:\n    import tonic\n    from tonic import DiskCachedDataset\nexcept:\n    warnings.warn(\"tonic should be installed, 'pip install git+https://github.com/FloyedShen/tonic.git'\")\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils\nimport torchvision.datasets as datasets\nfrom timm.data import ImageDataset, create_loader, Mixup, FastCollateMixup, AugMixDataset\nfrom timm.data import create_transform, distributed_sampler\nfrom timm.data.loader import PrefetchLoader\nfrom tonic import DiskCachedDataset\nfrom torchvision import transforms\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\nfrom braincog.datasets.NOmniglot.nomniglot_full import NOmniglotfull\nfrom braincog.datasets.NOmniglot.nomniglot_nw_ks import NOmniglotNWayKShot\nfrom braincog.datasets.NOmniglot.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet\n# from braincog.base.conversion.conversion import CIFAR10Policy, Cutout\n# from .cut_mix import CutMix, EventMix, MixUp\n# from .rand_aug import *\n# from .event_drop import event_drop\n# from .utils import dvs_channel_check_expend, rescale\n\nDVSCIFAR10_MEAN_16 = [0.3290, 0.4507]\nDVSCIFAR10_STD_16 = [1.8398, 1.6549]\n\nDATA_DIR = '/data/datasets'\n\nDEFAULT_CROP_PCT = 0.875\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\nIMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)\nIMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)\n\nCIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)\nCIFAR10_DEFAULT_STD = (0.2023, 0.1994, 0.2010)\nCIFAR100_DEFAULT_MEAN = (0.5071, 0.4867, 0.4408)\nCIFAR100_DEFAULT_STD = (0.2675, 0.2565, 0.2761)\n\n\ndef unpack_mix_param(args):\n    mix_up = args['mix_up'] if 'mix_up' in args else False\n    cut_mix = args['cut_mix'] if 'cut_mix' in args 
else False\n    event_mix = args['event_mix'] if 'event_mix' in args else False\n    beta = args['beta'] if 'beta' in args else 1.\n    prob = args['prob'] if 'prob' in args else .5\n    num = args['num'] if 'num' in args else 1\n    num_classes = args['num_classes'] if 'num_classes' in args else 10\n    noise = args['noise'] if 'noise' in args else 0.\n    gaussian_n = args['gaussian_n'] if 'gaussian_n' in args else None\n    return mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n\n\n\ndef build_transform(is_train, img_size):\n    \"\"\"\n    构建数据增强, 适用于static data\n    :param is_train: 是否训练集\n    :param img_size: 输出的图像尺寸\n    :return: 数据增强策略\n    \"\"\"\n    resize_im = img_size > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=img_size,\n            is_training=True,\n            color_jitter=0.4,\n            # auto_augment='rand-m9-mstd0.5-inc1',\n            interpolation='bicubic',\n            # re_prob=0.25,\n            # re_mode='pixel',\n            # re_count=1,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(\n                img_size, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        size = int((256 / 224) * img_size)\n        t.append(\n            # to maintain same ratio w.r.t. 
224 images\n            transforms.Resize(size, interpolation=3),\n        )\n        t.append(transforms.CenterCrop(img_size))\n\n    t.append(transforms.ToTensor())\n    if img_size > 32:\n        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    else:\n        t.append(transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD))\n    return transforms.Compose(t)\n\n\ndef build_dataset(is_train, img_size, dataset, path, same_da=False):\n    \"\"\"\n    构建带有增强策略的数据集\n    :param is_train: 是否训练集\n    :param img_size: 输出图像尺寸\n    :param dataset: 数据集名称\n    :param path: 数据集路径\n    :param same_da: 为训练集使用测试集的增广方法\n    :return: 增强后的数据集\n    \"\"\"\n    transform = build_transform(False, img_size) if same_da else build_transform(is_train, img_size)\n\n    if dataset == 'CIFAR10':\n        dataset = datasets.CIFAR10(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 10\n    elif dataset == 'CIFAR100':\n        dataset = datasets.CIFAR100(\n            path, train=is_train, transform=transform, download=True)\n        nb_classes = 100\n    else:\n        raise NotImplementedError\n\n    return dataset, nb_classes\n\n\nclass MNISTData(object):\n    \"\"\"\n    Load MNIST datesets.\n    \"\"\"\n\n    def __init__(self,\n                 data_path: str,\n                 batch_size: int,\n                 train_trans: Sequence[torch.nn.Module] = None,\n                 test_trans: Sequence[torch.nn.Module] = None,\n                 pin_memory: bool = True,\n                 drop_last: bool = True,\n                 shuffle: bool = True,\n                 ) -> None:\n        self._data_path = data_path\n        self._batch_size = batch_size\n        self._pin_memory = pin_memory\n        self._drop_last = drop_last\n        self._shuffle = shuffle\n        self._train_transform = transforms.Compose(train_trans) if train_trans else None\n        self._test_transform = transforms.Compose(test_trans) if 
test_trans else None\n\n    def get_data_loaders(self):\n        print('Batch size: ', self._batch_size)\n        train_datasets = datasets.MNIST(root=self._data_path, train=True, transform=self._train_transform, download=True)\n        test_datasets = datasets.MNIST(root=self._data_path, train=False, transform=self._test_transform, download=True)\n        train_loader = torch.utils.data.DataLoader(\n            train_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=self._drop_last, shuffle=self._shuffle\n        )\n        test_loader = torch.utils.data.DataLoader(\n            test_datasets, batch_size=self._batch_size,\n            pin_memory=self._pin_memory, drop_last=False\n        )\n        return train_loader, test_loader\n\n    def get_standard_data(self):\n        MNIST_MEAN = 0.1307\n        MNIST_STD = 0.3081\n        self._train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                                    transforms.ToTensor(),\n                                                    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        self._test_transform = transforms.Compose([transforms.ToTensor(),\n                                                   transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        return self.get_data_loaders()\n\n\ndef get_mnist_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取MNIST数据\n    http://data.pymvpa.org/datasets/mnist/\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    MNIST_MEAN = 0.1307\n    MNIST_STD = 0.3081\n    if 'skip_norm' in kwargs and kwargs['skip_norm'] is True:\n        train_transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n        test_transform = transforms.Compose([\n            
transforms.ToTensor(),\n            transforms.Lambda(rescale)\n        ])\n    else:\n        train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                              # transforms.RandomRotation(10),\n                                              transforms.ToTensor(),\n                                              transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n        test_transform = transforms.Compose([transforms.ToTensor(),\n                                             transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))])\n\n    train_datasets = datasets.MNIST(\n        root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.MNIST(\n        root=DATA_DIR, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_fashion_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    \"\"\"\n    获取fashion MNIST数据\n    http://arxiv.org/abs/1708.07747\n    :param batch_size: batch size\n    :param same_da: 为训练集使用测试集的增广方法\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4),\n                                          transforms.RandomHorizontalFlip(),\n                                          transforms.RandomRotation(10),\n                                          transforms.ToTensor()])\n    test_transform = transforms.Compose([transforms.ToTensor()])\n\n    train_datasets = datasets.FashionMNIST(\n     
   root=DATA_DIR, train=True, transform=test_transform if same_da else train_transform, download=True)\n    test_datasets = datasets.FashionMNIST(\n        root=DATA_DIR, train=False, transform=test_transform, download=True)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=num_workers\n    )\n\n    return train_loader, test_loader, False, None\n\n\ndef get_cifar10_data(batch_size, num_workers=8, same_da=False, **kwargs):\n    # \"\"\"\n    # 获取CIFAR10数据\n    #  https://www.cs.toronto.edu/~kriz/cifar.html\n    # :param batch_size: batch size\n    # :param kwargs:\n    # :return: (train loader, test loader, mixup_active, mixup_fn)\n    # \"\"\"\n    # train_datasets, _ = build_dataset(True, 32, 'CIFAR10', DATA_DIR, same_da)\n    # test_datasets, _ = build_dataset(False, 32, 'CIFAR10', DATA_DIR, same_da)\n    #\n    # train_loader = torch.utils.data.DataLoader(\n    #     train_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=True, shuffle=True,\n    #     num_workers=num_workers\n    # )\n    #\n    # test_loader = torch.utils.data.DataLoader(\n    #     test_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=False,\n    #     num_workers=num_workers\n    # )\n    normalize = transforms.Normalize(CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD)\n    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),\n                                          CIFAR10Policy(),\n                                          transforms.ToTensor(),\n                                          Cutout(n_holes=1, length=16),\n                                          normalize])\n    transform_test = 
transforms.Compose([transforms.ToTensor(), normalize])\n    train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform_train)\n    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform_test)\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,  batch_size=batch_size,\n        shuffle=True, num_workers=num_workers,\n        pin_memory=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        shuffle=False, num_workers=num_workers,\n        pin_memory=True\n    )\n    return train_loader, test_loader, None, None\n\ndef get_shd_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取SHD数据\n    https://ieeexplore.ieee.org/abstract/document/9311226\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    :format: (b,t,c,len) 不同于vision, audio中c为1, 并且没有h,w; 只有len=700\n    \"\"\"\n    sensor_size = tonic.datasets.SHD.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.SHD(os.path.join(DATA_DIR, 'DVS/SHD'),\n                                              transform=train_transform, train=True)\n\n    test_dataset = tonic.datasets.SHD(os.path.join(DATA_DIR, 'DVS/SHD'),\n                                             transform=test_transform, train=False)\n\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n  
  )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, None, None\n\ndef get_cifar100_data(batch_size, num_workers=8, same_data=False, *args, **kwargs):\n    # \"\"\"\n    # 获取CIFAR100数据\n    # https://www.cs.toronto.edu/~kriz/cifar.html\n    # :param batch_size: batch size\n    # :param kwargs:\n    # :return: (train loader, test loader, mixup_active, mixup_fn)\n    # \"\"\"\n    # train_datasets, _ = build_dataset(True, 32, 'CIFAR100', DATA_DIR, same_data)\n    # test_datasets, _ = build_dataset(False, 32, 'CIFAR100', DATA_DIR, same_data)\n    #\n    # train_loader = torch.utils.data.DataLoader(\n    #     train_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=True, shuffle=True, num_workers=num_workers\n    # )\n    #\n    # test_loader = torch.utils.data.DataLoader(\n    #     test_datasets, batch_size=batch_size,\n    #     pin_memory=True, drop_last=False, num_workers=num_workers\n    # )\n    # return train_loader, test_loader, False, None\n    normalize = transforms.Normalize(CIFAR100_DEFAULT_MEAN, CIFAR100_DEFAULT_STD)\n    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),\n                                          CIFAR10Policy(),\n                                          transforms.ToTensor(),\n                                          Cutout(n_holes=1, length=16),\n                                          normalize])\n    transform_test = transforms.Compose([transforms.ToTensor(), normalize])\n    train_dataset = datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=transform_train)\n    test_dataset = datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=transform_test)\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,  
batch_size=batch_size,\n        shuffle=True, num_workers=num_workers,\n        pin_memory=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        shuffle=False, num_workers=num_workers,\n        pin_memory=True\n    )\n    return train_loader, test_loader, None, None\n\n\ndef get_imnet_data(args, _logger, data_config, num_aug_splits, **kwargs):\n    \"\"\"\n    获取ImageNet数据集\n    http://arxiv.org/abs/1409.0575\n    :param args: 其他的参数\n    :param _logger: 日志路径\n    :param data_config: 增强策略\n    :param num_aug_splits: 不同增强策略的数量\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    train_dir = os.path.join(DATA_DIR, 'ILSVRC2012/train')\n    if not os.path.exists(train_dir):\n        _logger.error(\n            'Training folder does not exist at: {}'.format(train_dir))\n        exit(1)\n    dataset_train = ImageDataset(train_dir)\n    # collate_fn = None\n    # mixup_fn = None\n    # mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None\n    # if mixup_active:\n    #     mixup_args = dict(\n    #         mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n    #         prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n    #         label_smoothing=args.smoothing, num_classes=args.num_classes)\n    #     if args.prefetcher:\n    #         # collate conflict (need to support deinterleaving in collate mixup)\n    #         assert not num_aug_splits\n    #         collate_fn = FastCollateMixup(**mixup_args)\n    #     else:\n    #         mixup_fn = Mixup(**mixup_args)\n\n    # if num_aug_splits > 1:\n    #     dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)\n\n    train_interpolation = args.train_interpolation\n    if args.no_aug or not train_interpolation:\n        train_interpolation = data_config['interpolation']\n    loader_train = create_loader(\n        dataset_train,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        is_training=True,\n        use_prefetcher=args.prefetcher,\n        no_aug=args.no_aug,\n        # re_prob=args.reprob,\n        # re_mode=args.remode,\n        # re_count=args.recount,\n        # re_split=args.resplit,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        # vflip=args.vflip,\n        # color_jitter=args.color_jitter,\n        # auto_augment=args.aa,\n        # num_aug_splits=num_aug_splits,\n        interpolation=train_interpolation,\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        # collate_fn=collate_fn,\n        pin_memory=args.pin_mem,\n        # use_multi_epochs_loader=args.use_multi_epochs_loader\n    )\n\n    eval_dir = os.path.join(DATA_DIR, 'ILSVRC2012/val')\n    if not os.path.isdir(eval_dir):\n        eval_dir = os.path.join(DATA_DIR, 
'ILSVRC2012/validation')\n        if not os.path.isdir(eval_dir):\n            _logger.error(\n                'Validation folder does not exist at: {}'.format(eval_dir))\n            exit(1)\n    dataset_eval = ImageDataset(eval_dir)\n\n    loader_eval = create_loader(\n        dataset_eval,\n        input_size=data_config['input_size'],\n        batch_size=args.validation_batch_size_multiplier * args.batch_size,\n        is_training=False,\n        use_prefetcher=args.prefetcher,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        crop_pct=data_config['crop_pct'],\n        pin_memory=args.pin_mem,\n    )\n    return loader_train, loader_eval, None, None\n\n\ndef get_dvsg_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS Gesture数据\n    DOI: 10.1109/CVPR.2017.781\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = tonic.datasets.DVSGesture.sensor_size\n    size = kwargs['size'] if 'size' in kwargs else 48\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step),\n    ])\n\n    train_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),\n                                              transform=train_transform, train=True)\n    test_dataset = tonic.datasets.DVSGesture(os.path.join(DATA_DIR, 'DVS/DVSGesture'),\n                                             transform=test_transform, train=False)\n\n    train_transform = 
transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n        transforms.RandomCrop(size, padding=size // 12),\n        # lambda x: event_drop(x),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        lambda x: dvs_channel_check_expend(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVSGesture/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               
num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=True, num_workers=8,\n        shuffle=True,\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        pin_memory=True, drop_last=False, num_workers=2,\n        shuffle=False,\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_dvsc10_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        
tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)\n    test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # lambda x: event_drop(x),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n    
                                  cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/DVS_Cifar10/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                             
 prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_nmnist_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = tonic.datasets.NMNIST.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform)\n    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # 
lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # lambda x: event_drop(x),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/NMNIST/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    num_train = len(train_dataset)\n    num_per_cls = num_train // 10\n    indices_train, 
indices_test = [], []\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    for i in range(10):\n        indices_train.extend(\n            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))\n        indices_test.extend(\n            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=indices_train,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=indices_train,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=indices_train,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        
test_dataset, batch_size=batch_size,\n        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_NCALTECH101_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取NCaltech101数据\n    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    sensor_size = braincog.datasets.ncaltech101.NCALTECH101.sensor_size\n    cls_count = braincog.datasets.ncaltech101.NCALTECH101.cls_count\n    dataset_length = braincog.datasets.ncaltech101.NCALTECH101.length\n    portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += train_sample\n        train_sample_weight.extend(\n            [sample_weight] * train_sample\n        )\n        train_sample_weight.extend(\n            [0.] 
* test_sample\n        )\n        train_sample_index.extend(\n            list((range(idx_begin, idx_begin + train_sample)))\n        )\n        test_sample_index.extend(\n            list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=train_transform)\n    test_dataset = braincog.datasets.ncaltech101.NCALTECH101(os.path.join(DATA_DIR, 'DVS/NCALTECH101'), transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: print(x.shape),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: temporal_flatten(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, 
RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'DVS/NCALTECH101/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n              
                indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_UCF101DVS_data(batch_size, step, **kwargs):\n    \"\"\"\n    获取DVS CIFAR10数据\n    http://journal.frontiersin.org/article/10.3389/fnins.2017.00309/full\n    :param batch_size: batch size\n    :param step: 仿真步长\n    :param kwargs:\n    :return: (train loader, test loader, mixup_active, mixup_fn)\n    \"\"\"\n    size = kwargs['size'] if 'size' in kwargs else 48\n    sensor_size = braincog.datasets.ucf101_dvs.UCF101DVS.sensor_size\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    train_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=True, transform=train_transform)\n    test_dataset = braincog.datasets.ucf101_dvs.UCF101DVS(os.path.join(DATA_DIR, 'DVS/UCF101DVS'), train=False, transform=test_transform)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: TemporalShift(x, .01),\n        # lambda x: drop(x, 0.15),\n        # lambda x: ShearX(x, 15),\n        # lambda x: 
ShearY(x, 15),\n        # lambda x: TranslateX(x, 0.225),\n        # lambda x: TranslateY(x, 0.225),\n        # lambda x: Rotate(x, 15),\n        # lambda x: CutoutAbs(x, 0.25),\n        # lambda x: CutoutTemporal(x, 0.25),\n        # lambda x: GaussianBlur(x, 0.5),\n        # lambda x: SaltAndPepperNoise(x, 0.1),\n        # transforms.Normalize(DVSCIFAR10_MEAN_16, DVSCIFAR10_STD_16),\n        # transforms.RandomCrop(size, padding=size // 12),\n        transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n    ])\n\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            # print('randaug', m, n)\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'UCF101DVS/train_cache_{}'.format(step)),\n                                      transform=train_transform)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'UCF101DVS/test_cache_{}'.format(step)),\n                                     transform=test_transform)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        # print('cut_mix', beta, prob, num, num_classes)\n        train_dataset = 
CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size, shuffle=True,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size, shuffle=False,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\ndef get_HMDBDVS_data(batch_size, step, **kwargs):\n    sensor_size = braincog.datasets.hmdb_dvs.HMDBDVS.sensor_size\n\n    train_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        # tonic.transforms.DropEvent(p=0.1),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n    test_transform = transforms.Compose([\n        # tonic.transforms.Denoise(filter_time=10000),\n        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])\n\n    train_dataset = braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=train_transform)\n    test_dataset = 
braincog.datasets.hmdb_dvs.HMDBDVS(os.path.join(DATA_DIR, 'HMDBDVS'), transform=test_transform)\n\n    cls_count = train_dataset.cls_count\n    dataset_length = train_dataset.length\n\n    portion = .5\n    # portion = kwargs['portion'] if 'portion' in kwargs else .9\n    size = kwargs['size'] if 'size' in kwargs else 48\n    # print('portion', portion)\n    train_sample_weight = []\n    train_sample_index = []\n    train_count = 0\n    test_sample_index = []\n    idx_begin = 0\n    for count in cls_count:\n        sample_weight = dataset_length / count\n        train_sample = round(portion * count)\n        test_sample = count - train_sample\n        train_count += train_sample\n        train_sample_weight.extend(\n            [sample_weight] * train_sample\n        )\n        train_sample_weight.extend(\n            [0.] * test_sample\n        )\n        lst = list(range(idx_begin, idx_begin + train_sample + test_sample))\n        random.seed(0)\n        random.shuffle(lst)\n        train_sample_index.extend(\n            lst[:train_sample]\n            # list((range(idx_begin, idx_begin + train_sample)))\n        )\n        test_sample_index.extend(\n            lst[train_sample:train_sample + test_sample]\n            # list(range(idx_begin + train_sample, idx_begin + train_sample + test_sample))\n        )\n        idx_begin += count\n\n    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weight, train_count)\n    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_sample_index)\n\n    train_transform = transforms.Compose([\n        lambda x: torch.tensor(x, dtype=torch.float),\n        # lambda x: print(x.shape),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # transforms.RandomCrop(size, padding=size // 12),\n        # transforms.RandomHorizontalFlip(),\n        # transforms.RandomRotation(15)\n    ])\n    test_transform = transforms.Compose([\n        lambda x: 
torch.tensor(x, dtype=torch.float),\n        # lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n        # lambda x: temporal_flatten(x),\n    ])\n    if 'rand_aug' in kwargs.keys():\n        if kwargs['rand_aug'] is True:\n            n = kwargs['randaug_n']\n            m = kwargs['randaug_m']\n            train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n\n    # if 'temporal_flatten' in kwargs.keys():\n    #     if kwargs['temporal_flatten'] is True:\n    #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n    #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n\n    train_dataset = DiskCachedDataset(train_dataset,\n                                      cache_path=os.path.join(DATA_DIR, 'HMDBDVS/train_cache_{}'.format(step)),\n                                      transform=train_transform, num_copies=3)\n    test_dataset = DiskCachedDataset(test_dataset,\n                                     cache_path=os.path.join(DATA_DIR, 'HMDBDVS/test_cache_{}'.format(step)),\n                                     transform=test_transform, num_copies=3)\n\n    mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n    mixup_active = cut_mix | event_mix | mix_up\n\n    if cut_mix:\n        train_dataset = CutMix(train_dataset,\n                               beta=beta,\n                               prob=prob,\n                               num_mix=num,\n                               num_class=num_classes,\n                               indices=train_sample_index,\n                               noise=noise)\n\n    if event_mix:\n        train_dataset = EventMix(train_dataset,\n                                 beta=beta,\n                                 prob=prob,\n                                 num_mix=num,\n                                 num_class=num_classes,\n                                 
indices=train_sample_index,\n                                 noise=noise,\n                                 gaussian_n=gaussian_n)\n    if mix_up:\n        train_dataset = MixUp(train_dataset,\n                              beta=beta,\n                              prob=prob,\n                              num_mix=num,\n                              num_class=num_classes,\n                              indices=train_sample_index,\n                              noise=noise)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=batch_size,\n        sampler=train_sampler,\n        pin_memory=True, drop_last=True, num_workers=8\n    )\n\n    test_loader = torch.utils.data.DataLoader(\n        test_dataset, batch_size=batch_size,\n        sampler=test_sampler,\n        pin_memory=True, drop_last=False, num_workers=2\n    )\n\n    return train_loader, test_loader, mixup_active, None\n\n\n# def get_NCARS_data(batch_size, step, **kwargs):\n#     \"\"\"\n#     获取N-Cars数据\n#     https://ieeexplore.ieee.org/document/8578284/\n#     :param batch_size: batch size\n#     :param step: 仿真步长\n#     :param kwargs:\n#     :return: (train loader, test loader, mixup_active, mixup_fn)\n#     \"\"\"\n#     sensor_size = tonic.datasets.NCARS.sensor_size\n#     size = kwargs['size'] if 'size' in kwargs else 48\n#\n#     train_transform = transforms.Compose([\n#         # tonic.transforms.Denoise(filter_time=10000),\n#         # tonic.transforms.DropEvent(p=0.1),\n#         tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n#     ])\n#     test_transform = transforms.Compose([\n#         # tonic.transforms.Denoise(filter_time=10000),\n#         tonic.transforms.ToFrame(sensor_size=None, n_time_bins=step),\n#     ])\n#\n#     train_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=train_transform, train=True)\n#     test_dataset = tonic.datasets.NCARS(os.path.join(DATA_DIR, 'DVS/NCARS'), transform=test_transform, 
train=False)\n#\n#     train_transform = transforms.Compose([\n#         lambda x: torch.tensor(x, dtype=torch.float),\n#         lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n#         lambda x: dvs_channel_check_expend(x),\n#         transforms.RandomCrop(size, padding=size // 12),\n#         transforms.RandomHorizontalFlip(),\n#         transforms.RandomRotation(15)\n#     ])\n#     test_transform = transforms.Compose([\n#         lambda x: torch.tensor(x, dtype=torch.float),\n#         lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),\n#         lambda x: dvs_channel_check_expend(x),\n#     ])\n#     if 'rand_aug' in kwargs.keys():\n#         if kwargs['rand_aug'] is True:\n#             n = kwargs['randaug_n']\n#             m = kwargs['randaug_m']\n#             train_transform.transforms.insert(2, RandAugment(m=m, n=n))\n#\n#     # if 'temporal_flatten' in kwargs.keys():\n#     #     if kwargs['temporal_flatten'] is True:\n#     #         train_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n#     #         test_transform.transforms.insert(-1, lambda x: temporal_flatten(x))\n#\n#     train_dataset = DiskCachedDataset(train_dataset,\n#                                       cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/train_cache_{}'.format(step)),\n#                                       transform=train_transform, num_copies=3)\n#     test_dataset = DiskCachedDataset(test_dataset,\n#                                      cache_path=os.path.join(DATA_DIR, 'DVS/NCARS/test_cache_{}'.format(step)),\n#                                      transform=test_transform, num_copies=3)\n#\n#     mix_up, cut_mix, event_mix, beta, prob, num, num_classes, noise, gaussian_n = unpack_mix_param(kwargs)\n#     mixup_active = cut_mix | event_mix | mix_up\n#\n#     if cut_mix:\n#         train_dataset = CutMix(train_dataset,\n#                                beta=beta,\n#                        
        prob=prob,\n#                                num_mix=num,\n#                                num_class=num_classes,\n#                                noise=noise)\n#\n#     if event_mix:\n#         train_dataset = EventMix(train_dataset,\n#                                  beta=beta,\n#                                  prob=prob,\n#                                  num_mix=num,\n#                                  num_class=num_classes,\n#                                  noise=noise,\n#                                  gaussian_n=gaussian_n)\n#     if mix_up:\n#         train_dataset = MixUp(train_dataset,\n#                               beta=beta,\n#                               prob=prob,\n#                               num_mix=num,\n#                               num_class=num_classes,\n#                               noise=noise)\n#\n#     train_loader = torch.utils.data.DataLoader(\n#         train_dataset, batch_size=batch_size,\n#         pin_memory=True, drop_last=True, num_workers=8,\n#         shuffle=True,\n#     )\n#\n#     test_loader = torch.utils.data.DataLoader(\n#         test_dataset, batch_size=batch_size,\n#         pin_memory=True, drop_last=False, num_workers=2,\n#         shuffle=False,\n#     )\n#\n#     return train_loader, test_loader, mixup_active, None\n\n\ndef get_nomni_data(batch_size, train_portion=1., **kwargs):\n    \"\"\"\n    获取N-Omniglot数据\n    :param batch_size:batch的大小\n    :param data_mode:一共full nkks pair三种模式\n    :param frames_num:一个样本帧的个数\n    :param data_type:event frequency两种模式\n    \"\"\"\n    data_mode = kwargs[\"data_mode\"] if \"data_mode\" in kwargs else \"full\"\n    frames_num = kwargs[\"frames_num\"] if \"frames_num\" in kwargs else 10\n    data_type = kwargs[\"data_type\"] if \"data_type\" in kwargs else \"event\"\n\n    train_transform = transforms.Compose([\n        transforms.Resize((64, 64))])\n    test_transform = transforms.Compose([\n        transforms.Resize((64, 64))])\n    if data_mode == 
\"full\":\n        train_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=True, frames_num=frames_num,\n                                       data_type=data_type,\n                                       transform=train_transform)\n        test_datasets = NOmniglotfull(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), train=False, frames_num=frames_num,\n                                      data_type=data_type,\n                                      transform=test_transform)\n\n    elif data_mode == \"nkks\":\n        train_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),\n                                            n_way=kwargs[\"n_way\"],\n                                            k_shot=kwargs[\"k_shot\"],\n                                            k_query=kwargs[\"k_query\"],\n                                            train=True,\n                                            frames_num=frames_num,\n                                            data_type=data_type,\n                                            transform=train_transform)\n        test_datasets = NOmniglotNWayKShot(os.path.join(DATA_DIR, 'DVS/NOmniglot'),\n                                           n_way=kwargs[\"n_way\"],\n                                           k_shot=kwargs[\"k_shot\"],\n                                           k_query=kwargs[\"k_query\"],\n                                           train=False,\n                                           frames_num=frames_num,\n                                           data_type=data_type,\n                                           transform=test_transform)\n    elif data_mode == \"pair\":\n        train_datasets = NOmniglotTrainSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), use_frame=True,\n                                           frames_num=frames_num, data_type=data_type,\n                                           use_npz=False, resize=105)\n        test_datasets = 
NOmniglotTestSet(root=os.path.join(DATA_DIR, 'DVS/NOmniglot'), time=2000, way=kwargs[\"n_way\"],\n                                         shot=kwargs[\"k_shot\"], use_frame=True,\n                                         frames_num=frames_num, data_type=data_type, use_npz=False, resize=105)\n\n    else:\n        pass\n\n    train_loader = torch.utils.data.DataLoader(\n        train_datasets, batch_size=batch_size, num_workers=12,\n        pin_memory=True, drop_last=True, shuffle=True\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_datasets, batch_size=batch_size, num_workers=12,\n        pin_memory=True, drop_last=False\n    )\n    return train_loader, test_loader, None, None\n"
  },
  {
    "path": "examples/decision_making/BDM-SNN/BDM-SNN-UAV.py",
    "content": "import numpy as np\r\nimport torch,os,sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter \r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\n#from braincog.base.strategy.surrogate import *\r\nfrom braincog.base.node.node import IFNode\r\nfrom braincog.base.learningrule.STDP import STDP,MutliInputSTDP\r\nfrom braincog.base.connection.CustomLinear import CustomLinear\r\nfrom braincog.base.brainarea.basalganglia import basalganglia\r\nfrom braincog.model_zoo.bdmsnn import BDMSNN\r\n\r\nfrom robomaster import robot\r\nimport time\r\n\r\ndef chooseAct(Net,input,weight_trace_d1,weight_trace_d2):\r\n    \"\"\"\r\n    根据输入选择行为\r\n    :param Net: 输入BDM-SNN网络\r\n    :param input: 输入电流 编码状态的脉冲\r\n    :param weight_trace_d1: 不断累积保存资格迹\r\n    :param weight_trace_d2: 不断累积保存资格迹\r\n    :return: 返回选择的行为、资格迹和网络\r\n    \"\"\"\r\n    for i_train in range(500):\r\n        out, dw = Net(input)\r\n        # rstdp\r\n        weight_trace_d1 *= trace_decay\r\n        weight_trace_d1 += dw[0][0]\r\n        weight_trace_d2 *= trace_decay\r\n        weight_trace_d2 += dw[1][0]\r\n        if torch.max(out) > 0:\r\n            return torch.argmax(out),weight_trace_d1,weight_trace_d2,Net\r\n\r\ndef updateNet(Net,reward, action, state,weight_trace_d1,weight_trace_d2):\r\n    \"\"\"\r\n    更新网络\r\n    :param Net: BDM-SNN网络\r\n    :param reward: 获得的奖励\r\n    :param action: 执行的行为\r\n    :param state: 执行行为前的状态\r\n    :param weight_trace_d1: 直接通路累积的资格迹\r\n    :param weight_trace_d2: 间接通路累积的资格迹\r\n    :return: 更新后的网络\r\n    \"\"\"\r\n    r = torch.ones((num_state, num_state * num_action), dtype=torch.float)\r\n    r[state, state * num_action + action] = reward\r\n    dw_d1 = r * weight_trace_d1\r\n    dw_d2 = -1 * r * weight_trace_d2\r\n    Net.UpdateWeight(0, state,num_action,dw_d1)\r\n    Net.UpdateWeight(1, state,num_action,dw_d2)\r\n    return Net\r\n\r\nif __name__==\"__main__\":\r\n    
\"\"\"\r\n    定义无人机 大疆Tello Talent \r\n    定义BDM-SNN网络\r\n    用户自定义状态空间、奖励函数，调用行为选择及网络更新\r\n    \"\"\"\r\n    #define UAV\r\n    tl_drone = robot.Drone()\r\n    tl_drone.initialize()\r\n    tl_flight = tl_drone.flight\r\n    tl_flight.takeoff().wait_for_completed()\r\n\r\n    #define Net\r\n    num_state=9\r\n    num_action=2\r\n    weight_exc=1\r\n    weight_inh=-0.5\r\n    trace_decay = 0.8\r\n    DM=BDMSNN(num_state,num_action,weight_exc,weight_inh,\"lif\")\r\n    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)\r\n    for i in range(num_state):\r\n        for j in range(num_action):\r\n            con_matrix1[i, i * num_action + j] = weight_exc\r\n    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n    iteration=0\r\n    while iteration < 200:\r\n        input = torch.zeros((num_state), dtype=torch.float)\r\n        #users define the judgestate function\r\n        state=1\r\n        input[state]=2\r\n        action,weight_trace_d1,weight_trace_d2,DM = chooseAct(DM,input,weight_trace_d1,weight_trace_d2)\r\n        #uav do action\r\n        if action==0:\r\n            tl_flight.forward(distance=20).wait_for_completed()\r\n        if action == 1:\r\n            # flying left\r\n            tl_flight.rc(a=20, b=0, c=0, d=0)\r\n            time.sleep(4)\r\n        if action == 2:\r\n            # flying right\r\n            tl_flight.rc(a=-20, b=0, c=0, d=0)\r\n            time.sleep(3)\r\n        if action == 3:\r\n            tl_flight.backward(distance=20).wait_for_completed()\r\n       #users define the reward function\r\n        reward =1\r\n\r\n        DM=updateNet(DM,reward, action, state,weight_trace_d1,weight_trace_d2)\r\n        weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n        weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n        DM.reset()\r\n\r\n        iteration += 1"
  },
  {
    "path": "examples/decision_making/BDM-SNN/BDM-SNN-hh.py",
    "content": "import torch,os\r\n\r\nfrom random import randint\r\nimport torch\r\nfrom torch import nn\r\nfrom braincog.base.strategy.surrogate import *\r\nfrom braincog.model_zoo.bdmsnn import BDMSNN\r\nimport pygame\r\nfrom pygame.locals import *\r\nfrom collections import deque\r\nfrom random import randint\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\n\r\nimport random\r\n#os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\r\n\r\ndef load_images():\r\n    \"\"\"\r\n    Flappy Bird中load图像\r\n    :return:load的图像\r\n    \"\"\"\r\n    def load_image(img_file_name):\r\n        file_name = os.path.join('.', 'birdimages', img_file_name)\r\n        img = pygame.image.load(file_name)\r\n        # converting all images before use speeds up blitting\r\n        img.convert()\r\n        return img\r\n\r\n    return {'background': load_image('background.png'),\r\n            'pipe-end': load_image('pipe_end.png'),\r\n            'pipe-body': load_image('pipe_body.png'),\r\n            # images for animating the flapping bird -- animated GIFs are\r\n            # not supported in pygame\r\n            'bird-wingup': load_image('bird_wing_up.png'),\r\n            'bird-wingdown': load_image('bird_wing_down.png'),}\r\n\r\nclass Bird(pygame.sprite.Sprite):\r\n    \"\"\"\r\n    Flappy Bird类\r\n    \"\"\"\r\n    WIDTH = HEIGHT = 32\r\n    SINK_SPEED = 0.2\r\n    Fail_SINk_SPEED = 0.6\r\n    CLIMB_SPEED = 0.25\r\n    CLIMB_DURATION = 333.3\r\n    REGION = CLIMB_DURATION / 3\r\n    NEAR_COLLIDE = 30\r\n    NEAR_PIPE = 0\r\n\r\n    def __init__(self, x, y, msec_to_climb, images):\r\n        super(Bird, self).__init__()\r\n        self.x, self.y = x, y\r\n        self.msec_to_climb = msec_to_climb\r\n        self._img_wingup, self._img_wingdown = images\r\n        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)\r\n        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)\r\n\r\n    def update(self, action,state,delta_frames=1):\r\n        
\"\"\"\r\n        更新小鸟的位置\r\n        :param action: 输入行为\r\n        :param state:输入状态\r\n        :param delta_frames:Fault\r\n        :return:None\r\n        \"\"\"\r\n        if self.msec_to_climb > 0 and action == 1:\r\n            if state==4 or state==5 or state == 2 or state == 3:\r\n                self.y -= (2*Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\r\n            else:\r\n                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\r\n        else:\r\n            if state == 4 or state == 5 or state == 2 or state == 3:\r\n                self.y += 2*Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\r\n            else:\r\n                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\r\n\r\n    def sink(self, delta_frames=1):\r\n        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)\r\n\r\n    @property\r\n    def image(self):\r\n        if pygame.time.get_ticks() % 500 >= 250:\r\n            return self._img_wingup\r\n        else:\r\n            return self._img_wingdown\r\n\r\n    @property\r\n    def mask(self):\r\n        if pygame.time.get_ticks() % 500 >= 250:\r\n            return self._mask_wingup\r\n        else:\r\n            return self._mask_wingdown\r\n\r\n    @property\r\n    def rect(self):\r\n        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)\r\n\r\nclass PipePair(pygame.sprite.Sprite):\r\n    \"\"\"\r\n    Flappy Bird 中的管子类\r\n    \"\"\"\r\n    WIDTH = 80\r\n    PIECE_HEIGHT = 32\r\n    ADD_INTERVAL = 2000\r\n    ADD_EVENT = pygame.USEREVENT + 1\r\n    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT\r\n\r\n    def __init__(self, pipe_end_img, pipe_body_img):\r\n        self.x = float(WIN_WIDTH - 1)\r\n        self.score_counted = False\r\n        self.isNewPipe = True\r\n\r\n        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)\r\n        self.image.convert()   # speeds up blitting\r\n        self.image.fill((0, 0, 0, 0))\r\n        total_pipe_body_pieces = 
int(\r\n            (WIN_HEIGHT -  # fill window from top to bottom\r\n             3 * Bird.HEIGHT -  # make room for bird to fit through\r\n             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece\r\n            PipePair.PIECE_HEIGHT  # to get number of pipe pieces\r\n        )\r\n        self.bottom_pieces = randint(1, total_pipe_body_pieces)\r\n        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces\r\n\r\n        # bottom pipe\r\n        for i in range(1, self.bottom_pieces + 1):\r\n            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)\r\n            self.image.blit(pipe_body_img, piece_pos)\r\n        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px\r\n        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)\r\n        self.image.blit(pipe_end_img, bottom_end_piece_pos)\r\n\r\n        # top pipe\r\n        for i in range(self.top_pieces):\r\n            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))\r\n        top_pipe_end_y = self.top_height_px\r\n        self.image.blit(pipe_end_img, (0, top_pipe_end_y))\r\n\r\n        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2\r\n\r\n        # compensate for added end pieces\r\n        self.top_pieces += 1\r\n        self.bottom_pieces += 1\r\n\r\n        # for collision detection\r\n        self.mask = pygame.mask.from_surface(self.image)\r\n        self.top_y = top_pipe_end_y\r\n        self.bottom_y = bottom_pipe_end_y\r\n\r\n    @property\r\n    def top_height_px(self):\r\n        return self.top_pieces * PipePair.PIECE_HEIGHT\r\n\r\n    @property\r\n    def bottom_height_px(self):\r\n        return self.bottom_pieces * PipePair.PIECE_HEIGHT\r\n\r\n    @property\r\n    def visible(self):\r\n        return -PipePair.WIDTH < self.x < WIN_WIDTH\r\n\r\n    @property\r\n    def rect(self):\r\n        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)\r\n\r\n    def update(self, delta_frames=1):\r\n        self.x -= 
0.18 * 1000.0 * delta_frames /60\r\n\r\n    def collides_with(self, bird):\r\n        return pygame.sprite.collide_mask(self, bird)\r\n\r\ndef chooseAct(Net,input,weight_trace_d1,weight_trace_d2):\r\n    \"\"\"\r\n    根据输入选择行为\r\n    :param Net: 输入BDM-SNN网络\r\n    :param input: 输入电流 编码状态的脉冲\r\n    :param weight_trace_d1: 不断累积保存资格迹\r\n    :param weight_trace_d2: 不断累积保存资格迹\r\n    :return: 返回选择的行为、资格迹和网络\r\n    \"\"\"\r\n    for i_train in range(500):\r\n        out, dw = Net(input)\r\n        # rstdp\r\n        weight_trace_d1 *= trace_decay\r\n        weight_trace_d1 += dw[0][0]\r\n        weight_trace_d2 *= trace_decay\r\n        weight_trace_d2 += dw[1][0]\r\n        if torch.max(out) > 0:\r\n            return torch.argmax(out),weight_trace_d1,weight_trace_d2,Net\r\n\r\ndef judgeState(bird, pipes, collide):\r\n    \"\"\"\r\n    根据小鸟和管子之间的位置关系判断当前状态\r\n    :param bird:传入小鸟的各项属性\r\n    :param pipes:传入管子的各项属性\r\n    :param collide:是否发生碰撞\r\n    :return:状态，距离，是否是新的管子\r\n    \"\"\"\r\n    # bird's x and y coordinate in the left top of the image\r\n    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2\r\n    isNew = False\r\n    index = -1\r\n    state = -1\r\n    if collide:\r\n        state = 8\r\n        return state\r\n    for p in pipes:\r\n        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:\r\n            continue\r\n        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \\\r\n                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:\r\n\r\n            p_top_y = p.top_y + PipePair.PIECE_HEIGHT\r\n            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT\r\n            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:\r\n                state = 0\r\n            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:\r\n                state = 1\r\n            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 
10:\r\n                state = 6\r\n            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:\r\n                state = 7\r\n            if state > -0.5:\r\n                index = 1\r\n        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:\r\n            state = blankState(bird, p.center)\r\n            if p.isNewPipe:\r\n                isNew = True\r\n            p.isNewPipe = False\r\n            index = 1\r\n        if index > 0:  # only judge the nearest and not passed pipe\r\n            dist = bird.y + Bird.HEIGHT / 2 - p.center\r\n            break\r\n    if index < -0.5:  # no pipe left, key the bird in the middle\r\n        pos = WIN_HEIGHT / 2\r\n        dist = bird.y + Bird.HEIGHT / 2 - pos\r\n        state = blankState(bird, pos)\r\n\r\n    return state, dist, isNew\r\n\r\ndef blankState(bird, center):\r\n    \"\"\"\r\n    judgeState中调用的判断状态的函数 根据鸟的位置和管子中心的距离来判断\r\n    :param bird: 传入小鸟的各项属性\r\n    :param center:中心\r\n    :return:状态\r\n    \"\"\"\r\n    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2\r\n    if center - bird.y - Bird.HEIGHT / 2 >= 0 and \\\r\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\r\n        state = 0\r\n    elif bird.y - center + Bird.HEIGHT / 2 >= 0 and \\\r\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\r\n        state = 1\r\n    elif center - bird.y - Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\r\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 2\r\n    elif bird.y - center + Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\r\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 3\r\n    elif bird.y + Bird.HEIGHT / 2 <= center - (realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION):\r\n        state = 4\r\n    elif bird.y + 
Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 5\r\n    return state\r\n\r\ndef getReward(state,lastState,smallerError,isNewPipe):\r\n    \"\"\"\r\n    根据状态和距离的变化获得奖励\r\n    :param state: 执行行为后的当前状态\r\n    :param lastState:执行行为之前的上一状态\r\n    :param smallerError:距离是否变小\r\n    :param isNewPipe:是否是新的管子\r\n    :return:奖励\r\n    \"\"\"\r\n    if state == 0 or state == 1:\r\n        reward = 6\r\n    elif state == 2 or state == 3:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -5\r\n        else:\r\n            reward = -3\r\n    elif state == 4 or state == 5:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -8\r\n        else:\r\n            reward = -5\r\n    elif state == 6 or state == 7:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -3\r\n        else:\r\n            reward = -3\r\n    elif state == 8:   #  collide\r\n        reward = -100\r\n    return reward\r\n\r\ndef updateNet(Net,reward, action, state,weight_trace_d1,weight_trace_d2):\r\n    \"\"\"\r\n    更新网络\r\n    :param Net: BDM-SNN网络\r\n    :param reward: 获得的奖励\r\n    :param action: 执行的行为\r\n    :param state: 执行行为前的状态\r\n    :param weight_trace_d1: 直接通路累积的资格迹\r\n    :param weight_trace_d2: 间接通路累积的资格迹\r\n    :return: 更新后的网络\r\n    \"\"\"\r\n    r = torch.ones((num_state, num_state * num_action), dtype=torch.float)\r\n    r[state, state * num_action + action] = reward\r\n    dw_d1 = r * weight_trace_d1\r\n    dw_d2 = -1 * r * weight_trace_d2\r\n    Net.UpdateWeight(0, state,num_action,dw_d1)\r\n    Net.UpdateWeight(1, state,num_action,dw_d2)\r\n    return Net\r\n\r\nif __name__==\"__main__\":\r\n    \"\"\"\r\n    
执行网络，运行Flappy Bird游戏\r\n    \"\"\"\r\n    num_state=9\r\n    num_action=2\r\n    weight_exc=50\r\n    weight_inh=-60\r\n    trace_decay = 0.8\r\n    DM=BDMSNN(num_state,num_action,weight_exc,weight_inh,\"hh\")\r\n    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)\r\n    for i in range(num_state):\r\n        for j in range(num_action):\r\n            con_matrix1[i, i * num_action + j] = weight_exc\r\n    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n\r\n    pygame.init()\r\n    WIN_HEIGHT = 512\r\n    WIN_WIDTH = 284 * 2\r\n    heighest = 0\r\n    display_frame=0\r\n    display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))\r\n    pygame.display.set_caption('Flappy Bird')\r\n    images = load_images()\r\n    bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,\r\n                (images['bird-wingup'], images['bird-wingdown']))\r\n\r\n    clock = pygame.time.Clock()\r\n    score_font = pygame.font.SysFont(None, 25, bold=True)\r\n    info_font = pygame.font.SysFont(None, 50, bold=True)\r\n    collide = paused = False\r\n    frame_clock = 0\r\n    pipes = deque()\r\n    score = 0\r\n    lastDist = 0\r\n    lastState = 0 #init\r\n    state = lastState\r\n    num=0\r\n    num_reward=[]\r\n    num_score=[]\r\n    while not collide:\r\n        num=num+1\r\n        if num>30000:\r\n            break\r\n        input = torch.zeros((num_state), dtype=torch.float)\r\n        clock.tick(60)\r\n        if frame_clock %2==0 or frame_clock==1:\r\n            state, dist, isNewPipe = judgeState(bird, pipes, collide)\r\n            lastState = state\r\n            lastDist = dist\r\n            input[state]=2\r\n            print(input)\r\n            action,weight_trace_d1,weight_trace_d2,DM = chooseAct(DM,input,weight_trace_d1,weight_trace_d2)\r\n            print(\"state, dist:\", state, dist)\r\n            print(\"state, 
action:\",state,action)\r\n        if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):\r\n            pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))\r\n\r\n        for e in pygame.event.get():\r\n            if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):\r\n                collide = True\r\n            elif e.type == KEYUP and e.key in (K_PAUSE, K_p):\r\n                paused = not paused\r\n            elif e.type == PipePair.ADD_EVENT:\r\n                pp = PipePair(images['pipe-end'], images['pipe-body'])\r\n                pipes.append(pp)\r\n        if paused:\r\n            continue  # don't draw anything\r\n        pipe_collision = any(p.collides_with(bird) for p in pipes)\r\n        if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:\r\n            collide = True\r\n        for x in (0, WIN_WIDTH / 2):\r\n            display_surface.blit(images['background'], (x, 0))\r\n        while pipes and not pipes[0].visible:\r\n            pipes.popleft()\r\n        for p in pipes:\r\n            p.update()\r\n            display_surface.blit(p.image, p.rect)\r\n        bird.update(action,state)\r\n        display_surface.blit(bird.image, bird.rect)\r\n        if frame_clock %2==0 or frame_clock==1 or collide:\r\n            dist = 0\r\n            if collide:\r\n                nextState = 8\r\n                isNewPipe = False\r\n            else:\r\n                nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state\r\n                print(\"next state:\", nextState)\r\n            print(\"lastdist, dist:\", lastDist,dist)\r\n            isSmallerError = False\r\n            if state == nextState:\r\n                isSmallerError = False\r\n                if lastDist <= 0:\r\n                    if lastDist < dist:\r\n                        isSmallerError = True\r\n                else:\r\n                    if lastDist > dist:\r\n                        
isSmallerError = True\r\n            if frame_clock>0 and not collide:\r\n                reward = getReward(nextState, state, isSmallerError, isNewPipe)\r\n                print(\"reward:\", reward)\r\n                num_reward.append(reward)\r\n                DM=updateNet(DM,reward, action, state,weight_trace_d1,weight_trace_d2)\r\n            state = nextState  #going on the next state\r\n            weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n            weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n            DM.reset()\r\n            display_frame += 1\r\n        for p in pipes:\r\n            if p.x + PipePair.WIDTH < bird.x and not p.score_counted:\r\n                score += 1\r\n                p.score_counted = True\r\n\r\n        num_score.append(score)\r\n        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\r\n        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\r\n        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\r\n        if heighest < score:\r\n            heighest = score\r\n        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\r\n                                            (0, 0, 0))  # heighest score\r\n        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\r\n        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\r\n        score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score\r\n        score_x_i = 10\r\n        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\r\n        frame_clock += 1\r\n        pygame.display.flip()\r\n\r\n    #  if collide, display the fail information, for 2 frames\r\n    cct = 0\r\n    while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):\r\n        clock.tick(60)\r\n        for x in (0, WIN_WIDTH / 2):\r\n            
display_surface.blit(images['background'], (x, 0))\r\n        while pipes and not pipes[0].visible:\r\n            pipes.popleft()\r\n        for p in pipes:\r\n            display_surface.blit(p.image, p.rect)\r\n        if cct >= 6:\r\n            bird.sink()\r\n        display_surface.blit(bird.image, bird.rect)\r\n        fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # current score\r\n        pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2\r\n        pos_y = WIN_HEIGHT / 2 - 100\r\n        display_surface.blit(fail_infor, (pos_x, pos_y))\r\n        #  display the score\r\n        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\r\n        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\r\n        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\r\n        if heighest < score:\r\n            heighest = score\r\n        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\r\n                                            (0, 0, 0))  # heighest score\r\n        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\r\n        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\r\n        score_surface_i = score_font.render('Attempts: 0' , True, (0, 0, 0))  # heighest score\r\n        score_x_i = 10\r\n        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\r\n        pygame.display.flip()\r\n        cct += 1\r\n    if heighest < score:\r\n        heighest = score\r\n\r\n    num_reward_np=np.array(num_reward)\r\n    num_score_np=np.array(num_score)\r\n    print(num_reward_np,num_score_np)\r\n    np.save('hh_reward_l.npy', num_reward_np)\r\n    np.save('hh_score_l.npy', num_score_np)\r\n    print(score)\r\n"
  },
  {
    "path": "examples/decision_making/BDM-SNN/BDM-SNN.py",
    "content": "import torch\r\nimport os\r\nfrom braincog.model_zoo.bdmsnn import BDMSNN\r\nimport pygame\r\nfrom pygame.locals import *\r\nfrom collections import deque\r\nfrom random import randint\r\nimport numpy as np\r\n\r\ntry:\r\n    pygame.display.init()\r\n\r\nexcept:\r\n    os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\r\n\r\n\r\ndef load_images():\r\n    \"\"\"\r\n    Flappy Bird中load图像\r\n    :return:load的图像\r\n    \"\"\"\r\n\r\n    def load_image(img_file_name):\r\n        file_name = os.path.join('.', 'birdimages', img_file_name)\r\n        img = pygame.image.load(file_name)\r\n        # converting all images before use speeds up blitting\r\n        img.convert()\r\n        return img\r\n\r\n    return {'background': load_image('background.png'),\r\n            'pipe-end': load_image('pipe_end.png'),\r\n            'pipe-body': load_image('pipe_body.png'),\r\n            # images for animating the flapping bird -- animated GIFs are\r\n            # not supported in pygame\r\n            'bird-wingup': load_image('bird_wing_up.png'),\r\n            'bird-wingdown': load_image('bird_wing_down.png'), }\r\n\r\n\r\nclass Bird(pygame.sprite.Sprite):\r\n    \"\"\"\r\n    Flappy Bird类\r\n    \"\"\"\r\n    WIDTH = HEIGHT = 32\r\n    SINK_SPEED = 0.2\r\n    Fail_SINk_SPEED = 0.6\r\n    CLIMB_SPEED = 0.25\r\n    CLIMB_DURATION = 333.3\r\n    REGION = CLIMB_DURATION / 3\r\n    NEAR_COLLIDE = 30\r\n    NEAR_PIPE = 0\r\n\r\n    def __init__(self, x, y, msec_to_climb, images):\r\n        super(Bird, self).__init__()\r\n        self.x, self.y = x, y\r\n        self.msec_to_climb = msec_to_climb\r\n        self._img_wingup, self._img_wingdown = images\r\n        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)\r\n        self._mask_wingdown = pygame.mask.from_surface(self._img_wingdown)\r\n\r\n    def update(self, action, state, delta_frames=1):\r\n        \"\"\"\r\n        更新小鸟的位置\r\n        :param action: 输入行为\r\n        :param state:输入状态\r\n        
:param delta_frames:Fault\r\n        :return:None\r\n        \"\"\"\r\n        if self.msec_to_climb > 0 and action == 1:\r\n            if state == 4 or state == 5 or state == 2 or state == 3:\r\n                self.y -= (2 * Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\r\n            else:\r\n                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\r\n        else:\r\n            if state == 4 or state == 5 or state == 2 or state == 3:\r\n                self.y += 2 * Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\r\n            else:\r\n                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\r\n\r\n    def sink(self, delta_frames=1):\r\n        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)\r\n\r\n    @property\r\n    def image(self):\r\n        if pygame.time.get_ticks() % 500 >= 250:\r\n            return self._img_wingup\r\n        else:\r\n            return self._img_wingdown\r\n\r\n    @property\r\n    def mask(self):\r\n        if pygame.time.get_ticks() % 500 >= 250:\r\n            return self._mask_wingup\r\n        else:\r\n            return self._mask_wingdown\r\n\r\n    @property\r\n    def rect(self):\r\n        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)\r\n\r\n\r\nclass PipePair(pygame.sprite.Sprite):\r\n    \"\"\"\r\n    Flappy Bird 中的管子类\r\n    \"\"\"\r\n    WIDTH = 80\r\n    PIECE_HEIGHT = 32\r\n    ADD_INTERVAL = 2000\r\n    ADD_EVENT = pygame.USEREVENT + 1\r\n    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT\r\n\r\n    def __init__(self, pipe_end_img, pipe_body_img):\r\n        self.x = float(WIN_WIDTH - 1)\r\n        self.score_counted = False\r\n        self.isNewPipe = True\r\n\r\n        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)\r\n        self.image.convert()  # speeds up blitting\r\n        self.image.fill((0, 0, 0, 0))\r\n        total_pipe_body_pieces = int(\r\n            (WIN_HEIGHT -  # fill window from top to bottom\r\n             3 * 
Bird.HEIGHT -  # make room for bird to fit through\r\n             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece\r\n            PipePair.PIECE_HEIGHT  # to get number of pipe pieces\r\n        )\r\n        self.bottom_pieces = randint(1, total_pipe_body_pieces)\r\n        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces\r\n\r\n        # bottom pipe\r\n        for i in range(1, self.bottom_pieces + 1):\r\n            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)\r\n            self.image.blit(pipe_body_img, piece_pos)\r\n        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px\r\n        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)\r\n        self.image.blit(pipe_end_img, bottom_end_piece_pos)\r\n\r\n        # top pipe\r\n        for i in range(self.top_pieces):\r\n            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))\r\n        top_pipe_end_y = self.top_height_px\r\n        self.image.blit(pipe_end_img, (0, top_pipe_end_y))\r\n\r\n        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2\r\n\r\n        # compensate for added end pieces\r\n        self.top_pieces += 1\r\n        self.bottom_pieces += 1\r\n\r\n        # for collision detection\r\n        self.mask = pygame.mask.from_surface(self.image)\r\n        self.top_y = top_pipe_end_y\r\n        self.bottom_y = bottom_pipe_end_y\r\n\r\n    @property\r\n    def top_height_px(self):\r\n        return self.top_pieces * PipePair.PIECE_HEIGHT\r\n\r\n    @property\r\n    def bottom_height_px(self):\r\n        return self.bottom_pieces * PipePair.PIECE_HEIGHT\r\n\r\n    @property\r\n    def visible(self):\r\n        return -PipePair.WIDTH < self.x < WIN_WIDTH\r\n\r\n    @property\r\n    def rect(self):\r\n        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)\r\n\r\n    def update(self, delta_frames=1):\r\n        self.x -= 0.18 * 1000.0 * delta_frames / 60\r\n\r\n    def collides_with(self, bird):\r\n        
return pygame.sprite.collide_mask(self, bird)\r\n\r\n\r\ndef chooseAct(Net, input, weight_trace_d1, weight_trace_d2):\r\n    \"\"\"\r\n    根据输入选择行为\r\n    :param Net: 输入BDM-SNN网络\r\n    :param input: 输入电流 编码状态的脉冲\r\n    :param weight_trace_d1: 不断累积保存资格迹\r\n    :param weight_trace_d2: 不断累积保存资格迹\r\n    :return: 返回选择的行为、资格迹和网络\r\n    \"\"\"\r\n    for i_train in range(500):\r\n        out, dw = Net(input)\r\n        # rstdp\r\n        weight_trace_d1 *= trace_decay\r\n        weight_trace_d1 += dw[0][0]\r\n        weight_trace_d2 *= trace_decay\r\n        weight_trace_d2 += dw[1][0]\r\n        if torch.max(out) > 0:\r\n            return torch.argmax(out), weight_trace_d1, weight_trace_d2, Net\r\n\r\n\r\ndef judgeState(bird, pipes, collide):\r\n    \"\"\"\r\n    根据小鸟和管子之间的位置关系判断当前状态\r\n    :param bird:传入小鸟的各项属性\r\n    :param pipes:传入管子的各项属性\r\n    :param collide:是否发生碰撞\r\n    :return:状态，距离，是否是新的管子\r\n    \"\"\"\r\n    # bird's x and y coordinate in the left top of the image\r\n    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2\r\n    isNew = False\r\n    index = -1\r\n    state = -1\r\n    if collide:\r\n        state = 8\r\n        return state\r\n    for p in pipes:\r\n        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:\r\n            continue\r\n        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \\\r\n                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:\r\n\r\n            p_top_y = p.top_y + PipePair.PIECE_HEIGHT\r\n            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT\r\n            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:\r\n                state = 0\r\n            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:\r\n                state = 1\r\n            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:\r\n                state = 6\r\n            elif bird.y + 
Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:\r\n                state = 7\r\n            if state > -0.5:\r\n                index = 1\r\n        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:\r\n            state = blankState(bird, p.center)\r\n            if p.isNewPipe:\r\n                isNew = True\r\n            p.isNewPipe = False\r\n            index = 1\r\n        if index > 0:  # only judge the nearest and not passed pipe\r\n            dist = bird.y + Bird.HEIGHT / 2 - p.center\r\n            break\r\n    if index < -0.5:  # no pipe left, key the bird in the middle\r\n        pos = WIN_HEIGHT / 2\r\n        dist = bird.y + Bird.HEIGHT / 2 - pos\r\n        state = blankState(bird, pos)\r\n\r\n    return state, dist, isNew\r\n\r\n\r\ndef blankState(bird, center):\r\n    \"\"\"\r\n    judgeState中调用的判断状态的函数 根据鸟的位置和管子中心的距离来判断\r\n    :param bird: 传入小鸟的各项属性\r\n    :param center:中心\r\n    :return:状态\r\n    \"\"\"\r\n    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2\r\n    if center - bird.y - Bird.HEIGHT / 2 >= 0 and \\\r\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\r\n        state = 0\r\n    elif bird.y - center + Bird.HEIGHT / 2 >= 0 and \\\r\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\r\n        state = 1\r\n    elif center - bird.y - Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\r\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 2\r\n    elif bird.y - center + Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\r\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 3\r\n    elif bird.y + Bird.HEIGHT / 2 <= center - (realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION):\r\n        state = 4\r\n    elif bird.y + Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE 
/ 2 + Bird.REGION:\r\n        state = 5\r\n    return state\r\n\r\n\r\ndef getReward(state, lastState, smallerError, isNewPipe):\r\n    \"\"\"\r\n    根据状态和距离的变化获得奖励\r\n    :param state: 执行行为后的当前状态\r\n    :param lastState:执行行为之前的上一状态\r\n    :param smallerError:距离是否变小\r\n    :param isNewPipe:是否是新的管子\r\n    :return:奖励\r\n    \"\"\"\r\n    if state == 0 or state == 1:\r\n        reward = 6\r\n    elif state == 2 or state == 3:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -5\r\n        else:\r\n            reward = -3\r\n    elif state == 4 or state == 5:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -8\r\n        else:\r\n            reward = -5\r\n    elif state == 6 or state == 7:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -3\r\n        else:\r\n            reward = -3\r\n    elif state == 8:  # collide\r\n        reward = -100\r\n    return reward\r\n\r\n\r\ndef updateNet(Net, reward, action, state, weight_trace_d1, weight_trace_d2):\r\n    \"\"\"\r\n    更新网络\r\n    :param Net: BDM-SNN网络\r\n    :param reward: 获得的奖励\r\n    :param action: 执行的行为\r\n    :param state: 执行行为前的状态\r\n    :param weight_trace_d1: 直接通路累积的资格迹\r\n    :param weight_trace_d2: 间接通路累积的资格迹\r\n    :return: 更新后的网络\r\n    \"\"\"\r\n    r = torch.ones((num_state, num_state * num_action), dtype=torch.float)\r\n    r[state, state * num_action + action] = reward\r\n    dw_d1 = r * weight_trace_d1\r\n    dw_d2 = -1 * r * weight_trace_d2\r\n    Net.UpdateWeight(0, state, num_action, dw_d1)\r\n    Net.UpdateWeight(1, state, num_action, dw_d2)\r\n    return Net\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    \"\"\"\r\n    执行网络，运行Flappy Bird游戏\r\n    \"\"\"\r\n    
num_state = 9\r\n    num_action = 2\r\n    weight_exc = 1\r\n    weight_inh = -0.5\r\n    trace_decay = 0.8\r\n    DM = BDMSNN(num_state, num_action, weight_exc, weight_inh, \"lif\")\r\n    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)\r\n    for i in range(num_state):\r\n        for j in range(num_action):\r\n            con_matrix1[i, i * num_action + j] = weight_exc\r\n    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n\r\n    pygame.init()\r\n    WIN_HEIGHT = 512\r\n    WIN_WIDTH = 284 * 2\r\n    heighest = 0\r\n    contTime = 0\r\n    display_frame = 0\r\n    display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))\r\n    pygame.display.set_caption('Flappy Bird')\r\n    images = load_images()\r\n    bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,\r\n                (images['bird-wingup'], images['bird-wingdown']))\r\n\r\n    clock = pygame.time.Clock()\r\n    score_font = pygame.font.SysFont(None, 25, bold=True)\r\n    info_font = pygame.font.SysFont(None, 50, bold=True)\r\n    collide = paused = False\r\n    frame_clock = 0\r\n    pipes = deque()\r\n    score = 0\r\n    lastDist = 0\r\n    lastState = 0  # init\r\n    state = lastState\r\n    num = 0\r\n    num_reward = []\r\n    num_score = []\r\n    while not collide:\r\n        num = num + 1\r\n        if num > 30000:\r\n            break\r\n        input = torch.zeros((num_state), dtype=torch.float)\r\n        clock.tick(60)\r\n        if frame_clock % 2 == 0 or frame_clock == 1:\r\n            state, dist, isNewPipe = judgeState(bird, pipes, collide)\r\n            lastState = state\r\n            lastDist = dist\r\n            input[state] = 2\r\n            action, weight_trace_d1, weight_trace_d2, DM = chooseAct(DM, input, weight_trace_d1, weight_trace_d2)\r\n            print(\"state, dist:\", state, dist)\r\n            print(\"state, action:\", 
state, action)\r\n        if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):\r\n            pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))\r\n\r\n        for e in pygame.event.get():\r\n            if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):\r\n                collide = True\r\n            elif e.type == KEYUP and e.key in (K_PAUSE, K_p):\r\n                paused = not paused\r\n            elif e.type == PipePair.ADD_EVENT:\r\n                pp = PipePair(images['pipe-end'], images['pipe-body'])\r\n                pipes.append(pp)\r\n        if paused:\r\n            continue  # don't draw anything\r\n        pipe_collision = any(p.collides_with(bird) for p in pipes)\r\n        if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:\r\n            collide = True\r\n        for x in (0, WIN_WIDTH / 2):\r\n            display_surface.blit(images['background'], (x, 0))\r\n        while pipes and not pipes[0].visible:\r\n            pipes.popleft()\r\n        for p in pipes:\r\n            p.update()\r\n            display_surface.blit(p.image, p.rect)\r\n        bird.update(action, state)\r\n        display_surface.blit(bird.image, bird.rect)\r\n        if frame_clock % 2 == 0 or frame_clock == 1 or collide:\r\n            dist = 0\r\n            if collide:\r\n                nextState = 8\r\n                isNewPipe = False\r\n            else:\r\n                nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state\r\n                print(\"next state:\", nextState)\r\n            print(\"lastdist, dist:\", lastDist, dist)\r\n            isSmallerError = False\r\n            if state == nextState:\r\n                isSmallerError = False\r\n                if lastDist <= 0:\r\n                    if lastDist < dist:\r\n                        isSmallerError = True\r\n                else:\r\n                    if lastDist > dist:\r\n                        
isSmallerError = True\r\n            if frame_clock > 0 and not collide:\r\n                reward = getReward(nextState, state, isSmallerError, isNewPipe)\r\n                print(\"reward:\", reward)\r\n                num_reward.append(reward)\r\n                DM = updateNet(DM, reward, action, state, weight_trace_d1, weight_trace_d2)\r\n            state = nextState  # going on the next state\r\n            weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n            weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n            DM.reset()\r\n            display_frame += 1\r\n        for p in pipes:\r\n            if p.x + PipePair.WIDTH < bird.x and not p.score_counted:\r\n                score += 1\r\n                p.score_counted = True\r\n        num_score.append(score)\r\n        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\r\n        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\r\n        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\r\n        if heighest < score:\r\n            heighest = score\r\n        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\r\n                                            (0, 0, 0))  # heighest score\r\n        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\r\n        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\r\n        score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score\r\n        score_x_i = 10\r\n        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\r\n        frame_clock += 1\r\n        pygame.display.flip()\r\n\r\n    #  if collide, display the fail information, for 2 frames\r\n    cct = 0\r\n    while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):\r\n        clock.tick(60)\r\n        for x in (0, WIN_WIDTH / 2):\r\n            
display_surface.blit(images['background'], (x, 0))\r\n        while pipes and not pipes[0].visible:\r\n            pipes.popleft()\r\n        for p in pipes:\r\n            display_surface.blit(p.image, p.rect)\r\n        if cct >= 6:\r\n            bird.sink()\r\n        display_surface.blit(bird.image, bird.rect)\r\n        fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # current score\r\n        pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2\r\n        pos_y = WIN_HEIGHT / 2 - 100\r\n        display_surface.blit(fail_infor, (pos_x, pos_y))\r\n        #  display the score\r\n        score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\r\n        score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\r\n        display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\r\n        if heighest < score:\r\n            heighest = score\r\n        score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\r\n                                            (0, 0, 0))  # heighest score\r\n        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\r\n        display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\r\n        score_surface_i = score_font.render('Attempts: 0', True, (0, 0, 0))  # heighest score\r\n        score_x_i = 10\r\n        display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\r\n        pygame.display.flip()\r\n        cct += 1\r\n    if heighest < score:\r\n        heighest = score\r\n    contTime += 1\r\n\r\n    num_reward_np = np.array(num_reward)\r\n    num_score_np = np.array(num_score)\r\n    print(num_reward_np, num_score_np)\r\n    np.save('lif_reward_l.npy', num_reward_np)\r\n    np.save('lif_score_l.npy', num_score_np)\r\n    print(score)\r\n"
  },
  {
    "path": "examples/decision_making/BDM-SNN/README.md",
    "content": "\n# Brain-inspired Decision-Making SNN\n\n## Requirements\n\n\"decisionmaking.py\", \"BDM-SNN.py\"，\"BDM-SNN-hh.py\"：pygame\n\n\"BDM-SNN-UAV.py\"：robomaster\n\n\n## Run\n The decisionmaking.py and BDM-SNN.py implements the core code of the brain-inspired decision-making spiking neural network in paper entitled \"A brain-inspired decision-making spiking neural network and its application in unmanned aerial vehicle\".\n\n \"decisionmaking.py, BDM-SNN.py\"  includes the multi-brain regions coordinated decision-making spiking neural network with LIF neurons.\n\n \"BDM-SNN-hh.py\"  includes the BDM-SNN with simplified HH neurons.\n\n \"BDM-SNN-UAV.py\"  includes the BDM-SNN applied to the UAV (DJI Tello talent), users need to define the reinforcement learning task.\n\n```shell\npython decisionmaking.py\n\npython BDM-SNN.py\n\npython BDM-SNN-hh.py\n\npython BDM-SNN-UAV.py\n```\n\n## Results\n\"decisionmaking.py\", \"BDM-SNN.py\"  and  \"BDM-SNN-hh.py\"  have been verified on Flappy Bird game. 
BDM-SNN could stably pass the pipeline on the first try.\n\n![description](./bdm.png)\n\nDifferences from the original article: an improved reward-modulated STDP learning rule.\n\n## Citation\n\nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@article{zhao2018brain,\n  title={A brain-inspired decision-making spiking neural network and its application in unmanned aerial vehicle},\n  author={Zhao, Feifei and Zeng, Yi and Xu, Bo},\n  journal={Frontiers in neurorobotics},\n  volume={12},\n  pages={56},\n  year={2018},\n  publisher={Frontiers Media SA}\n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n\n```\n"
  },
  {
    "path": "examples/decision_making/BDM-SNN/decisionmaking.py",
    "content": "import numpy as np\r\nimport torch,os,sys\r\nfrom torch import nn\r\nfrom torch.nn import Parameter\r\n\r\nimport abc\r\nimport math\r\nfrom abc import ABC\r\n\r\nimport torch.nn.functional as F\r\nimport matplotlib.pyplot as plt\r\n#from BrainCog.base.strategy.surrogate import *\r\nfrom braincog.base.node.node import IFNode, SimHHNode\r\nfrom braincog.base.learningrule.STDP import STDP, MutliInputSTDP\r\nfrom braincog.base.connection.CustomLinear import CustomLinear\r\nfrom braincog.base.brainarea.basalganglia import basalganglia\r\n#from braincog.model_zoo.bdmsnn import BDMSNN\r\n\r\nimport pygame\r\nfrom pygame.locals import *\r\nfrom collections import deque\r\nfrom random import randint\r\n#os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\r\nclass BDMSNN(nn.Module):\r\n    def __init__(self, num_state, num_action, weight_exc, weight_inh, node_type):\r\n        \"\"\"\r\n        定义BDM-SNN网络\r\n        :param num_state: 状态个数\r\n        :param num_action: 动作个数\r\n        :param weight_exc: 兴奋性连接权重\r\n        :param weight_inh: 抑制性连接权重\r\n        \"\"\"\r\n        super().__init__()\r\n        # parameters\r\n        BG = basalganglia(num_state, num_action, weight_exc, weight_inh, node_type)\r\n        dm_connection = BG.getweight()\r\n        dm_mask = BG.getmask()\r\n        # input-dlpfc\r\n        con_matrix9 = torch.eye((num_state), dtype=torch.float)\r\n        dm_connection.append(CustomLinear(weight_exc * con_matrix9, con_matrix9))\r\n        dm_mask.append(con_matrix9)\r\n        # gpi-th\r\n        con_matrix10 = torch.eye((num_action), dtype=torch.float)\r\n        dm_mask.append(con_matrix10)\r\n        dm_connection.append(CustomLinear(weight_inh * con_matrix10, con_matrix10))\r\n        # th-pm\r\n        dm_mask.append(con_matrix10)\r\n        dm_connection.append(CustomLinear(weight_exc * con_matrix10, con_matrix10))\r\n        # dlpfc-th\r\n        con_matrix11 = torch.ones((num_state, num_action), dtype=torch.float)\r\n        
dm_mask.append(con_matrix11)\r\n        dm_connection.append(CustomLinear(0.2 * weight_exc * con_matrix11, con_matrix11))\r\n        # pm-pm\r\n        con_matrix3 = torch.ones((num_action, num_action), dtype=torch.float)\r\n        con_matrix4 = torch.eye((num_action), dtype=torch.float)\r\n        con_matrix5 = con_matrix3 - con_matrix4\r\n        con_matrix5 = con_matrix5\r\n        dm_mask.append(con_matrix5)\r\n        dm_connection.append(CustomLinear(5 * weight_inh * con_matrix5, con_matrix5))\r\n        # dlpfc thalamus pm +bg\r\n        self.weight_exc = weight_exc\r\n        self.num_subDM = 8\r\n        self.connection = dm_connection\r\n        self.mask = dm_mask\r\n        self.node = BG.node\r\n        self.node_type = node_type\r\n        if self.node_type == \"hh\":\r\n            self.node.extend([SimHHNode() for i in range(self.num_subDM - BG.num_subBG)])\r\n            self.node[6].g_Na = torch.tensor(12)\r\n            self.node[6].g_K = torch.tensor(3.6)\r\n            self.node[6].g_L = torch.tensor(0.03)\r\n        if self.node_type == \"lif\":\r\n            self.node.extend([IFNode() for i in range(self.num_subDM - BG.num_subBG)])\r\n        self.learning_rule = BG.learning_rule\r\n        self.learning_rule.append(MutliInputSTDP(self.node[5], [self.connection[10], self.connection[12]]))  # gpi-丘脑\r\n        self.learning_rule.append(MutliInputSTDP(self.node[6], [self.connection[11], self.connection[13]]))  # pm\r\n        self.learning_rule.append(STDP(self.node[7], self.connection[9]))\r\n\r\n        out_shape=[self.connection[0].weight.shape[1],self.connection[1].weight.shape[1],self.connection[2].weight.shape[1],self.connection[4].weight.shape[1],self.connection[3].weight.shape[1],self.connection[10].weight.shape[1],self.connection[11].weight.shape[1],self.connection[9].weight.shape[1]]\r\n        self.out = []\r\n        self.dw = []\r\n        for i in range(self.num_subDM):\r\n            self.out.append(torch.zeros((out_shape[i]), 
dtype=torch.float))\r\n            self.dw.append(torch.zeros((out_shape[i]), dtype=torch.float))\r\n\r\n    def forward(self, input):\r\n        \"\"\"\r\n        根据输入得到网络的输出\r\n        :param input: 输入\r\n        :return: 网络的输出\r\n        \"\"\"\r\n        self.out[7] = self.node[7](self.connection[9](input))\r\n        self.out[0], self.dw[0] = self.learning_rule[0](self.out[7])\r\n        self.out[1], self.dw[1] = self.learning_rule[1](self.out[7])\r\n        self.out[2], self.dw[2] = self.learning_rule[2](self.out[7], self.out[3])\r\n        self.out[3], self.dw[3] = self.learning_rule[3](self.out[1], self.out[2])\r\n        self.out[4], self.dw[4] = self.learning_rule[4](self.out[0], self.out[3], self.out[2])\r\n        self.out[5], self.dw[5] = self.learning_rule[5](self.out[4], self.out[7])\r\n        self.out[6], self.dw[6] = self.learning_rule[6](self.out[5], self.out[6])\r\n        br = [\"StrD1\", \"StrD2\", \"STN\", \"Gpe\", \"Gpi\", \"thalamus\", \"PM\", \"DLPFC\"]\r\n        for i in range(self.num_subDM):\r\n            if torch.max(self.out[i]) > 0 and self.node_type == \"hh\":\r\n                self.node[i].n_reset()\r\n            print(\"every areas:\", br[i], self.out[i])\r\n        return self.out[6], self.dw\r\n\r\n    def UpdateWeight(self, i, s, num_action, dw):\r\n        \"\"\"\r\n        更新网络中第i组连接的权重\r\n        :param i:要更新的连接组索引\r\n        :param s:传入状态\r\n        :param dw:更新权重的量\r\n        :return:\r\n        \"\"\"\r\n        if self.node_type == \"hh\":\r\n            self.connection[i].update(0.2 * self.weight_exc * dw)\r\n            self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)\r\n            self.connection[i].weight.data[s, :] = self.connection[i].weight.data[s, :] * self.weight_exc\r\n        if self.node_type == \"lif\":\r\n            dw_mean = dw[s, [s * num_action, s * num_action + 1]].mean()\r\n   
         dw_std = dw[s, [s * num_action, s * num_action + 1]].std()\r\n            dw[s, [s * num_action, s * num_action + 1]] = (dw[s, [s * num_action,s * num_action + 1]] - dw_mean) / dw_std\r\n            dw[s, :] = dw[s, :] * self.mask[i][s, :]\r\n            self.connection[i].update(dw)\r\n            self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]] /= (self.connection[i].weight.data[s, [s * num_action, s * num_action + 1]].float().max() + 1e-12)\r\n        if i in [0, 1, 2, 6, 7, 11, 12]:\r\n            self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, 0, None)\r\n        if i in [3, 4, 5, 8, 10]:\r\n            self.connection[i].weight.data = torch.clamp(self.connection[i].weight.data, None, 0)\r\n\r\n    def reset(self):\r\n        \"\"\"\r\n        reset神经元或学习法则的中间量\r\n        :return: None\r\n        \"\"\"\r\n        for i in range(self.num_subDM):\r\n            self.node[i].n_reset()\r\n        for i in range(len(self.learning_rule)):\r\n            self.learning_rule[i].reset()\r\n\r\n    def getweight(self):\r\n        \"\"\"\r\n        获取网络的连接(包括权值等)\r\n        :return: 网络的连接\r\n        \"\"\"\r\n        return self.connection\r\n\r\ndef load_images():\r\n    \"\"\"Load all images required by the game and return a dict of them.\r\n\r\n    The returned dict has the following keys:\r\n    background: The game's background image.\r\n    bird-wingup: An image of the bird with its wing pointing upward.\r\n        Use this and bird-wingdown to create a flapping bird.\r\n    bird-wingdown: An image of the bird with its wing pointing downward.\r\n        Use this and bird-wingup to create a flapping bird.\r\n    pipe-end: An image of a pipe's end piece (the slightly wider bit).\r\n        Use this and pipe-body to make pipes.\r\n    pipe-body: An image of a slice of a pipe's body.  
Use this and\r\n        pipe-body to make pipes.\r\n    \"\"\"\r\n\r\n    def load_image(img_file_name):\r\n        \"\"\"Return the loaded pygame image with the specified file name.\r\n\r\n        This function looks for images in the game's images folder\r\n        (./images/).  All images are converted before being returned to\r\n        speed up blitting.\r\n\r\n        Arguments:\r\n        img_file_name: The file name (including its extension, e.g.\r\n            '.png') of the required image, without a file path.\r\n        \"\"\"\r\n        file_name = os.path.join('.', 'birdimages', img_file_name)\r\n        img = pygame.image.load(file_name)\r\n        # converting all images before use speeds up blitting\r\n        img.convert()\r\n        return img\r\n\r\n    return {'background': load_image('background.png'),\r\n            'pipe-end': load_image('pipe_end.png'),\r\n            'pipe-body': load_image('pipe_body.png'),\r\n            # images for animating the flapping bird -- animated GIFs are\r\n            # not supported in pygame\r\n            'bird-wingup': load_image('bird_wing_up.png'),\r\n            'bird-wingdown': load_image('bird_wing_down.png'),}\r\n\r\nclass Bird(pygame.sprite.Sprite):\r\n    WIDTH = HEIGHT = 32\r\n    SINK_SPEED = 0.2\r\n    Fail_SINk_SPEED = 0.6\r\n    CLIMB_SPEED = 0.25\r\n    CLIMB_DURATION = 333.3\r\n    REGION = CLIMB_DURATION / 3  # when far from the pipe, the bird can fluctuate in a certain region,wgx\r\n    NEAR_COLLIDE = 30  # when inside the pipe, near collide distance, this define another state\r\n    NEAR_PIPE = 0  # at what distance does the bird near the pipe\r\n\r\n    def __init__(self, x, y, msec_to_climb, images):\r\n        super(Bird, self).__init__()\r\n        self.x, self.y = x, y\r\n        self.msec_to_climb = msec_to_climb\r\n        self._img_wingup, self._img_wingdown = images\r\n        self._mask_wingup = pygame.mask.from_surface(self._img_wingup)\r\n        self._mask_wingdown = 
pygame.mask.from_surface(self._img_wingdown)\r\n\r\n    def update(self, action,state,delta_frames=1):\r\n        if self.msec_to_climb > 0 and action == 1:\r\n            if state==4 or state==5 or state == 2 or state == 3:\r\n                self.y -= (2*Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\r\n            else:\r\n                self.y -= (Bird.CLIMB_SPEED * (1000.0 * delta_frames / 60))\r\n        else:\r\n            if state == 4 or state == 5 or state == 2 or state == 3:\r\n                self.y += 2*Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\r\n            else:\r\n                self.y += Bird.SINK_SPEED * (1000.0 * delta_frames / 60)\r\n\r\n    #  if the bird fails, sink the bird till it hit the bottom\r\n    def sink(self, delta_frames=1):\r\n        self.y += Bird.Fail_SINk_SPEED * (1000.0 * delta_frames / 60)\r\n\r\n    @property\r\n    def image(self):\r\n        if pygame.time.get_ticks() % 500 >= 250:\r\n            return self._img_wingup\r\n        else:\r\n            return self._img_wingdown\r\n\r\n    @property\r\n    def mask(self):\r\n        if pygame.time.get_ticks() % 500 >= 250:\r\n            return self._mask_wingup\r\n        else:\r\n            return self._mask_wingdown\r\n\r\n    @property\r\n    def rect(self):\r\n        return Rect(self.x, self.y, Bird.WIDTH, Bird.HEIGHT)\r\n\r\nclass PipePair(pygame.sprite.Sprite):\r\n    WIDTH = 80\r\n    PIECE_HEIGHT = 32\r\n    ADD_INTERVAL = 2000\r\n    ADD_EVENT = pygame.USEREVENT + 1\r\n    ROOM_HIGHT = 2 * Bird.HEIGHT + 2 * PIECE_HEIGHT\r\n\r\n    def __init__(self, pipe_end_img, pipe_body_img):\r\n        \"\"\"Initialises a new random PipePair.\r\n\r\n        The new PipePair will automatically be assigned an x attribute of\r\n        float(WIN_WIDTH - 1).\r\n\r\n        Arguments:\r\n        pipe_end_img: The image to use to represent a pipe's end piece.\r\n        pipe_body_img: The image to use to represent one horizontal slice\r\n            of a pipe's body.\r\n 
       \"\"\"\r\n        self.x = float(WIN_WIDTH - 1)\r\n        self.score_counted = False\r\n        self.isNewPipe = True\r\n\r\n        self.image = pygame.Surface((PipePair.WIDTH, WIN_HEIGHT), SRCALPHA)\r\n        self.image.convert()   # speeds up blitting\r\n        self.image.fill((0, 0, 0, 0))\r\n        total_pipe_body_pieces = int(\r\n            (WIN_HEIGHT -  # fill window from top to bottom\r\n             3 * Bird.HEIGHT -  # make room for bird to fit through\r\n             3 * PipePair.PIECE_HEIGHT) /  # 2 end pieces + 1 body piece\r\n            PipePair.PIECE_HEIGHT  # to get number of pipe pieces\r\n        )\r\n        self.bottom_pieces = randint(1, total_pipe_body_pieces)\r\n        self.top_pieces = total_pipe_body_pieces - self.bottom_pieces\r\n\r\n        # bottom pipe\r\n        for i in range(1, self.bottom_pieces + 1):\r\n            piece_pos = (0, WIN_HEIGHT - i * PipePair.PIECE_HEIGHT)\r\n            self.image.blit(pipe_body_img, piece_pos)\r\n        bottom_pipe_end_y = WIN_HEIGHT - self.bottom_height_px\r\n        bottom_end_piece_pos = (0, bottom_pipe_end_y - PipePair.PIECE_HEIGHT)\r\n        self.image.blit(pipe_end_img, bottom_end_piece_pos)\r\n\r\n        # top pipe\r\n        for i in range(self.top_pieces):\r\n            self.image.blit(pipe_body_img, (0, i * PipePair.PIECE_HEIGHT))\r\n        top_pipe_end_y = self.top_height_px\r\n        self.image.blit(pipe_end_img, (0, top_pipe_end_y))\r\n\r\n        self.center = (top_pipe_end_y + bottom_pipe_end_y) / 2  # center of pipe-room,wgx\r\n\r\n        # compensate for added end pieces\r\n        self.top_pieces += 1\r\n        self.bottom_pieces += 1\r\n\r\n        # for collision detection\r\n        self.mask = pygame.mask.from_surface(self.image)\r\n        self.top_y = top_pipe_end_y\r\n        self.bottom_y = bottom_pipe_end_y\r\n\r\n    @property\r\n    def top_height_px(self):\r\n        \"\"\"Get the top pipe's height, in pixels.\"\"\"\r\n        return 
self.top_pieces * PipePair.PIECE_HEIGHT\r\n\r\n    @property\r\n    def bottom_height_px(self):\r\n        \"\"\"Get the bottom pipe's height, in pixels.\"\"\"\r\n        return self.bottom_pieces * PipePair.PIECE_HEIGHT\r\n\r\n    @property\r\n    def visible(self):\r\n        \"\"\"Get whether this PipePair on screen, visible to the player.\"\"\"\r\n        return -PipePair.WIDTH < self.x < WIN_WIDTH\r\n\r\n    @property\r\n    def rect(self):\r\n        \"\"\"Get the Rect which contains this PipePair.\"\"\"\r\n        return Rect(self.x, 0, PipePair.WIDTH, PipePair.PIECE_HEIGHT)\r\n\r\n    def update(self, delta_frames=1):\r\n        \"\"\"Update the PipePair's position.\r\n\r\n        Attributes:\r\n        delta_frames: The number of frames elapsed since this method was\r\n            last called.\r\n        \"\"\"\r\n        self.x -= 0.18 * 1000.0 * delta_frames /60\r\n\r\n    def collides_with(self, bird):\r\n        \"\"\"Get whether the bird collides with a pipe in this PipePair.\r\n\r\n        Arguments:\r\n        bird: The Bird which should be tested for collision with this\r\n            PipePair.\r\n        \"\"\"\r\n        return pygame.sprite.collide_mask(self, bird)\r\n\r\ndef chooseAct(Net,s,input,weight_trace_d1,weight_trace_d2):\r\n    for i_train in range(500):\r\n        out, dw = Net(input)\r\n        # 更新权重\r\n        # Net.UpdateWeight(10, dw[5][0])\r\n        # Net.UpdateWeight(12, dw[5][1])\r\n        # Net.UpdateWeight(11, dw[6][0])\r\n        # rstdp\r\n        weight_trace_d1 *= trace_decay\r\n        weight_trace_d1 += dw[0][0]\r\n        weight_trace_d2 *= trace_decay\r\n        weight_trace_d2 += dw[1][0]\r\n        if torch.max(out) > 0:\r\n            return torch.argmax(out),weight_trace_d1,weight_trace_d2,Net\r\n\r\ndef judgeState(bird, pipes, collide):\r\n    # bird's x and y coordinate in the left top of the image\r\n    dist = bird.y + Bird.HEIGHT / 2 - WIN_HEIGHT / 2\r\n    isNew = False\r\n    index = -1\r\n    state = 
-1\r\n    if collide:\r\n        state = 8\r\n        return state\r\n    for p in pipes:\r\n        if p.x + PipePair.WIDTH - Bird.HEIGHT / 4 < bird.x and not p.score_counted:\r\n            continue\r\n        if p.x - Bird.NEAR_PIPE <= bird.x + Bird.HEIGHT and \\\r\n                p.x + PipePair.WIDTH - Bird.HEIGHT / 4 >= bird.x:\r\n\r\n            p_top_y = p.top_y + PipePair.PIECE_HEIGHT\r\n            p_bottom_y = p.bottom_y - PipePair.PIECE_HEIGHT\r\n            if p.center - bird.y - Bird.HEIGHT / 2 >= 0 and bird.y >= p_top_y + Bird.NEAR_COLLIDE / 2:\r\n                state = 0\r\n            elif bird.y - p.center + Bird.HEIGHT / 2 > 0 and bird.y + Bird.HEIGHT <= p_bottom_y - Bird.NEAR_COLLIDE / 2:\r\n                state = 1\r\n            elif bird.y < p_top_y + Bird.NEAR_COLLIDE / 2 and bird.y > p_top_y - 10:\r\n                state = 6\r\n            elif bird.y + Bird.HEIGHT > p_bottom_y - Bird.NEAR_COLLIDE / 2 and bird.y + Bird.HEIGHT < p_bottom_y + 10:\r\n                state = 7\r\n            if state > -0.5:\r\n                index = 1\r\n        elif p.x > bird.x + Bird.HEIGHT + Bird.NEAR_PIPE:\r\n            state = blankState(bird, p.center)\r\n            if p.isNewPipe:\r\n                isNew = True\r\n            p.isNewPipe = False\r\n            index = 1\r\n        if index > 0:  # only judge the nearest and not passed pipe\r\n            dist = bird.y + Bird.HEIGHT / 2 - p.center\r\n            break\r\n    if index < -0.5:  # no pipe left, key the bird in the middle\r\n        pos = WIN_HEIGHT / 2\r\n        dist = bird.y + Bird.HEIGHT / 2 - pos\r\n        state = blankState(bird, pos)\r\n\r\n    return state, dist, isNew\r\n\r\ndef blankState(bird, center):  # judge the state before passing the pipe\r\n    realHeight = (PipePair.ROOM_HIGHT - Bird.HEIGHT) / 2\r\n    if center - bird.y - Bird.HEIGHT / 2 >= 0 and \\\r\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\r\n        state = 0\r\n    
elif bird.y - center + Bird.HEIGHT / 2 >= 0 and \\\r\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2:\r\n        state = 1\r\n    elif center - bird.y - Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\r\n            center - bird.y - Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 2\r\n    elif bird.y - center + Bird.HEIGHT / 2 >= realHeight - Bird.NEAR_COLLIDE / 2 and \\\r\n            bird.y - center + Bird.HEIGHT / 2 < realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 3\r\n    elif bird.y + Bird.HEIGHT / 2 <= center - (realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION):\r\n        state = 4\r\n    elif bird.y + Bird.HEIGHT / 2 >= center + realHeight - Bird.NEAR_COLLIDE / 2 + Bird.REGION:\r\n        state = 5\r\n    return state\r\n\r\ndef getReward(state,lastState,smallerError,isNewPipe):\r\n    if state == 0 or state == 1:\r\n        reward = 6\r\n    elif state == 2 or state == 3:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -5\r\n        else:\r\n            reward = -3\r\n    elif state == 4 or state == 5:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -8\r\n        else:\r\n            reward = -5\r\n    elif state == 6 or state == 7:\r\n        if lastState == state and not isNewPipe:\r\n            if smallerError:\r\n                reward = 3\r\n            else:\r\n                reward = -3\r\n        else:\r\n            reward = -3\r\n    elif state == 8:   #  collide\r\n        reward = -100\r\n    return reward\r\n\r\ndef updateNet(Net,reward, action, state,weight_trace_d1,weight_trace_d2):\r\n    r = torch.ones((num_state, num_state * num_action), dtype=torch.float)\r\n    r[state, state * num_action + action] = 
reward\r\n    dw_d1 = r * weight_trace_d1\r\n    dw_d2 = -1 * r * weight_trace_d2\r\n    Net.UpdateWeight(0, state, num_action, dw_d1)\r\n    Net.UpdateWeight(1, state, num_action, dw_d2)\r\n    return Net\r\n\r\nif __name__==\"__main__\":\r\n    #定义网络\r\n    num_state=9\r\n    num_action=2\r\n    weight_exc=1\r\n    weight_inh=-0.5\r\n    trace_decay = 0.8\r\n    DM = BDMSNN(num_state, num_action, weight_exc, weight_inh, \"lif\")\r\n    con_matrix1 = torch.zeros((num_state, num_state * num_action), dtype=torch.float)\r\n    for i in range(num_state):\r\n        for j in range(num_action):\r\n            con_matrix1[i, i * num_action + j] = weight_exc\r\n    weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n    weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n\r\n    #定义游戏场景\r\n    pygame.init()\r\n    WIN_HEIGHT = 512\r\n    WIN_WIDTH = 284 * 2  # image size: 284x512 px; tiled twice\r\n    heighest = 0\r\n    iteration=0\r\n    contTime = 0  # number of times to restart\r\n    display_frame=0\r\n    while iteration < 20:       #  restart the game for reinforcement learning, wgx\r\n        display_surface = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))\r\n        pygame.display.set_caption('Flappy Bird')\r\n        images = load_images()\r\n        bird = Bird(250, int(WIN_HEIGHT / 2 - Bird.HEIGHT / 2), 2,\r\n                    (images['bird-wingup'], images['bird-wingdown']))\r\n\r\n        clock = pygame.time.Clock()\r\n        score_font = pygame.font.SysFont(None, 25, bold=True)  # default font\r\n        info_font = pygame.font.SysFont(None, 50, bold=True)\r\n        collide = paused = False\r\n        frame_clock = 0\r\n        pipes = deque()\r\n        score = 0\r\n        lastDist = 0\r\n        lastState = 0 #init\r\n        state = lastState\r\n        while not collide:\r\n            # 输入\r\n            input = torch.zeros((num_state), dtype=torch.float)\r\n            clock.tick(60)\r\n            if 
frame_clock %2==0 or frame_clock==1:\r\n                state, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state\r\n                lastState = state\r\n                lastDist = dist\r\n                input[state]=2\r\n                action,weight_trace_d1,weight_trace_d2,DM = chooseAct(DM,state,input,weight_trace_d1,weight_trace_d2)\r\n                print(\"state, dist:\", state, dist)\r\n                print(\"state, action:\",state,action)\r\n            if not (paused or frame_clock % (60 * PipePair.ADD_INTERVAL / 1000.0)):\r\n                pygame.event.post(pygame.event.Event(PipePair.ADD_EVENT))\r\n\r\n            for e in pygame.event.get():\r\n                if e.type == QUIT or (e.type == KEYUP and e.key == K_ESCAPE):\r\n                    collide = True\r\n                elif e.type == KEYUP and e.key in (K_PAUSE, K_p):\r\n                    paused = not paused\r\n                elif e.type == PipePair.ADD_EVENT:\r\n                    pp = PipePair(images['pipe-end'], images['pipe-body'])\r\n                    pipes.append(pp)\r\n            if paused:\r\n                continue  # don't draw anything\r\n            # check for collisions\r\n            pipe_collision = any(p.collides_with(bird) for p in pipes)\r\n            if pipe_collision or 0 >= bird.y or bird.y >= WIN_HEIGHT - Bird.HEIGHT:\r\n                collide = True\r\n            for x in (0, WIN_WIDTH / 2):\r\n                display_surface.blit(images['background'], (x, 0))\r\n            while pipes and not pipes[0].visible:\r\n                pipes.popleft()\r\n            for p in pipes:\r\n                p.update()\r\n                display_surface.blit(p.image, p.rect)\r\n            bird.update(action,state)\r\n            display_surface.blit(bird.image, bird.rect)\r\n            if frame_clock %2==0 or frame_clock==1 or collide:\r\n                # judge the state and update the value function\r\n                dist = 0\r\n          
      if collide:\r\n                    nextState = 8\r\n                    isNewPipe = False\r\n                else:\r\n                    nextState, dist, isNewPipe = judgeState(bird, pipes, collide)  # judge the bird's state\r\n                    print(\"next state:\", nextState)\r\n                print(\"lastdist, dist:\", lastDist,dist)\r\n                isSmallerError = False\r\n                if state == nextState:\r\n                    isSmallerError = False\r\n                    if lastDist <= 0:\r\n                        if lastDist < dist:\r\n                            isSmallerError = True\r\n                    else:\r\n                        if lastDist > dist:\r\n                            isSmallerError = True\r\n                if frame_clock>0 and not collide:\r\n                    reward = getReward(nextState, state, isSmallerError, isNewPipe)\r\n                    print(\"reward:\", reward)\r\n                    DM=updateNet(DM,reward, action, state,weight_trace_d1,weight_trace_d2)\r\n                state = nextState  #going on the next state\r\n                weight_trace_d1 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n                weight_trace_d2 = torch.zeros(con_matrix1.shape, dtype=torch.float)\r\n                DM.reset()\r\n                display_frame += 1\r\n                # update and display score\r\n            for p in pipes:\r\n                if p.x + PipePair.WIDTH < bird.x and not p.score_counted:\r\n                    score += 1\r\n                    p.score_counted = True\r\n\r\n            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\r\n            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\r\n            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\r\n            if heighest < score:\r\n                heighest = score\r\n            score_surface_h = score_font.render('Highest score: ' + 
str(heighest), True,\r\n                                                (0, 0, 0))  # heighest score\r\n            score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\r\n            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\r\n            score_surface_i = score_font.render('Attempts: ' + str(iteration), True, (0, 0, 0))  # heighest score\r\n            score_x_i = 10\r\n            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\r\n            frame_clock += 1\r\n            pygame.display.flip()\r\n\r\n        #  if collide, display the fail information, for 2 frames\r\n        cct = 0\r\n        while (bird.y < WIN_HEIGHT - Bird.HEIGHT - 3):\r\n            clock.tick(60)\r\n            for x in (0, WIN_WIDTH / 2):\r\n                display_surface.blit(images['background'], (x, 0))\r\n            while pipes and not pipes[0].visible:\r\n                pipes.popleft()\r\n            for p in pipes:\r\n                display_surface.blit(p.image, p.rect)\r\n            if cct >= 6:\r\n                bird.sink()\r\n            display_surface.blit(bird.image, bird.rect)\r\n            fail_infor = info_font.render('Game over !', True, (255, 60, 30))  # current score\r\n            pos_x = WIN_WIDTH / 2 - fail_infor.get_width() / 2\r\n            pos_y = WIN_HEIGHT / 2 - 100\r\n            display_surface.blit(fail_infor, (pos_x, pos_y))\r\n            #  display the score\r\n            score_surface = score_font.render('Current score: ' + str(score), True, (0, 0, 0))  # current score\r\n            score_x = WIN_WIDTH / 2 - 3 * score_surface.get_width() / 4\r\n            display_surface.blit(score_surface, (score_x, PipePair.PIECE_HEIGHT))\r\n            if heighest < score:\r\n                heighest = score\r\n            score_surface_h = score_font.render('Highest score: ' + str(heighest), True,\r\n                                                (0, 0, 0))  # heighest score\r\n    
        score_x_h = 4 * WIN_WIDTH / 5 - 1.2 * score_surface.get_width() / 3\r\n            display_surface.blit(score_surface_h, (score_x_h, PipePair.PIECE_HEIGHT))\r\n            score_surface_i = score_font.render('Attempts: ' + str(iteration), True, (0, 0, 0))  # heighest score\r\n            score_x_i = 10\r\n            display_surface.blit(score_surface_i, (score_x_i, PipePair.PIECE_HEIGHT))\r\n            pygame.display.flip()\r\n            cct += 1\r\n        if heighest < score:\r\n            heighest = score\r\n        contTime += 1\r\n        iteration += 1"
  },
  {
    "path": "examples/decision_making/RL/README.md",
    "content": "# PL-SDQN\n\nThis repository contains code from our paper [**Solving the Spike Feature Information Vanishing Problem in Spiking Deep Q Network with Potential Based Normalization**]. If you use our code or refer to this project, please cite this paper.\nTo run the PL-SDQN model, please install 'tianshou' framework first https://github.com/thu-ml/tianshou\n\n## Requirments\n\n* numpy\n* scipy\n* pytorch >= 1.7.0\n* torchvision\n* gym\n* atari-py\n* opencv-python\n* tianshou\n\n## Train\n\n```shell  \npython ./sdqn/main.py\n```\n\n\n\n## Citation\n\nIf you find this package helpful, please consider citing the following papers:\n\n```BibTex\n@ARTICLE{sun2022,\nAUTHOR={Sun, Yinqian and Zeng, Yi and Li, Yang},    \nTITLE={Solving the spike feature information vanishing problem in spiking deep Q network with potential based normalization},      \nJOURNAL={Frontiers in Neuroscience},      \nVOLUME={16},           \nYEAR={2022},      \t  \nURL={https://www.frontiersin.org/articles/10.3389/fnins.2022.953368},       \nDOI={10.3389/fnins.2022.953368},      \nISSN={1662-453X},   \n}\n\n@misc{https://doi.org/10.48550/arxiv.2207.08533,\n  doi = {10.48550/ARXIV.2207.08533},\n  url = {https://arxiv.org/abs/2207.08533},\n  author = {Zeng, Yi and Zhao, Dongcheng and Zhao, Feifei and Shen, Guobin and Dong, Yiting and Lu, Enmeng and Zhang, Qian and Sun, Yinqian and Liang, Qian and Zhao, Yuxuan and Zhao, Zhuoya and Fang, Hongjian and Wang, Yuwei and Li, Yang and Liu, Xin and Du, Chengcheng and Kong, Qingqun and Ruan, Zizhe and Bi, Weida},\n  title = {BrainCog: A Spiking Neural Network based Brain-inspired Cognitive Intelligence Engine for Brain-inspired AI and Brain Simulation},\n  publisher = {arXiv},\n  year = {2022},\n}\n```\n"
  },
  {
    "path": "examples/decision_making/RL/atari/__init__.py",
    "content": ""
  },
  {
    "path": "examples/decision_making/RL/atari/atari_wrapper.py",
    "content": "# Borrow a lot from openai baselines:\n# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py\n\nfrom collections import deque\n\nimport cv2\nimport gym\nimport numpy as np\n\n\nclass NoopResetEnv(gym.Wrapper):\n    \"\"\"Sample initial states by taking random number of no-ops on reset.\n    No-op is assumed to be action 0.\n\n    :param gym.Env env: the environment to wrap.\n    :param int noop_max: the maximum value of no-ops to run.\n    \"\"\"\n\n    def __init__(self, env, noop_max=30):\n        super().__init__(env)\n        self.noop_max = noop_max\n        self.noop_action = 0\n        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'\n\n    def reset(self):\n        self.env.reset()\n        noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)\n        for _ in range(noops):\n            obs, _, done, _ = self.env.step(self.noop_action)\n            if done:\n                obs = self.env.reset()\n        return obs\n\n\nclass MaxAndSkipEnv(gym.Wrapper):\n    \"\"\"Return only every `skip`-th frame (frameskipping) using most recent raw\n    observations (for max pooling across time steps)\n\n    :param gym.Env env: the environment to wrap.\n    :param int skip: number of `skip`-th frame.\n    \"\"\"\n\n    def __init__(self, env, skip=4):\n        super().__init__(env)\n        self._skip = skip\n\n    def step(self, action):\n        \"\"\"Step the environment with the given action. 
Repeat action, sum\n        reward, and max over last observations.\n        \"\"\"\n        obs_list, total_reward, done = [], 0., False\n        for _ in range(self._skip):\n            obs, reward, done, info = self.env.step(action)\n            obs_list.append(obs)\n            total_reward += reward\n            if done:\n                break\n        max_frame = np.max(obs_list[-2:], axis=0)\n        return max_frame, total_reward, done, info\n\n\nclass EpisodicLifeEnv(gym.Wrapper):\n    \"\"\"Make end-of-life == end-of-episode, but only reset on true game over. It\n    helps the value estimation.\n\n    :param gym.Env env: the environment to wrap.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        self.lives = 0\n        self.was_real_done = True\n\n    def step(self, action):\n        obs, reward, done, info = self.env.step(action)\n        self.was_real_done = done\n        # check current lives, make loss of life terminal, then update lives to\n        # handle bonus lives\n        lives = self.env.unwrapped.ale.lives()\n        if 0 < lives < self.lives:\n            # for Qbert sometimes we stay in lives == 0 condition for a few\n            # frames, so its important to keep lives > 0, so that we only reset\n            # once the environment is actually done.\n            done = True\n        self.lives = lives\n        return obs, reward, done, info\n\n    def reset(self):\n        \"\"\"Calls the Gym environment reset, only when lives are exhausted. 
This\n        way all states are still reachable even though lives are episodic, and\n        the learner need not know about any of this behind-the-scenes.\n        \"\"\"\n        if self.was_real_done:\n            obs = self.env.reset()\n        else:\n            # no-op step to advance from terminal/lost life state\n            obs = self.env.step(0)[0]\n        self.lives = self.env.unwrapped.ale.lives()\n        return obs\n\n\nclass FireResetEnv(gym.Wrapper):\n    \"\"\"Take action on reset for environments that are fixed until firing.\n    Related discussion: https://github.com/openai/baselines/issues/240\n\n    :param gym.Env env: the environment to wrap.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'\n        assert len(env.unwrapped.get_action_meanings()) >= 3\n\n    def reset(self):\n        self.env.reset()\n        return self.env.step(1)[0]\n\n\nclass WarpFrame(gym.ObservationWrapper):\n    \"\"\"Warp frames to 84x84 as done in the Nature paper and later work.\n\n    :param gym.Env env: the environment to wrap.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        self.size = 84\n        self.observation_space = gym.spaces.Box(\n            low=np.min(env.observation_space.low),\n            high=np.max(env.observation_space.high),\n            shape=(self.size, self.size),\n            dtype=env.observation_space.dtype\n        )\n\n    def observation(self, frame):\n        \"\"\"returns the current observation from a frame\"\"\"\n        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)\n        return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)\n\n\nclass ScaledFloatFrame(gym.ObservationWrapper):\n    \"\"\"Normalize observations to 0~1.\n\n    :param gym.Env env: the environment to wrap.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        low = 
np.min(env.observation_space.low)\n        high = np.max(env.observation_space.high)\n        self.bias = low\n        self.scale = high - low\n        self.observation_space = gym.spaces.Box(\n            low=0., high=1., shape=env.observation_space.shape, dtype=np.float32\n        )\n\n    def observation(self, observation):\n        return (observation - self.bias) / self.scale\n\n\nclass ClipRewardEnv(gym.RewardWrapper):\n    \"\"\"clips the reward to {+1, 0, -1} by its sign.\n\n    :param gym.Env env: the environment to wrap.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        self.reward_range = (-1, 1)\n\n    def reward(self, reward):\n        \"\"\"Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0.\"\"\"\n        return np.sign(reward)\n\n\nclass FrameStack(gym.Wrapper):\n    \"\"\"Stack n_frames last frames.\n\n    :param gym.Env env: the environment to wrap.\n    :param int n_frames: the number of frames to stack.\n    \"\"\"\n\n    def __init__(self, env, n_frames):\n        super().__init__(env)\n        self.n_frames = n_frames\n        self.frames = deque([], maxlen=n_frames)\n        shape = (n_frames, ) + env.observation_space.shape\n        self.observation_space = gym.spaces.Box(\n            low=np.min(env.observation_space.low),\n            high=np.max(env.observation_space.high),\n            shape=shape,\n            dtype=env.observation_space.dtype\n        )\n\n    def reset(self):\n        obs = self.env.reset()\n        for _ in range(self.n_frames):\n            self.frames.append(obs)\n        return self._get_ob()\n\n    def step(self, action):\n        obs, reward, done, info = self.env.step(action)\n        self.frames.append(obs)\n        return self._get_ob(), reward, done, info\n\n    def _get_ob(self):\n        # the original wrapper use `LazyFrames` but since we use np buffer,\n        # it has no effect\n        return np.stack(self.frames, axis=0)\n\n\ndef wrap_deepmind(\n    
env_id,\n    episode_life=True,\n    clip_rewards=True,\n    frame_stack=4,\n    scale=False,\n    warp_frame=True\n):\n    \"\"\"Configure environment for DeepMind-style Atari. The observation is\n    channel-first: (c, h, w) instead of (h, w, c).\n\n    :param str env_id: the atari environment id.\n    :param bool episode_life: wrap the episode life wrapper.\n    :param bool clip_rewards: wrap the reward clipping wrapper.\n    :param int frame_stack: wrap the frame stacking wrapper.\n    :param bool scale: wrap the scaling observation wrapper.\n    :param bool warp_frame: wrap the grayscale + resize observation wrapper.\n    :return: the wrapped atari environment.\n    \"\"\"\n    assert 'NoFrameskip' in env_id\n    env = gym.make(env_id)\n    env = NoopResetEnv(env, noop_max=30)\n    env = MaxAndSkipEnv(env, skip=4)\n    if episode_life:\n        env = EpisodicLifeEnv(env)\n    if 'FIRE' in env.unwrapped.get_action_meanings():\n        env = FireResetEnv(env)\n    if warp_frame:\n        env = WarpFrame(env)\n    if scale:\n        env = ScaledFloatFrame(env)\n    if clip_rewards:\n        env = ClipRewardEnv(env)\n    if frame_stack:\n        env = FrameStack(env, frame_stack)\n    return env\n"
  },
  {
    "path": "examples/decision_making/RL/mcs-fqf/discrete.py",
    "content": "from audioop import bias\nfrom time import time\nfrom typing import Any, Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom tianshou.data import Batch\nfrom braincog.base.node.node import LIFNode, ThreeCompNode\n\nclass SpikePopEncodingNetwork(nn.Module):\n    \"\"\"Cosine embedding network for IQN. Convert a scalar in [0, 1] to a list \\\n    of n-dim vectors.\n\n    :param num_cosines: the number of cosines used for the embedding.\n    :param embedding_dim: the dimension of the embedding/output.\n\n    .. note::\n\n        From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master\n        /fqf_iqn_qrdqn/network.py .\n    \"\"\"\n\n    def __init__(self, num_cosines: int, embedding_dim: int, device, time_window: int=8) -> None:\n        super().__init__()\n        self._threshold = 0.5\n        # self._decay = 0.2\n        self._decay = 0.5\n        self.r_max = 0.5\n        # self.mus = torch.from_numpy(np.arange(num_cosines) / num_cosines)\n        \n        self.sigma = 0.05\n        # self.sigma = 0.01\n        self._node = LIFNode\n        self.net = nn.Sequential(\n            nn.Linear(num_cosines, embedding_dim), \n            self._node()\n            # self._node(threshold=self._threshold, decay=self._decay)\n            )\n        self.num_cosines = num_cosines\n        self.embedding_dim = embedding_dim\n        self.mus = torch.arange(0, num_cosines, device=device).view(1, 1, self.num_cosines) / num_cosines\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n\n    def forward(self, taus: torch.Tensor, time_window: int) -> torch.Tensor:\n        batch_size = taus.shape[0]\n        N = taus.shape[1]\n        self.reset()\n\n\n        taus_lam = self.r_max * torch.exp(-(taus.unsqueeze(-1) - self.mus)**2/2/self.sigma**2).view(batch_size*N, self.num_cosines)\n        
taus_repeat = taus_lam.unsqueeze(0).repeat(time_window, 1, 1)\n        taus_emb = torch.poisson(taus_repeat)\n      \n        tau_embeddings = []\n        \n        for i in range(time_window):\n            t_e = self.net(taus_emb[i])\n            tau_embeddings.append(t_e)\n        return tau_embeddings   \n\nclass SpikeFractionProposalNetwork(nn.Module):\n    \"\"\"Fraction proposal network for FQF.\n\n    :param num_fractions: the number of factions to propose.\n    :param embedding_dim: the dimension of the embedding/input.\n\n    .. note::\n\n        Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master\n        /fqf_iqn_qrdqn/network.py .\n    \"\"\"\n\n    def __init__(self, num_fractions: int, embedding_dim: int) -> None:\n        super().__init__()\n        self.net = nn.Linear(embedding_dim, num_fractions)\n        torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01)\n        torch.nn.init.constant_(self.net.bias, 0)\n        self.num_fractions = num_fractions\n        self.embedding_dim = embedding_dim\n\n    def forward(\n        self, state_embeddings: list\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        state_embeddings = torch.stack(state_embeddings).detach() \n        \n        time_window = state_embeddings.shape[0]\n        batch_size = state_embeddings.shape[1]\n        logits = self.net(state_embeddings.view(time_window*batch_size, -1))\n        logits = logits.view(time_window, batch_size, -1)\n        m = torch.distributions.Categorical(logits=logits.mean(0))\n        taus_1_N = torch.cumsum(m.probs, dim=1)\n        taus = F.pad(taus_1_N, (1, 0))\n        tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0\n        entropies = m.entropy()\n        return taus, tau_hats, entropies \n\n\nclass MCQuantiles(nn.Module):\n    def __init__(self, \n            state_embedings_shape: int, \n            tau_embeddings_shape: int,\n            hidden_size: int, \n            last_size: int,\n            
fusion_size : int=512,\n            tau_s : int = 2.0):\n        super().__init__()\n       \n        self.basal_w = nn.Linear(state_embedings_shape, fusion_size, bias=False)\n        self.apical_w = nn.Linear(tau_embeddings_shape, fusion_size, bias=False)\n        self._node = LIFNode\n        self.mc_node = ThreeCompNode()\n        \n        self._last = nn.Sequential(\n            nn.Linear(fusion_size, hidden_size),           \n            self._node(),\n            nn.Linear(hidden_size, last_size),\n        )\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n    \n    def forward(self, state_embedding, tau_embedding):\n        \"\"\"\n        state_embedding: list\n        tau_embedding: torch.Tensor\n        \"\"\"\n        self.reset()\n        \n        assert isinstance(state_embedding, type(tau_embedding))\n        if isinstance(state_embedding, list):\n            time_window = len(state_embedding) \n            \n        elif isinstance(state_embedding, torch.Tensor):\n            time_window = state_embedding.shape[0]\n        else:\n            raise TypeError('Not support data type.')\n        batch_size = state_embedding[0].shape[0]\n        sample_size = tau_embedding[0].shape[0] // batch_size \n\n\n        quantiles = []\n      \n        for step in range(time_window):\n           \n            basal_psp = self.basal_w(state_embedding[step]).unsqueeze(1)     \n            apical_psp = self.apical_w(tau_embedding[step]).view(batch_size, sample_size, -1)  \n            embeddings = self.mc_node({'basal_inputs': basal_psp, 'apical_inputs': apical_psp}).view(batch_size*sample_size, -1) \n            out = self._last(embeddings)    \n            quantiles.append(out)\n        \n        quantiles = sum(quantiles) / time_window\n        return quantiles.view(batch_size, sample_size, -1).transpose(1, 2)\n\n   \n    \nclass SpikeFullQuantileFunction(nn.Module):\n    
\"\"\"Full(y parameterized) Quantile Function.\n\n    :param preprocess_net: a self-defined preprocess_net which output a\n        flattened hidden state.\n    :param int action_dim: the dimension of action space.\n    :param hidden_sizes: a sequence of int for constructing the MLP after\n        preprocess_net. Default to empty sequence (where the MLP now contains\n        only a single linear layer).\n    :param int num_cosines: the number of cosines to use for cosine embedding.\n        Default to 64.\n    :param int preprocess_net_output_dim: the output dimension of\n        preprocess_net.\n\n    .. note::\n\n        The first return value is a tuple of (quantiles, fractions, quantiles_tau),\n        where fractions is a Batch(taus, tau_hats, entropies).\n    \"\"\"\n\n    def __init__(\n        self,\n        preprocess_net: nn.Module,\n        action_shape: Sequence[int],\n        hidden_sizes: Sequence[int] = (),\n        num_cosines: int = 64,\n        preprocess_net_output_dim: Optional[int] = None,\n        device: Union[str, int, torch.device] = \"cpu\",\n    ) -> None:\n        super().__init__()\n        self.device = device\n        self.last_size = np.prod(action_shape)\n        self.preprocess = preprocess_net\n        self.input_dim = getattr(\n            self.preprocess, \"output_dim\", preprocess_net_output_dim\n        )\n        self.embed_model = SpikePopEncodingNetwork(num_cosines,\n                                                  self.input_dim, device=device).to(device)\n        self.mcquantiles = MCQuantiles(self.input_dim, self.input_dim, hidden_size=np.prod(hidden_sizes),\n                                     last_size=action_shape).to(device)\n    def forward(  # type: ignore\n        self, s: Union[np.ndarray, torch.Tensor],\n        propose_model: SpikeFractionProposalNetwork,\n        fractions: Optional[Batch] = None,\n        **kwargs: Any\n    ) -> Tuple[Any, torch.Tensor]:\n        r\"\"\"Mapping: s -> Q(s, \\*).\"\"\"\n       
 logits, h = self.preprocess(s, state=kwargs.get(\"state\", None))  \n        # Propose fractions\n        if fractions is None:\n            taus, tau_hats, entropies = propose_model(logits)\n            fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies)\n        else:\n            taus, tau_hats = fractions.taus, fractions.tau_hats\n\n        time_window = len(logits)\n        tau_hats_emb = self.embed_model(tau_hats, time_window)\n        \n        quantiles = self.mcquantiles(logits, tau_hats_emb)\n        \n        quantiles_tau = None\n        if self.training:\n            with torch.no_grad():\n                tau_emb = self.embed_model(taus[:, 1:-1], time_window)        \n                quantiles_tau = self.mcquantiles(logits, tau_emb)\n        return (quantiles, fractions, quantiles_tau), h\n\n\n\n"
  },
  {
    "path": "examples/decision_making/RL/mcs-fqf/main.py",
    "content": "import argparse\nimport os\nimport pprint\nimport numpy as np\nimport torch\nfrom network import SpikingDQN\nfrom ..atari.atari_wrapper import wrap_deepmind\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom tianshou.data import Collector, VectorReplayBuffer\nfrom tianshou.env import ShmemVectorEnv\nfrom tianshou.trainer import offpolicy_trainer\nfrom tianshou.utils import TensorboardLogger, SequenceLogger\nfrom discrete import SpikeFractionProposalNetwork, SpikeFullQuantileFunction\nfrom policy import FQFPolicy\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--task', type=str, default='MsPacmanNoFrameskip-v4')\n    parser.add_argument('--seed', type=int, default=3128)\n    parser.add_argument('--eps-test', type=float, default=0.005)\n    parser.add_argument('--eps-train', type=float, default=1.)\n    parser.add_argument('--eps-train-final', type=float, default=0.05)\n    parser.add_argument('--buffer-size', type=int, default=100000)\n    parser.add_argument('--lr', type=float, default=1e-4)\n    parser.add_argument('--fraction-lr', type=float, default=2.5e-9)\n    parser.add_argument('--gamma', type=float, default=0.99)\n    parser.add_argument('--num-fractions', type=int, default=32)\n    parser.add_argument('--num-cosines', type=int, default=64)\n    parser.add_argument('--ent-coef', type=float, default=10.)\n    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])\n    parser.add_argument('--n-step', type=int, default=3)\n    parser.add_argument('--target-update-freq', type=int, default=500)\n    parser.add_argument('--epoch', type=int, default=200)\n    parser.add_argument('--step-per-epoch', type=int, default=100000)\n    parser.add_argument('--step-per-collect', type=int, default=10)\n    # parser.add_argument('--update-per-step', type=float, default=0.1)\n    parser.add_argument('--update-per-step', type=float, default=0.1)\n    parser.add_argument('--batch-size', type=int, 
default=32)\n    parser.add_argument('--training-num', type=int, default=10)\n    parser.add_argument('--test-num', type=int, default=10)\n    parser.add_argument('--logdir', type=str, default='log')\n    parser.add_argument('--render', type=float, default=0.)\n    parser.add_argument(\n        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'\n    )\n    parser.add_argument('--frames-stack', type=int, default=4)\n    parser.add_argument('--resume-path', type=str, default=None)\n    parser.add_argument('--resume-id', type=str, default=None)\n    parser.add_argument(\n        '--watch',\n        default=False,\n        action='store_true',\n        help='watch the play of pre-trained policy only'\n    )\n    parser.add_argument('--save-buffer-name', type=str, default=None)\n\n    parser.add_argument('--time-window', type=int, default=8)\n    parser.add_argument('--prefix', type=str, default='')\n\n    parser.add_argument('--save-interval', type=int, default=10)\n\n    \n    return parser.parse_args()\n\n\ndef make_atari_env(args):\n    return wrap_deepmind(args.task, frame_stack=args.frames_stack)\n\n\ndef make_atari_env_watch(args):\n    return wrap_deepmind(\n        args.task,\n        frame_stack=args.frames_stack,\n        episode_life=False,\n        clip_rewards=False\n    )\n\n\ndef main(args=get_args()):\n    print('Setting: ', args)\n    env = make_atari_env(args)\n    args.state_shape = env.observation_space.shape or env.observation_space.n\n    args.action_shape = env.action_space.shape or env.action_space.n\n    # should be N_FRAMES x H x W\n    print(\"Observations shape:\", args.state_shape)\n    print(\"Actions shape:\", args.action_shape)\n    print('update_per_step: ', args.update_per_step)\n    print('lr: ', args.lr)\n    # make environments\n    train_envs = ShmemVectorEnv(\n        [lambda: make_atari_env(args) for _ in range(args.training_num)]\n    )\n    test_envs = ShmemVectorEnv(\n        [lambda: 
make_atari_env_watch(args) for _ in range(args.test_num)]\n    )\n    # define model\n    feature_net = SpikingDQN(\n        *args.state_shape, args.action_shape, args.device, time_window=args.time_window, features_only=True\n    )\n    net = SpikeFullQuantileFunction(\n        feature_net,\n        args.action_shape,\n        args.hidden_sizes,\n        args.num_cosines,\n        device=args.device,\n    ).to(args.device)\n    optim = torch.optim.Adam(net.parameters(), lr=args.lr)\n    fraction_net = SpikeFractionProposalNetwork(args.num_fractions, net.input_dim)\n    fraction_optim = torch.optim.RMSprop(\n        fraction_net.parameters(), lr=args.fraction_lr\n    )\n    # define policy\n    policy = FQFPolicy(\n        net,\n        optim,\n        fraction_net,\n        fraction_optim,\n        args.gamma,\n        args.num_fractions,\n        args.ent_coef,\n        args.n_step,\n        target_update_freq=args.target_update_freq\n    ).to(args.device)\n    # load a previous policy\n    if args.resume_path:\n        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))\n        print(\"Loaded agent from: \", args.resume_path)\n    # replay buffer: `save_last_obs` and `stack_num` can be removed together\n    # when you have enough RAM\n    buffer = VectorReplayBuffer(\n        args.buffer_size,\n        buffer_num=len(train_envs),\n        ignore_obs_next=True,\n        save_only_last_obs=True,\n        stack_num=args.frames_stack\n    )\n    # collector\n    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)\n    test_collector = Collector(policy, test_envs, exploration_noise=True)\n    # log\n    log_path = os.path.join(args.logdir, args.task, 'spike_fqf', args.prefix)\n    model_log_path = os.path.join(log_path, 'models')\n    if not os.path.exists(model_log_path):\n        os.makedirs(model_log_path)\n    print('log_path: ', log_path)\n   \n    writer = SummaryWriter(log_path)\n    
writer.add_text(\"args\", str(args))\n    logger = TensorboardLogger(writer, save_interval=args.save_interval)\n    result_logger = SequenceLogger(log_path)\n\n    def save_checkpoint_fn(epoch, env_step, gradient_step, epoch_round=True):\n        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html\n        if epoch_round:\n            ckpt_path = os.path.join(model_log_path, 'checkpoint_epoch{}.pth'.format(epoch))\n        else:\n            ckpt_path = os.path.join(model_log_path, 'checkpoint.pth')\n        ckpt = {\n            'epoch': epoch,\n            'env_step': env_step,\n            'gradient_step': gradient_step,\n            'model': policy.state_dict()\n        }\n        torch.save(ckpt, ckpt_path)\n        return ckpt_path\n\n    setting_path = os.path.join(log_path, 'settings.txt')\n    argsDict = args.__dict__\n    with open(setting_path, 'w') as f:\n        f.writelines('------------------ start ------------------' + '\\n')\n        for eachArg, value in argsDict.items():\n            f.writelines(eachArg + ' : ' + str(value) + '\\n')\n        f.writelines('------------------- end -------------------')\n        \n\n    def save_fn(policy, is_best=False):\n        if is_best:\n            torch.save(policy.state_dict(), os.path.join(log_path, 'best_policy.pth'))\n        else:\n            torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))\n\n    def stop_fn(mean_rewards):\n        if env.spec.reward_threshold:\n            return mean_rewards >= env.spec.reward_threshold\n        elif 'Pong' in args.task:\n            return mean_rewards >= 20\n        else:\n            return False\n\n    def train_fn(epoch, env_step):\n        # nature DQN setting, linear decay in the first 1M steps\n        if env_step <= 1e6:\n            eps = args.eps_train - env_step / 1e6 * \\\n                (args.eps_train - args.eps_train_final)\n        else:\n            eps = args.eps_train_final\n        
policy.set_eps(eps)\n        if env_step % 1000 == 0:\n            logger.write(\"train/env_step\", env_step, {\"train/eps\": eps})\n\n    def test_fn(epoch, env_step):\n        policy.set_eps(args.eps_test)\n\n    # watch agent's performance\n    def watch():\n        print(\"Setup test envs ...\")\n        policy.eval()\n        policy.set_eps(args.eps_test)\n        test_envs.seed(args.seed)\n        if args.save_buffer_name:\n            print(f\"Generate buffer with size {args.buffer_size}\")\n            buffer = VectorReplayBuffer(\n                args.buffer_size,\n                buffer_num=len(test_envs),\n                ignore_obs_next=True,\n                save_only_last_obs=True,\n                stack_num=args.frames_stack\n            )\n            collector = Collector(policy, test_envs, buffer, exploration_noise=True)\n            result = collector.collect(n_step=args.buffer_size)\n            print(f\"Save buffer into {args.save_buffer_name}\")\n            buffer.save_hdf5(args.save_buffer_name)\n        else:\n            print(\"Testing agent ...\")\n            test_collector.reset()\n            result = test_collector.collect(\n                n_episode=args.test_num, render=args.render\n            )\n        rew = result[\"rews\"].mean()\n        print(f'Mean reward (over {result[\"n/ep\"]} episodes): {rew}')\n\n    if args.watch:\n        watch()\n        exit(0)\n\n    # test train_collector and start filling replay buffer\n    train_collector.collect(n_step=args.batch_size * args.training_num)\n    # trainer\n    result = offpolicy_trainer(\n        policy,\n        train_collector,\n        test_collector,\n        args.epoch,\n        args.step_per_epoch,\n        args.step_per_collect,\n        args.test_num,\n        args.batch_size,\n        train_fn=train_fn,\n        test_fn=test_fn,\n        stop_fn=stop_fn,\n        save_fn=save_fn,\n        logger=logger,\n        update_per_step=args.update_per_step,\n        
test_in_train=False,\n        resume_from_log=args.resume_id is not None,\n        save_checkpoint_fn=save_checkpoint_fn,\n        result_logger=result_logger\n    )\n\n    pprint.pprint(result)\n    watch()\n\n\nif __name__ == '__main__':\n    main(get_args())\n"
  },
  {
    "path": "examples/decision_making/RL/mcs-fqf/network.py",
    "content": "from typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom braincog.base.node.node import LIFNode\nfrom ..utils.normalization import PopNorm\n\n\nclass SpikingDQN(nn.Module):\n    \"\"\"Reference: Human-level control through deep reinforcement learning.\n\n    For advanced usage (how to customize the network), please refer to\n    :ref:`build_the_network`.\n    \"\"\"\n\n    def __init__(\n        self,\n        c: int,\n        h: int,\n        w: int,\n        action_shape: Sequence[int],\n        device: Union[str, int, torch.device] = \"cpu\",\n        time_window: int = 16,\n        features_only: bool = False,\n    ) -> None:\n        super().__init__()\n        self._node = LIFNode\n        # self._node = ReLUNode\n        self.features_only = features_only\n        self.device = device\n        # self._threshold = 0.5\n        self._threshold = 1.0\n        self.v_reset = 0.0\n        # self._decay = 0.2\n        self._decay = 0.5\n        self._time_window = time_window\n\n        init_layer = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), gain=1)\n        self.p_count = 0\n\n        self.net = nn.Sequential(\n           # nn.utils.weight_norm(nn.Conv2d(c, 32, kernel_size=8, stride=4)),\n            nn.Conv2d(c, 32, kernel_size=8, stride=4),\n            # nn.BatchNorm2d(32),\n            # self._node(threshold=self._threshold, decay=self._decay),\n            PopNorm([32, 20, 20], threshold=self._threshold, v_reset=self.v_reset),\n\n            self._node(threshold=self._threshold, v_reset=self.v_reset),\n            # nn.utils.weight_norm(nn.Conv2d(32, 64, kernel_size=4, stride=2)),\n            nn.Conv2d(32, 64, kernel_size=4, stride=2),\n            # nn.BatchNorm2d(64),\n            # self._node(threshold=self._threshold, decay=self._decay),\n            PopNorm([64, 9, 9], threshold=self._threshold, v_reset=self.v_reset),\n            
self._node(threshold=self._threshold, v_reset=self.v_reset),\n            # nn.utils.weight_norm(nn.Conv2d(64, 64, kernel_size=3, stride=1)),\n            nn.Conv2d(64, 64, kernel_size=3, stride=1),\n            # nn.BatchNorm2d(64),\n            PopNorm([64, 7, 7], threshold=self._threshold, v_reset=self.v_reset),\n            # self._node(threshold=self._threshold, decay=self._decay),\n            self._node(threshold=self._threshold, v_reset=self.v_reset),\n            nn.Flatten()\n        )\n        with torch.no_grad():\n            self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])\n        if not features_only:\n            self.net = nn.Sequential(\n                self.net, nn.Linear(self.output_dim, 512),\n                # self.net, nn.Linear(self.output_dim, 512),\n                # self._node(threshold=self._threshold, decay=self._decay),\n                self._node(threshold=self._threshold, v_reset=self.v_reset),\n                # nn.Linear(512, np.prod(action_shape))\n                nn.Linear(512, np.prod(action_shape), bias=False)\n            )\n            self.output_dim = np.prod(action_shape)\n\n       \n       \n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n        \n    def forward(\n        self,\n        x: Union[np.ndarray, torch.Tensor],\n        state: Optional[Any] = None,\n        info: Dict[str, Any] = {},\n    ) -> Tuple[torch.Tensor, Any]:\n        r\"\"\"Mapping: x -> Q(x, \\*).\"\"\"\n        self.reset()\n        # obs = torch.as_tensor(x, device=self.device, dtype=torch.float32) \n        x = torch.as_tensor(x, device=self.device, dtype=torch.float32) / 255.0\n\n        qs = []\n\n        for i in range(self._time_window):\n            value = self.net(x)\n            qs.append(value)\n        if self.features_only:\n            return qs, state\n        else:\n            q_values = sum(qs) / self._time_window    \n       
     return q_values, state\n            "
  },
  {
    "path": "examples/decision_making/RL/mcs-fqf/policy.py",
    "content": "from typing import Any, Dict, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom tianshou.data import Batch, ReplayBuffer, to_numpy\nfrom tianshou.policy import DQNPolicy, QRDQNPolicy\nfrom discrete import SpikeFractionProposalNetwork, SpikeFullQuantileFunction\n\n\nclass FQFPolicy(QRDQNPolicy):\n    \"\"\"Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.\n\n    :param torch.nn.Module model: a model following the rules in\n        :class:`~tianshou.policy.BasePolicy`. (s -> logits)\n    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.\n    :param FractionProposalNetwork fraction_model: a FractionProposalNetwork for\n        proposing fractions/quantiles given state.\n    :param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing\n        the fraction model above.\n    :param float discount_factor: in [0, 1].\n    :param int num_fractions: the number of fractions to use. Default to 32.\n    :param float ent_coef: the coefficient for entropy loss. Default to 0.\n    :param int estimation_step: the number of steps to look ahead. Default to 1.\n    :param int target_update_freq: the target network update frequency (0 if\n        you do not use the target network).\n    :param bool reward_normalization: normalize the reward to Normal(0, 1).\n        Default to False.\n\n    .. 
seealso::\n\n        Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed\n        explanation.\n    \"\"\"\n\n    def __init__(\n        self,\n        model: SpikeFullQuantileFunction,\n        optim: torch.optim.Optimizer,\n        fraction_model: SpikeFractionProposalNetwork,\n        fraction_optim: torch.optim.Optimizer,\n        discount_factor: float = 0.99,\n        num_fractions: int = 32,\n        ent_coef: float = 0.0,\n        estimation_step: int = 1,\n        target_update_freq: int = 0,\n        reward_normalization: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            model, optim, discount_factor, num_fractions, estimation_step,\n            target_update_freq, reward_normalization, **kwargs\n        )\n        self.propose_model = fraction_model\n        self._ent_coef = ent_coef\n        self._fraction_optim = fraction_optim\n\n    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:\n        batch = buffer[indices]  # batch.obs_next: s_{t+n}\n        if self._target:\n            result = self(batch, input=\"obs_next\")\n            a, fractions = result.act, result.fractions\n            next_dist = self(\n                batch, model=\"model_old\", input=\"obs_next\", fractions=fractions\n            ).logits\n        else:\n            next_b = self(batch, input=\"obs_next\")\n            a = next_b.act\n            next_dist = next_b.logits\n        next_dist = next_dist[np.arange(len(a)), a, :]\n        return next_dist  # shape: [bsz, num_quantiles]\n\n    def forward(\n        self,\n        batch: Batch,\n        state: Optional[Union[dict, Batch, np.ndarray]] = None,\n        model: str = \"model\",\n        input: str = \"obs\",\n        fractions: Optional[Batch] = None,\n        **kwargs: Any,\n    ) -> Batch:\n        model = getattr(self, model)\n        obs = batch[input]\n        obs_ = obs.obs if hasattr(obs, \"obs\") else obs\n        # 
print('fractions: ', fractions)\n        if fractions is None:\n            (logits, fractions, quantiles_tau), h = model(\n                obs_, propose_model=self.propose_model, state=state, info=batch.info\n            )\n        else:\n            (logits, _, quantiles_tau), h = model(\n                obs_,\n                propose_model=self.propose_model,\n                fractions=fractions,\n                state=state,\n                info=batch.info\n            )\n        # print('fractions.taus shape : ', fractions.taus.shape)\n        # print('logits shape: ', logits.shape)\n        weighted_logits = (fractions.taus[:, 1:] -\n                           fractions.taus[:, :-1]).unsqueeze(1) * logits\n        # print('weighted_logits shape: ', weighted_logits.shape)\n        q = DQNPolicy.compute_q_value(\n            self, weighted_logits.sum(2), getattr(obs, \"mask\", None)\n        )\n        if not hasattr(self, \"max_action_num\"):\n            self.max_action_num = q.shape[1]\n        act = to_numpy(q.max(dim=1)[1])\n        return Batch(\n            logits=logits,\n            act=act,\n            state=h,\n            fractions=fractions,\n            quantiles_tau=quantiles_tau\n        )\n\n    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:\n        if self._target and self._iter % self._freq == 0:\n            self.sync_weight()\n        weight = batch.pop(\"weight\", 1.0)\n        out = self(batch)\n        curr_dist_orig = out.logits\n        taus, tau_hats = out.fractions.taus, out.fractions.tau_hats\n        act = batch.act\n        curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)\n        target_dist = batch.returns.unsqueeze(1)\n        # calculate each element's difference between curr_dist and target_dist\n        u = F.smooth_l1_loss(target_dist, curr_dist, reduction=\"none\")\n        huber_loss = (\n            u * (\n                tau_hats.unsqueeze(2) -\n                (target_dist - 
curr_dist).detach().le(0.).float()\n            ).abs()\n        ).sum(-1).mean(1)\n        quantile_loss = (huber_loss * weight).mean()\n        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/\n        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130\n        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer\n        # calculate fraction loss\n        with torch.no_grad():\n            sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]\n            sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]\n            # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/\n            # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169\n            values_1 = sa_quantiles - sa_quantile_hats[:, :-1]\n            signs_1 = sa_quantiles > torch.cat(\n                [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1\n            )\n\n            values_2 = sa_quantiles - sa_quantile_hats[:, 1:]\n            signs_2 = sa_quantiles < torch.cat(\n                [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1\n            )\n\n            gradient_of_taus = (\n                torch.where(signs_1, values_1, -values_1) +\n                torch.where(signs_2, values_2, -values_2)\n            )\n        fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()\n        # calculate entropy loss\n        entropy_loss = out.fractions.entropies.mean()\n        fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss\n        self._fraction_optim.zero_grad()\n        fraction_entropy_loss.backward(retain_graph=True)\n        self._fraction_optim.step()\n        self.optim.zero_grad()\n        quantile_loss.backward()\n        self.optim.step()\n        self._iter += 1\n        return {\n            \"loss\": quantile_loss.item() + fraction_entropy_loss.item(),\n            \"loss/quantile\": quantile_loss.item(),\n            \"loss/fraction\": fraction_loss.item(),\n            \"loss/entropy\": 
entropy_loss.item()\n        }\n"
  },
  {
    "path": "examples/decision_making/RL/requirements.txt",
    "content": "gym\natari-py\nopencv-python\ntianshou\n"
  },
  {
    "path": "examples/decision_making/RL/sdqn/main.py",
    "content": "import argparse\nimport os\nimport pprint\n\nimport numpy as np\nimport torch\n\ntry:\n    import tianshou\nexcept:\n    raise ImportError('Need install \"tianshou\" lib at  https://github.com/thu-ml/tianshou !')\nfrom tianshou.data import Collector, VectorReplayBuffer\nfrom tianshou.env import ShmemVectorEnv\nfrom tianshou.policy import DQNPolicy\nfrom tianshou.trainer import offpolicy_trainer\nfrom tianshou.utils import TensorboardLogger, WandbLogger\nimport random \n\nfrom network import SpikingDQN\nfrom ..atari.atari_wrapper import wrap_deepmind\nfrom torch.utils.tensorboard import SummaryWriter\n\n\n\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')\n    parser.add_argument('--seed', type=int, default=40)\n    parser.add_argument('--eps-test', type=float, default=0.005)\n    parser.add_argument('--eps-train', type=float, default=1.)\n    parser.add_argument('--eps-train-final', type=float, default=0.05)\n    parser.add_argument('--buffer-size', type=int, default=100000)\n    parser.add_argument('--lr', type=float, default=0.0001)\n    parser.add_argument('--gamma', type=float, default=0.99)\n    parser.add_argument('--n-step', type=int, default=3)\n    parser.add_argument('--target-update-freq', type=int, default=500)\n    parser.add_argument('--epoch', type=int, default=100)\n    parser.add_argument('--step-per-epoch', type=int, default=100000)\n    parser.add_argument('--step-per-collect', type=int, default=10)\n    parser.add_argument('--update-per-step', type=float, default=0.1)\n    parser.add_argument('--batch-size', type=int, default=32)\n    parser.add_argument('--training-num', type=int, default=10)\n    parser.add_argument('--test-num', type=int, default=10)\n    parser.add_argument('--logdir', type=str, default='log')\n    parser.add_argument('--render', type=float, default=0.)\n    parser.add_argument(\n        '--device', type=str, default='cuda' 
if torch.cuda.is_available() else 'cpu'\n    )\n    parser.add_argument('--frames-stack', type=int, default=4)\n    parser.add_argument('--resume-path', type=str, default=None)\n    parser.add_argument('--resume-id', type=str, default=None)\n    parser.add_argument('--time-window', type=int, default=16)\n    parser.add_argument(\n        '--spike',\n        default=False,\n        action='store_true',\n        help='execute spike dqn'\n    )\n    parser.add_argument(\n        '--logger',\n        type=str,\n        default=\"tensorboard\",\n        choices=[\"tensorboard\", \"wandb\"],\n    )\n    parser.add_argument(\n        '--watch',\n        default=False,\n        action='store_true',\n        help='watch the play of pre-trained policy only'\n    )\n    parser.add_argument('--save-buffer-name', type=str, default=None)\n    return parser.parse_args()\n\n\ndef make_atari_env(args):\n    return wrap_deepmind(args.task, frame_stack=args.frames_stack)\n\n\ndef make_atari_env_watch(args):\n    return wrap_deepmind(\n        args.task,\n        frame_stack=args.frames_stack,\n        episode_life=False,\n        clip_rewards=False\n    )\n\n\ndef main(args=get_args()):\n    print('n_step: ', args.n_step)\n    env = make_atari_env(args)\n    args.state_shape = env.observation_space.shape or env.observation_space.n\n    args.action_shape = env.action_space.shape or env.action_space.n\n    # should be N_FRAMES x H x W\n    print(\"Observations shape:\", args.state_shape)\n    print(\"Actions shape:\", args.action_shape)\n    print('logdir', args.logdir)\n    print('Spiking', args.spike)\n    # make environments\n    train_envs = ShmemVectorEnv(\n        [lambda: make_atari_env(args) for _ in range(args.training_num)]\n    )\n    test_envs = ShmemVectorEnv(\n        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]\n    )\n    # seed\n    os.environ['PYTHONHASHSEED'] = str(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    
torch.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    train_envs.seed(args.seed)\n    test_envs.seed(args.seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n   \n    net = SpikingDQN(*args.state_shape, args.action_shape, args.device, args.time_window).to(args.device)\n\n    # optim = torch.optim.Adam(net.parameters(), lr=args.lr)\n    optim = torch.optim.AdamW(net.parameters(), lr=args.lr)\n    # define policy\n    policy = DQNPolicy(\n        net,\n        optim,\n        args.gamma,\n        args.n_step,\n        target_update_freq=args.target_update_freq\n    )\n    # load a previous policy\n    if args.resume_path:\n        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))\n        print(\"Loaded agent from: \", args.resume_path)\n    # replay buffer: `save_last_obs` and `stack_num` can be removed together\n    # when you have enough RAM\n    buffer = VectorReplayBuffer(\n        args.buffer_size,\n        buffer_num=len(train_envs),\n        ignore_obs_next=True,\n        save_only_last_obs=True,\n        stack_num=args.frames_stack\n    )\n    # collector\n    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)\n    test_collector = Collector(policy, test_envs, exploration_noise=True)\n    # log\n    \n    log_path = os.path.join(args.logdir, args.task, 'csdqn')\n    if args.logger == \"tensorboard\":\n        writer = SummaryWriter(log_path)\n        writer.add_text(\"args\", str(args))\n        logger = TensorboardLogger(writer)\n    else:\n        logger = WandbLogger(\n            save_interval=1,\n            project=args.task,\n            name='dqn',\n            run_id=args.resume_id,\n            config=args,\n        )\n\n    def save_fn(policy):\n        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))\n\n    def stop_fn(mean_rewards):\n        if env.spec.reward_threshold:\n            return 
mean_rewards >= env.spec.reward_threshold\n        elif 'Pong' in args.task:\n            return mean_rewards >= 20\n        else:\n            return False\n\n    def train_fn(epoch, env_step):\n        # nature DQN setting, linear decay in the first 1M steps\n        if env_step <= 1e6:\n            eps = args.eps_train - env_step / 1e6 * \\\n                (args.eps_train - args.eps_train_final)\n        else:\n            eps = args.eps_train_final\n        policy.set_eps(eps)\n        if env_step % 1000 == 0:\n            logger.write(\"train/env_step\", env_step, {\"train/eps\": eps})\n\n    def test_fn(epoch, env_step):\n        policy.set_eps(args.eps_test)\n\n    def save_checkpoint_fn(epoch, env_step, gradient_step):\n        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html\n        ckpt_path = os.path.join(log_path, 'checkpoint.pth')\n        torch.save({'model': policy.state_dict()}, ckpt_path)\n        return ckpt_path\n\n    # watch agent's performance\n    def watch():\n        print(\"Setup test envs ...\")\n        policy.eval()\n        policy.set_eps(args.eps_test)\n        test_envs.seed(args.seed)\n        if args.save_buffer_name:\n            print(f\"Generate buffer with size {args.buffer_size}\")\n            buffer = VectorReplayBuffer(\n                args.buffer_size,\n                buffer_num=len(test_envs),\n                ignore_obs_next=True,\n                save_only_last_obs=True,\n                stack_num=args.frames_stack\n            )\n            collector = Collector(policy, test_envs, buffer, exploration_noise=True)\n            result = collector.collect(n_step=args.buffer_size)\n            print(f\"Save buffer into {args.save_buffer_name}\")\n            # Unfortunately, pickle will cause oom with 1M buffer size\n            buffer.save_hdf5(args.save_buffer_name)\n        else:\n            print(\"Testing agent ...\")\n            test_collector.reset()\n            result = 
test_collector.collect(\n                n_episode=args.test_num, render=args.render\n            )\n        rew = result[\"rews\"].mean()\n        print(f'Mean reward (over {result[\"n/ep\"]} episodes): {rew}')\n\n    if args.watch:\n        watch()\n        exit(0)\n\n    # test train_collector and start filling replay buffer\n    train_collector.collect(n_step=args.batch_size * args.training_num)\n    # trainer\n    result = offpolicy_trainer(\n        policy,\n        train_collector,\n        test_collector,\n        args.epoch,\n        args.step_per_epoch,\n        args.step_per_collect,\n        args.test_num,\n        args.batch_size,\n        train_fn=train_fn,\n        test_fn=test_fn,\n        stop_fn=stop_fn,\n        save_fn=save_fn,\n        logger=logger,\n        update_per_step=args.update_per_step,\n        test_in_train=False,\n        resume_from_log=args.resume_id is not None,\n        save_checkpoint_fn=save_checkpoint_fn,\n    )\n\n    pprint.pprint(result)\n    watch()\n\n\nif __name__ == '__main__':\n    main(get_args())\n"
  },
  {
    "path": "examples/decision_making/RL/sdqn/network.py",
    "content": "from typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom braincog.base.node.node import LIFNode\nfrom ..utils.normalization import PopNorm\n\n\nclass SpikingDQN(nn.Module):\n    \"\"\"Reference: Human-level control through deep reinforcement learning.\n\n    For advanced usage (how to customize the network), please refer to\n    :ref:`build_the_network`.\n    \"\"\"\n\n    def __init__(\n        self,\n        c: int,\n        h: int,\n        w: int,\n        action_shape: Sequence[int],\n        device: Union[str, int, torch.device] = \"cpu\",\n        time_window: int = 16,\n        features_only: bool = False,\n    ) -> None:\n        super().__init__()\n        self._node = LIFNode\n        self.features_only = features_only\n        self.device = device\n        self._threshold = 1.0\n        self.v_reset = 0.0\n        self._decay = 0.5\n        self._time_window = time_window\n        self.p_count = 0\n\n        self.net = nn.Sequential(\n            nn.Conv2d(c, 32, kernel_size=8, stride=4),\n            PopNorm([32, 20, 20], threshold=self._threshold, v_reset=self.v_reset),\n            self._node(threshold=self._threshold, v_reset=self.v_reset),\n            nn.Conv2d(32, 64, kernel_size=4, stride=2),\n            PopNorm([64, 9, 9], threshold=self._threshold, v_reset=self.v_reset),\n            self._node(threshold=self._threshold, v_reset=self.v_reset),\n            nn.Conv2d(64, 64, kernel_size=3, stride=1),\n            PopNorm([64, 7, 7], threshold=self._threshold, v_reset=self.v_reset),\n            self._node(threshold=self._threshold, v_reset=self.v_reset),\n            nn.Flatten()\n        )\n        with torch.no_grad():\n            self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])\n        if not features_only:\n            self.net = nn.Sequential(\n                self.net, nn.Linear(self.output_dim, 512),\n                
self._node(threshold=self._threshold, v_reset=self.v_reset),\n                nn.Linear(512, np.prod(action_shape), bias=False)\n            )\n            self.output_dim = np.prod(action_shape)\n\n    def reset(self):\n        for mod in self.modules():\n            if hasattr(mod, 'n_reset'):\n                mod.n_reset()\n        \n    def forward(\n        self,\n        x: Union[np.ndarray, torch.Tensor],\n        state: Optional[Any] = None,\n        info: Dict[str, Any] = {},\n    ) -> Tuple[torch.Tensor, Any]:\n        r\"\"\"Mapping: x -> Q(x, \\*).\"\"\"\n        self.reset()\n        x = torch.as_tensor(x, device=self.device, dtype=torch.float32) / 255.0\n        qs = []\n        for i in range(self._time_window):\n            value = self.net(x)\n            qs.append(value)\n        if self.features_only:\n            return qs, state\n        else:\n            q_values = sum(qs) / self._time_window    \n            return q_values, state\n\n\n\n"
  },
  {
    "path": "examples/decision_making/RL/utils/__init__.py",
    "content": "__all__ = ['normalization']\n\nfrom . import (\n    normalization,\n)\n"
  },
  {
    "path": "examples/decision_making/RL/utils/normalization.py",
    "content": "from typing import Optional, Any\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\nfrom torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer\nimport torch.nn.functional as F\nfrom torch import Tensor, Size\nfrom typing import Union, List\nimport numbers\nfrom torch.nn import Module\n\n_shape_t = Union[int, List[int], Size]\n\nclass PopNorm(Module):\n    r\"\"\"Applies Layer Normalization over a mini-batch of inputs as described in\n    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__\n\n    .. math::\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated separately over the last\n    certain number dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n    The standard-deviation is calculated via the biased estimator, equivalent to\n    `torch.var(input, unbiased=False)`.\n\n    .. note::\n        Unlike Batch Normalization and Instance Normalization, which applies\n        scalar scale and bias for each entire channel/plane with the\n        :attr:`affine` option, Layer Normalization applies per-element scale and\n        bias with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or torch.Size): input shape from an expected input\n            of size\n\n            .. 
math::\n                [* \\times \\text{normalized\\_shape}[0] \\times \\text{normalized\\_shape}[1]\n                    \\times \\ldots \\times \\text{normalized\\_shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has learnable per-element affine parameters initialized to ones (for weights)\n            and zeros (for biases). Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    Examples::\n\n        >>> input = torch.randn(20, 5, 10, 10)\n        >>> # With Learnable Parameters\n        >>> m = nn.LayerNorm(input.size()[1:])\n        >>> # Without Learnable Parameters\n        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)\n        >>> # Normalize over last two dimensions\n        >>> m = nn.LayerNorm([10, 10])\n        >>> # Normalize over last dimension of size 10\n        >>> m = nn.LayerNorm(10)\n        >>> # Activating the module\n        >>> output = m(input)\n    \"\"\"\n    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']\n    normalized_shape: _shape_t\n    eps: float\n    elementwise_affine: bool\n\n    def __init__(self, normalized_shape: _shape_t, threshold: float, v_reset: float, eps: float = 1e-5, affine: bool = True) -> None:\n        super().__init__()\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = tuple(normalized_shape)\n        self.threshold = threshold\n        self.v_reset = v_reset\n        self.eps = eps\n        self.affine = affine\n        if self.affine:\n            # self.weight = 
Parameter(torch.Tensor(*normalized_shape))\n            self.weight = Parameter(torch.Tensor(*normalized_shape))\n            self.bias = Parameter(torch.Tensor(*normalized_shape))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        if self.affine:\n            # nn.init.ones_(self.weight)\n            # nn.init.zeros_(self.bias)\n            nn.init.constant_(self.weight, self.threshold-self.v_reset)\n            nn.init.constant_(self.bias, self.v_reset)\n    def forward(self, input: Tensor) -> Tensor:\n        out = F.layer_norm(\n            input, self.normalized_shape, self.weight, self.bias, self.eps)\n        # out = F.layer_norm(\n        #     input, self.normalized_shape, None, None, self.eps)\n        # if self.affine:\n        #     out = self.weight * out + self.bias\n        return out\n    def extra_repr(self) -> str:\n        return '{normalized_shape}, eps={eps}, ' \\\n            'affine={affine}'.format(**self.__dict__)\n\n\nclass _NormBase(nn.Module):\n    \"\"\"Common base of _InstanceNorm and _BatchNorm\"\"\"\n\n    _version = 2\n    __constants__ = [\"track_running_stats\", \"momentum\", \"eps\", \"num_features\", \"affine\"]\n    num_features: int\n    eps: float\n    momentum: float\n    affine: bool\n    track_running_stats: bool\n    # WARNING: weight and bias purposely not defined here.\n    # See https://github.com/pytorch/pytorch/issues/39670\n\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-5,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n        mean: float = 0.2,\n        device=None,\n        dtype=None\n    ) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super(_NormBase, self).__init__()\n        self.num_features = 
num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.affine = affine\n        self.track_running_stats = track_running_stats\n        self.mean = mean\n        if self.affine:\n            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))\n            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))\n            # self.bias = Parameter(torch.empty(num_features, **factory_kwargs), requires_grad=False)\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        if self.track_running_stats:\n            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))\n            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))\n            self.running_mean: Optional[Tensor]\n            self.running_var: Optional[Tensor]\n            self.register_buffer('num_batches_tracked',\n                                 torch.tensor(0, dtype=torch.long,\n                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))\n        else:\n            self.register_buffer(\"running_mean\", None)\n            self.register_buffer(\"running_var\", None)\n            self.register_buffer(\"num_batches_tracked\", None)\n        self.reset_parameters()\n\n    def reset_running_stats(self) -> None:\n        if self.track_running_stats:\n            # running_mean/running_var/num_batches... 
are registered at runtime depending\n            # if self.track_running_stats is on\n            self.running_mean.zero_()  # type: ignore[union-attr]\n            self.running_var.fill_(1)  # type: ignore[union-attr]\n            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]\n\n    def reset_parameters(self) -> None:\n        self.reset_running_stats()\n        if self.affine:\n            nn.init.ones_(self.weight)\n            # nn.init.zeros_(self.bias)\n            nn.init.constant_(self.bias, self.mean)\n\n    def _check_input_dim(self, input):\n        raise NotImplementedError\n\n    def extra_repr(self):\n        return (\n            \"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, \"\n            \"track_running_stats={track_running_stats}\".format(**self.__dict__)\n        )\n\n    def _load_from_state_dict(\n        self,\n        state_dict,\n        prefix,\n        local_metadata,\n        strict,\n        missing_keys,\n        unexpected_keys,\n        error_msgs,\n    ):\n        version = local_metadata.get(\"version\", None)\n\n        if (version is None or version < 2) and self.track_running_stats:\n            # at version 2: added num_batches_tracked buffer\n            #               this should have a default value of 0\n            num_batches_tracked_key = prefix + \"num_batches_tracked\"\n            if num_batches_tracked_key not in state_dict:\n                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)\n\n        super(_NormBase, self)._load_from_state_dict(\n            state_dict,\n            prefix,\n            local_metadata,\n            strict,\n            missing_keys,\n            unexpected_keys,\n            error_msgs,\n        )\n\nclass _BatchNorm(_NormBase):\n    def __init__(\n        self,\n        num_features,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n        mean=0.2,\n        
device=None,\n        dtype=None\n    ):\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super().__init__(\n            num_features, eps, momentum, affine, track_running_stats, mean,  **factory_kwargs\n        )\n\n    def forward(self, input: Tensor) -> Tensor:\n        self._check_input_dim(input)\n\n        # exponential_average_factor is set to self.momentum\n        # (when it is available) only so that it gets updated\n        # in ONNX graph when this node is exported to ONNX.\n        if self.momentum is None:\n            exponential_average_factor = 0.0\n        else:\n            exponential_average_factor = self.momentum\n\n        if self.training and self.track_running_stats:\n            # TODO: if statement only here to tell the jit to skip emitting this when it is None\n            if self.num_batches_tracked is not None:  # type: ignore[has-type]\n                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]\n                if self.momentum is None:  # use cumulative moving average\n                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)\n                else:  # use exponential moving average\n                    exponential_average_factor = self.momentum\n\n        r\"\"\"\n        Decide whether the mini-batch stats should be used for normalization rather than the buffers.\n        Mini-batch stats are used in training mode, and in eval mode when buffers are None.\n        \"\"\"\n        if self.training:\n            bn_training = True\n        else:\n            bn_training = (self.running_mean is None) and (self.running_var is None)\n\n        r\"\"\"\n        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be\n        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are\n        used for normalization (i.e. 
in eval mode when buffers are not None).\n        \"\"\"\n        return F.batch_norm(\n            input,\n            # If buffers are not to be tracked, ensure that they won't be updated\n            self.running_mean\n            if not self.training or self.track_running_stats\n            else None,\n            self.running_var if not self.training or self.track_running_stats else None,\n            self.weight,\n            self.bias,\n            bn_training,\n            exponential_average_factor,\n            self.eps,\n        )\n\n\nclass PDBatchNorm2d(_BatchNorm):\n    r\"\"\"Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs\n    with additional channel dimension) as described in the paper\n    `Batch Normalization: Accelerating Deep Network Training by Reducing\n    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .\n\n    .. math::\n\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). By default, the elements of :math:`\\gamma` are set\n    to 1 and the elements of :math:`\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. 
note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`,\n        where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, H, W)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics.\n            in both training and eval modes. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    Examples::\n\n        >>> # With Learnable Parameters\n        >>> m = nn.BatchNorm2d(100)\n        >>> # Without Learnable Parameters\n        >>> m = nn.BatchNorm2d(100, affine=False)\n        >>> input = torch.randn(20, 100, 35, 45)\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError(\"expected 4D input (got {}D input)\".format(input.dim()))"
  },
  {
    "path": "examples/decision_making/swarm/Collision-Avoidance.py",
    "content": "import torch,os\r\nfrom braincog.model_zoo.rsnn import RSNN\r\nfrom random import randint\r\nimport math\r\nimport random\r\nimport matplotlib\r\n# matplotlib.use(\"TkAgg\")\r\nimport numpy as np\r\nimport  random\r\nimport matplotlib.pyplot as plt\r\nimport matplotlib.animation as animation\r\n#os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\"\r\n\r\n#parameters\r\nN =10\r\nWORLD_WIDTH = 500\r\nCOLLISION_THRE =25 #60 65 70\r\nWALL_COLLISION_LIMIT=10\r\nVISIBLE_THRE=75  #3=75/COLLISION_THRE   #3*COLLISION_THRE\r\n\r\n#eight velocity\r\nvel_space=[[0,1],[1,0],[0,-1],[-1,0],[1,1],[1,-1],[-1,-1],[-1,1]]\r\nvel_x_small=[[0,1],[1,0],[0,-1],[1,1],[1,-1]]\r\nvel_x_large=[[0,1],[0,-1],[-1,0],[-1,-1],[-1,1]]\r\nvel_y_small=[[0,1],[1,0],[-1,0],[1,1],[-1,1]]\r\nvel_y_large=[[1,0],[0,-1],[-1,0],[1,-1],[-1,-1]]\r\n\r\nN_action=len(vel_space)\r\ncol_robot=[i for i in range(N)]\r\n# parameters for rl+snn\r\nC = 50\r\nruntime = 100  # Runtime in ms for choosing action\r\n\r\n# parameters for snn\r\ntau = 10  # time constant of STDP\r\nstdpwin = 10  # STDP windows in ms\r\nApos = 0.925\r\nAneg = 0.1\r\nvr = 0  # Reset Potential\r\nvt = 0.1  # Judge if the neurons fire or not\r\n\r\ntau_m = 20\r\nRm = 0.5\r\ntau_e = 5\r\n# inhibition weight between output population\r\ns_in = np.random.rand(N_action * C, N_action * C)\r\nfor i in range(N_action):\r\n    for j in range(C):\r\n        for k in range(C):\r\n            s_in[i * C + j][i * C + k] = 0\r\n\r\n#init boids with no collision\r\nglobal boids\r\nboids = np.zeros(N, dtype=[('pos', int, 2), ('vel', int, 2),('nn',RSNN)])\r\nlist_rand=[i for i in range(16)]\r\nrand_int=random.sample(list_rand,N)\r\nfor i in range(len(rand_int)):\r\n    boids['pos'][i,0]=np.random.uniform(int(rand_int[i]%4)*125,(int(rand_int[i]%4)+1)*125+1,1)\r\n    boids['pos'][i,1] = np.random.uniform(int(rand_int[i]/4) * 125, (int(rand_int[i]/4) + 1) * 125 + 1, 1)\r\nboids['vel'] = np.random.uniform(-1, 2, (N, 2))\r\nfor i_vel in 
range(len(boids['vel'])):\r\n    boids['nn'][i_vel] = RSNN(N_action*2,N_action*C).cuda()\r\n    while(boids['vel'][i_vel][0]==0 and boids['vel'][i_vel][1]==0):\r\n        boids['vel'][i_vel] = np.random.uniform(-1, 2, (1, 2))\r\n\r\n#update boids parameters\r\ndo_update=np.zeros(N)\r\ndistance_pre=np.zeros((N,N))\r\ntmp_min_robot=[i for i in range(N)]\r\ntmp_input=[i for i in range(N)]\r\nsum_deta_tmp=np.zeros(N)\r\nsum_deta_new=np.zeros(N)\r\n\r\ntrace_decay = 0.8\r\ndef chooseAct(Net,input,explore):\r\n    count_group = np.zeros(N_action)\r\n    count_output = np.zeros(N_action * C)\r\n    if explore==-1:\r\n        pass\r\n    else:\r\n        pass\r\n    for i_train in range(runtime):\r\n        out, dw = Net(input[:,i_train])\r\n        # rstdp\r\n        Net.weight_trace *= trace_decay\r\n        Net.weight_trace += dw[0][0]\r\n\r\n        count_output=count_output+np.array(out)\r\n        for i in range(N_action):\r\n            count_group[i]=count_output[i*C:(i+1)*C].sum()\r\n        if count_group.max()>C/2:\r\n            action=count_group.argmax()\r\n    return action,Net\r\n        # if t==runtime-2 and len(np.where(self.count_group==0)[0])!=len(self.count_group):\r\n        #     self.action=self.count_group.argmax()\r\n\r\n\r\ndef update_boids(xs, ys, xvs, yvs,frame):\r\n    global distance_pre,col_c\r\n    # Matrix off position difference and distance\r\n    xdiff = np.add.outer(xs, -xs)\r\n    ydiff = np.add.outer(ys, -ys)\r\n    distance = np.sqrt(xdiff ** 2 + ydiff ** 2)\r\n    # Calculate the boids that are visible to every other boid   -pi/2 to pi/2\r\n    visible = np.zeros((N, N))\r\n    dir = np.zeros((N, N))\r\n    col_c = WORLD_WIDTH * np.ones((N, 4))\r\n    dir_c = np.zeros((N, 4))\r\n    angle_towards = np.arctan2(-ydiff, -xdiff)\r\n    angle_vel = np.arctan2(yvs, xvs)\r\n    for i in range(N):\r\n        for j in range(N):\r\n            if (xvs[i] == 1 and yvs[i] == 0) or (xvs[i] == 1 and yvs[i] == 1) or (xvs[i] == 0 and yvs[i] == 1) 
or (\r\n                    xvs[i] == 0 and yvs[i] == -1) or (xvs[i] == 1 and yvs[i] == -1):\r\n                if angle_towards[i][j] < angle_vel[i] + np.pi / 2 and angle_towards[i][j] > angle_vel[i] - np.pi / 2:\r\n                    visible[i][j] = True\r\n                if angle_towards[i][j] > angle_vel[i] - np.pi / 2 and angle_towards[i][j] < angle_vel[i]:\r\n                    dir[i][j]=1#right\r\n                if angle_towards[i][j] < angle_vel[i] + np.pi / 2 and angle_towards[i][j] >= angle_vel[i]:\r\n                    dir[i][j] = 2#left\r\n            if xvs[i] == -1 and yvs[i] == 1:\r\n                if (angle_towards[i][j] > angle_vel[i] - np.pi / 2 and angle_towards[i][j] < np.pi) or (\r\n                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < angle_vel[i] - 1.5 * np.pi):\r\n                    visible[i][j] = True\r\n                    if angle_towards[i][j] > angle_vel[i] - np.pi / 2 and angle_towards[i][j] < angle_vel[i]:\r\n                        dir[i][j] = 1\r\n                    if (angle_towards[i][j] < np.pi and angle_towards[i][j] >= angle_vel[i]) or (\r\n                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < angle_vel[i] - 1.5 * np.pi):\r\n                        dir[i][j] = 2\r\n            if xvs[i] == -1 and yvs[i] == 0:\r\n                if (angle_towards[i][j] > np.pi / 2 and angle_towards[i][j] < np.pi) or (\r\n                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < -np.pi / 2):\r\n                    visible[i][j] = True\r\n                if angle_towards[i][j] > np.pi / 2 and angle_towards[i][j] < np.pi:\r\n                    dir[i][j] = 1\r\n                if angle_towards[i][j] >= -np.pi and angle_towards[i][j] < -np.pi / 2:\r\n                    dir[i][j] = 2\r\n            if xvs[i] == -1 and yvs[i] == -1:\r\n                if (angle_towards[i][j] > -np.pi and angle_towards[i][j] < -np.pi / 4) or (\r\n                        angle_towards[i][j] 
> 0.75 * np.pi and angle_towards[i][j] < np.pi):\r\n                    visible[i][j] = True\r\n                if (angle_towards[i][j] > 0.75 * np.pi and angle_towards[i][j] < np.pi) or (\r\n                        angle_towards[i][j] > -np.pi and angle_towards[i][j] < angle_vel[i]):\r\n                    dir[i][j] = 1\r\n                if angle_towards[i][j] >= angle_vel[i] and angle_towards[i][j] < -np.pi / 4:\r\n                    dir[i][j] = 2\r\n    v_tmp = np.diag(np.diag(visible))\r\n    visible = visible - v_tmp\r\n    # the danger of collision, considering dis=6*collision\r\n    collision = np.clip(VISIBLE_THRE/COLLISION_THRE - distance / COLLISION_THRE, 0,VISIBLE_THRE/COLLISION_THRE) * visible  # visible and in some distance 3*collision_thre\r\n    c_tmp = np.diag(np.diag(collision))\r\n    collision = collision - c_tmp\r\n\r\n    if len(np.where(yvs[np.where(ys < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)] == -1)[0])>0:\r\n        wall_tmp=np.where(ys < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)[0]\r\n        for i_wall in range(len(wall_tmp)):\r\n            if yvs[wall_tmp[i_wall]] == -1:\r\n                col_c[wall_tmp[i_wall], 0] = ys[wall_tmp[i_wall]]\r\n                if xvs[wall_tmp[i_wall]] >= 0:\r\n                    dir_c[wall_tmp[i_wall], 0] = 1\r\n                else:\r\n                    dir_c[wall_tmp[i_wall], 1] = 2\r\n    if len(np.where(xvs[np.where(xs < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)]==-1)[0])>0:\r\n        wall_tmp = np.where(xs < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)[0]\r\n        for i_wall in range(len(wall_tmp)):\r\n            if xvs[wall_tmp[i_wall]] == -1:\r\n                col_c[wall_tmp[i_wall], 1] = xs[wall_tmp[i_wall]]\r\n                if yvs[wall_tmp[i_wall]] >= 0:\r\n                    dir_c[wall_tmp[i_wall], 1] = 2\r\n                else:\r\n                    dir_c[wall_tmp[i_wall], 1] = 1\r\n    if len(np.where(yvs[np.where((WORLD_WIDTH - ys) < 
(VISIBLE_THRE/COLLISION_THRE) * WALL_COLLISION_LIMIT)] == 1)[0]) > 0:\r\n        wall_tmp = np.where((WORLD_WIDTH - ys) < (VISIBLE_THRE/COLLISION_THRE) * WALL_COLLISION_LIMIT)[0]\r\n        for i_wall in range(len(wall_tmp)):\r\n            if yvs[wall_tmp[i_wall]]==1:\r\n                col_c[wall_tmp[i_wall],2] =WORLD_WIDTH - ys[wall_tmp[i_wall]]\r\n                if xvs[wall_tmp[i_wall]]>=0:\r\n                    dir_c[wall_tmp[i_wall],2]=2\r\n                else:\r\n                    dir_c[wall_tmp[i_wall], 2] = 1\r\n    if len(np.where(xvs[np.where((WORLD_WIDTH - xs) < (VISIBLE_THRE/COLLISION_THRE)*WALL_COLLISION_LIMIT)] ==1)[0])>0:\r\n        wall_tmp=np.where((WORLD_WIDTH - xs) < (VISIBLE_THRE/COLLISION_THRE) * WALL_COLLISION_LIMIT)[0]\r\n        for i_wall in range(len(wall_tmp)):\r\n            if xvs[wall_tmp[i_wall]]==1:\r\n                col_c[wall_tmp[i_wall],3] =WORLD_WIDTH - xs[wall_tmp[i_wall]]\r\n                if yvs[wall_tmp[i_wall]]>=0:\r\n                    dir_c[wall_tmp[i_wall],3]=1\r\n                else:\r\n                    dir_c[wall_tmp[i_wall], 3] = 2\r\n    # print(col_c)\r\n    col_c_tmp = np.clip(VISIBLE_THRE/COLLISION_THRE - col_c / WALL_COLLISION_LIMIT, 0, VISIBLE_THRE/COLLISION_THRE)\r\n    deta_dis_tmp = distance - distance_pre\r\n    deta_dis = deta_dis_tmp * collision  # <0 and small is the obstacle\r\n    collision=np.c_[collision, col_c_tmp]\r\n    deta_dis=np.c_[deta_dis, -col_c_tmp]\r\n    dir=np.c_[dir,dir_c]\r\n    # print(collision,deta_dis)\r\n    #for every agent, choose the approaching agent as input\r\n    for i in range(N):\r\n        if frame>1 and do_update[i]>0:\r\n            sum_deta_new[i] = (tmp_input[i] * collision[i][tmp_min_robot[i]]).sum()\r\n            # print(sum_deta_new[i] ,sum_deta_tmp[i] )\r\n            if sum_deta_new[i]  < sum_deta_tmp[i] :\r\n                r=10*(sum_deta_tmp[i]-sum_deta_new[i])\r\n            else:\r\n                r=-10*(sum_deta_new[i]-sum_deta_tmp[i])\r\n      
      boids['nn'][i].UpdateWeight(r)\r\n        if frame > 0:\r\n            do_update[i] =0\r\n            if len(np.where(deta_dis[i] < 0)[0]) > 0:\r\n                do_update[i] += 1\r\n                # then get the velocity direction of objects and the distance between them as the network input\r\n                appro_index = np.where(deta_dis[i] < 0)[0]  # the input is the approching directions and distances\r\n                # print(appro_index)\r\n                input = []\r\n                for j in range(len(appro_index)):\r\n                    if appro_index[j]<=N-1:\r\n                        xvs_input = xvs[appro_index[j]]\r\n                        yvs_input = yvs[appro_index[j]]\r\n                        input.append(vel_space.index([xvs_input, yvs_input]))\r\n                    else:\r\n                        vel_tmp=int(appro_index[j]%N)\r\n                        input.append(vel_tmp)\r\n                dis_tmp=np.c_[distance,col_c]\r\n                weight = -1 * dis_tmp[i][np.where(deta_dis[i] < 0)]\r\n                # input=input[np.argmin(weight)]\r\n                if weight.max() - weight.min() == 0:\r\n                    weight = np.random.randint(1, 5, weight.shape)\r\n                    weight[0] = 4\r\n                else:\r\n                    k = (4 - 1) / (weight.max() - weight.min())\r\n                    weight = 1 + k * (weight - weight.min())\r\n                # print(input,weight)\r\n                I = np.zeros((N_action*2, runtime))\r\n                for j in range(len(input)):\r\n                    # print(appro_index,input,appro_index[j],dir[i][appro_index[j]],input[j]*dir[i][appro_index[j]])\r\n                    I[int(input[j]+N_action*(dir[i][appro_index[j]]-1))][0:runtime] = max(I[int(input[j]+N_action*(dir[i][appro_index[j]]-1))][0], weight[j])\r\n                if random.random()<0.7:\r\n                    action_index,boids['nn'][i] = chooseAct(boids['nn'][i],I,-1)#exploitation\r\n                
else:\r\n                    action_index,boids['nn'][i] = chooseAct(boids['nn'][i],I, 1)  #exploration\r\n                xvs[i] = vel_space[action_index][0]\r\n                yvs[i] = vel_space[action_index][1]\r\n                tmp_min_robot[i] = np.where(deta_dis[i] < 0)[0]\r\n                tmp_input[i] = weight\r\n                sum_deta_tmp[i] = (tmp_input[i] * collision[i][tmp_min_robot[i]]).sum()\r\n    xs+=xvs\r\n    ys+=yvs\r\n    xs=np.clip(xs,0,WORLD_WIDTH)\r\n    ys = np.clip(ys, 0, WORLD_WIDTH)\r\n    distance_pre = distance\r\n    if frame>=10000:\r\n        for i in range(N):\r\n            for j in range(N_action*2):\r\n                I = np.zeros((N_action * 2, runtime))\r\n                I[j][0:runtime]=4\r\n                a=chooseAct(boids['nn'][i],I,-1)\r\n                # print(a)\r\n                aaa=1\r\n\r\n\r\ndef animate(frame):\r\n    update_boids(boids['pos'][:, 0], boids['pos'][:, 1], boids['vel'][:, 0], boids['vel'][:, 1],frame)\r\n    scatter.set_offsets(boids['pos'])\r\n    scatter1.set_offsets(boids['pos'])\r\n\r\n#build background\r\nfig = plt.figure(figsize=(8, 8))\r\nax1 = fig.add_subplot(111)\r\nax1.set_title('Scatter Plot')\r\nplt.xlim(-20,520)\r\nplt.ylim(-20,520)\r\nplt.grid(ls='--',c='gray')\r\nplt.xlabel('X')\r\nplt.ylabel('Y')\r\n# Use a scatter plot to visualize the boids\r\ncolor_list=['r','b','g','y','m','c','deeppink','tomato','gold','crimson','cornsilk','darkred','greenyellow','lightcoral','mintcream',\r\n'rosybrown']\r\ncolors=color_list[0:N]\r\n#colors=random.sample(color_list,N)\r\nlines=np.zeros(N)+5\r\nscatter = ax1.scatter(boids['pos'][:, 0], boids['pos'][:, 1],s=500,alpha=0.5,linewidths=lines)\r\nscatter1 = ax1.scatter(boids['pos'][:, 0], boids['pos'][:, 1],s=2500,c=colors,alpha=0.5)\r\nboids_newp=boids['pos']+boids['vel']*10\r\nfor i in range(N):\r\n    boids_linex=np.hstack((boids['pos'][i, 0],boids_newp[i,0]))\r\n    boids_liney=np.hstack((boids['pos'][i, 1],boids_newp[i,1]))\r\n    
#line,=plt.plot(boids_linex,boids_liney,linewidth=5)\r\n#lines = [ax1.plot(np.hstack((boids['pos'][i, 0],boids_newp[i,0])), np.hstack((boids['pos'][i, 1],boids_newp[i,1])),linewidth=5) for i in range(N)]\r\nanimation = animation.FuncAnimation(fig, animate,interval=0.001)\r\nplt.show()"
  },
  {
    "path": "examples/decision_making/swarm/README.md",
    "content": "# Reward-modulated Spiking Neural Network for Self-organizing Collision Avoidance of Drone Swarm\n\nThis repository contains code from our paper [*Nature-inspired Self-organizing Collision Avoidance for Drones Swarm Based on Reward-modulated Spiking Neural Network*] published in Cell Patterns. \nhttps://www.cell.com/patterns/fulltext/S2666-3899(22)00236-7\n\nWe also provide the BrainCog-based version: https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/decision_making/swarm\n\nIf you use our code or refer to this project, please cite this paper:\nFeifei Zhao, Yi Zeng, Bing Han, Hongjian Fang, and Zhuoya Zhao. Nature-inspired Self-organizing Collision Avoidance for Drones Swarm Based on Reward-modulated Spiking Neural Network. Patterns, DOI: https://doi.org/10.1016/j.patter.2022.100611\n\n\n## Paper Introduction \nThe collaborative interaction mechanisms of biological swarms in nature are of great importance to inspire the study of swarm intelligence. This paper proposed a self-organizing obstacle avoidance model by drawing on the decentralized, self-organizing properties of intelligent behavior of biological swarms. Each individual independently adopts brain-inspired reward-modulated spiking neural network (RSNN) to achieve online learning and makes decentralized decisions based on local observations. The following picture shows the decision-making process of our model.\n\n<center><img src=\"./process.jpg\" width=\"50%\"></center>\n\n\nWe validated the proposed model on swarm collision avoidance tasks (a swarm of unmanned aerial vehicles without central control) in a bounded space, carrying out simulation and real-world experiments. The drone swarm emerges with safe flight behavior, as shown in the following videos. Compared with artificial neural network-based online learning methods, our proposed method exhibits superior performance and better stability. 
\n\n<img src=\"./simulation.gif\" alt=\"mt\" width=\"55%\" /><img src=\"./figures/joy.gif\" alt=\"mt\" width=\"55%\" />\n<img src=\"./collision_avoidance.gif\" alt=\"mt\" width=\"55%\" /><img src=\"./figures/joy.gif\" alt=\"mt\" width=\"55%\" />\n\n## Run\n * \"reward-modulated snn on swarm collision avoidance.py\"  includes the self-organized collision avoidance implemented by RSNN for simulation scenarios.\n * \"flytestfive.py\"  includes five UAVs swarm collision avoidance validation in real bounded scenario.\n \n ## Requirements\n* \"reward-modulated snn on swarm collision avoidance.py\": python==3.7, numpy>=1.21.6\n* \"flytestfive.py\": multi_robomaster\n\n\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy\nscipy\nh5py\ntorch\ntorchvision\ntorchaudio\ntimm == 0.6.13\nscikit-learn\neinops\nthop\npyyaml\nmatplotlib\nseaborn\npygame\ndv\ntensorboard\ntonic\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import find_packages\nfrom setuptools import setup\n\nwith open(\"./requirements.txt\", \"r\", encoding=\"utf-8\") as fh:\n    install_requires = fh.read()\nwith open(\"README.md\", \"r\", encoding=\"utf-8\") as fh:\n    long_description = fh.read()\nsetup(\n    install_requires=install_requires,\n    packages=find_packages(),\n    description=\"BrainCog is an open source spiking neural network based brain-inspired cognitive intelligence engine for Brain-inspired Artificial Intelligence and brain simulation. More information on braincog can be found on its homepage http://www.brain-cog.network/\",\n    long_description=long_description,  \n    long_description_content_type=\"text/markdown\",  \n    classifiers=[\n        \"Programming Language :: Python :: 3 :: Only\",\n        \"Programming Language :: Python :: 3.6\",\n        \"Programming Language :: Python :: 3.7\",\n        \"Programming Language :: Python :: 3.8\",\n        \"Programming Language :: Python :: 3.9\",\n        \"License :: Other/Proprietary License\",\n        \"Operating System :: OS Independent\",\n    ],\n    name='braincog',\n    version='0.2.7.19',\n    author='braincog',\n    python_requires='>=3.6'\n)\n"
  }
]