[
  {
    "path": ".gitignore",
    "content": "*.json\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\nout/\n# C extensions\n*.so\n*.pkl\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndistf/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nwandb/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n.vscode/*\n.vscode/settings.json\n\n\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n# target/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nvenv/\nenv.bak/\nvenv.bak/\n \n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n # Pyre type checker\n.pyre/\ncheckpoints/\ndata/*\noutput/\nlog/\nruns/\n\n*.png\n*.jpg\n*.mp4\n*.gif\n*.pkl\n*.pt"
  },
  {
    "path": "GenMM.py",
    "content": "import os\nimport os.path as osp\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom utils.base import logger\n\nclass GenMM:\n    def __init__(self, mode = 'random_synthesis', noise_sigma = 1.0, coarse_ratio = 0.2, coarse_ratio_factor = 6, pyr_factor = 0.75, num_stages_limit = -1, device = 'cuda:0', silent = False):\n        '''\n        GenMM main constructor\n        Args:\n            device : str = 'cuda:0', default device.\n            silent : bool = False, whether to mute the output.\n        '''\n        self.device = torch.device(device)\n        self.silent = silent\n\n    def _get_pyramid_lengths(self, final_len, coarse_ratio, pyr_factor):\n        '''\n        Get a list of pyramid lengths using given target length and ratio\n        '''\n        lengths = [int(np.round(final_len * coarse_ratio))]\n        while lengths[-1] < final_len:\n            lengths.append(int(np.round(lengths[-1] / pyr_factor)))\n            if lengths[-1] == lengths[-2]:\n                lengths[-1] += 1\n        lengths[-1] = final_len\n\n        return lengths\n\n    def _get_target_pyramid(self, target, coarse_ratio, pyr_factor, num_stages_limit=-1):\n        '''\n        Reads a target motion(s) and create a pyraimd out of it. Ordered in increatorch.sing size\n        '''\n        self.num_target = len(target)\n        lengths = []\n        min_len = 10000\n        for i in range(len(target)):\n            new_length = self._get_pyramid_lengths(len(target[i].motion_data), coarse_ratio, pyr_factor)\n            min_len = min(min_len, len(new_length))\n            if num_stages_limit != -1:\n                new_length = new_length[:num_stages_limit]\n            lengths.append(new_length)\n        for i in range(len(target)):\n            lengths[i] = lengths[i][-min_len:]\n        self.pyraimd_lengths = lengths\n\n        target_pyramid = [[] for _ in range(len(lengths[0]))]\n        for step in range(len(lengths[0])):\n            for i in range(len(target)):\n                length = lengths[i][step]\n                target_pyramid[step].append(target[i].sample(size=length).to(self.device))\n\n        if not self.silent:\n            print('Levels:', lengths)\n            for i in range(len(target_pyramid)):\n                print(f'Number of clips in target pyramid {i} is {len(target_pyramid[i])}, ranging {[[tgt.min(), tgt.max()] for tgt in target_pyramid[i]]}')\n\n        return target_pyramid\n\n    def _get_initial_motion(self, init_length, noise_sigma):\n        '''\n        Prepare the initial motion for optimization\n        '''\n        initial_motion = F.interpolate(torch.cat([self.target_pyramid[0][i] for i in range(self.num_target)], dim=-1),\n                                       size=init_length, mode='linear', align_corners=True)\n        if noise_sigma > 0:\n            initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * noise_sigma\n            initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)\n        else:\n            initial_motion_w_noise = initial_motion\n\n        if not self.silent:\n            print('Initial motion:', initial_motion.min(), initial_motion.max())\n            print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())\n\n        return initial_motion_w_noise\n\n    def run(self, target, criteria, num_frames, num_steps, noise_sigma, patch_size, coarse_ratio, pyr_factor, ext=None, debug_dir=None):\n        '''\n        generation function\n        Args:\n            mode             : - string = 'x?', generate x times longer frames results\n                             : - int, specifying the number of times to generate\n            noise_sigma      : float = 1.0, random noise.\n            coarse_ratio     : float = 0.2, ratio at the coarse level.\n            pyr_factor       : float = 0.75, pyramid factor.\n            num_stages_limit : int = -1, no limit.\n        '''\n        if debug_dir is not None:\n            from tensorboardX import SummaryWriter\n            writer = SummaryWriter(log_dir=debug_dir)\n\n        # build target pyramid\n        if 'patchsize' in coarse_ratio:\n            coarse_ratio = patch_size * float(coarse_ratio.split('x_')[0]) / max([len(t.motion_data) for t in target])\n        elif 'nframes' in coarse_ratio:\n            coarse_ratio = float(coarse_ratio.split('x_')[0])\n        else:\n            raise ValueError('Unsupported coarse ratio specified')\n        self.target_pyramid = self._get_target_pyramid(target, coarse_ratio, pyr_factor)\n\n        # get the initial motion data\n        if 'nframes' in num_frames:\n            syn_length = int(sum([i[-1] for i in self.pyraimd_lengths]) * float(num_frames.split('x_')[0]))\n        elif num_frames.isdigit():\n            syn_length = int(num_frames)\n        else:\n            raise ValueError(f'Unsupported mode {self.mode}')\n        self.synthesized_lengths = self._get_pyramid_lengths(syn_length, coarse_ratio, pyr_factor)\n        if not self.silent:\n            print('Synthesized lengths:', self.synthesized_lengths)\n        self.synthesized = self._get_initial_motion(self.synthesized_lengths[0], noise_sigma)\n\n        # perform the optimization\n        self.synthesized.requires_grad_(False)\n        self.pbar = logger(num_steps, len(self.target_pyramid))\n        for lvl, lvl_target in enumerate(self.target_pyramid):\n            self.pbar.new_lvl()\n            if lvl > 0:\n                with torch.no_grad():\n                    self.synthesized = F.interpolate(self.synthesized.detach(), size=self.synthesized_lengths[lvl], mode='linear')\n\n            self.synthesized, losses = GenMM.match_and_blend(self.synthesized, lvl_target, criteria, num_steps, self.pbar, ext=ext)\n\n            criteria.clean_cache()\n            if debug_dir is not None:\n                for itr in range(len(losses)):\n                    writer.add_scalar(f'optimize/losses_lvl{lvl}', losses[itr], itr)\n        self.pbar.pbar.close()\n\n        return self.synthesized.detach()\n\n\n    @staticmethod\n    @torch.no_grad()\n    def match_and_blend(synthesized, targets, criteria, n_steps, pbar, ext=None):\n        '''\n        Minimizes criteria between synthesized and target\n        Args:\n            synthesized    : torch.Tensor, optimized motion data\n            targets        : torch.Tensor, target motion data\n            criteria       : optimize target function\n            n_steps        : int, number of steps to optimize\n            pbar           : logger\n            ext            : extra configurations or constraints (optional)\n        '''\n        losses = []\n        keyframe_motion = targets[0] if isinstance(targets, list) else targets\n        syn_length = synthesized.shape[-1]\n        km_length = keyframe_motion.shape[-1]\n\n        print(\"Synthesized shape:\", synthesized.shape)\n        print(\"Keyframe_motion shape:\", keyframe_motion.shape)\n\n        # Use the class-level KEYFRAME_INDICES\n        keyframe_indices = GenMM.KEYFRAME_INDICES\n\n        for _i in range(n_steps):\n            synthesized, loss = criteria(synthesized, targets, ext=ext, return_blended_results=True)\n\n            # Manually set the keyframes in synthesized motion to be the ones from the input motion\n            if syn_length >= keyframe_indices.stop and km_length >= keyframe_indices.stop:\n                synthesized[..., keyframe_indices] = keyframe_motion[..., keyframe_indices]\n\n            # Update status\n            losses.append(loss.item())\n            pbar.step()\n            pbar.print()\n\n        return synthesized, losses\n\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Example-based Motion Synthesis via Generative Motion Matching, ACM Transactions on Graphics (Proceedings of SIGGRAPH 2023)\n\n#####  <p align=\"center\"> [Weiyu Li*](https://wyysf-98.github.io/), [Xuelin Chen*†](https://xuelin-chen.github.io/), [Peizhuo Li](https://peizhuoli.github.io/), [Olga Sorkine-Hornung](https://igl.ethz.ch/people/sorkine/), [Baoquan Chen](https://cfcs.pku.edu.cn/baoquan/)</p>\n \n#### <p align=\"center\">[Project Page](https://wyysf-98.github.io/GenMM) | [ArXiv](https://arxiv.org/abs/2306.00378) | [Paper](https://wyysf-98.github.io/GenMM/paper/Paper_high_res.pdf) | [Video](https://youtu.be/lehnxcade4I)</p>\n\n<p align=\"center\">\n  <img src=\"https://wyysf-98.github.io/GenMM/assets/images/teaser.png\"/>\n</p>\n\n<p align=\"center\"> All Code and demo will be released in this week(still ongoing...) 🏗️ 🚧 🔨</p>\n\n- [x] Release main code\n- [x] Release blender addon\n- [x] Detailed README and installation guide\n- [ ] Release skeleton-aware component, WIP as we need to split the joints into groups manually.\n- [ ] Release codes for evaluation\n\n## Prerequisite\n\n<details> <summary>Setup environment</summary>\n\n:smiley: We also provide a Dockerfile for easy installation, see [Setup using Docker](./docker/README.md).\n\n - Python 3.8\n - PyTorch 1.12.1\n - [unfoldNd](https://github.com/f-dangel/unfoldNd)\n\nClone this repository.\n\n```sh\ngit clone git@github.com:wyysf-98/GenMM.git\n```\n\nInstall the required packages.\n\n```sh\nconda create -n GenMM python=3.8\nconda activate GenMM\nconda install -c pytorch pytorch=1.12.1 torchvision=0.13.1 cudatoolkit=11.3 && \\\npip install -r docker/requirements.txt\npip install torch-scatter==2.1.1\n```\n\n</details>\n\n## Quick inference demo\nFor local quick inference demo using .bvh file, you can use\n\n```sh\npython run_random_generation.py -i './data/Malcolm/Gangnam-Style.bvh'\n```\nMore configuration can be found in the `run_random_generation.py`.\nWe use an Apple M1 and NVIDIA Tesla V100 with 32 GB RAM to generate each motion, which takes about ~0.2s and ~0.05s as mentioned in our paper.\n\n## Blender add-on\nYou can install and use the blender add-on with easy installation as our method is efficient and you do not need to install CUDA Toolkit.\nWe test our code using blender 3.22.0, and will support 2.8.0 in the future.\n\nStep 1: Find yout blender python path. Common paths are as follows\n```sh\n(Windows) 'C:\\Program Files\\Blender Foundation\\Blender 3.2\\3.2\\python\\bin'\n(Linux) '/path/to/blender/blender-path/3.2/python/bin'\n(Windows) '/Applications/Blender.app/Contents/Resources/3.2/python/bin'\n```\n\nStep 2: Install required packages. Open your shell(Linux) or powershell(Windows), \n```sh\ncd {your python path} && pip3 install -r docker/requirements.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+${CUDA}.html\n```\n, where ${CUDA} should be replaced by either cpu, cu117, or cu118 depending on your PyTorch installation.\nOn my MacOS with M1 cpu,\n\n```sh\ncd /Applications/Blender.app/Contents/Resources/3.2/python/bin && pip3 install -r docker/requirements_blender.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+cpu.html\n```\n\nStep 3: Install add-on in blender. [Blender Add-ons Official Tutorial](https://docs.blender.org/manual/en/latest/editors/preferences/addons.html). `edit -> Preferences -> Add-ons -> Install -> Select the downloaded .zip file`\n\nStep 4: Have fun! Click the armature and you will find a `GenMM` tag.\n\n(GPU support) If you have GPU and CUDA Toolskits installed, we automatically dectect the running device.\n\nFeel free to submit an issue if you run into any issues during the installation :)\n\n## Acknowledgement\n\nWe thank [@stefanonuvoli](https://github.com/stefanonuvoli/skinmixer) for the help for the discussion of implementation about `Motion Reassembly` part (we eventually manually merged the meshes of different characters). And [@Radamés Ajna](https://github.com/radames) for the help of a better huggingface demo. \n\n\n## Citation\n\nIf you find our work useful for your research, please consider citing using the following BibTeX entry.\n\n```BibTeX\n@article{10.1145/weiyu23GenMM,\n    author     = {Li, Weiyu and Chen, Xuelin and Li, Peizhuo and Sorkine-Hornung, Olga and Chen, Baoquan},\n    title      = {Example-Based Motion Synthesis via Generative Motion Matching},\n    journal    = {ACM Transactions on Graphics (TOG)},\n    volume     = {42},\n    number     = {4},\n    year       = {2023},\n    articleno  = {94},\n    doi = {10.1145/3592395},\n    publisher  = {Association for Computing Machinery},\n}\n```\n"
  },
  {
    "path": "__init__.py",
    "content": "# This program is free software; you can redistribute it and/or modify\n# it under the terms of the GNU General Public License as published by\n# the Free Software Foundation; either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful, but\n# WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTIBILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU\n# General Public License for more details.\n#\n# You should have received a copy of the GNU General Public License\n# along with this program. If not, see <http://www.gnu.org/licenses/>.\nimport os\nimport sys\nimport bpy\nimport torch\nimport mathutils\nimport numpy as np\nfrom math import degrees, radians, ceil\nfrom mathutils import Vector, Matrix, Euler\nfrom typing import List, Iterable, Tuple, Any, Dict\n\nabs_path = os.path.abspath(__file__)\nsys.path.append(os.path.dirname(abs_path))\nfrom GenMM import GenMM\nfrom nearest_neighbor.losses import PatchCoherentLoss\nfrom dataset.blender_motion import BlenderMotion\n\nbl_info = {\n    \"name\" : \"GenMM\",\n    \"author\" : \"Weiyu Li\",\n    \"description\" : \"Blender addon for SIGGRAPH paper 'Example-Based Motion Synthesis via Generative Motion Matching'\",\n    \"blender\" : (3, 2, 0),\n    \"version\" : (0, 0, 1),\n    \"location\": \"3D View\",\n    \"description\": \"Synthesis novel motions form a few exemplars.\",\n    \"location\" : \"\",\n    \"support\": \"TESTING\",\n    \"warning\" : \"\",\n    \"category\" : \"Generic\"\n}\n\ndef capture_rest_pose(armature_obj):\n    \"\"\"Capture the rest pose bone data (head, tail, roll) from an armature.\"\"\"\n    rest_pose_data = {}\n    bpy.ops.object.mode_set(mode='EDIT')\n    arm_data = armature_obj.data\n    for bone in arm_data.edit_bones:\n        rest_pose_data[bone.name] = {\n            'head': bone.head.copy(),\n            'tail': bone.tail.copy(),\n            'roll': bone.roll,\n            'matrix_local': bone.matrix.copy()\n        }\n    bpy.ops.object.mode_set(mode='OBJECT')\n    return rest_pose_data\n\n# This function is modified from\n# https://github.com/bwrsandman/blender-addons/blob/master/io_anim_bvh\ndef get_bvh_data(context,\n                 frame_end,\n                 frame_start,\n                 global_scale=1.0,\n                 rotate_mode='NATIVE',\n                 root_transform_only=False,\n                 ):\n\n    def ensure_rot_order(rot_order_str):\n        if set(rot_order_str) != {'X', 'Y', 'Z'}:\n            rot_order_str = \"XYZ\"\n        return rot_order_str\n    \n    file_str = []\n\n    obj = context.object\n    arm = obj.data\n\n    # Build a dictionary of children.\n    # None for parentless\n    children = {None: []}\n\n    # initialize with blank lists\n    for bone in arm.bones:\n        children[bone.name] = []\n\n    # keep bone order from armature, no sorting, not esspential but means\n    # we can maintain order from import -> export which secondlife incorrectly expects.\n    for bone in arm.bones:\n        children[getattr(bone.parent, \"name\", None)].append(bone.name)\n\n    # bone name list in the order that the bones are written\n    serialized_names = []\n\n    node_locations = {}\n\n    file_str.append(\"HIERARCHY\\n\")\n\n    def write_recursive_nodes(bone_name, indent):\n        my_children = children[bone_name]\n\n        indent_str = \"\\t\" * indent\n\n        bone = arm.bones[bone_name]\n        pose_bone = obj.pose.bones[bone_name]\n        loc = bone.head_local\n        node_locations[bone_name] = loc\n\n        if rotate_mode == \"NATIVE\":\n            rot_order_str = ensure_rot_order(pose_bone.rotation_mode)\n        else:\n            rot_order_str = rotate_mode\n\n        # make relative if we can\n        if bone.parent:\n            loc = loc - node_locations[bone.parent.name]\n\n        if indent:\n            file_str.append(\"%sJOINT %s\\n\" % (indent_str, bone_name))\n        else:\n            file_str.append(\"%sROOT %s\\n\" % (indent_str, bone_name))\n\n        file_str.append(\"%s{\\n\" % indent_str)\n        file_str.append(\"%s\\tOFFSET %.6f %.6f %.6f\\n\" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale))\n        if (bone.use_connect or root_transform_only) and bone.parent:\n            file_str.append(\"%s\\tCHANNELS 3 %srotation %srotation %srotation\\n\" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2]))\n        else:\n            file_str.append(\"%s\\tCHANNELS 6 Xposition Yposition Zposition %srotation %srotation %srotation\\n\" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2]))\n\n        if my_children:\n            # store the location for the children\n            # to get their relative offset\n\n            # Write children\n            for child_bone in my_children:\n                serialized_names.append(child_bone)\n                write_recursive_nodes(child_bone, indent + 1)\n\n        else:\n            # Write the bone end.\n            file_str.append(\"%s\\tEnd Site\\n\" % indent_str)\n            file_str.append(\"%s\\t{\\n\" % indent_str)\n            loc = bone.tail_local - node_locations[bone_name]\n            file_str.append(\"%s\\t\\tOFFSET %.6f %.6f %.6f\\n\" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale))\n            file_str.append(\"%s\\t}\\n\" % indent_str)\n\n        file_str.append(\"%s}\\n\" % indent_str)\n\n    if len(children[None]) == 1:\n        key = children[None][0]\n        serialized_names.append(key)\n        indent = 0\n\n        write_recursive_nodes(key, indent)\n\n    else:\n        # Write a dummy parent node, with a dummy key name\n        # Just be sure it's not used by another bone!\n        i = 0\n        key = \"__%d\" % i\n        while key in children:\n            i += 1\n            key = \"__%d\" % i\n        file_str.append(\"ROOT %s\\n\" % key)\n        file_str.append(\"{\\n\")\n        file_str.append(\"\\tOFFSET 0.0 0.0 0.0\\n\")\n        file_str.append(\"\\tCHANNELS 0\\n\")  # Xposition Yposition Zposition Xrotation Yrotation Zrotation\n        indent = 1\n\n        # Write children\n        for child_bone in children[None]:\n            serialized_names.append(child_bone)\n            write_recursive_nodes(child_bone, indent)\n\n        file_str.append(\"}\\n\")\n    file_str = ''.join(file_str)\n    # redefine bones as sorted by serialized_names\n    # so we can write motion\n\n    class DecoratedBone:\n        __slots__ = (\n            # Bone name, used as key in many places.\n            \"name\",\n            \"parent\",  # decorated bone parent, set in a later loop\n            # Blender armature bone.\n            \"rest_bone\",\n            # Blender pose bone.\n            \"pose_bone\",\n            # Blender pose matrix.\n            \"pose_mat\",\n            # Blender rest matrix (armature space).\n            \"rest_arm_mat\",\n            # Blender rest matrix (local space).\n            \"rest_local_mat\",\n            # Pose_mat inverted.\n            \"pose_imat\",\n            # Rest_arm_mat inverted.\n            \"rest_arm_imat\",\n            # Rest_local_mat inverted.\n            \"rest_local_imat\",\n            # Last used euler to preserve euler compatibility in between keyframes.\n            \"prev_euler\",\n            # Is the bone disconnected to the parent bone?\n            \"skip_position\",\n            \"rot_order\",\n            \"rot_order_str\",\n            # Needed for the euler order when converting from a matrix.\n            \"rot_order_str_reverse\",\n        )\n\n        _eul_order_lookup = {\n            'XYZ': (0, 1, 2),\n            'XZY': (0, 2, 1),\n            'YXZ': (1, 0, 2),\n            'YZX': (1, 2, 0),\n            'ZXY': (2, 0, 1),\n            'ZYX': (2, 1, 0),\n        }\n\n        def __init__(self, bone_name):\n            self.name = bone_name\n            self.rest_bone = arm.bones[bone_name]\n            self.pose_bone = obj.pose.bones[bone_name]\n\n            if rotate_mode == \"NATIVE\":\n                self.rot_order_str = ensure_rot_order(self.pose_bone.rotation_mode)\n            else:\n                self.rot_order_str = rotate_mode\n            self.rot_order_str_reverse = self.rot_order_str[::-1]\n\n            self.rot_order = DecoratedBone._eul_order_lookup[self.rot_order_str]\n\n            self.pose_mat = self.pose_bone.matrix\n\n            # mat = self.rest_bone.matrix  # UNUSED\n            self.rest_arm_mat = self.rest_bone.matrix_local\n            self.rest_local_mat = self.rest_bone.matrix\n\n            # inverted mats\n            self.pose_imat = self.pose_mat.inverted()\n            self.rest_arm_imat = self.rest_arm_mat.inverted()\n            self.rest_local_imat = self.rest_local_mat.inverted()\n\n            self.parent = None\n            self.prev_euler = Euler((0.0, 0.0, 0.0), self.rot_order_str_reverse)\n            self.skip_position = ((self.rest_bone.use_connect or root_transform_only) and self.rest_bone.parent)\n\n        def update_posedata(self):\n            self.pose_mat = self.pose_bone.matrix\n            self.pose_imat = self.pose_mat.inverted()\n\n        def __repr__(self):\n            if self.parent:\n                return \"[\\\"%s\\\" child on \\\"%s\\\"]\\n\" % (self.name, self.parent.name)\n            else:\n                return \"[\\\"%s\\\" root bone]\\n\" % (self.name)\n\n    bones_decorated = [DecoratedBone(bone_name) for bone_name in serialized_names]\n\n\n    # Assign parents\n    bones_decorated_dict = {dbone.name: dbone for dbone in bones_decorated}\n    for dbone in bones_decorated:\n        parent = dbone.rest_bone.parent\n        if parent:\n            dbone.parent = bones_decorated_dict[parent.name]\n    del bones_decorated_dict\n    # finish assigning parents\n\n    scene = context.scene\n    frame_current = scene.frame_current\n\n    file_str += \"MOTION\\n\"\n    file_str += \"Frames: %d\\n\" % (frame_end - frame_start + 1)\n    file_str += \"Frame Time: %.6f\\n\" % (1.0 / (scene.render.fps / scene.render.fps_base))\n\n    for frame in range(frame_start, frame_end + 1):\n        scene.frame_set(frame)\n\n        for dbone in bones_decorated:\n            dbone.update_posedata()\n\n        for dbone in bones_decorated:\n            trans = Matrix.Translation(dbone.rest_bone.head_local)\n            itrans = Matrix.Translation(-dbone.rest_bone.head_local)\n\n            if dbone.parent:\n                mat_final = dbone.parent.rest_arm_mat @ dbone.parent.pose_imat @ dbone.pose_mat @ dbone.rest_arm_imat\n                mat_final = itrans @ mat_final @ trans\n                loc = mat_final.to_translation() + (dbone.rest_bone.head_local - dbone.parent.rest_bone.head_local)\n            else:\n                mat_final = dbone.pose_mat @ dbone.rest_arm_imat\n                mat_final = itrans @ mat_final @ trans\n                loc = mat_final.to_translation() + dbone.rest_bone.head\n\n            # keep eulers compatible, no jumping on interpolation.\n            rot = mat_final.to_euler(dbone.rot_order_str_reverse, dbone.prev_euler)\n\n            if not dbone.skip_position:\n                file_str += \"%.6f %.6f %.6f \" % (loc * global_scale)[:]\n\n            file_str += \"%.6f %.6f %.6f \" % (degrees(rot[dbone.rot_order[0]]), degrees(rot[dbone.rot_order[1]]), degrees(rot[dbone.rot_order[2]]))\n\n            dbone.prev_euler = rot\n\n        file_str += \"\\n\"\n\n    scene.frame_set(frame_current)\n\n    return file_str\n\n\nclass BVH_Node:\n    __slots__ = (\n        # Bvh joint name.\n        'name',\n        # BVH_Node type or None for no parent.\n        'parent',\n        # A list of children of this type..\n        'children',\n        # Worldspace rest location for the head of this node.\n        'rest_head_world',\n        # Localspace rest location for the head of this node.\n        'rest_head_local',\n        # Worldspace rest location for the tail of this node.\n        'rest_tail_world',\n        # Worldspace rest location for the tail of this node.\n        'rest_tail_local',\n        # List of 6 ints, -1 for an unused channel,\n        # otherwise an index for the BVH motion data lines,\n        # loc triple then rot triple.\n        'channels',\n        # A triple of indices as to the order rotation is applied.\n        # [0,1,2] is x/y/z - [None, None, None] if no rotation..\n        'rot_order',\n        # Same as above but a string 'XYZ' format..\n        'rot_order_str',\n        # A list one tuple's one for each frame: (locx, locy, locz, rotx, roty, rotz),\n        # euler rotation ALWAYS stored xyz order, even when native used.\n        'anim_data',\n        # Convenience function, bool, same as: (channels[0] != -1 or channels[1] != -1 or channels[2] != -1).\n        'has_loc',\n        # Convenience function, bool, same as: (channels[3] != -1 or channels[4] != -1 or channels[5] != -1).\n        'has_rot',\n        # Index from the file, not strictly needed but nice to maintain order.\n        'index',\n        # Use this for whatever you want.\n        'temp',\n    )\n\n    _eul_order_lookup = {\n        (None, None, None): 'XYZ',  # XXX Dummy one, no rotation anyway!\n        (0, 1, 2): 'XYZ',\n        (0, 2, 1): 'XZY',\n        (1, 0, 2): 'YXZ',\n        (1, 2, 0): 'YZX',\n        (2, 0, 1): 'ZXY',\n        (2, 1, 0): 'ZYX',\n    }\n\n    def __init__(self, name, rest_head_world, rest_head_local, parent, channels, rot_order, index):\n        self.name = name\n        self.rest_head_world = rest_head_world\n        self.rest_head_local = rest_head_local\n        self.rest_tail_world = None\n        self.rest_tail_local = None\n        self.parent = parent\n        self.channels = channels\n        self.rot_order = tuple(rot_order)\n        self.rot_order_str = BVH_Node._eul_order_lookup[self.rot_order]\n        self.index = index\n\n        # convenience functions\n        self.has_loc = channels[0] != -1 or channels[1] != -1 or channels[2] != -1\n        self.has_rot = channels[3] != -1 or channels[4] != -1 or channels[5] != -1\n\n        self.children = []\n\n        # List of 6 length tuples: (lx, ly, lz, rx, ry, rz)\n        # even if the channels aren't used they will just be zero.\n        self.anim_data = [(0, 0, 0, 0, 0, 0)]\n\n    def __repr__(self):\n        return (\n            \"BVH name: '%s', rest_loc:(%.3f,%.3f,%.3f), rest_tail:(%.3f,%.3f,%.3f)\" % (\n                self.name,\n                *self.rest_head_world,\n                *self.rest_head_world,\n            )\n        )\n\n\ndef sorted_nodes(bvh_nodes):\n    bvh_nodes_list = list(bvh_nodes.values())\n    bvh_nodes_list.sort(key=lambda bvh_node: bvh_node.index)\n    return bvh_nodes_list\n\n\ndef read_bvh(context, bvh_str, rotate_mode='XYZ', global_scale=1.0):\n    # Separate into a list of lists, each line a list of words.\n    file_lines = bvh_str\n    # Non standard carriage returns?\n    if len(file_lines) == 1:\n        file_lines = file_lines[0].split('\\r')\n\n    # Split by whitespace.\n    file_lines = [ll for ll in [l.split() for l in file_lines] if ll]\n\n    # Create hierarchy as empties\n    if file_lines[0][0].lower() == 'hierarchy':\n        # print 'Importing the BVH Hierarchy for:', file_path\n        pass\n    else:\n        raise Exception(\"This is not a BVH file\")\n\n    bvh_nodes = {None: None}\n    bvh_nodes_serial = [None]\n    bvh_frame_count = None\n    bvh_frame_time = None\n\n    channelIndex = -1\n\n    lineIdx = 0  # An index for the file.\n    while lineIdx < len(file_lines) - 1:\n        if file_lines[lineIdx][0].lower() in {'root', 'joint'}:\n\n            # Join spaces into 1 word with underscores joining it.\n            if len(file_lines[lineIdx]) > 2:\n                file_lines[lineIdx][1] = '_'.join(file_lines[lineIdx][1:])\n                file_lines[lineIdx] = file_lines[lineIdx][:2]\n\n            # MAY NEED TO SUPPORT MULTIPLE ROOTS HERE! Still unsure weather multiple roots are possible?\n\n            # Make sure the names are unique - Object names will match joint names exactly and both will be unique.\n            name = file_lines[lineIdx][1]\n\n            # print '%snode: %s, parent: %s' % (len(bvh_nodes_serial) * '  ', name,  bvh_nodes_serial[-1])\n\n            lineIdx += 2  # Increment to the next line (Offset)\n            rest_head_local = global_scale * Vector((\n                float(file_lines[lineIdx][1]),\n                float(file_lines[lineIdx][2]),\n                float(file_lines[lineIdx][3]),\n            ))\n            lineIdx += 1  # Increment to the next line (Channels)\n\n            # newChannel[Xposition, Yposition, Zposition, Xrotation, Yrotation, Zrotation]\n            # newChannel references indices to the motiondata,\n            # if not assigned then -1 refers to the last value that will be added on loading at a value of zero, this is appended\n            # We'll add a zero value onto the end of the MotionDATA so this always refers to a value.\n            my_channel = [-1, -1, -1, -1, -1, -1]\n            my_rot_order = [None, None, None]\n            rot_count = 0\n            for channel in file_lines[lineIdx][2:]:\n                channel = channel.lower()\n                channelIndex += 1  # So the index points to the right channel\n                if channel == 'xposition':\n                    my_channel[0] = channelIndex\n                elif channel == 'yposition':\n                    my_channel[1] = channelIndex\n                elif channel == 'zposition':\n                    my_channel[2] = channelIndex\n\n                elif channel == 'xrotation':\n                    my_channel[3] = channelIndex\n                    my_rot_order[rot_count] = 0\n                    rot_count += 1\n                elif channel == 'yrotation':\n                    my_channel[4] = channelIndex\n                    my_rot_order[rot_count] = 1\n                    rot_count += 1\n                elif channel == 'zrotation':\n                    my_channel[5] = channelIndex\n                    my_rot_order[rot_count] = 2\n                    rot_count += 1\n\n            channels = file_lines[lineIdx][2:]\n\n            my_parent = bvh_nodes_serial[-1]  # account for none\n\n            # Apply the parents offset accumulatively\n            if my_parent is None:\n                rest_head_world = Vector(rest_head_local)\n            else:\n                rest_head_world = my_parent.rest_head_world + rest_head_local\n\n            bvh_node = bvh_nodes[name] = BVH_Node(\n                name,\n                rest_head_world,\n                rest_head_local,\n                my_parent,\n                my_channel,\n                my_rot_order,\n                len(bvh_nodes) - 1,\n            )\n\n            # If we have another child then we can call ourselves a parent, else\n            bvh_nodes_serial.append(bvh_node)\n\n        # Account for an end node.\n        # There is sometimes a name after 'End Site' but we will ignore it.\n        if file_lines[lineIdx][0].lower() == 'end' and file_lines[lineIdx][1].lower() == 'site':\n            # Increment to the next line (Offset)\n            lineIdx += 2\n            rest_tail = global_scale * Vector((\n                float(file_lines[lineIdx][1]),\n                float(file_lines[lineIdx][2]),\n                float(file_lines[lineIdx][3]),\n            ))\n\n            bvh_nodes_serial[-1].rest_tail_world = bvh_nodes_serial[-1].rest_head_world + rest_tail\n            bvh_nodes_serial[-1].rest_tail_local = bvh_nodes_serial[-1].rest_head_local + rest_tail\n\n            # Just so we can remove the parents in a uniform way,\n            # the end has kids so this is a placeholder.\n            bvh_nodes_serial.append(None)\n\n        if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0] == '}':  # == ['}']\n            bvh_nodes_serial.pop()  # Remove the last item\n\n        # End of the hierarchy. Begin the animation section of the file with\n        # the following header.\n        #  MOTION\n        #  Frames: n\n        #  Frame Time: dt\n        if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0].lower() == 'motion':\n            lineIdx += 1  # Read frame count.\n            if (\n                    len(file_lines[lineIdx]) == 2 and\n                    file_lines[lineIdx][0].lower() == 'frames:'\n            ):\n                bvh_frame_count = int(file_lines[lineIdx][1])\n\n            lineIdx += 1  # Read frame rate.\n            if (\n                    len(file_lines[lineIdx]) == 3 and\n                    file_lines[lineIdx][0].lower() == 'frame' and\n                    file_lines[lineIdx][1].lower() == 'time:'\n            ):\n                bvh_frame_time = float(file_lines[lineIdx][2])\n\n            lineIdx += 1  # Set the cursor to the first frame\n\n            break\n\n        lineIdx += 1\n\n    # Remove the None value used for easy parent reference\n    del bvh_nodes[None]\n    # Don't use anymore\n    del bvh_nodes_serial\n\n    # importing world with any order but nicer to maintain order\n    # second life expects it, which isn't to spec.\n    bvh_nodes_list = sorted_nodes(bvh_nodes)\n\n    while lineIdx < len(file_lines):\n        line = file_lines[lineIdx]\n        for bvh_node in bvh_nodes_list:\n            # for bvh_node in bvh_nodes_serial:\n            lx = ly = lz = rx = ry = rz = 0.0\n            channels = bvh_node.channels\n            anim_data = bvh_node.anim_data\n            if channels[0] != -1:\n                lx = global_scale * float(line[channels[0]])\n\n            if channels[1] != -1:\n                ly = global_scale * float(line[channels[1]])\n\n            if channels[2] != -1:\n                lz = global_scale * float(line[channels[2]])\n\n            if channels[3] != -1 or channels[4] != -1 or channels[5] != -1:\n\n                rx = radians(float(line[channels[3]]))\n                ry = radians(float(line[channels[4]]))\n                rz = radians(float(line[channels[5]]))\n\n            # Done importing motion data #\n            anim_data.append((lx, ly, lz, rx, ry, rz))\n        lineIdx += 1\n\n    # Assign children\n    for bvh_node in bvh_nodes_list:\n        bvh_node_parent = bvh_node.parent\n        if bvh_node_parent:\n            bvh_node_parent.children.append(bvh_node)\n\n    # Now set the tip of each bvh_node\n    for bvh_node in bvh_nodes_list:\n\n        if not bvh_node.rest_tail_world:\n            if len(bvh_node.children) == 0:\n                # could just fail here, but rare BVH files have childless nodes\n                bvh_node.rest_tail_world = Vector(bvh_node.rest_head_world)\n                bvh_node.rest_tail_local = Vector(bvh_node.rest_head_local)\n            elif len(bvh_node.children) == 1:\n                bvh_node.rest_tail_world = Vector(bvh_node.children[0].rest_head_world)\n                bvh_node.rest_tail_local = bvh_node.rest_head_local + bvh_node.children[0].rest_head_local\n            else:\n                # allow this, see above\n                # if not bvh_node.children:\n                #     raise Exception(\"bvh node has no end and no children. bad file\")\n\n                # Removed temp for now\n                rest_tail_world = Vector((0.0, 0.0, 0.0))\n                rest_tail_local = Vector((0.0, 0.0, 0.0))\n                for bvh_node_child in bvh_node.children:\n                    rest_tail_world += bvh_node_child.rest_head_world\n                    rest_tail_local += bvh_node_child.rest_head_local\n\n                bvh_node.rest_tail_world = rest_tail_world * (1.0 / len(bvh_node.children))\n                bvh_node.rest_tail_local = rest_tail_local * (1.0 / len(bvh_node.children))\n\n        # Make sure tail isn't the same location as the head.\n        if (bvh_node.rest_tail_local - bvh_node.rest_head_local).length <= 0.001 * global_scale:\n            print(\"\\tzero length node found:\", bvh_node.name)\n            bvh_node.rest_tail_local.y = bvh_node.rest_tail_local.y + global_scale / 10\n            bvh_node.rest_tail_world.y = bvh_node.rest_tail_world.y + global_scale / 10\n\n    return bvh_nodes, bvh_frame_time, bvh_frame_count\n\n\ndef bvh_node_dict2objects(context, bvh_name, bvh_nodes, rotate_mode='NATIVE', frame_start=1, IMPORT_LOOP=False):\n\n    if frame_start < 1:\n        frame_start = 1\n\n    scene = context.scene\n    for obj in scene.objects:\n        obj.select_set(False)\n\n    objects = []\n\n    def add_ob(name):\n        obj = bpy.data.objects.new(name, None)\n        context.collection.objects.link(obj)\n        objects.append(obj)\n        obj.select_set(True)\n\n        # nicer drawing.\n        obj.empty_display_type = 'CUBE'\n        obj.empty_display_size = 0.1\n\n        return obj\n\n    # Add objects\n    for name, bvh_node in bvh_nodes.items():\n        bvh_node.temp = add_ob(name)\n        bvh_node.temp.rotation_mode = bvh_node.rot_order_str[::-1]\n\n    # Parent the objects\n    for bvh_node in bvh_nodes.values():\n        for bvh_node_child in bvh_node.children:\n            bvh_node_child.temp.parent = bvh_node.temp\n\n    # Offset\n    for bvh_node in bvh_nodes.values():\n        # Make relative to parents offset\n        bvh_node.temp.location = bvh_node.rest_head_local\n\n    # Add tail objects\n    for name, bvh_node in bvh_nodes.items():\n        if not bvh_node.children:\n            ob_end = add_ob(name + '_end')\n            ob_end.parent = bvh_node.temp\n            ob_end.location = bvh_node.rest_tail_world - bvh_node.rest_head_world\n\n    for name, bvh_node in bvh_nodes.items():\n        obj = bvh_node.temp\n\n        for frame_current in range(len(bvh_node.anim_data)):\n\n            lx, ly, lz, rx, ry, rz = bvh_node.anim_data[frame_current]\n\n            if bvh_node.has_loc:\n                obj.delta_location = Vector((lx, ly, lz)) - bvh_node.rest_head_world\n                obj.keyframe_insert(\"delta_location\", index=-1, frame=frame_start + frame_current)\n\n            if bvh_node.has_rot:\n                obj.delta_rotation_euler = rx, ry, rz\n                obj.keyframe_insert(\"delta_rotation_euler\", index=-1, frame=frame_start + frame_current)\n\n    return objects\n\n\ndef bvh_node_dict2armature(\n        context,\n        bvh_name,\n        bvh_nodes,\n        bvh_frame_time,\n        rotate_mode='XYZ',\n        frame_start=1,\n        IMPORT_LOOP=False,\n        global_matrix=None,\n        use_fps_scale=False,\n        original_rest_pose=None  # New parameter for the original rest pose\n):\n    if frame_start < 1:\n        frame_start = 1\n\n    scene = context.scene\n    for obj in scene.objects:\n        obj.select_set(False)\n\n    arm_data = bpy.data.armatures.new(bvh_name)\n    arm_ob = bpy.data.objects.new(bvh_name, arm_data)\n\n    context.collection.objects.link(arm_ob)\n\n    arm_ob.select_set(True)\n    context.view_layer.objects.active = arm_ob\n\n    bpy.ops.object.mode_set(mode='EDIT', toggle=False)\n\n    bvh_nodes_list = sorted_nodes(bvh_nodes)\n\n    # Get the average bone length for zero length bones\n    average_bone_length = 0.0\n    nonzero_count = 0\n    for bvh_node in bvh_nodes_list:\n        l = (bvh_node.rest_head_local - bvh_node.rest_tail_local).length\n        if l:\n            average_bone_length += l\n            nonzero_count += 1\n\n    if not average_bone_length:\n        average_bone_length = 0.1\n    else:\n        average_bone_length = average_bone_length / nonzero_count\n\n    while arm_data.edit_bones:\n        arm_ob.edit_bones.remove(arm_data.edit_bones[-1])\n\n    ZERO_AREA_BONES = []\n    # First pass: Create all bones and assign to temp\n    for bvh_node in bvh_nodes_list:\n        bone = arm_data.edit_bones.new(bvh_node.name)\n\n        # Use the original rest pose if provided, otherwise fall back to BVH data\n        if original_rest_pose and bvh_node.name in original_rest_pose:\n            bone.head = original_rest_pose[bvh_node.name]['head']\n            bone.tail = original_rest_pose[bvh_node.name]['tail']\n            bone.roll = original_rest_pose[bvh_node.name]['roll']\n        else:\n            bone.head = bvh_node.rest_head_world\n            bone.tail = bvh_node.rest_tail_world\n\n            # Handle zero-length bones\n            if (bone.head - bone.tail).length < 0.001:\n                print(\"\\tzero length bone found:\", bone.name)\n                if bvh_node.parent:\n                    ofs = bvh_node.parent.rest_head_local - bvh_node.parent.rest_tail_local\n                    if ofs.length:\n                        bone.tail = bone.tail - ofs\n                    else:\n                        bone.tail.y = bone.tail.y + average_bone_length\n                else:\n                    bone.tail.y = bone.tail.y + average_bone_length\n\n                ZERO_AREA_BONES.append(bvh_node.name)\n\n        # Assign the edit bone to the temp attribute\n        bvh_node.temp = bone\n\n    # Second pass: Set parenting and connection\n    for bvh_node in bvh_nodes_list:\n        if bvh_node.parent:\n            # Now bvh_node.temp and bvh_node.parent.temp should both be valid\n            bvh_node.temp.parent = bvh_node.parent.temp\n\n            if (\n                (not bvh_node.has_loc) and\n                (bvh_node.parent.temp.name not in ZERO_AREA_BONES) and\n                (bvh_node.parent.rest_tail_local == bvh_node.rest_head_local)\n            ):\n                bvh_node.temp.use_connect = True\n\n    # Replace temp with bone name for later use\n    for bvh_node in bvh_nodes_list:\n        bvh_node.temp = bvh_node.temp.name\n\n    bpy.ops.object.mode_set(mode='OBJECT', toggle=False)\n\n    pose = arm_ob.pose\n    pose_bones = pose.bones\n\n    if rotate_mode == 'NATIVE':\n        for bvh_node in bvh_nodes_list:\n            bone_name = bvh_node.temp\n            pose_bone = pose_bones[bone_name]\n            pose_bone.rotation_mode = bvh_node.rot_order_str\n    elif rotate_mode != 'QUATERNION':\n        for pose_bone in pose_bones:\n            pose_bone.rotation_mode = rotate_mode\n\n    context.view_layer.update()\n\n    arm_ob.animation_data_create()\n    action = bpy.data.actions.new(name=bvh_name)\n    arm_ob.animation_data.action = action\n\n    num_frame = 0\n    for bvh_node in bvh_nodes_list:\n        bone_name = bvh_node.temp\n        pose_bone = pose_bones[bone_name]\n        rest_bone = arm_data.bones[bone_name]\n        bone_rest_matrix = rest_bone.matrix_local.to_3x3()\n\n        bone_rest_matrix_inv = Matrix(bone_rest_matrix)\n        bone_rest_matrix_inv.invert()\n\n        bone_rest_matrix_inv.resize_4x4()\n        bone_rest_matrix.resize_4x4()\n        bvh_node.temp = (pose_bone, rest_bone, bone_rest_matrix, bone_rest_matrix_inv)\n\n        if 0 == num_frame:\n            num_frame = len(bvh_node.anim_data)\n\n    skip_frame = 1\n    if num_frame > skip_frame:\n        num_frame = num_frame - skip_frame\n\n    time = [float(frame_start)] * num_frame\n    if use_fps_scale:\n        dt = scene.render.fps * bvh_frame_time\n        for frame_i in range(1, num_frame):\n            time[frame_i] += float(frame_i) * dt\n    else:\n        for frame_i in range(1, num_frame):\n            time[frame_i] += float(frame_i)\n\n    for i, bvh_node in enumerate(bvh_nodes_list):\n        pose_bone, bone, bone_rest_matrix, bone_rest_matrix_inv = bvh_node.temp\n\n        if bvh_node.has_loc:\n            data_path = f'pose.bones[\"{pose_bone.name}\"].location'\n            location = [(0.0, 0.0, 0.0)] * num_frame\n            for frame_i in range(num_frame):\n                bvh_loc = bvh_node.anim_data[frame_i + skip_frame][:3]\n                bone_translate_matrix = Matrix.Translation(\n                    Vector(bvh_loc) - bvh_node.rest_head_local)\n                location[frame_i] = (bone_rest_matrix_inv @\n                                     bone_translate_matrix).to_translation()\n\n            for axis_i in range(3):\n                curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name)\n                keyframe_points = curve.keyframe_points\n                keyframe_points.add(num_frame)\n                for frame_i in range(num_frame):\n                    keyframe_points[frame_i].co = (\n                        time[frame_i],\n                        location[frame_i][axis_i],\n                    )\n\n        if bvh_node.has_rot:\n            data_path = None\n            rotate = None\n            if 'QUATERNION' == rotate_mode:\n                rotate = [(1.0, 0.0, 0.0, 0.0)] * num_frame\n                data_path = f'pose.bones[\"{pose_bone.name}\"].rotation_quaternion'\n            else:\n                rotate = [(0.0, 0.0, 0.0)] * num_frame\n                data_path = f'pose.bones[\"{pose_bone.name}\"].rotation_euler'\n\n            prev_euler = Euler((0.0, 0.0, 0.0))\n            for frame_i in range(num_frame):\n                bvh_rot = bvh_node.anim_data[frame_i + skip_frame][3:]\n                euler = Euler(bvh_rot, bvh_node.rot_order_str[::-1])\n                bone_rotation_matrix = euler.to_matrix().to_4x4()\n                bone_rotation_matrix = (\n                    bone_rest_matrix_inv @\n                    bone_rotation_matrix @\n                    bone_rest_matrix\n                )\n\n                if len(rotate[frame_i]) == 4:\n                    rotate[frame_i] = bone_rotation_matrix.to_quaternion()\n                else:\n                    rotate[frame_i] = bone_rotation_matrix.to_euler(\n                        pose_bone.rotation_mode, prev_euler)\n                    prev_euler = rotate[frame_i]\n\n            for axis_i in range(len(rotate[0])):\n                curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name)\n                keyframe_points = curve.keyframe_points\n                keyframe_points.add(num_frame)\n                for frame_i in range(num_frame):\n                    keyframe_points[frame_i].co = (\n                        time[frame_i],\n                        rotate[frame_i][axis_i],\n                    )\n\n    for cu in action.fcurves:\n        if IMPORT_LOOP:\n            pass\n        for bez in cu.keyframe_points:\n            bez.interpolation = 'LINEAR'\n\n    try:\n        arm_ob.matrix_world = global_matrix\n    except:\n        pass\n    bpy.ops.object.transform_apply(location=False, rotation=True, scale=False)\n\n    return arm_ob\n\n\ndef load(\n        context,\n        bvh_str,\n        *,\n        target='ARMATURE',\n        rotate_mode='NATIVE',\n        global_scale=1.0,\n        use_cyclic=False,\n        frame_start=1,\n        global_matrix=None,\n        use_fps_scale=False,\n        update_scene_fps=False,\n        update_scene_duration=False,\n        original_rest_pose=None,\n        bvh_name='synsized',  # Added parameter\n        report=print,\n):\n    import time\n    t1 = time.time()\n\n    bvh_nodes, bvh_frame_time, bvh_frame_count = read_bvh(\n        context, bvh_str,\n        rotate_mode=rotate_mode,\n        global_scale=global_scale,\n    )\n\n    print(\"%.4f\" % (time.time() - t1))\n\n    scene = context.scene\n    frame_orig = scene.frame_current\n\n    if bvh_frame_time is None:\n        report(\n            {'WARNING'},\n            \"The BVH file does not contain frame duration in its MOTION \"\n            \"section, assuming the BVH and Blender scene have the same \"\n            \"frame rate\"\n        )\n        bvh_frame_time = scene.render.fps_base / scene.render.fps\n        use_fps_scale = False\n\n    if update_scene_fps:\n        _update_scene_fps(context, report, bvh_frame_time)\n        use_fps_scale = False\n\n    if update_scene_duration:\n        _update_scene_duration(context, report, bvh_frame_count, bvh_frame_time, frame_start, use_fps_scale)\n\n    t1 = time.time()\n    print(\"\\timporting to blender...\", end=\"\")\n\n    if target == 'ARMATURE':\n        bvh_node_dict2armature(\n            context, bvh_name, bvh_nodes, bvh_frame_time,\n            rotate_mode=rotate_mode,\n            frame_start=frame_start,\n            IMPORT_LOOP=use_cyclic,\n            global_matrix=global_matrix,\n            use_fps_scale=use_fps_scale,\n            original_rest_pose=original_rest_pose\n        )\n    elif target == 'OBJECT':\n        bvh_node_dict2objects(\n            context, bvh_name, bvh_nodes,\n            rotate_mode=rotate_mode,\n            frame_start=frame_start,\n            IMPORT_LOOP=use_cyclic,\n        )\n    else:\n        report({'ERROR'}, tip_(\"Invalid target %r (must be 'ARMATURE' or 'OBJECT')\") % target)\n        return {'CANCELLED'}\n\n    print('Done in %.4f\\n' % (time.time() - t1))\n    context.scene.frame_set(frame_orig)\n    return {'FINISHED'}\n\n\ndef _update_scene_fps(context, report, bvh_frame_time):\n    \"\"\"Update the scene's FPS settings from the BVH, but only if the BVH contains enough info.\"\"\"\n\n    # Broken BVH handling: prevent division by zero.\n    if bvh_frame_time == 0.0:\n        report(\n            {'WARNING'},\n            \"Unable to update scene frame rate, as the BVH file \"\n            \"contains a zero frame duration in its MOTION section\",\n        )\n        return\n\n    scene = context.scene\n    scene_fps = scene.render.fps / scene.render.fps_base\n    new_fps = 1.0 / bvh_frame_time\n\n    if scene.render.fps != new_fps or scene.render.fps_base != 1.0:\n        print(\"\\tupdating scene FPS (was %f) to BVH FPS (%f)\" % (scene_fps, new_fps))\n    scene.render.fps = int(round(new_fps))\n    scene.render.fps_base = scene.render.fps / new_fps\n\n\ndef _update_scene_duration(\n        context, report, bvh_frame_count, bvh_frame_time, frame_start,\n        use_fps_scale):\n    \"\"\"Extend the scene's duration so that the BVH file fits in its entirety.\"\"\"\n\n    if bvh_frame_count is None:\n        report(\n            {'WARNING'},\n            \"Unable to extend the scene duration, as the BVH file does not \"\n            \"contain the number of frames in its MOTION section\",\n        )\n        return\n\n    # Not likely, but it can happen when a BVH is just used to store an armature.\n    if bvh_frame_count == 0:\n        return\n\n    if use_fps_scale:\n        scene_fps = context.scene.render.fps / context.scene.render.fps_base\n        scaled_frame_count = int(ceil(bvh_frame_count * bvh_frame_time * scene_fps))\n        bvh_last_frame = frame_start + scaled_frame_count\n    else:\n        bvh_last_frame = frame_start + bvh_frame_count\n\n    # Only extend the scene, never shorten it.\n    if context.scene.frame_end < bvh_last_frame:\n        context.scene.frame_end = bvh_last_frame\n\n\n# This function is from\n# https://github.com/yuki-koyama/blender-cli-rendering\ndef set_smooth_shading(mesh: bpy.types.Mesh) -> None:\n    for polygon in mesh.polygons:\n        polygon.use_smooth = True\n\n\n# This function is from\n# https://github.com/yuki-koyama/blender-cli-rendering\ndef create_mesh_from_pydata(scene: bpy.types.Scene,\n                            vertices: Iterable[Iterable[float]],\n                            faces: Iterable[Iterable[int]],\n                            mesh_name: str,\n                            object_name: str,\n                            use_smooth: bool = True) -> bpy.types.Object:\n    # Add a new mesh and set vertices and faces\n    # Note: In this case, it does not require to set edges.\n    # Note: After manipulating mesh data, update() needs to be called.\n    new_mesh: bpy.types.Mesh = bpy.data.meshes.new(mesh_name)\n    new_mesh.from_pydata(vertices, [], faces)\n    new_mesh.update()\n    if use_smooth:\n        set_smooth_shading(new_mesh)\n\n    new_object: bpy.types.Object = bpy.data.objects.new(object_name, new_mesh)\n    scene.collection.objects.link(new_object)\n\n    return new_object\n\n\n# This function is from\n# https://github.com/yuki-koyama/blender-cli-rendering\ndef add_subdivision_surface_modifier(mesh_object: bpy.types.Object, level: int, is_simple: bool = False) -> None:\n    '''\n    https://docs.blender.org/api/current/bpy.types.SubsurfModifier.html\n    '''\n\n    modifier: bpy.types.SubsurfModifier = mesh_object.modifiers.new(name=\"Subsurf\", type='SUBSURF')\n\n    modifier.levels = level\n    modifier.render_levels = level\n    modifier.subdivision_type = 'SIMPLE' if is_simple else 'CATMULL_CLARK'\n\n\n# This function is from\n# https://github.com/yuki-koyama/blender-cli-rendering\ndef create_armature_mesh(scene: bpy.types.Scene, armature_object: bpy.types.Object, mesh_name: str) -> bpy.types.Object:\n    assert armature_object.type == 'ARMATURE', 'Error'\n    assert len(armature_object.data.bones) != 0, 'Error'\n\n    def add_rigid_vertex_group(target_object: bpy.types.Object, name: str, vertex_indices: Iterable[int]) -> None:\n        new_vertex_group = target_object.vertex_groups.new(name=name)\n        for vertex_index in vertex_indices:\n            new_vertex_group.add([vertex_index], 1.0, 'REPLACE')\n\n    def generate_bone_mesh_pydata(radius: float, length: float) -> Tuple[List[mathutils.Vector], List[List[int]]]:\n        base_radius = radius\n        top_radius = 0.5 * radius\n\n        vertices = [\n            # Cross section of the base part\n            mathutils.Vector((-base_radius, 0.0, +base_radius)),\n            mathutils.Vector((+base_radius, 0.0, +base_radius)),\n            mathutils.Vector((+base_radius, 0.0, -base_radius)),\n            mathutils.Vector((-base_radius, 0.0, -base_radius)),\n\n            # Cross section of the top part\n            mathutils.Vector((-top_radius, length, +top_radius)),\n            mathutils.Vector((+top_radius, length, +top_radius)),\n            mathutils.Vector((+top_radius, length, -top_radius)),\n            mathutils.Vector((-top_radius, length, -top_radius)),\n\n            # End points\n            mathutils.Vector((0.0, -base_radius, 0.0)),\n            mathutils.Vector((0.0, length + top_radius, 0.0))\n        ]\n\n        faces = [\n            # End point for the base part\n            [8, 1, 0],\n            [8, 2, 1],\n            [8, 3, 2],\n            [8, 0, 3],\n\n            # End point for the top part\n            [9, 4, 5],\n            [9, 5, 6],\n            [9, 6, 7],\n            [9, 7, 4],\n\n            # Side faces\n            [0, 1, 5, 4],\n            [1, 2, 6, 5],\n            [2, 3, 7, 6],\n            [3, 0, 4, 7],\n        ]\n\n        return vertices, faces\n\n    armature_data: bpy.types.Armature = armature_object.data\n\n    vertices: List[mathutils.Vector] = []\n    faces: List[List[int]] = []\n    vertex_groups: List[Dict[str, Any]] = []\n\n    for bone in armature_data.bones:\n        radius = 0.10 * (0.10 + bone.length)\n        temp_vertices, temp_faces = generate_bone_mesh_pydata(radius, bone.length)\n\n        vertex_index_offset = len(vertices)\n\n        temp_vertex_group = {'name': bone.name, 'vertex_indices': []}\n        for local_index, vertex in enumerate(temp_vertices):\n            vertices.append(bone.matrix_local @ vertex)\n            temp_vertex_group['vertex_indices'].append(local_index + vertex_index_offset)\n        vertex_groups.append(temp_vertex_group)\n\n        for face in temp_faces:\n            if len(face) == 3:\n                faces.append([\n                    face[0] + vertex_index_offset,\n                    face[1] + vertex_index_offset,\n                    face[2] + vertex_index_offset,\n                ])\n            else:\n                faces.append([\n                    face[0] + vertex_index_offset,\n                    face[1] + vertex_index_offset,\n                    face[2] + vertex_index_offset,\n                    face[3] + vertex_index_offset,\n                ])\n\n    new_object = create_mesh_from_pydata(scene, vertices, faces, mesh_name, mesh_name)\n    new_object.matrix_world = armature_object.matrix_world\n\n    for vertex_group in vertex_groups:\n        add_rigid_vertex_group(new_object, vertex_group['name'], vertex_group['vertex_indices'])\n\n    armature_modifier = new_object.modifiers.new('Armature', 'ARMATURE')\n    armature_modifier.object = armature_object\n    armature_modifier.use_vertex_groups = True\n\n    add_subdivision_surface_modifier(new_object, 1, is_simple=True)\n    add_subdivision_surface_modifier(new_object, 2, is_simple=False)\n\n    # Set the armature as the parent of the new object\n    bpy.ops.object.select_all(action='DESELECT')\n    new_object.select_set(True)\n    armature_object.select_set(True)\n    bpy.context.view_layer.objects.active = armature_object\n    bpy.ops.object.parent_set(type='OBJECT')\n\n    return new_object\n\n\nclass OP_AddMesh(bpy.types.Operator):\n    bl_idname = \"genmm.add_mesh\"\n    bl_label = \"Add mesh\"\n    bl_description = \"\"\n    bl_options = {\"REGISTER\", \"UNDO\"}\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    def execute(self, context: bpy.types.Context):\n        name = bpy.context.object.name + \"_proxy\"\n        create_armature_mesh(bpy.context.scene, bpy.context.object, name)\n        return {'FINISHED'}\n\nclass OP_RunSynthesis(bpy.types.Operator):\n    bl_idname = \"genmm.run_synthesis\"\n    bl_label = \"Run synthesis\"\n    bl_description = \"\"\n    bl_options = {\"REGISTER\", \"UNDO\"}\n\n    def execute(self, context: bpy.types.Context):\n        setting = context.scene.setting\n        original_armature = context.object\n        rest_pose_data = capture_rest_pose(original_armature)\n\n        anim = original_armature.animation_data.action\n        start_frame, end_frame = map(int, anim.frame_range)\n        start_frame = start_frame if setting.start_frame == -1 else setting.start_frame\n        end_frame = end_frame if setting.end_frame == -1 else setting.end_frame\n\n        bvh_str = get_bvh_data(context,\n                               frame_start=start_frame,\n                               frame_end=end_frame)\n        frames_str, frame_time_str = bvh_str.split('MOTION\\n')[1].split('\\n')[:2]\n        motion_data_str = bvh_str.split('MOTION\\n')[1].split('\\n')[2:-1]\n        motion_data = np.array([item.strip().split(' ') for item in motion_data_str], dtype=np.float32)\n\n        model = GenMM(device='cuda' if torch.cuda.is_available() else 'cpu', silent=True)\n        criteria = PatchCoherentLoss(patch_size=setting.patch_size, \n                                     alpha=setting.alpha, \n                                     loop=setting.loop, cache=True)\n\n        for i in range(setting.num_output):\n            print(f\"Generating motion {i+1} of {setting.num_output}\")\n            # Create a new BlenderMotion instance for each iteration\n            motion = [BlenderMotion(motion_data.copy(), repr='repr6d', use_velo=True, \n                                    keep_up_pos=True, up_axis=setting.up_axis, padding_last=False)]\n            syn = model.run(motion, criteria,\n                            num_frames=str(setting.num_syn_frames),\n                            num_steps=setting.num_steps,\n                            noise_sigma=setting.noise,\n                            patch_size=setting.patch_size, \n                            coarse_ratio=f'{setting.coarse_ratio}x_nframes',\n                            pyr_factor=setting.pyr_factor)\n            motion_data_str = [' '.join(str(x) for x in item) for item in motion[0].parse(syn)]\n            bvh_name = f\"synsized_{i+1}\"\n            load(context,\n                 bvh_str.split('MOTION\\n')[0].split('\\n') + ['MOTION'] + [frames_str] + [frame_time_str] + motion_data_str,\n                 rotate_mode='QUATERNION',\n                 global_matrix=original_armature.matrix_world,\n                 original_rest_pose=rest_pose_data,\n                 target='ARMATURE',\n                 use_fps_scale=False,\n                 bvh_name=bvh_name)\n\n        return {'FINISHED'}\n\nclass GENMM_PT_ControlPanel(bpy.types.Panel):\n    bl_label = \"GenMM\"\n    bl_space_type = 'VIEW_3D'\n    bl_region_type = 'UI'\n    bl_category = \"GenMM\"\n\n    @classmethod\n    def poll(cls, context: bpy.types.Context):\n        return True\n\n    def draw_header(self, context: bpy.types.Context):\n        layout = self.layout\n        layout.label(text=\"\", icon='PLUGIN')\n\n    def draw(self, context: bpy.types.Context):\n        layout = self.layout\n        scene = bpy.context.scene\n\n        ops: List[bpy.types.Operator] = [\n            OP_AddMesh,\n        ]\n        for op in ops:\n            layout.operator(op.bl_idname, text=op.bl_label)\n        \n        box = layout.box()\n        box.label(text=\"Exemplar config:\")\n        exemplar_row = box.row()\n        exemplar_row.prop(scene.setting, \"start_frame\")\n        exemplar_row.prop(scene.setting, \"end_frame\")\n        exemplar_row = box.row()\n        exemplar_row.prop(scene.setting, \"up_axis\")\n\n        box = layout.box()\n        box.label(text=\"Synthesis config:\")\n        box.prop(scene.setting, \"loop\")\n        box.prop(scene.setting, \"noise\")\n        box.prop(scene.setting, \"num_syn_frames\")\n        box.prop(scene.setting, \"patch_size\")\n        box.prop(scene.setting, \"coarse_ratio\")\n        box.prop(scene.setting, \"pyr_factor\")\n        box.prop(scene.setting, \"alpha\")\n        box.prop(scene.setting, \"num_steps\")\n        box.prop(scene.setting, \"num_output\")  # New parameter\n\n        ops: List[bpy.types.Operator] = [\n            OP_RunSynthesis,\n        ]\n        for op in ops:\n            layout.operator(op.bl_idname, text=op.bl_label)\n\nclass PropertyGroup(bpy.types.PropertyGroup):\n    '''Property container for options and paths of GenMM'''\n    start_frame: bpy.props.IntProperty(\n        name=\"Start Frame\",\n        description=\"Start Frame of the Exemplar Motion.\",\n        default=1)\n    end_frame: bpy.props.IntProperty(\n        name=\"End Frame\",\n        description=\"End Frame of the Exemplar Motion.\",\n        default=-1)\n    up_axis: bpy.props.EnumProperty(\n        name=\"Up Axis\", \n        default='Z_UP',\n        description=\"Up axis of the Exemplar Motion\",\n        items=[('Z_UP', \"Z-Up\", 'Z Up'),\n               ('Y_UP', \"Y-Up\", 'Y Up'),\n               ('X_UP', \"X-Up\", 'X Up'),\n               ]\n    )\n    noise: bpy.props.FloatProperty(\n        name=\"Noise Intensity\",\n        description=\"Intensity of Noise Added to the Synthesized Motion.\",\n        default=10)\n    num_syn_frames: bpy.props.IntProperty(\n        name=\"Num. of Frames\",\n        description=\"Number of the Synthesized Motion.\",\n        default=600)\n    patch_size: bpy.props.IntProperty(\n        name=\"Patch Size\",\n        description=\"Size for Patch Extraction.\",\n        min=7,\n        default=15)\n    coarse_ratio: bpy.props.FloatProperty(\n        name=\"Coarse Ratio\",\n        description=\"Ratio of the Coarest Pyramid.\",\n        min=0.0,\n        default=0.2)\n    pyr_factor: bpy.props.FloatProperty(\n        name=\"Pyramid Factor\",\n        description=\"Pyramid Downsample Factor.\",\n        min=0.1,\n        default=0.75)\n    alpha: bpy.props.FloatProperty(\n        name=\"Completeness Alpha\",\n        description=\"Alpha Value for Completeness/Diversity Trade-off.\",\n        default=0.05)\n    loop: bpy.props.BoolProperty(\n        name=\"Endless Loop\",\n        description=\"Whether to Use Loop Constrain.\",\n        default=False)\n    num_steps: bpy.props.IntProperty(\n        name=\"Num of Steps\",\n        description=\"Number of Optimized Steps.\",\n        default=5)\n    num_output: bpy.props.IntProperty(\n        name=\"Num. of Output\",\n        description=\"Number of different motions to generate.\",\n        min=1,\n        default=1)\n\nclasses = [\n    OP_AddMesh,\n    OP_RunSynthesis,\n    GENMM_PT_ControlPanel,\n]\n\ndef register():\n    bpy.utils.register_class(PropertyGroup)\n    bpy.types.Scene.setting = bpy.props.PointerProperty(type=PropertyGroup)\n    for cls in classes:\n        bpy.utils.register_class(cls)\n\ndef unregister():\n    bpy.utils.unregister_class(PropertyGroup)\n    for cls in classes:\n        bpy.utils.unregister_class(cls)\n\nif __name__ == \"__main__\":\n    register()"
  },
  {
    "path": "configs/default.yaml",
    "content": "# motion data config\nrepr: 'repr6d'\nskeleton_name: null\nuse_velo: true\nkeep_up_pos: true\nup_axis: 'Y_UP'\npadding_last: false\nrequires_contact: false\njoint_reduction: false\nskeleton_aware: false\njoints_group: null\n\n# generate parameters\nnum_frames: '2x_nframes'\nalpha: 0.01\nnum_steps: 3\nnoise_sigma: 10.0\ncoarse_ratio: '5x_patchsize'\n# coarse_ratio: '0.2x_nframes'\npyr_factor: 0.75\nnum_stages_limit: -1\npatch_size: 11\nloop: false"
  },
  {
    "path": "configs/ganimator.yaml",
    "content": "################################################################\n# This configuration uses the same input format of GANimmator for generation\n################################################################\noutout_dir: './output/ganimator_format'\n\n# for GANimator BVH data\nrepr: 'repr6d'\nskeleton_name: 'mixamo'\nuse_velo: true\nkeep_up_pos: true\nup_axis: 'Y_UP'\npadding_last: true\nrequires_contact: true\njoint_reduction: true\nskeleton_aware: false\njoints_group: null\n\n# generate parameters\nnum_frames: '2x_nframes'\nalpha: 0.01\nnum_steps: 3\nnoise_sigma: 10.0\ncoarse_ratio: '3x_patchsize'\n# coarse_ratio: '0.1x_nframes'\npyr_factor: 0.75\nnum_stages_limit: -1\npatch_size: 11\nloop: false"
  },
  {
    "path": "dataset/blender_motion.py",
    "content": "import os\nimport os.path as osp\nimport torch\nimport numpy as np\nimport torch.nn.functional as F\nfrom .motion import MotionData\nfrom utils.transforms import quat2repr6d, euler2mat, mat2quat, repr6d2quat, quat2euler\n\nclass BlenderMotion:\n    def __init__(self, motion_data, repr='quat', use_velo=True, keep_up_pos=True, up_axis=None, padding_last=False):\n        '''\n        BVHMotion constructor\n        Args:\n            motion_data      : np.array, bvh format data to load from\n            repr             : string, rotation representation, support ['quat', 'repr6d', 'euler'] \n            use_velo         : book, whether to transform the joints positions to velocities\n            keep_up_pos      : bool, whether to keep y position when converting to velocity\n            up_axis          : string, up axis of the motion data\n            padding_last     : bool, whether to pad the last position\n            requires_contact : bool, whether to concatenate contact information\n        '''\n        self.motion_data = motion_data\n\n        def to_tensor(motion_data, repr='euler', rot_only=False):\n            if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:\n                raise Exception('Unknown rotation representation')\n            if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d': # default is euler for blender data\n                rotations = torch.tensor(motion_data[:, 3:], dtype=torch.float).view(motion_data.shape[0], -1, 3)\n            if repr == 'quat':\n                rotations = euler2mat(rotations)\n                rotations = mat2quat(rotations)\n            if repr == 'repr6d':\n                rotations = euler2mat(rotations)\n                rotations = mat2quat(rotations)\n                rotations = quat2repr6d(rotations)\n\n            positions = torch.tensor(motion_data[:, :3], dtype=torch.float32)\n\n            if rot_only:\n                return rotations.reshape(rotations.shape[0], -1)\n\n            rotations = rotations.reshape(rotations.shape[0], -1)\n            return torch.cat((rotations, positions), dim=-1)\n        \n        self.motion_data = MotionData(to_tensor(motion_data, repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, \n                                      keep_up_pos=keep_up_pos, up_axis=up_axis, padding_last=padding_last, contact_id=None)\n    @property\n    def repr(self):\n        return self.motion_data.repr\n\n    @property\n    def use_velo(self):\n        return self.motion_data.use_velo\n\n    @property\n    def keep_up_pos(self):\n        return self.motion_data.keep_up_pos\n    \n    @property\n    def padding_last(self):\n        return self.motion_data.padding_last\n    \n    @property\n    def concat_id(self):\n        return self.motion_data.contact_id\n    \n    @property\n    def n_pad(self):\n        return self.motion_data.n_pad\n    \n    @property\n    def n_contact(self):\n        return self.motion_data.n_contact\n\n    @property\n    def n_rot(self):\n        return self.motion_data.n_rot\n\n    def sample(self, size=None, slerp=False):\n        '''\n        Sample motion data, support slerp\n        '''\n        return self.motion_data.sample(size, slerp)\n\n    def parse(self, motion, keep_velo=False,):\n        \"\"\"\n        No batch support here!!!\n        :returns tracks_json\n        \"\"\"\n        motion = motion.clone()\n\n        if self.use_velo and not keep_velo:\n            motion = self.motion_data.to_position(motion)\n        if self.n_pad:\n            motion = motion[:, :-self.n_pad]\n\n        motion = motion.squeeze().permute(1, 0)\n        pos = motion[..., -3:]\n        rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)\n        if self.repr == 'quat':\n            rot = quat2euler(rot)\n        elif self.repr == 'repr6d':\n            rot = repr6d2quat(rot)\n            rot = quat2euler(rot)\n\n        return torch.cat([pos, rot.view(motion.shape[0], -1)], dim=-1).cpu().numpy()\n"
  },
  {
    "path": "dataset/bvh/Quaternions.py",
    "content": "\"\"\"\nThis code is modified from:\nhttp://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing\n\nby Daniel Holden et al\n\"\"\"\n\n\nimport numpy as np\n\nclass Quaternions:\n    \"\"\"\n    Quaternions is a wrapper around a numpy ndarray\n    that allows it to act as if it were an narray of\n    a quater data type.\n    \n    Therefore addition, subtraction, multiplication,\n    division, negation, absolute, are all defined\n    in terms of quater operations such as quater\n    multiplication.\n    \n    This allows for much neater code and many routines\n    which conceptually do the same thing to be written\n    in the same way for point data and for rotation data.\n    \n    The Quaternions class has been desgined such that it\n    should support broadcasting and slicing in all of the\n    usual ways.\n    \"\"\"\n    \n    def __init__(self, qs):\n        if isinstance(qs, np.ndarray):\n            if len(qs.shape) == 1: qs = np.array([qs])\n            self.qs = qs\n            return\n\n        if isinstance(qs, Quaternions):\n            self.qs = qs\n            return\n\n        raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs))\n    \n    def __str__(self): return \"Quaternions(\"+ str(self.qs) + \")\"\n    def __repr__(self): return \"Quaternions(\"+ repr(self.qs) + \")\"\n    \n    \"\"\" Helper Methods for Broadcasting and Data extraction \"\"\"\n    \n    @classmethod\n    def _broadcast(cls, sqs, oqs, scalar=False):\n        if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1])\n        \n        ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])\n        os = np.array(oqs.shape)\n\n        if len(ss) != len(os):\n            raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))\n            \n        if np.all(ss == os): return sqs, oqs\n        \n        if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):\n            raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))\n\n        sqsn, oqsn = sqs.copy(), oqs.copy()\n\n        for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a)\n        for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a)\n        \n        return sqsn, oqsn\n        \n    \"\"\" Adding Quaterions is just Defined as Multiplication \"\"\"\n    \n    def __add__(self, other): return self * other\n    def __sub__(self, other): return self / other\n    \n    \"\"\" Quaterion Multiplication \"\"\"\n    \n    def __mul__(self, other):\n        \"\"\"\n        Quaternion multiplication has three main methods.\n        \n        When multiplying a Quaternions array by Quaternions\n        normal quater multiplication is performed.\n        \n        When multiplying a Quaternions array by a vector\n        array of the same shape, where the last axis is 3,\n        it is assumed to be a Quaternion by 3D-Vector \n        multiplication and the 3D-Vectors are rotated\n        in space by the Quaternions.\n        \n        When multipplying a Quaternions array by a scalar\n        or vector of different shape it is assumed to be\n        a Quaternions by Scalars multiplication and the\n        Quaternions are scaled using Slerp and the identity\n        quaternions.\n        \"\"\"\n        \n        \"\"\" If Quaternions type do Quaternions * Quaternions \"\"\"\n        if isinstance(other, Quaternions):\n            sqs, oqs = Quaternions._broadcast(self.qs, other.qs)\n\n            q0 = sqs[...,0]; q1 = sqs[...,1]; \n            q2 = sqs[...,2]; q3 = sqs[...,3]; \n            r0 = oqs[...,0]; r1 = oqs[...,1]; \n            r2 = oqs[...,2]; r3 = oqs[...,3]; \n            \n            qs = np.empty(sqs.shape)\n            qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3\n            qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2\n            qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1\n            qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0\n            \n            return Quaternions(qs)\n        \n        \"\"\" If array type do Quaternions * Vectors \"\"\"\n        if isinstance(other, np.ndarray) and other.shape[-1] == 3:\n            vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))\n\n            return (self * (vs * -self)).imaginaries\n\n        \"\"\" If float do Quaternions * Scalars \"\"\"\n        if isinstance(other, np.ndarray) or isinstance(other, float):\n            return Quaternions.slerp(Quaternions.id_like(self), self, other)\n        \n        raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))\n        \n    def __div__(self, other):\n        \"\"\"\n        When a Quaternion type is supplied, division is defined\n        as multiplication by the inverse of that Quaternion.\n        \n        When a scalar or vector is supplied it is defined\n        as multiplicaion of one over the supplied value.\n        Essentially a scaling.\n        \"\"\"\n        \n        if isinstance(other, Quaternions): return self * (-other)\n        if isinstance(other, np.ndarray): return self * (1.0 / other)\n        if isinstance(other, float): return self * (1.0 / other)\n        raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other)))\n        \n    def __eq__(self, other): return self.qs == other.qs\n    def __ne__(self, other): return self.qs != other.qs\n    \n    def __neg__(self):\n        \"\"\" Invert Quaternions \"\"\"\n        return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))\n    \n    def __abs__(self):\n        \"\"\" Unify Quaternions To Single Pole \"\"\"\n        qabs = self.normalized().copy()\n        top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1)\n        bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1)\n        qabs.qs[top < bot] = -qabs.qs[top <  bot]\n        return qabs\n    \n    def __iter__(self): return iter(self.qs)\n    def __len__(self): return len(self.qs)\n    \n    def __getitem__(self, k):    return Quaternions(self.qs[k]) \n    def __setitem__(self, k, v): self.qs[k] = v.qs\n        \n    @property\n    def lengths(self):\n        return np.sum(self.qs**2.0, axis=-1)**0.5\n    \n    @property\n    def reals(self):\n        return self.qs[...,0]\n        \n    @property\n    def imaginaries(self):\n        return self.qs[...,1:4]\n    \n    @property\n    def shape(self): return self.qs.shape[:-1]\n    \n    def repeat(self, n, **kwargs):\n        return Quaternions(self.qs.repeat(n, **kwargs))\n    \n    def normalized(self):\n        return Quaternions(self.qs / self.lengths[...,np.newaxis])\n    \n    def log(self):\n        norm = abs(self.normalized())\n        imgs = norm.imaginaries\n        lens = np.sqrt(np.sum(imgs**2, axis=-1))\n        lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)\n        return imgs * lens[...,np.newaxis]\n    \n    def constrained(self, axis):\n        \n        rl = self.reals\n        im = np.sum(axis * self.imaginaries, axis=-1)\n        \n        t1 = -2 * np.arctan2(rl, im) + np.pi\n        t2 = -2 * np.arctan2(rl, im) - np.pi\n        \n        top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0))\n        bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0))\n        img = self.dot(top) > self.dot(bot)\n        \n        ret = top.copy()\n        ret[ img] = top[ img]\n        ret[~img] = bot[~img]\n        return ret\n    \n    def constrained_x(self): return self.constrained(np.array([1,0,0]))\n    def constrained_y(self): return self.constrained(np.array([0,1,0]))\n    def constrained_z(self): return self.constrained(np.array([0,0,1]))\n    \n    def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)\n    \n    def copy(self): return Quaternions(np.copy(self.qs))\n    \n    def reshape(self, s):\n        self.qs.reshape(s)\n        return self\n    \n    def interpolate(self, ws):\n        return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws))\n    \n    def euler(self, order='xyz'):\n        \n        q = self.normalized().qs\n        q0 = q[...,0]\n        q1 = q[...,1]\n        q2 = q[...,2]\n        q3 = q[...,3]\n        es = np.zeros(self.shape + (3,))\n\n        # These version is wrong on converting\n        '''\n        if   order == 'xyz':\n            es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))\n            es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1))\n            es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))\n        elif order == 'yzx':\n            es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0)\n            es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3),  q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0)\n            es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1))\n        else:\n            raise NotImplementedError('Cannot convert from ordering %s' % order)\n        \n        '''\n        \n        if   order == 'xyz':\n            es[..., 2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)\n            es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))\n            es[..., 0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)\n        else:\n            raise NotImplementedError('Cannot convert from ordering %s' % order)\n\n        # These conversion don't appear to work correctly for Maya.\n        # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/\n        '''\n        if   order == 'xyz':\n            es[..., 0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)\n            es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))\n            es[..., 2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)\n        elif order == 'yzx':\n            es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)\n            es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1))\n            es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)\n        elif order == 'zxy':\n            es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)\n            es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1))\n            es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) \n        elif order == 'xzy':\n            es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)\n            es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1))\n            es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)\n        elif order == 'yxz':\n            es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)\n            es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1))\n            es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)\n        elif order == 'zyx':\n            es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)\n            es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1))\n            es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)\n        \n        else:\n            raise KeyError('Unknown ordering %s' % order)\n        '''\n\n        \n        # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp\n        # Use this class and convert from matrix\n        \n        return es\n        \n    \n    def average(self):\n        \n        if len(self.shape) == 1:\n            \n            import numpy.core.umath_tests as ut\n            system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0)\n            w, v = np.linalg.eigh(system)\n            qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1)\n            return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])            \n        \n        else:\n            \n            raise NotImplementedError('Cannot average multi-dimensionsal Quaternions')\n\n    def angle_axis(self):\n        \n        norm = self.normalized()        \n        s = np.sqrt(1 - (norm.reals**2.0))\n        s[s == 0] = 0.001\n        \n        angles = 2.0 * np.arccos(norm.reals)\n        axis = norm.imaginaries / s[...,np.newaxis]\n        \n        return angles, axis\n        \n    \n    def transforms(self):\n        \n        qw = self.qs[...,0]\n        qx = self.qs[...,1]\n        qy = self.qs[...,2]\n        qz = self.qs[...,3]\n        \n        x2 = qx + qx; y2 = qy + qy; z2 = qz + qz;\n        xx = qx * x2; yy = qy * y2; wx = qw * x2;\n        xy = qx * y2; yz = qy * z2; wy = qw * y2;\n        xz = qx * z2; zz = qz * z2; wz = qw * z2;\n\n        m = np.empty(self.shape + (3,3))\n        m[...,0,0] = 1.0 - (yy + zz)\n        m[...,0,1] = xy - wz\n        m[...,0,2] = xz + wy\n        m[...,1,0] = xy + wz\n        m[...,1,1] = 1.0 - (xx + zz)\n        m[...,1,2] = yz - wx\n        m[...,2,0] = xz - wy\n        m[...,2,1] = yz + wx\n        m[...,2,2] = 1.0 - (xx + yy)\n        \n        return m\n    \n    def ravel(self):\n        return self.qs.ravel()\n    \n    @classmethod\n    def id(cls, n):\n        \n        if isinstance(n, tuple):\n            qs = np.zeros(n + (4,))\n            qs[...,0] = 1.0\n            return Quaternions(qs)\n        \n        if isinstance(n, int) or isinstance(n, long):\n            qs = np.zeros((n,4))\n            qs[:,0] = 1.0\n            return Quaternions(qs)\n        \n        raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))\n\n    @classmethod\n    def id_like(cls, a):\n        qs = np.zeros(a.shape + (4,))\n        qs[...,0] = 1.0\n        return Quaternions(qs)\n        \n    @classmethod\n    def exp(cls, ws):\n    \n        ts = np.sum(ws**2.0, axis=-1)**0.5\n        ts[ts == 0] = 0.001\n        ls = np.sin(ts) / ts\n        \n        qs = np.empty(ws.shape[:-1] + (4,))\n        qs[...,0] = np.cos(ts)\n        qs[...,1] = ws[...,0] * ls\n        qs[...,2] = ws[...,1] * ls\n        qs[...,3] = ws[...,2] * ls\n        \n        return Quaternions(qs).normalized()\n        \n    @classmethod\n    def slerp(cls, q0s, q1s, a):\n        \n        fst, snd = cls._broadcast(q0s.qs, q1s.qs)\n        fst, a = cls._broadcast(fst, a, scalar=True)\n        snd, a = cls._broadcast(snd, a, scalar=True)\n        \n        len = np.sum(fst * snd, axis=-1)\n        \n        neg = len < 0.0\n        len[neg] = -len[neg]\n        snd[neg] = -snd[neg]\n        \n        amount0 = np.zeros(a.shape)\n        amount1 = np.zeros(a.shape)\n\n        linear = (1.0 - len) < 0.01\n        omegas = np.arccos(len[~linear])\n        sinoms = np.sin(omegas)\n        \n        amount0[ linear] = 1.0 - a[linear]\n        amount1[ linear] =       a[linear]\n        amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms\n        amount1[~linear] = np.sin(       a[~linear]  * omegas) / sinoms\n        \n        return Quaternions(\n            amount0[...,np.newaxis] * fst + \n            amount1[...,np.newaxis] * snd)\n    \n    @classmethod\n    def between(cls, v0s, v1s):\n        a = np.cross(v0s, v1s)\n        w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)\n        return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized()\n    \n    @classmethod\n    def from_angle_axis(cls, angles, axis):\n        axis    = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis]\n        sines   = np.sin(angles / 2.0)[...,np.newaxis]\n        cosines = np.cos(angles / 2.0)[...,np.newaxis]\n        return Quaternions(np.concatenate([cosines, axis * sines], axis=-1))\n    \n    @classmethod\n    def from_euler(cls, es, order='xyz', world=False):\n    \n        axis = {\n            'x' : np.array([1,0,0]),\n            'y' : np.array([0,1,0]),\n            'z' : np.array([0,0,1]),\n        }\n        \n        q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]])\n        q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]])\n        q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]])\n        \n        return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))\n    \n    @classmethod\n    def from_transforms(cls, ts):\n        \n        d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2]\n        \n        q0 = ( d0 + d1 + d2 + 1.0) / 4.0\n        q1 = ( d0 - d1 - d2 + 1.0) / 4.0\n        q2 = (-d0 + d1 - d2 + 1.0) / 4.0\n        q3 = (-d0 - d1 + d2 + 1.0) / 4.0\n        \n        q0 = np.sqrt(q0.clip(0,None))\n        q1 = np.sqrt(q1.clip(0,None))\n        q2 = np.sqrt(q2.clip(0,None))\n        q3 = np.sqrt(q3.clip(0,None))\n        \n        c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)\n        c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)\n        c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)\n        c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)\n        \n        q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2])\n        q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0])\n        q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1])\n        \n        q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2])\n        q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1])\n        q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0])  \n        \n        q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0])\n        q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1])\n        q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2])  \n        \n        q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1])\n        q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2])\n        q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2])  \n        \n        qs = np.empty(ts.shape[:-2] + (4,))\n        qs[...,0] = q0\n        qs[...,1] = q1\n        qs[...,2] = q2\n        qs[...,3] = q3\n        \n        return cls(qs)\n"
  },
  {
    "path": "dataset/bvh/bvh_io.py",
    "content": "\"\"\"\nThis code is modified from:\nhttp://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing\n\nby Daniel Holden et al\n\"\"\"\n\n\nimport re\nimport numpy as np\nfrom dataset.bvh.Quaternions import Quaternions\n\nchannelmap = {\n    'Xrotation' : 'x',\n    'Yrotation' : 'y',\n    'Zrotation' : 'z'   \n}\n\nchannelmap_inv = {\n    'x': 'Xrotation',\n    'y': 'Yrotation',\n    'z': 'Zrotation',\n}\n\nordermap = {\n    'x': 0,\n    'y': 1,\n    'z': 2,\n}\n\n\nclass Animation:\n    def __init__(self, rotations, positions, orients, offsets, parents, names, frametime):\n        self.rotations = rotations\n        self.positions = positions\n        self.orients   = orients\n        self.offsets   = offsets\n        self.parent    = parents\n        self.names     = names\n        self.frametime = frametime\n\n    @property\n    def shape(self):\n        return self.rotations.shape\n\n\ndef load(filename, start=None, end=None, order=None, world=False, need_quater=False) -> Animation:\n    \"\"\"\n    Reads a BVH file and constructs an animation\n\n    Parameters\n    ----------\n    filename: str\n        File to be opened\n\n    start : int\n        Optional Starting Frame\n\n    end : int\n        Optional Ending Frame\n\n    order : str\n        Optional Specifier for joint order.\n        Given as string E.G 'xyz', 'zxy'\n\n    world : bool\n        If set to true euler angles are applied\n        together in world space rather than local\n        space\n    Returns\n    -------\n\n    (animation, joint_names, frametime)\n        Tuple of loaded animation and joint names\n    \"\"\"\n\n    f = open(filename, \"r\")\n\n    i = 0\n    active = -1\n    end_site = False\n\n    names = []\n    orients = Quaternions.id(0)\n    offsets = np.array([]).reshape((0, 3))\n    parents = np.array([], dtype=int)\n    orders = []\n\n    for line in f:\n\n        if \"HIERARCHY\" in line: continue\n        if \"MOTION\" in line: continue\n\n        \"\"\" Modified line read to handle mixamo data \"\"\"\n        #        rmatch = re.match(r\"ROOT (\\w+)\", line)\n        rmatch = re.match(r\"ROOT (\\w+:?\\w+)\", line)\n        if rmatch:\n            names.append(rmatch.group(1))\n            offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)\n            orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)\n            parents = np.append(parents, active)\n            active = (len(parents) - 1)\n            continue\n\n        if \"{\" in line: continue\n\n        if \"}\" in line:\n            if end_site:\n                end_site = False\n            else:\n                active = parents[active]\n            continue\n\n        offmatch = re.match(r\"\\s*OFFSET\\s+([\\-\\d\\.e]+)\\s+([\\-\\d\\.e]+)\\s+([\\-\\d\\.e]+)\", line)\n        if offmatch:\n            if not end_site:\n                offsets[active] = np.array([list(map(float, offmatch.groups()))])\n            continue\n\n        chanmatch = re.match(r\"\\s*CHANNELS\\s+(\\d+)\", line)\n        if chanmatch:\n            channels = int(chanmatch.group(1))\n\n            channelis = 0 if channels == 3 else 3\n            channelie = 3 if channels == 3 else 6\n            parts = line.split()[2 + channelis:2 + channelie]\n            if any([p not in channelmap for p in parts]):\n                continue\n            order = \"\".join([channelmap[p] for p in parts])\n            orders.append(order)\n            continue\n\n        \"\"\" Modified line read to handle mixamo data \"\"\"\n        #        jmatch = re.match(\"\\s*JOINT\\s+(\\w+)\", line)\n        jmatch = re.match(\"\\s*JOINT\\s+(\\w+:?\\w+)\", line)\n        if jmatch:\n            names.append(jmatch.group(1))\n            offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)\n            orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)\n            parents = np.append(parents, active)\n            active = (len(parents) - 1)\n            continue\n\n        if \"End Site\" in line:\n            end_site = True\n            continue\n\n        fmatch = re.match(\"\\s*Frames:\\s+(\\d+)\", line)\n        if fmatch:\n            if start and end:\n                fnum = (end - start) - 1\n            else:\n                fnum = int(fmatch.group(1))\n            jnum = len(parents)\n            positions = offsets[np.newaxis].repeat(fnum, axis=0)\n            rotations = np.zeros((fnum, len(orients), 3))\n            continue\n\n        fmatch = re.match(\"\\s*Frame Time:\\s+([\\d\\.]+)\", line)\n        if fmatch:\n            frametime = float(fmatch.group(1))\n            continue\n\n        if (start and end) and (i < start or i >= end - 1):\n            i += 1\n            continue\n\n        # dmatch = line.strip().split(' ')\n        dmatch = line.strip().split()\n        if dmatch:\n            data_block = np.array(list(map(float, dmatch)))\n            N = len(parents)\n            fi = i - start if start else i\n            if channels == 3:\n                positions[fi, 0:1] = data_block[0:3]\n                rotations[fi, :] = data_block[3:].reshape(N, 3)\n            elif channels == 6:\n                data_block = data_block.reshape(N, 6)\n                positions[fi, :] = data_block[:, 0:3]\n                rotations[fi, :] = data_block[:, 3:6]\n            elif channels == 9:\n                positions[fi, 0] = data_block[0:3]\n                data_block = data_block[3:].reshape(N - 1, 9)\n                rotations[fi, 1:] = data_block[:, 3:6]\n                positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]\n            else:\n                raise Exception(\"Too many channels! %i\" % channels)\n\n            i += 1\n\n    f.close()\n\n    all_rotations = []\n    canonical_order = 'xyz'\n    for i, order in enumerate(orders):\n        rot = rotations[:, i:i + 1]\n        if need_quater:\n            quat = Quaternions.from_euler(np.radians(rot), order=order, world=world)\n            all_rotations.append(quat)\n            continue\n        elif order != canonical_order:\n            quat = Quaternions.from_euler(np.radians(rot), order=order, world=world)\n            rot = np.degrees(quat.euler(order=canonical_order))\n        all_rotations.append(rot)\n    rotations = np.concatenate(all_rotations, axis=1)\n\n    return Animation(rotations, positions, orients, offsets, parents, names, frametime)\n\n    \ndef save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True):\n    \"\"\"\n    Saves an Animation to file as BVH\n    \n    Parameters\n    ----------\n    filename: str\n        File to be saved to\n        \n    anim : Animation\n        Animation to save\n        \n    names : [str]\n        List of joint names\n    \n    order : str\n        Optional Specifier for joint order.\n        Given as string E.G 'xyz', 'zxy'\n    \n    frametime : float\n        Optional Animation Frame time\n        \n    positions : bool\n        Optional specfier to save bone\n        positions for each frame\n        \n    orients : bool\n        Multiply joint orients to the rotations\n        before saving.\n        \n    \"\"\"\n    \n    if names is None:\n        names = [\"joint_\" + str(i) for i in range(len(anim.parents))]\n    \n    with open(filename, 'w') as f:\n\n        t = \"\"\n        f.write(\"%sHIERARCHY\\n\" % t)\n        f.write(\"%sROOT %s\\n\" % (t, names[0]))\n        f.write(\"%s{\\n\" % t)\n        t += '\\t'\n\n        f.write(\"%sOFFSET %f %f %f\\n\" % (t, anim.offsets[0,0], anim.offsets[0,1], anim.offsets[0,2]) )\n        f.write(\"%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \\n\" % \n            (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))\n\n        for i in range(anim.shape[1]):\n            if anim.parents[i] == 0:\n                t = save_joint(f, anim, names, t, i, order=order, positions=positions)\n\n        t = t[:-1]\n        f.write(\"%s}\\n\" % t)\n\n        f.write(\"MOTION\\n\")\n        f.write(\"Frames: %i\\n\" % anim.shape[0]);\n        f.write(\"Frame Time: %f\\n\" % frametime);\n            \n        #if orients:        \n        #    rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1]))\n        #else:\n        #    rots = np.degrees(anim.rotations.euler(order=order[::-1]))\n        rots = np.degrees(anim.rotations.euler(order=order[::-1]))\n        poss = anim.positions\n        \n        for i in range(anim.shape[0]):\n            for j in range(anim.shape[1]):\n                \n                if positions or j == 0:\n                \n                    f.write(\"%f %f %f %f %f %f \" % (\n                        poss[i,j,0],                  poss[i,j,1],                  poss[i,j,2], \n                        rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]]))\n                \n                else:\n                    \n                    f.write(\"%f %f %f \" % (\n                        rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]]))\n\n            f.write(\"\\n\")\n    \n    \ndef save_joint(f, anim, names, t, i, order='zyx', positions=False):\n    \n    f.write(\"%sJOINT %s\\n\" % (t, names[i]))\n    f.write(\"%s{\\n\" % t)\n    t += '\\t'\n  \n    f.write(\"%sOFFSET %f %f %f\\n\" % (t, anim.offsets[i,0], anim.offsets[i,1], anim.offsets[i,2]))\n    \n    if positions:\n        f.write(\"%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \\n\" % (t, \n            channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))\n    else:\n        f.write(\"%sCHANNELS 3 %s %s %s\\n\" % (t, \n            channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))\n    \n    end_site = True\n    \n    for j in range(anim.shape[1]):\n        if anim.parents[j] == i:\n            t = save_joint(f, anim, names, t, j, order=order, positions=positions)\n            end_site = False\n    \n    if end_site:\n        f.write(\"%sEnd Site\\n\" % t)\n        f.write(\"%s{\\n\" % t)\n        t += '\\t'\n        f.write(\"%sOFFSET %f %f %f\\n\" % (t, 0.0, 0.0, 0.0))\n        t = t[:-1]\n        f.write(\"%s}\\n\" % t)\n  \n    t = t[:-1]\n    f.write(\"%s}\\n\" % t)\n    \n    return t\n"
  },
  {
    "path": "dataset/bvh/bvh_parser.py",
    "content": "import torch\nimport numpy as np\nimport dataset.bvh.bvh_io as bvh_io\nfrom utils.kinematics import ForwardKinematicsJoint\nfrom utils.transforms import quat2repr6d\nfrom utils.contact import foot_contact\nfrom dataset.bvh.Quaternions import Quaternions\nfrom dataset.bvh.bvh_writer import WriterWrapper\n\n\nclass Skeleton:\n    def __init__(self, names, parent, offsets, joint_reduction=True, skeleton_conf=None):\n        self._names = names\n        self.original_parent = parent\n        self._offsets = offsets\n        self._parent = None\n        self._ee_id = None\n        self.contact_names = []\n\n        for i, name in enumerate(self._names):\n            if ':' in name:\n                self._names[i] = name[name.find(':')+1:]\n\n        if joint_reduction or skeleton_conf is not None:\n            assert skeleton_conf is not None, 'skeleton_conf can not be None if you use joint reduction'\n            corps_names = skeleton_conf['corps_names']\n            self.contact_names = skeleton_conf['corps_names']\n            self.contact_threshold = skeleton_conf['contact_threshold']\n\n            self.contact_id = []\n            for i in self.contact_names:\n                self.contact_id.append(corps_names.index(i))\n        else:\n            self.skeleton_type = -1\n            corps_names = self._names\n\n        self.details = []    # joints that does not belong to the corps (we are not interested in them)\n        for i, name in enumerate(self._names):\n            if name not in corps_names: self.details.append(i)\n\n        self.corps = []\n        self.simplified_name = []\n        self.simplify_map = {}\n        self.inverse_simplify_map = {}\n\n        # Repermute the skeleton id according to the databse\n        for name in corps_names:\n            for j in range(len(self._names)):\n                if name in self._names[j]:\n                    self.corps.append(j)\n                    break\n        if len(self.corps) != len(corps_names):\n            for i in self.corps:\n                print(self._names[i], end=' ')\n            print(self.corps, self.skeleton_type, len(self.corps), sep='\\n')\n            raise Exception('Problem in this skeleton')\n\n        self.joint_num_simplify = len(self.corps)\n        for i, j in enumerate(self.corps):\n            self.simplify_map[j] = i\n            self.inverse_simplify_map[i] = j\n            self.simplified_name.append(self._names[j])\n        self.inverse_simplify_map[0] = -1\n        for i in range(len(self._names)):\n            if i in self.details:\n                self.simplify_map[i] = -1\n\n    @property\n    def parent(self):\n        if self._parent is None:\n            self._parent = self.original_parent[self.corps].copy()\n            for i in range(self._parent.shape[0]):\n                if i >= 1: self._parent[i] = self.simplify_map[self._parent[i]]\n            self._parent = tuple(self._parent)\n        return self._parent\n\n    @property\n    def offsets(self):\n        return torch.tensor(self._offsets[self.corps], dtype=torch.float)\n\n    @property\n    def names(self):\n        return self.simplified_name\n\n    @property\n    def ee_id(self):\n        raise Exception('Abaddoned')\n        # if self._ee_id is None:\n        #     self._ee_id = []\n        #     for i in SkeletonDatabase.ee_names[self.skeleton_type]:\n        #         self.ee_id._ee_id(corps_names[self.skeleton_type].index(i))\n\n\nclass BVH_file:\n    def __init__(self, file_path, skeleton_conf=None, requires_contact=False, joint_reduction=True, auto_scale=True):\n        self.anim = bvh_io.load(file_path)\n        self._names = self.anim.names\n        self.frametime = self.anim.frametime\n        if requires_contact or joint_reduction:\n            assert skeleton_conf is not None, 'Please provide a skeleton configuration for contact or joint reduction'\n        self.skeleton = Skeleton(self.anim.names, self.anim.parent, self.anim.offsets, joint_reduction, skeleton_conf)\n\n        # Downsample to 30 fps for our application\n        if self.frametime < 0.0084:\n            self.frametime *= 2\n            self.anim.positions = self.anim.positions[::2]\n            self.anim.rotations = self.anim.rotations[::2]\n        if self.frametime < 0.017:\n            self.frametime *= 2\n            self.anim.positions = self.anim.positions[::2]\n            self.anim.rotations = self.anim.rotations[::2]\n\n        self.requires_contact = requires_contact\n\n        if requires_contact:\n            self.contact_names = self.skeleton.contact_names\n        else:\n            self.contact_names = []\n\n        self.fk = ForwardKinematicsJoint(self.skeleton.parent, self.skeleton.offsets)\n        self.writer = WriterWrapper(self.skeleton.parent, self.skeleton.offsets)\n\n        self.auto_scale = auto_scale\n        if auto_scale:\n            self.scale = 1. / np.ceil(self.skeleton.offsets.max().cpu().numpy())\n            print(f'rescale the skeleton with scale: {self.scale}')\n            self.rescale(self.scale)\n        else:\n            self.scale = 1.0\n\n        if self.requires_contact:\n            gl_pos = self.joint_position()\n            self.contact_label = foot_contact(gl_pos[:, self.skeleton.contact_id],\n                                              threshold=self.skeleton.contact_threshold)\n            self.gl_pos = gl_pos\n\n    def local_pos(self):\n        gl_pos = self.joint_position()\n        local_pos = gl_pos - gl_pos[:, 0:1, :]\n        return local_pos[:, 1:]\n\n    def rescale(self, ratio):\n        self.anim.offsets *= ratio\n        self.anim.positions *= ratio\n\n    def to_tensor(self, repr='euler', rot_only=False):\n        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:\n            raise Exception('Unknown rotation representation')\n        positions = self.get_position()\n        rotations = self.get_rotation(repr=repr)\n\n        if rot_only:\n            return rotations.reshape(rotations.shape[0], -1)\n\n        if self.requires_contact:\n            virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])\n            virtual_contact[..., 0] = self.contact_label\n            rotations = torch.cat([rotations, virtual_contact], dim=1)\n\n        rotations = rotations.reshape(rotations.shape[0], -1)\n        return torch.cat((rotations, positions), dim=-1)\n\n    def joint_position(self):\n        positions = torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float)\n        rotations = self.anim.rotations[:, self.skeleton.corps, :]\n        rotations = Quaternions.from_euler(np.radians(rotations)).qs\n        rotations = torch.tensor(rotations, dtype=torch.float)\n        j_loc = self.fk.forward(rotations, positions)\n        return j_loc\n\n    def get_rotation(self, repr='quat'):\n        rotations = self.anim.rotations[:, self.skeleton.corps, :]\n        if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':\n            rotations = Quaternions.from_euler(np.radians(rotations)).qs\n            rotations = torch.tensor(rotations, dtype=torch.float)\n        if repr == 'repr6d':\n            rotations = quat2repr6d(rotations)\n        if repr == 'euler':\n            rotations = torch.tensor(rotations, dtype=torch.float)\n        return rotations\n\n    def get_position(self):\n        return torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float)\n\n    def dfs(self, x, vis, dist):\n        fa = self.skeleton.parent\n        vis[x] = 1\n        for y in range(len(fa)):\n            if (fa[y] == x or fa[x] == y) and vis[y] == 0:\n                dist[y] = dist[x] + 1\n                self.dfs(y, vis, dist)\n\n    def get_neighbor(self, threshold, enforce_contact=False):\n        fa = self.skeleton.parent\n        neighbor_list = []\n        for x in range(0, len(fa)):\n            vis = [0 for _ in range(len(fa))]\n            dist = [0 for _ in range(len(fa))]\n            self.dfs(x, vis, dist)\n            neighbor = []\n            for j in range(0, len(fa)):\n                if dist[j] <= threshold:\n                    neighbor.append(j)\n            neighbor_list.append(neighbor)\n\n        contact_list = []\n        if self.requires_contact:\n            for i, p_id in enumerate(self.skeleton.contact_id):\n                v_id = len(neighbor_list)\n                neighbor_list[p_id].append(v_id)\n                neighbor_list.append(neighbor_list[p_id])\n                contact_list.append(v_id)\n\n        root_neighbor = neighbor_list[0]\n        id_root = len(neighbor_list)\n\n        if enforce_contact:\n            root_neighbor = root_neighbor + contact_list\n            for j in contact_list:\n                neighbor_list[j] = list(set(neighbor_list[j]))\n\n        root_neighbor = list(set(root_neighbor))\n        for j in root_neighbor:\n            neighbor_list[j].append(id_root)\n        root_neighbor.append(id_root)\n        neighbor_list.append(root_neighbor)  # Neighbor for root position\n        return neighbor_list"
  },
  {
    "path": "dataset/bvh/bvh_writer.py",
    "content": "import torch\nfrom utils.transforms import quat2euler, repr6d2quat\n\n\n# rotation with shape frame * J * 3\ndef write_bvh(parent, offset, rotation, position, names, frametime, order, path, endsite=None):\n    file = open(path, 'w')\n    frame = rotation.shape[0]\n    joint_num = rotation.shape[1]\n    order = order.upper()\n\n    file_string = 'HIERARCHY\\n'\n\n    seq = []\n\n    def write_static(idx, prefix):\n        nonlocal parent, offset, rotation, names, order, endsite, file_string, seq\n        seq.append(idx)\n        if idx == 0:\n            name_label = 'ROOT ' + names[idx]\n            channel_label = 'CHANNELS 6 Xposition Yposition Zposition {}rotation {}rotation {}rotation'.format(*order)\n        else:\n            name_label = 'JOINT ' + names[idx]\n            channel_label = 'CHANNELS 3 {}rotation {}rotation {}rotation'.format(*order)\n        offset_label = 'OFFSET %.6f %.6f %.6f' % (offset[idx][0], offset[idx][1], offset[idx][2])\n\n        file_string += prefix + name_label + '\\n'\n        file_string += prefix + '{\\n'\n        file_string += prefix + '\\t' + offset_label + '\\n'\n        file_string += prefix + '\\t' + channel_label + '\\n'\n\n        has_child = False\n        for y in range(idx+1, rotation.shape[1]):\n            if parent[y] == idx:\n                has_child = True\n                write_static(y, prefix + '\\t')\n        if not has_child:\n            file_string += prefix + '\\t' + 'End Site\\n'\n            file_string += prefix + '\\t' + '{\\n'\n            file_string += prefix + '\\t\\t' + 'OFFSET 0 0 0\\n'\n            file_string += prefix + '\\t' + '}\\n'\n\n        file_string += prefix + '}\\n'\n\n    write_static(0, '')\n\n    file_string += 'MOTION\\n' + 'Frames: {}\\n'.format(frame) + 'Frame Time: %.8f\\n' % frametime\n    for i in range(frame):\n        file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1], position[i][2])\n        for j in range(joint_num):\n            idx = seq[j]\n            file_string += '%.6f %.6f %.6f ' % (rotation[i][idx][0], rotation[i][idx][1], rotation[i][idx][2])\n        file_string += '\\n'\n\n    file.write(file_string)\n    return file_string\n\n\nclass WriterWrapper:\n    def __init__(self, parents, offset=None):\n        self.parents = parents\n        self.offset = offset\n\n    def write(self, filename, rot, pos, offset=None, names=None, repr='quat'):\n        \"\"\"\n        Write animation to bvh file\n        :param filename:\n        :param rot: Quaternion as (w, x, y, z)\n        :param pos:\n        :param offset:\n        :return:\n        \"\"\"\n        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:\n            raise Exception('Unknown rotation representation')\n        if offset is None:\n            offset = self.offset\n        if not isinstance(offset, torch.Tensor):\n            offset = torch.tensor(offset)\n        n_bone = offset.shape[0]\n\n        if repr == 'repr6d':\n            rot = rot.reshape(rot.shape[0], -1, 6)\n            rot = repr6d2quat(rot)\n        if repr == 'repr6d' or repr == 'quat' or repr == 'quaternion':\n            rot = rot.reshape(rot.shape[0], -1, 4)\n            rot /= rot.norm(dim=-1, keepdim=True) ** 0.5\n            euler = quat2euler(rot, order='xyz')\n            rot = euler\n\n        if names is None:\n            names = ['%02d' % i for i in range(n_bone)]\n        write_bvh(self.parents, offset, rot, pos, names, 1, 'xyz', filename)\n"
  },
  {
    "path": "dataset/bvh_motion.py",
    "content": "import os\nimport os.path as osp\nimport torch\nimport numpy as np\nimport torch.nn.functional as F\nfrom .motion import MotionData\nfrom .bvh.bvh_parser import BVH_file\n\n\n## Some skeleton configurations\ncrab_dance_corps_names = ['ORG_Hips', 'ORG_BN_Bip01_Pelvis', 'DEF_BN_Eye_L_01', 'DEF_BN_Eye_L_02', 'DEF_BN_Eye_L_03', 'DEF_BN_Eye_L_03_end', 'DEF_BN_Eye_R_01', 'DEF_BN_Eye_R_02', 'DEF_BN_Eye_R_03', 'DEF_BN_Eye_R_03_end', 'DEF_BN_Leg_L_11', 'DEF_BN_Leg_L_12', 'DEF_BN_Leg_L_13', 'DEF_BN_Leg_L_14', 'DEF_BN_Leg_L_15', 'DEF_BN_Leg_L_15_end', 'DEF_BN_Leg_R_11', 'DEF_BN_Leg_R_12', 'DEF_BN_Leg_R_13', 'DEF_BN_Leg_R_14', 'DEF_BN_Leg_R_15', 'DEF_BN_Leg_R_15_end', 'DEF_BN_leg_L_01', 'DEF_BN_leg_L_02', 'DEF_BN_leg_L_03', 'DEF_BN_leg_L_04', 'DEF_BN_leg_L_05', 'DEF_BN_leg_L_05_end',\n                         'DEF_BN_leg_L_06', 'DEF_BN_Leg_L_07', 'DEF_BN_Leg_L_08', 'DEF_BN_Leg_L_09', 'DEF_BN_Leg_L_10', 'DEF_BN_Leg_L_10_end', 'DEF_BN_leg_R_01', 'DEF_BN_leg_R_02', 'DEF_BN_leg_R_03', 'DEF_BN_leg_R_04', 'DEF_BN_leg_R_05', 'DEF_BN_leg_R_05_end', 'DEF_BN_leg_R_06', 'DEF_BN_Leg_R_07', 'DEF_BN_Leg_R_08', 'DEF_BN_Leg_R_09', 'DEF_BN_Leg_R_10', 'DEF_BN_Leg_R_10_end', 'DEF_BN_Bip01_Pelvis', 'DEF_BN_Bip01_Pelvis_end', 'DEF_BN_Arm_L_01', 'DEF_BN_Arm_L_02', 'DEF_BN_Arm_L_03', 'DEF_BN_Arm_L_03_end', 'DEF_BN_Arm_R_01', 'DEF_BN_Arm_R_02', 'DEF_BN_Arm_R_03', 'DEF_BN_Arm_R_03_end']\nskeleton_confs = {\n    'mixamo': {\n        'corps_names': ['Hips', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'LeftToe_End', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'RightToe_End', 'Spine', 'Spine1', 'Spine2', 'Neck', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand'],\n        'contact_names': ['LeftToe_End', 'RightToe_End', 'LeftToeBase', 'RightToeBase'],\n        'contact_threshold': 0.018\n    },\n    'crab_dance': {\n        'corps_names': crab_dance_corps_names,\n        'contact_names': [name for name in crab_dance_corps_names if 'end' in name and ('05' in name or '10' in name or '15' in name)],\n        'contact_threshold': 0.006\n    },\n    'xia': {\n        'corps_names': ['Hips', 'LHipJoint', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'RHipJoint', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'LowerBack', 'Spine', 'Spine1', 'Neck', 'Neck1', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'LeftFingerBase', 'LeftHandIndex1', 'LThumb', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand', 'RightFingerBase', 'RightHandIndex1', 'RThumb'],\n        'contact_names': ['LeftToeBase', 'RightToeBase'],\n        'contact_threshold': 0.006\n    }\n}\n\nclass BVHMotion:\n    def __init__(self, bvh_file, skeleton_name=None, repr='quat', use_velo=True, keep_up_pos=False, up_axis='Y_UP', padding_last=False, requires_contact=False, joint_reduction=False):\n        '''\n        BVHMotion constructor\n        Args:\n            bvh_file         : string, bvh_file path to load from\n            skelton_name     : string, name of predefined skeleton, used when joint_reduction==True or contact==True\n            repr             : string, rotation representation, support ['quat', 'repr6d', 'euler'] \n            use_velo         : book, whether to transform the joints positions to velocities\n            keep_up_pos      : bool, whether to keep y position when converting to velocity\n            up_axis          : string, string, up axis of the motion data\n            padding_last     : bool, whether to pad the last position\n            requires_contact : bool, whether to concatenate contact information\n            joint_reduction  : bool, whether to reduce the joint number\n        '''\n        self.bvh_file = bvh_file\n        self.skeleton_name = skeleton_name\n        if skeleton_name is not None:\n            assert skeleton_name in skeleton_confs, f'{skeleton_name} not found, please add a skeleton configuration.'\n        self.requires_contact = requires_contact\n        self.joint_reduction = joint_reduction\n\n        self.raw_data = BVH_file(bvh_file, skeleton_confs[skeleton_name] if skeleton_name is not None else None, requires_contact, joint_reduction, auto_scale=True)\n        self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis, \n                                      padding_last=padding_last, contact_id=self.raw_data.skeleton.contact_id if requires_contact else None)\n    @property\n    def repr(self):\n        return self.motion_data.repr\n\n    @property\n    def use_velo(self):\n        return self.motion_data.use_velo\n\n    @property\n    def keep_up_pos(self):\n        return self.motion_data.keep_up_pos\n    \n    @property\n    def padding_last(self):\n        return self.motion_data.padding_last\n    \n    @property\n    def concat_id(self):\n        return self.motion_data.contact_id\n    \n    @property\n    def n_pad(self):\n        return self.motion_data.n_pad\n    \n    @property\n    def n_contact(self):\n        return self.motion_data.n_contact\n\n    @property\n    def n_rot(self):\n        return self.motion_data.n_rot\n\n    def sample(self, size=None, slerp=False):\n        '''\n        Sample motion data, support slerp\n        '''\n        return self.motion_data.sample(size, slerp)\n\n\n    def write(self, filename, data):\n        '''\n        Parse motion data into position, velocity and contact(if exists)\n        data should be []\n        No batch support here!!!\n        '''\n        assert len(data.shape) == 3, 'The data format should be [batch_size x n_channels x n_frames]' \n\n        if self.n_pad:\n            data = data.clone()[:, :-self.n_pad]\n        if self.use_velo:\n            data = self.motion_data.to_position(data)\n        data = data.squeeze().permute(1, 0)\n        pos = data[..., -3:]\n        rot = data[..., :-3].reshape(data.shape[0], -1, self.n_rot)\n        if self.requires_contact:\n            contact = rot[..., -self.n_contact:, 0]\n            rot = rot[..., :-self.n_contact, :]\n        else:\n            contact = None\n\n        if contact is not None:\n            np.save(filename + '.contact', contact.detach().cpu().numpy())\n\n        # rescale the output\n        self.raw_data.rescale(1. / self.raw_data.scale)\n        pos *= 1. / self.raw_data.scale\n        self.raw_data.writer.write(filename, rot, pos, names=self.raw_data.skeleton.names, repr=self.repr)\n\n\ndef load_multiple_dataset(name_list, **kargs):\n        with open(name_list, 'r') as f:\n            names = [line.strip() for line in f.readlines()]\n        datasets = []\n        for f in names:\n            kargs['bvh_file'] = osp.join(osp.dirname(name_list), f)\n            datasets.append(BVHMotion(**kargs))\n        return datasets"
  },
  {
    "path": "dataset/motion.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\nclass MotionData:\n    def __init__(self, data, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y', padding_last=False, contact_id=None):\n        '''\n        BaseMotionData constructor\n        Args:\n            data         : torch.Tensor, [batch_size x n_channels x n_frames] input motion data, \n                           the channels dim shoud be [n_joints x n_dim_of_rotation + 3(global position)]\n            repr         : string, rotation representation, support ['quat', 'repr6d', 'euler'] \n            use_velo     : book, whether to transform the joints positions to velocities\n            keep_up_pos  : bool, whether to keep up position when converting to velocity\n            up_axis      : string, string, up axis of the motion data\n            padding_last : bool, whether to pad the last position\n            contact_id   : list, contact joints id\n        '''\n        self.data = data \n        self.repr = repr\n        self.use_velo = use_velo\n        self.keep_up_pos = keep_up_pos\n        self.up_axis = up_axis\n        self.padding_last = padding_last\n        self.contact_id = contact_id\n        self.begin_pos = None\n\n        # assert the rotation representation\n        if self.repr == 'quat':\n            self.n_rot = 4\n            assert (self.data.shape[1] - 3) % 4 == 0, 'rotation is not \"quaternion\" representation'\n        elif self.repr == 'repr6d':\n            self.n_rot = 6\n            assert (self.data.shape[1] - 3) % 6 == 0, 'rotation is not \"repr6d\" representation'\n        elif self.repr == 'eluer':\n            self.n_rot = 3\n            assert (self.data.shape[1] - 3) % 3 == 0, 'rotation is not \"euler\" representation'\n\n        # whether to pad the position data with zero\n        if self.padding_last:\n            self.n_pad = self.data.shape[1] - 3  # pad position channels to match the n_channels of rotation\n            paddings = torch.zeros_like(self.data[:, :self.n_pad])\n            self.data = torch.cat((self.data, paddings), dim=1)\n        else:\n            self.n_pad = 0\n\n        # get the contact information\n        if self.contact_id is not None:\n            self.n_contact = len(contact_id)\n        else:\n            self.n_contact = 0\n\n        # whether to keep y position when converting to velocity\n        if self.keep_up_pos:\n            if self.up_axis == 'X_UP':\n                self.velo_mask = [-2, -1]\n            elif self.up_axis == 'Y_UP':\n                self.velo_mask = [-3, -1]\n            elif self.up_axis == 'Z_UP':\n                self.velo_mask = [-3, -2]\n        else:\n            self.velo_mask = [-3, -2, -1]\n\n        # whether to convert global position to velocity\n        if self.use_velo:\n            self.data =  self.to_velocity(self.data)\n\n\n    def __len__(self):\n        '''\n        return the number of motion frames\n        '''\n        return self.data.shape[-1]\n\n\n    def sample(self, size=None, slerp=False):\n        '''\n        sample the motion data using given size\n        '''\n        if size is None:\n            return self.data\n        else:\n            if slerp:\n                motion = self.slerp(self.data, size=size)\n            else:\n                motion = F.interpolate(self.data, size=size, mode='linear', align_corners=False)\n            return motion\n\n\n    def to_velocity(self, pos):\n        '''\n        convert motion data to velocity\n        '''\n        assert self.begin_pos is None, 'the motion data had been converted to velocity'\n        msk = [i - self.n_pad for i in self.velo_mask]\n        velo = pos.detach().clone().to(pos.device)\n        velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]\n        self.begin_pos = pos[:, msk, 0].clone()\n        velo[:, msk, 0] = pos[:, msk, 1]\n        return velo\n\n    def to_position(self, velo):\n        '''\n        convert motion data to position\n        '''\n        assert self.begin_pos is not None, 'the motion data is already position'\n        msk = [i - self.n_pad for i in self.velo_mask]\n        pos = velo.detach().clone().to(velo.device)\n        pos[:, msk, 0] = self.begin_pos.to(velo.device)\n        pos[:, msk] = torch.cumsum(pos[:, msk], dim=-1)\n        self.begin_pos = None\n        return pos"
  },
  {
    "path": "dataset/tracks_motion.py",
    "content": "import os\nfrom os.path import join as pjoin\nimport numpy as np\nimport copy\nimport torch\nfrom .motion import MotionData\nfrom ..utils.transforms import quat2repr6d, quat2euler, repr6d2quat\n\nclass TracksParser():\n    def __init__(self, tracks_json, scale):\n        self.tracks_json = tracks_json\n        self.scale = scale\n        \n        self.skeleton_names = []\n        self.rotations = []\n        for i, track in enumerate(self.tracks_json):\n            self.skeleton_names.append(track['name'])\n            if i == 0:\n                assert track['type'] == 'vector'\n                self.position = np.array(track['values']).reshape(-1, 3) * self.scale\n                self.num_frames = self.position.shape[0]\n            else:\n                assert track['type'] == 'quaternion' # DEAFULT: quaternion\n                rotation = np.array(track['values']).reshape(-1, 4)\n                if rotation.shape[0] == 0:\n                    rotation = np.zeros((self.num_frames, 4))\n                elif rotation.shape[0] < self.num_frames:\n                    rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)\n                elif rotation.shape[0] > self.num_frames:\n                    rotation = rotation[:self.num_frames]\n                self.rotations += [rotation]\n        self.rotations = np.array(self.rotations, dtype=np.float32)\n\n    def to_tensor(self, repr='euler', rot_only=False):\n        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:\n            raise Exception('Unknown rotation representation')\n        rotations = self.get_rotation(repr=repr)\n        positions = self.get_position()\n\n        if rot_only:\n            return rotations.reshape(rotations.shape[0], -1)\n\n        rotations = rotations.reshape(rotations.shape[0], -1)\n        return torch.cat((rotations, positions), dim=-1)\n\n    def get_rotation(self, repr='quat'):\n        if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':\n            rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)\n        if repr == 'repr6d':\n            rotations = quat2repr6d(rotations)\n        if repr == 'euler':\n            rotations = quat2euler(rotations)\n        return rotations\n\n    def get_position(self):\n        return torch.tensor(self.position, dtype=torch.float32)\n\nclass TracksMotion:\n    def __init__(self, tracks_json, scale=1.0, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y_UP', padding_last=False):\n        '''\n        TracksMotion constructor\n        Args:\n            tracks_json      : dict, json format tracks data to load from\n            scale            : float, scale of the tracks motion data\n            repr             : string, rotation representation, support ['quat', 'repr6d', 'euler'] \n            use_velo         : book, whether to transform the joints positions to velocities\n            keep_up_pos      : bool, whether to keep y position when converting to velocity\n            up_axis          : string, string, up axis of the motion data\n            padding_last     : bool, whether to pad the last position\n        '''\n        self.tracks_json = tracks_json\n\n        self.raw_data = TracksParser(tracks_json, scale)\n        self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis, \n                                      padding_last=padding_last, contact_id=None)\n    @property\n    def repr(self):\n        return self.motion_data.repr\n\n    @property\n    def use_velo(self):\n        return self.motion_data.use_velo\n\n    @property\n    def keep_up_pos(self):\n        return self.motion_data.keep_up_pos\n    \n    @property\n    def padding_last(self):\n        return self.motion_data.padding_last\n\n    @property\n    def n_pad(self):\n        return self.motion_data.n_pad\n\n    @property\n    def n_rot(self):\n        return self.motion_data.n_rot\n\n    def sample(self, size=None, slerp=False):\n        '''\n        Sample motion data, support slerp\n        '''\n        return self.motion_data.sample(size, slerp)\n\n\n    def parse(self, motion, keep_velo=False,):\n        \"\"\"\n        No batch support here!!!\n        :returns tracks_json\n        \"\"\"\n        motion = motion.clone()\n\n        if self.use_velo and not keep_velo:\n            motion = self.motion_data.to_position(motion)\n        if self.n_pad:\n            motion = motion[:, :-self.n_pad]\n\n        motion = motion.squeeze().permute(1, 0)\n        pos = motion[..., -3:] / self.raw_data.scale\n        rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)\n        if self.repr == 'repr6d':\n            rot = repr6d2quat(rot)\n        elif self.repr == 'euler':\n            raise NotImplementedError('parse \"euler is not implemented yet!!!')\n\n        times = []\n        out_tracks_json = copy.deepcopy(self.tracks_json)\n        for i, _track in enumerate(out_tracks_json):\n            if i == 0:\n                times = [ j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]\n                out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist() \n            else:\n                out_tracks_json[i]['values'] = rot[:, i-1, :].flatten().detach().cpu().numpy().tolist()\n            out_tracks_json[i]['times'] = times\n\n        return out_tracks_json\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel\n\n# For the convenience for users in China mainland\nCOPY apt-sources.list /etc/apt/sources.list\n# Install some basic utilities\nRUN rm /etc/apt/sources.list.d/cuda.list\nRUN rm /etc/apt/sources.list.d/nvidia-ml.list\nRUN apt-get update && apt-get install -y \\\n    curl \\\n    ca-certificates \\\n    sudo \\\n    git \\\n    bzip2 \\\n    libx11-6 \\\n    gcc \\\n    g++ \\\n    libusb-1.0-0 \\\n    libgl1-mesa-glx \\\n    libglib2.0-dev \\\n    openssh-server \\\n    openssh-client \\\n    iputils-ping \\\n    unzip \\\n    cmake \\\n    libssl-dev \\\n    libosmesa6-dev \\\n    freeglut3-dev \\\n    ffmpeg \\\n    iputils-ping \\\n && rm -rf /var/lib/apt/lists/*\n\n# For the convenience for users in China mainland\nRUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple \\\n  && export PATH=\"/usr/local/bin:$PATH\" \\\n  && /bin/bash -c \"source ~/.bashrc\"\nRUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ \\\n && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ \\\n && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ \\\n && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \\\n && conda config --set show_channel_urls yes\n\n# Install dependencies\nCOPY requirements.txt requirements.txt \nRUN pip install -r requirements.txt --user \n\nCMD [\"python3\"]"
  },
  {
    "path": "docker/README.md",
    "content": "## Build Docker Environment and use with GPU Support\n\nBefore you can use this Docker environment, you need to have the following:\n\n- Docker installed on your system\n- NVIDIA drivers installed on your system\n- NVIDIA Container Toolkit installed on your system\n\n\n### Build and Run\n1. Build docker image:\n   ```sh\n   docker build -t GenMM:latest .\n   ```\n2. Start the docker container:\n   ```sh\n   docker run --gpus all -it GenMM:latest /bin/bash\n   ```\n3. Clone the repository:\n   ```sh\n   git clone git@github.com:wyysf-98/GenMM.git\n   ```\n\n## Troubleshooting\n\nIf you encounter any issues with the Docker environment with GPU support, please check the following:\n\n- Make sure that you have installed the NVIDIA drivers and NVIDIA Container Toolkit on your system.\n- Make sure that you have specified the --gpus all option when starting the Docker container.\n- Make sure that your deep learning application is configured to use the GPU."
  },
  {
    "path": "docker/apt-sources.list",
    "content": "deb https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse\ndeb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse\ndeb https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\ndeb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\ndeb https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\ndeb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\ndeb https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse\ndeb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse\ndeb https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse\ndeb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse"
  },
  {
    "path": "docker/requirements.txt",
    "content": "torch==1.12.1\ntorchvision==0.13.1\ntensorboardX==2.5\ntqdm==4.62.3\nunfoldNd==0.2.0\npyyaml>=5.3.1\ngradio==3.34.0\nmatplotlib==3.3.2"
  },
  {
    "path": "docker/requirements_blender.txt",
    "content": "torch==2.2.0\ntorchvision==0.17.0\ntqdm==4.62.3\nunfoldNd==0.2.0\npyyaml>=5.3.1\n"
  },
  {
    "path": "fix_contact.py",
    "content": "from dataset.bvh.bvh_parser import BVH_file\nfrom os.path import join as pjoin\nimport numpy as np\nimport torch\nfrom utils.contact import constrain_from_contact\nfrom utils.kinematics import InverseKinematicsJoint2\nfrom utils.transforms import repr6d2quat\nfrom tqdm import tqdm\nimport argparse\nimport matplotlib.pyplot as plt\nfrom dataset.bvh_motion import skeleton_confs\n\ndef continuous_filter(contact, length=2):\n    contact = contact.copy()\n    for j in range(contact.shape[1]):\n        c = contact[:, j]\n        t_len = 0\n        prev = c[0]\n        for i in range(contact.shape[0]):\n            if prev == c[i]:\n                t_len += 1\n            else:\n                if t_len <= length:\n                    c[i - t_len:i] = c[i]\n                t_len = 1\n                prev = c[i]\n    return contact\n\n\ndef fix_negative_height(contact, constrain, cid):\n    floor = -1\n    constrain = constrain.clone()\n    for i in range(constrain.shape[0]):\n        for j in range(constrain.shape[1]):\n            if constrain[i, j, 1] < floor:\n                constrain[i, j, 1] = floor\n    return constrain\n\n\ndef fix_contact(bvh_file, contact):\n    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n    cid = bvh_file.skeleton.contact_id\n    glb = bvh_file.joint_position()\n    rotation = bvh_file.get_rotation(repr='repr6d').to(device)\n    position = bvh_file.get_position().to(device)\n    contact = contact > 0.5\n    # contact = continuous_filter(contact)\n    constrain = constrain_from_contact(contact, glb, cid)\n    constrain = fix_negative_height(contact, constrain, cid).to(device)\n    cid = list(range(glb.shape[1]))\n    ik_solver = InverseKinematicsJoint2(rotation, position, bvh_file.skeleton.offsets.to(device), bvh_file.skeleton.parent,\n                                        constrain[:, cid], cid, 0.1, 0.01, use_velo=True)\n\n    loop = tqdm(range(500))\n    losses = []\n    for i in loop:\n        loss = ik_solver.step()\n        loop.set_description(f'loss = {loss:.07f}')\n        losses += [loss]\n        plt.plot(losses)\n    \n\n    return repr6d2quat(ik_solver.rotations.detach()), ik_solver.get_position()\n\n\ndef fix_contact_on_file(prefix, name):\n    try:\n        contact = np.load(pjoin(prefix, name + '.bvh.contact.npy'))\n    except:\n        print(f'{name} not found')\n        return\n    bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), no_scale=True, requires_contact=True)\n    print('Fixing foot contact with IK...')\n    res = fix_contact(bvh_file, contact)\n    bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--prefix', type=str, required=True)\n    parser.add_argument('--name', type=str, required=True)\n    parser.add_argument('--skeleton_name', type=str, required=True)\n    args = parser.parse_args()\n    if args.prefix[0] == '/':\n        prefix = args.prefix\n    else:\n        prefix = f'./results/{args.prefix}'\n    name = args.name\n    contact = np.load(pjoin(prefix, name + '.bvh.contact.npy'))\n    bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), skeleton_confs[args.skeleton_name], auto_scale=False, requires_contact=True)\n\n    res = fix_contact(bvh_file, contact)\n    plt.savefig(f'{prefix}/losses.png')\n\n    bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat')"
  },
  {
    "path": "nearest_neighbor/losses.py",
    "content": "import torch    \nimport torch.nn as nn\n\nfrom .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists\n\nclass PatchCoherentLoss(torch.nn.Module):\n    def __init__(self, patch_size=7, stride=1, alpha=None, loop=False, cache=False):\n        super(PatchCoherentLoss, self).__init__()\n        self.patch_size = patch_size\n        assert self.patch_size % 2 == 1, \"Only support odd patch size\"\n        self.stride = stride\n        assert self.stride == 1, \"Only support stride of 1\"\n        self.alpha = alpha\n        self.loop = loop\n        self.cache = cache\n        if cache:\n            self.cached_data = None\n\n    def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):\n        \"\"\"For each patch in input X find its NN in target Y and sum the their distances\"\"\"\n        assert X.shape[0] == 1, \"Only support batch size of 1 for X\"\n        dist_fn = lambda X, Y: dist_wrapper(efficient_cdist, X, Y) if dist_wrapper is not None else efficient_cdist(X, Y)\n\n        x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)\n\n        if not self.cache or self.cached_data is None:\n            y_patches = []\n            for y in Ys:\n                y_patches += [extract_patches(y, self.patch_size, self.stride, loop=False)]\n            y_patches = torch.cat(y_patches, dim=1)\n            self.cached_data = y_patches\n        else:\n            y_patches = self.cached_data\n        \n        nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.alpha)\n\n        if return_blended_results:\n            return combine_patches(X.shape, y_patches[:, nnf, :], self.patch_size, self.stride, loop=self.loop), dist.mean()\n        else:\n            return dist.mean()\n    \n    def clean_cache(self):\n        self.cached_data = None"
  },
  {
    "path": "nearest_neighbor/utils.py",
    "content": "\"\"\"\nthis file borrows some codes from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nimport unfoldNd\n\ndef extract_patches(x, patch_size, stride, loop=False):\n    \"\"\"Extract patches from a motion sequence\"\"\"\n    b, c, _t = x.shape\n\n    # manually padding to loop\n    if loop:\n        half = patch_size // 2\n        front, tail = x[:,:,:half], x[:,:,-half:]\n        x = torch.concat([tail, x, front], dim=-1)\n\n    x_patches = unfoldNd.unfoldNd(x, kernel_size=patch_size, stride=stride).transpose(1, 2).reshape(b, -1, c, patch_size)\n    \n    return x_patches.view(b, -1, c * patch_size)\n\ndef combine_patches(x_shape, ys, patch_size, stride, loop=False):\n    \"\"\"Combine motion patches\"\"\"\n    \n    # manually handle the loop situation\n    out_shape = [*x_shape]\n    if loop:\n        padding = patch_size // 2\n        out_shape[-1] = out_shape[-1] + padding * 2\n\n    combined = unfoldNd.foldNd(ys.permute(0, 2, 1), output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)\n\n    # normal fold matrix\n    input_ones = torch.ones(tuple(out_shape), dtype=ys.dtype, device=ys.device)\n    divisor = unfoldNd.unfoldNd(input_ones, kernel_size=patch_size, stride=stride)\n    divisor = unfoldNd.foldNd(divisor, output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)\n    combined = (combined / divisor).squeeze(dim=0).unsqueeze(0)\n    \n    if loop:\n        half = patch_size // 2\n        front, tail = combined[:,:,:half], combined[:,:,-half:]\n        combined[:, :, half:2 * half] = (combined[:, :, half:2 * half] + tail) / 2\n        combined[:, :, - 2 * half:-half] = (front + combined[:, :, - 2 * half:-half]) / 2\n        combined = combined[:, :, half:-half]\n\n    return combined\n\n\ndef efficient_cdist(X, Y):\n    \"\"\"\n    borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py\n    Pytorch efficient way of computing distances between all vectors in X and Y, i.e (X[:, None] - Y[None, :])**2\n    Get the nearest neighbor index from Y for each X\n    :param X:  (n1, d) tensor\n    :param Y:  (n2, d) tensor\n    Returns a n2 n1 of indices\n    \"\"\"\n    dist = (X * X).sum(1)[:, None] + (Y * Y).sum(1)[None, :] - 2.0 * torch.mm(X, torch.transpose(Y, 0, 1))\n    d = X.shape[1]\n    dist /= d # normalize by size of vector to make dists independent of the size of d ( use same alpha for all patche-sizes)\n    return dist # DO NOT use torch.sqrt\n\n\ndef get_col_mins_efficient(dist_fn, X, Y, b=1024):\n    \"\"\"\n    borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py\n    Computes the l2 distance to the closest x or each y.\n    :param X:  (n1, d) tensor\n    :param Y:  (n2, d) tensor\n    Returns n1 long array of L2 distances\n    \"\"\"\n    n_batches = len(Y) // b\n    mins = torch.zeros(Y.shape[0], dtype=X.dtype, device=X.device)\n    for i in range(n_batches):\n        mins[i * b:(i + 1) * b] = dist_fn(X, Y[i * b:(i + 1) * b]).min(0)[0]\n    if len(Y) % b != 0:\n        mins[n_batches * b:] = dist_fn(X, Y[n_batches * b:]).min(0)[0]\n\n    return mins\n\n\ndef get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024):\n    \"\"\"\n    borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py\n    Get the nearest neighbor index from Y for each X.\n    Avoids holding a (n1 * n2) amtrix in order to reducing memory footprint to (b * max(n1,n2)).\n    :param X:  (n1, d) tensor\n    :param Y:  (n2, d) tensor\n    Returns a n2 n1 of indices amd distances\n    \"\"\"\n    if alpha is not None:\n        normalizing_row = get_col_mins_efficient(dist_fn, X, Y, b=b)\n        normalizing_row = alpha + normalizing_row[None, :]\n    else:\n        normalizing_row = 1\n\n    NNs = torch.zeros(X.shape[0], dtype=torch.long, device=X.device)\n    Dists = torch.zeros(X.shape[0], dtype=torch.float, device=X.device)\n\n    n_batches = len(X) // b\n    for i in range(n_batches):\n        dists = dist_fn(X[i * b:(i + 1) * b], Y) / normalizing_row\n        NNs[i * b:(i + 1) * b] = dists.min(1)[1]\n        Dists[i * b:(i + 1) * b] = dists.min(1)[0]\n    if len(X) % b != 0:\n        dists = dist_fn(X[n_batches * b:], Y) / normalizing_row\n        NNs[n_batches * b:] = dists.min(1)[1]\n        Dists[n_batches * b: ] = dists.min(1)[0]\n\n    return NNs, Dists\n"
  },
  {
    "path": "run_random_generation.py",
    "content": "import os\nimport os.path as osp\nimport argparse\nfrom GenMM import GenMM\nfrom nearest_neighbor.losses import PatchCoherentLoss\nfrom dataset.bvh_motion import BVHMotion, load_multiple_dataset\nfrom utils.base import ConfigParser, set_seed\n\nargs = argparse.ArgumentParser(\n    description='Random shuffle the input motion sequence')\nargs.add_argument('-m', '--mode', default='run',\n                  choices=['run', 'eval', 'debug'], type=str, help='current run mode.')\nargs.add_argument('-i', '--input', required=True,\n                  type=str, help='exemplar motion path.')\nargs.add_argument('-o', '--output_dir', default='./output',\n                  type=str, help='output folder path for saving results.')\nargs.add_argument('-c', '--config', default='./configs/default.yaml',\n                  type=str, help='config file path.')\nargs.add_argument('-s', '--seed', default=None,\n                  type=int, help='random seed used.')\nargs.add_argument('-d', '--device', default=\"cuda:0\",\n                  type=str, help='device to use.')\nargs.add_argument('--post_precess', action='store_true',\n                  help='whether to use IK post-process to fix foot contact.')\n\n# Use argsparser to overwrite the configuration\n# for dataset\nargs.add_argument('--skeleton_name', type=str,\n                  help='(used when joint_reduction==True or contact==True) skeleton name to load pre-defined joints configuration.')\nargs.add_argument('--use_velo', type=int,\n                  help='whether to use velocity rather than global position of each joint.')\nargs.add_argument('--repr', choices=['repr6d', 'quat', 'euler'], type=str,\n                  help='rotation representation, support [epr6d, quat, reuler].')\nargs.add_argument('--requires_contact', type=int,\n                  help='whether to use contact label.')\nargs.add_argument('--keep_up_pos', type=int,\n                  help='whether to do not use velocity and keep the y(up) position.')\nargs.add_argument('--up_axis', type=str, choices=['X_UP', 'Y_UP', 'Z_UP'],\n                  help='up axis of the motion.')\nargs.add_argument('--padding_last', type=int,\n                  help='whether to pad the last position channel to match the rotation dimension.')\nargs.add_argument('--joint_reduction', type=int,\n                  help='whether to simplify the skeleton using provided skeleton config.')\nargs.add_argument('--skeleton_aware', type=int,\n                  help='whether to enable skeleton-aware component.')\nargs.add_argument('--joints_group', type=str,\n                  help='joints spliting group for using skeleton-aware component.')\n# for synthesis\nargs.add_argument('--num_frames', type=str, \n                  help='number of synthesized frames, supported Nx(N times) and int input.')\nargs.add_argument('--alpha', type=float,\n                  help='completeness/diversity trade-off alpha value.')\nargs.add_argument('--num_steps', type=int,\n                  help='number of optimization steps at each pyramid level.')\nargs.add_argument('--noise_sigma', type=float,\n                  help='standard deviation of the zero mean normal noise added to the initialization.')\nargs.add_argument('--coarse_ratio', type=float,\n                  help='downscale ratio of the coarse level.')\nargs.add_argument('--coarse_ratio_factor', type=float,\n                  help='downscale ratio of the coarse level.')\nargs.add_argument('--pyr_factor', type=float,\n                  help='upsample ratio of each pyramid level.')\nargs.add_argument('--num_stages_limit', type=int,\n                  help='limit of the number of stages.')\nargs.add_argument('--patch_size', type=int, help='patch size for generation.')\nargs.add_argument('--loop', type=int, help='whether to loop the sequence.')\ncfg = ConfigParser(args)\n\n\ndef generate(cfg):\n    # seet seed for reproducible\n    set_seed(cfg.seed)\n\n    # set save path and prepare data for generation\n    if cfg.input.endswith('.bvh'):\n        base_dir = osp.join(\n            cfg.output_dir, cfg.input.split('/')[-1].split('.')[0])\n        motion_data = [BVHMotion(cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr,\n                                 use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last,\n                                 requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)]\n    elif cfg.input.endswith('.txt'):\n        base_dir = osp.join(cfg.output_dir, cfg.input.split(\n            '/')[-2], cfg.input.split('/')[-1].split('.')[0])\n        motion_data = load_multiple_dataset(name_list=cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr,\n                                            use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last,\n                                            requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)\n    else:\n        raise ValueError('exemplar must be a bvh file or a txt file')\n    prefix = f\"s{cfg.seed}+{cfg.num_frames}+{cfg.repr}+use_velo_{cfg.use_velo}+kypose_{cfg.keep_up_pos}+padding_{cfg.padding_last}\" \\\n             f\"+contact_{cfg.requires_contact}+jredu_{cfg.joint_reduction}+n{cfg.noise_sigma}+pyr{cfg.pyr_factor}\" \\\n             f\"+r{cfg.coarse_ratio}_{cfg.coarse_ratio_factor}+itr{cfg.num_steps}+ps_{cfg.patch_size}+alpha_{cfg.alpha}\" \\\n             f\"+loop_{cfg.loop}\"\n\n    # perform the generation\n    model = GenMM(device=cfg.device, silent=True if cfg.mode == 'eval' else False)\n    criteria = PatchCoherentLoss(patch_size=cfg.patch_size, alpha=cfg.alpha, loop=cfg.loop, cache=True)\n    syn = model.run(motion_data, criteria,\n                    num_frames=cfg.num_frames,\n                    num_steps=cfg.num_steps,\n                    noise_sigma=cfg.noise_sigma,\n                    patch_size=cfg.patch_size, \n                    coarse_ratio=cfg.coarse_ratio,\n                    pyr_factor=cfg.pyr_factor,\n                    debug_dir=save_dir if cfg.mode == 'debug' else None)\n    \n    # save the generated results\n    save_dir = osp.join(base_dir, prefix)\n    os.makedirs(save_dir, exist_ok=True)\n    motion_data[0].write(f\"{save_dir}/syn.bvh\", syn)\n\n    if cfg.post_precess:\n        cmd = f\"python fix_contact.py --prefix {osp.abspath(save_dir)} --name syn --skeleton_name={cfg.skeleton_name}\"\n        os.system(cmd)\n\nif __name__ == '__main__':\n    generate(cfg)\n"
  },
  {
    "path": "run_web_server.py",
    "content": "import json\nimport time\nimport torch\nimport argparse\nimport gradio as gr\n\nfrom GenMM import GenMM\nfrom nearest_neighbor.losses import PatchCoherentLoss\nfrom dataset.tracks_motion import TracksMotion\n\nargs = argparse.ArgumentParser(description='Web server for GenMM')\nargs.add_argument('-d', '--device', default=\"cuda:0\", type=str, help='device to use.')\nargs.add_argument('--ip', default=\"0.0.0.0\", type=str, help='interface url to host.')\nargs.add_argument('--port', default=8000, type=int, help='interface port to serve.')\nargs.add_argument('--debug', action='store_true', help='debug mode.')\nargs = args.parse_args()\n\ndef generate(data):\n    data = json.loads(data)\n\n    # create track object\n    motion_data = [TracksMotion(data['tracks'], repr='repr6d', use_velo=True, keep_y_pos=True, padding_last=False)]\n    model = GenMM(device=args.device, silent=True)\n    criteria = PatchCoherentLoss(patch_size=data['setting']['patch_size'], \n                                alpha=data['setting']['alpha'] if data['setting']['completeness'] else None, \n                                loop=data['setting']['loop'], cache=True)\n\n    # start generation\n    start = time.time()\n    syn = model.run(motion_data, criteria,\n                    num_frames=str(data['setting']['frames']),\n                    num_steps=data['setting']['num_steps'],\n                    noise_sigma=data['setting']['noise_sigma'],\n                    patch_size=data['setting']['patch_size'], \n                    coarse_ratio=f'{data[\"setting\"][\"coarse_ratio\"]}x_nframes',\n                    # coarse_ratio=f'3x_patchsize',\n                    pyr_factor=data['setting']['pyr_factor'])\n    end = time.time()\n\n    data['time'] = end - start\n    data['tracks'] = motion_data[0].parse(syn)\n\n    return data\n\nif __name__ == '__main__':\n    demo = gr.Interface(fn=generate, inputs=\"json\", outputs=\"json\")\n    demo.launch(debug=args.debug, server_name=args.ip, server_port=args.port)"
  },
  {
    "path": "utils/base.py",
    "content": "import os\nimport os.path as osp\nimport sys\nimport time\nimport yaml\nimport imageio\nimport random\nimport shutil\nimport random\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\n# configuration\nclass ConfigParser():\n    def __init__(self, args):\n        \"\"\"\n        class to parse configuration.\n        \"\"\"\n        args = args.parse_args()\n        self.cfg = self.merge_config_file(args)\n\n        # set random seed\n        self.set_seed()\n\n    def __str__(self):\n        return str(self.cfg.__dict__)\n\n    def __getattr__(self, name):\n        \"\"\"\n        Access items use dot.notation.\n        \"\"\"\n        return self.cfg.__dict__[name]\n\n    def __getitem__(self, name):\n        \"\"\"\n        Access items like ordinary dict.\n        \"\"\"\n        return self.cfg.__dict__[name]\n\n    def merge_config_file(self, args, allow_invalid=True):\n        \"\"\"\n        Load json config file and merge the arguments\n        \"\"\"\n        assert args.config is not None\n        with open(args.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            if 'config' in cfg.keys():\n                del cfg['config']\n        f.close()\n        invalid_args = list(set(cfg.keys()) - set(dir(args)))\n        if invalid_args and not allow_invalid:\n            raise ValueError(f\"Invalid args {invalid_args} in {args.config}.\")\n        \n        for k in list(cfg.keys()):\n            if k in args.__dict__.keys() and args.__dict__[k] is not None:\n                print('=========>  overwrite config: {} = {}'.format(k, args.__dict__[k]))\n                del cfg[k]\n\n        args.__dict__.update(cfg)\n\n        return args\n\n    def set_seed(self):\n        ''' set random seed for random, numpy and torch. '''\n        if 'seed' not in self.cfg.__dict__.keys():\n            return\n        if self.cfg.seed is None:\n            self.cfg.seed = int(time.time()) % 1000000\n        print('=========>  set random seed: {}'.format(self.cfg.seed))\n        # fix random seeds for reproducibility\n        random.seed(self.cfg.seed)\n        np.random.seed(self.cfg.seed)\n        torch.manual_seed(self.cfg.seed)\n        torch.cuda.manual_seed(self.cfg.seed)\n\n    def save_codes_and_config(self, save_path):\n        \"\"\"\n        save codes and config to $save_path.\n        \"\"\"\n        cur_codes_path = osp.dirname(osp.dirname(os.path.abspath(__file__)))\n        if os.path.exists(save_path):\n            shutil.rmtree(save_path)\n        shutil.copytree(cur_codes_path, osp.join(save_path, 'codes'), \\\n            ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz'))\n\n        with open(osp.join(save_path, 'config.yaml'), 'w') as f:\n            f.write(yaml.dump(self.cfg.__dict__))\n        f.close()\n\n\n# logger util\nclass logger:\n    \"\"\"\n    Keeps track of the levels and steps of optimization. Logs it via TQDM\n    \"\"\"\n    def __init__(self, n_steps, n_lvls):\n        self.n_steps = n_steps\n        self.n_lvls = n_lvls\n        self.lvl = -1\n        self.lvl_step = 0\n        self.steps = 0\n        self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting')\n\n    def step(self):\n        self.pbar.update(1)\n        self.steps += 1\n        self.lvl_step += 1\n\n    def new_lvl(self):\n        self.lvl += 1\n        self.lvl_step = 0\n\n    def print(self):\n        self.pbar.set_description(f'Lvl {self.lvl}/{self.n_lvls-1}, step {self.lvl_step}/{self.n_steps}')\n\n\n# other utils\ndef set_seed(seed=None):\n    \"\"\"\n    Set all the seed for the reproducible\n    \"\"\"\n    if seed is not None:\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)"
  },
  {
    "path": "utils/contact.py",
    "content": "import torch\n\n\ndef foot_contact_by_height(pos):\n    eps = 0.25\n    return (-eps < pos[..., 1]) * (pos[..., 1] < eps)\n\n\ndef velocity(pos, padding=False):\n    velo = pos[1:, ...] - pos[:-1, ...]\n    velo_norm = torch.norm(velo, dim=-1)\n    if padding:\n        pad = torch.zeros_like(velo_norm[:1, :])\n        velo_norm = torch.cat([pad, velo_norm], dim=0)\n    return velo_norm\n\n\ndef foot_contact(pos, ref_height=1., threshold=0.018):\n    velo_norm = velocity(pos)\n    contact = velo_norm < threshold\n    contact = contact.int()\n    padding = torch.zeros_like(contact)\n    contact = torch.cat([padding[:1, :], contact], dim=0)\n    return contact\n\n\ndef alpha(t):\n    return 2.0 * t * t * t - 3.0 * t * t + 1\n\n\ndef lerp(a, l, r):\n    return (1 - a) * l + a * r\n\n\ndef constrain_from_contact(contact, glb, fid='TBD', L=5):\n    \"\"\"\n    :param contact: contact label\n    :param glb: original global position\n    :param fid: joint id to fix, corresponding to the order in contact\n    :param L: frame to look forward/backward\n    :return:\n    \"\"\"\n    T = glb.shape[0]\n\n    for i, fidx in enumerate(fid):  # fidx: index of the foot joint\n        fixed = contact[:, i]  # [T]\n        s = 0\n        while s < T:\n            while s < T and fixed[s] == 0:\n                s += 1\n            if s >= T:\n                break\n            t = s\n            avg = glb[t, fidx].clone()\n            while t + 1 < T and fixed[t + 1] == 1:\n                t += 1\n                avg += glb[t, fidx].clone()\n            avg /= (t - s + 1)\n\n            for j in range(s, t + 1):\n                glb[j, fidx] = avg.clone()\n            s = t + 1\n\n        for s in range(T):\n            if fixed[s] == 1:\n                continue\n            l, r = None, None\n            consl, consr = False, False\n            for k in range(L):\n                if s - k - 1 < 0:\n                    break\n                if fixed[s - k - 1]:\n                    l = s - k - 1\n                    consl = True\n                    break\n            for k in range(L):\n                if s + k + 1 >= T:\n                    break\n                if fixed[s + k + 1]:\n                    r = s + k + 1\n                    consr = True\n                    break\n            if not consl and not consr:\n                continue\n            if consl and consr:\n                litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),\n                            glb[s, fidx], glb[l, fidx])\n                ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),\n                            glb[s, fidx], glb[r, fidx])\n                itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)),\n                           ritp, litp)\n                glb[s, fidx] = itp.clone()\n                continue\n            if consl:\n                litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),\n                            glb[s, fidx], glb[l, fidx])\n                glb[s, fidx] = litp.clone()\n                continue\n            if consr:\n                ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),\n                            glb[s, fidx], glb[r, fidx])\n                glb[s, fidx] = ritp.clone()\n    return glb\n"
  },
  {
    "path": "utils/kinematics.py",
    "content": "import torch\nfrom utils.transforms import quat2mat, repr6d2mat, euler2mat\n\n\nclass ForwardKinematics:\n    def __init__(self, parents, offsets=None):\n        self.parents = parents\n        if offsets is not None and len(offsets.shape) == 2:\n            offsets = offsets.unsqueeze(0)\n        self.offsets = offsets\n\n    def forward(self, rots, offsets=None, global_pos=None):\n        \"\"\"\n        Forward Kinematics: returns a per-bone transformation\n        @param rots: local joint rotations (batch_size, bone_num, 3, 3)\n        @param offsets: (batch_size, bone_num, 3) or None\n        @param global_pos: global_position: (batch_size, 3) or keep it as in offsets (default)\n        @return: (batch_szie, bone_num, 3, 4)\n        \"\"\"\n        rots = rots.clone()\n        if offsets is None:\n            offsets = self.offsets.to(rots.device)\n        if global_pos is None:\n            global_pos = offsets[:, 0]\n\n        pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device)\n        rest_pos = torch.zeros_like(pos)\n        res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device)\n\n        pos[:, 0] = global_pos\n        rest_pos[:, 0] = offsets[:, 0]\n\n        for i, p in enumerate(self.parents):\n            if i != 0:\n                rots[:, i] = torch.matmul(rots[:, p], rots[:, i])\n                pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]\n                rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]\n\n            res[:, i, :3, :3] = rots[:, i]\n            res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]\n\n        return res\n\n    def accumulate(self, local_rots):\n        \"\"\"\n        Get global joint rotation from local rotations\n        @param local_rots: (batch_size, n_bone, 3, 3)\n        @return: global_rotations\n        \"\"\"\n        res = torch.empty_like(local_rots)\n        for i, p in enumerate(self.parents):\n            if i == 0:\n                res[:, i] = local_rots[:, i]\n            else:\n                res[:, i] = torch.matmul(res[:, p], local_rots[:, i])\n        return res\n\n    def unaccumulate(self, global_rots):\n        \"\"\"\n        Get local joint rotation from global rotations\n        @param global_rots: (batch_size, n_bone, 3, 3)\n        @return: local_rotations\n        \"\"\"\n        res = torch.empty_like(global_rots)\n        inv = torch.empty_like(global_rots)\n\n        for i, p in enumerate(self.parents):\n            if i == 0:\n                inv[:, i] = global_rots[:, i].transpose(-2, -1)\n                res[:, i] = global_rots[:, i]\n                continue\n            res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])\n            inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])\n\n        return res\n\n\nclass ForwardKinematicsJoint:\n    def __init__(self, parents, offset):\n        self.parents = parents\n        self.offset = offset\n\n    '''\n        rotation should have shape batch_size * Joint_num * (3/4) * Time\n        position should have shape batch_size * 3 * Time\n        offset should have shape batch_size * Joint_num * 3\n        output have shape batch_size * Time * Joint_num * 3\n    '''\n\n    def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,\n                world=True):\n        '''\n        if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation')\n        if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation')\n        rotation = rotation.permute(0, 3, 1, 2)\n        position = position.permute(0, 2, 1)\n        '''\n        if rotation.shape[-1] == 6:\n            transform = repr6d2mat(rotation)\n        elif rotation.shape[-1] == 4:\n            norm = torch.norm(rotation, dim=-1, keepdim=True)\n            rotation = rotation / norm\n            transform = quat2mat(rotation)\n        elif rotation.shape[-1] == 3:\n            transform = euler2mat(rotation)\n        else:\n            raise Exception('Only accept quaternion rotation input')\n        result = torch.empty(transform.shape[:-2] + (3,), device=position.device)\n\n        if offset is None:\n            offset = self.offset\n        offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))\n\n        result[..., 0, :] = position\n        for i, pi in enumerate(self.parents):\n            if pi == -1:\n                assert i == 0\n                continue\n\n            result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze()\n            transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone())\n            if world: result[..., i, :] += result[..., pi, :]\n        return result\n\n\nclass InverseKinematicsJoint:\n    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):\n        self.rotations = rotations.detach().clone()\n        self.rotations.requires_grad_(True)\n        self.position = positions.detach().clone()\n        self.position.requires_grad_(True)\n\n        self.parents = parents\n        self.offset = offset\n        self.constrains = constrains\n\n        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))\n        self.criteria = torch.nn.MSELoss()\n\n        self.fk = ForwardKinematicsJoint(parents, offset)\n\n        self.glb = None\n\n    def step(self):\n        self.optimizer.zero_grad()\n        glb = self.fk.forward(self.rotations, self.position)\n        loss = self.criteria(glb, self.constrains)\n        loss.backward()\n        self.optimizer.step()\n        self.glb = glb\n        return loss.item()\n\n\nclass InverseKinematicsJoint2:\n    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,\n                 lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):\n        self.use_velo = use_velo\n        self.rotations_ori = rotations.detach().clone()\n        self.rotations = rotations.detach().clone()\n        self.rotations.requires_grad_(True)\n        self.position_ori = positions.detach().clone()\n        self.position = positions.detach().clone()\n        if self.use_velo:\n            self.position[1:] = self.position[1:] - self.position[:-1]\n        self.position.requires_grad_(True)\n\n        self.parents = parents\n        self.offset = offset\n        self.constrains = constrains.detach().clone()\n        self.cid = cid\n\n        self.lambda_rec_rot = lambda_rec_rot\n        self.lambda_rec_pos = lambda_rec_pos\n\n        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))\n        self.criteria = torch.nn.MSELoss()\n\n        self.fk = ForwardKinematicsJoint(parents, offset)\n\n        self.glb = None\n\n    def step(self):\n        self.optimizer.zero_grad()\n        if self.use_velo:\n            position = torch.cumsum(self.position, dim=0)\n        else:\n            position = self.position\n        glb = self.fk.forward(self.rotations, position)\n        self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)\n        self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)\n        self.rec_loss_pos = self.criteria(self.position, self.position_ori)\n        loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos\n        loss.backward()\n        self.optimizer.step()\n        self.glb = glb\n        return loss.item()\n\n    def get_position(self):\n        if self.use_velo:\n            position = torch.cumsum(self.position.detach(), dim=0)\n        else:\n            position = self.position.detach()\n        return position\n"
  },
  {
    "path": "utils/rename_mixamo_rig.py",
    "content": "# rename_mixamo_prefix.py\nimport bpy, re\nrx = re.compile(r\"mixamorig\\d+:\")          # any number before the colon\n\nfor obj in bpy.data.objects:\n    if obj.type == 'ARMATURE':\n        for b in obj.data.bones:\n            b.name = rx.sub(\"mixamorig:\", b.name)"
  },
  {
    "path": "utils/skeleton.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nimport numpy as np\n\n\nclass SkeletonConv(nn.Module):\n    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,\n                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):\n        super(SkeletonConv, self).__init__()\n\n        if in_channels % joint_num != 0 or out_channels % joint_num != 0:\n            raise Exception('in/out channels should be divided by joint_num')\n        self.in_channels_per_joint = in_channels // joint_num\n        self.out_channels_per_joint = out_channels // joint_num\n\n        if padding_mode == 'zeros': padding_mode = 'constant'\n\n        self.expanded_neighbour_list = []\n        self.expanded_neighbour_list_offset = []\n        self.neighbour_list = neighbour_list\n        self.add_offset = add_offset\n        self.joint_num = joint_num\n\n        self.stride = stride\n        self.dilation = 1\n        self.groups = 1\n        self.padding = padding\n        self.padding_mode = padding_mode\n        self._padding_repeated_twice = (padding, padding)\n\n        for neighbour in neighbour_list:\n            expanded = []\n            for k in neighbour:\n                for i in range(self.in_channels_per_joint):\n                    expanded.append(k * self.in_channels_per_joint + i)\n            self.expanded_neighbour_list.append(expanded)\n\n        if self.add_offset:\n            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)\n\n            for neighbour in neighbour_list:\n                expanded = []\n                for k in neighbour:\n                    for i in range(add_offset):\n                        expanded.append(k * in_offset_channel + i)\n                self.expanded_neighbour_list_offset.append(expanded)\n\n        self.weight = torch.zeros(out_channels, in_channels, kernel_size)\n        if bias:\n            self.bias = torch.zeros(out_channels)\n        else:\n            self.register_parameter('bias', None)\n\n        self.mask = torch.zeros_like(self.weight)\n        for i, neighbour in enumerate(self.expanded_neighbour_list):\n            self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1\n        self.mask = nn.Parameter(self.mask, requires_grad=False)\n\n        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \\\n                           'joint_num={}, stride={}, padding={}, bias={})'.format(\n            in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for i, neighbour in enumerate(self.expanded_neighbour_list):\n            \"\"\" Use temporary variable to avoid assign to copy of slice, which might lead to un expected result \"\"\"\n            tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),\n                                   neighbour, ...])\n            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))\n            self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),\n                        neighbour, ...] = tmp\n            if self.bias is not None:\n                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(\n                    self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])\n                bound = 1 / math.sqrt(fan_in)\n                tmp = torch.zeros_like(\n                    self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])\n                nn.init.uniform_(tmp, -bound, bound)\n                self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp\n\n        self.weight = nn.Parameter(self.weight)\n        if self.bias is not None:\n            self.bias = nn.Parameter(self.bias)\n\n    def set_offset(self, offset):\n        if not self.add_offset: raise Exception('Wrong Combination of Parameters')\n        self.offset = offset.reshape(offset.shape[0], -1)\n\n    def forward(self, input):\n        weight_masked = self.weight * self.mask\n        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),\n                       weight_masked, self.bias, self.stride,\n                       0, self.dilation, self.groups)\n\n        if self.add_offset:\n            offset_res = self.offset_enc(self.offset)\n            offset_res = offset_res.reshape(offset_res.shape + (1, ))\n            res += offset_res / 100\n        return res\n\n    def __repr__(self):\n        return self.description\n\n\nclass SkeletonLinear(nn.Module):\n    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):\n        super(SkeletonLinear, self).__init__()\n        self.neighbour_list = neighbour_list\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.in_channels_per_joint = in_channels // len(neighbour_list)\n        self.out_channels_per_joint = out_channels // len(neighbour_list)\n        self.extra_dim1 = extra_dim1\n        self.expanded_neighbour_list = []\n\n        for neighbour in neighbour_list:\n            expanded = []\n            for k in neighbour:\n                for i in range(self.in_channels_per_joint):\n                    expanded.append(k * self.in_channels_per_joint + i)\n            self.expanded_neighbour_list.append(expanded)\n\n        self.weight = torch.zeros(out_channels, in_channels)\n        self.mask = torch.zeros(out_channels, in_channels)\n        self.bias = nn.Parameter(torch.Tensor(out_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for i, neighbour in enumerate(self.expanded_neighbour_list):\n            tmp = torch.zeros_like(\n                self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]\n            )\n            self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1\n            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))\n            self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp\n\n        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)\n        bound = 1 / math.sqrt(fan_in)\n        nn.init.uniform_(self.bias, -bound, bound)\n\n        self.weight = nn.Parameter(self.weight)\n        self.mask = nn.Parameter(self.mask, requires_grad=False)\n\n    def forward(self, input):\n        input = input.reshape(input.shape[0], -1)\n        weight_masked = self.weight * self.mask\n        res = F.linear(input, weight_masked, self.bias)\n        if self.extra_dim1: res = res.reshape(res.shape + (1,))\n        return res\n\n\nclass SkeletonPoolJoint(nn.Module):\n    def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):\n        super(SkeletonPoolJoint, self).__init__()\n\n        if pooling_mode != 'mean':\n            raise Exception('Unimplemented pooling mode in matrix_implementation')\n\n        self.joint_num = len(topology)\n        self.parent = topology\n        self.pooling_list = []\n        self.pooling_mode = pooling_mode\n\n        self.pooling_map = [-1 for _ in range(len(self.parent))]\n        self.child = [-1 for _ in range(len(self.parent))]\n        children_cnt = [0 for _ in range(len(self.parent))]\n        for x, pa in enumerate(self.parent):\n            if pa < 0: continue\n            children_cnt[pa] += 1\n            self.child[pa] = x\n        self.pooling_map[0] = 0\n        for x in range(len(self.parent)):\n            if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):\n                while children_cnt[x] <= 1:\n                    pa = self.parent[x]\n                    if last_pool:\n                        seq = [x]\n                        while pa != -1 and children_cnt[pa] == 1:\n                            seq = [pa] + seq\n                            x = pa\n                            pa = self.parent[x]\n                        self.pooling_list.append(seq)\n                        break\n                    else:\n                        if pa != -1 and children_cnt[pa] == 1:\n                            self.pooling_list.append([pa, x])\n                            x = self.parent[pa]\n                        else:\n                            self.pooling_list.append([x, ])\n                            break\n            elif children_cnt[x] > 1:\n                self.pooling_list.append([x, ])\n\n        self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(\n            len(topology), len(self.pooling_list),\n        )\n\n        self.pooling_list.sort(key=lambda x:x[0])\n        for i, a in enumerate(self.pooling_list):\n            for j in a:\n                self.pooling_map[j] = i\n\n        self.output_joint_num = len(self.pooling_list)\n        self.new_topology = [-1 for _ in range(len(self.pooling_list))]\n        for i, x in enumerate(self.pooling_list):\n            if i < 1: continue\n            self.new_topology[i] = self.pooling_map[self.parent[x[0]]]\n\n        self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)\n\n        for i, pair in enumerate(self.pooling_list):\n            for j in pair:\n                for c in range(channels_per_joint):\n                    self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)\n\n        self.weight = nn.Parameter(self.weight, requires_grad=False)\n\n    def forward(self, input: torch.Tensor):\n        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)\n\n\nclass SkeletonPool(nn.Module):\n    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):\n        super(SkeletonPool, self).__init__()\n\n        if pooling_mode != 'mean':\n            raise Exception('Unimplemented pooling mode in matrix_implementation')\n\n        self.channels_per_edge = channels_per_edge\n        self.pooling_mode = pooling_mode\n        self.edge_num = len(edges) + 1\n        self.seq_list = []\n        self.pooling_list = []\n        self.new_edges = []\n        degree = [0] * 100\n\n        for edge in edges:\n            degree[edge[0]] += 1\n            degree[edge[1]] += 1\n\n        def find_seq(j, seq):\n            nonlocal self, degree, edges\n\n            if degree[j] > 2 and j != 0:\n                self.seq_list.append(seq)\n                seq = []\n\n            if degree[j] == 1:\n                self.seq_list.append(seq)\n                return\n\n            for idx, edge in enumerate(edges):\n                if edge[0] == j:\n                    find_seq(edge[1], seq + [idx])\n\n        find_seq(0, [])\n        for seq in self.seq_list:\n            if last_pool:\n                self.pooling_list.append(seq)\n                continue\n            if len(seq) % 2 == 1:\n                self.pooling_list.append([seq[0]])\n                self.new_edges.append(edges[seq[0]])\n                seq = seq[1:]\n            for i in range(0, len(seq), 2):\n                self.pooling_list.append([seq[i], seq[i + 1]])\n                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])\n\n        # add global position\n        self.pooling_list.append([self.edge_num - 1])\n\n        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(\n            len(edges), len(self.pooling_list)\n        )\n\n        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)\n\n        for i, pair in enumerate(self.pooling_list):\n            for j in pair:\n                for c in range(channels_per_edge):\n                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)\n\n        self.weight = nn.Parameter(self.weight, requires_grad=False)\n\n    def forward(self, input: torch.Tensor):\n        return torch.matmul(self.weight, input)\n\n\nclass SkeletonUnpool(nn.Module):\n    def __init__(self, pooling_list, channels_per_edge):\n        super(SkeletonUnpool, self).__init__()\n        self.pooling_list = pooling_list\n        self.input_joint_num = len(pooling_list)\n        self.output_joint_num = 0\n        self.channels_per_edge = channels_per_edge\n        for t in self.pooling_list:\n            self.output_joint_num += len(t)\n\n        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(\n            self.input_joint_num, self.output_joint_num,\n        )\n\n        self.weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)\n\n        for i, pair in enumerate(self.pooling_list):\n            for j in pair:\n                for c in range(channels_per_edge):\n                    self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1\n\n        self.weight = nn.Parameter(self.weight)\n        self.weight.requires_grad_(False)\n\n    def forward(self, input: torch.Tensor):\n        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)\n\n\ndef find_neighbor_joint(parents, threshold):\n    n_joint = len(parents)\n    dist_mat = np.empty((n_joint, n_joint), dtype=np.int)\n    dist_mat[:, :] = 100000\n    for i, p in enumerate(parents):\n        dist_mat[i, i] = 0\n        if i != 0:\n            dist_mat[i, p] = dist_mat[p, i] = 1\n\n    \"\"\"\n    Floyd's algorithm\n    \"\"\"\n    for k in range(n_joint):\n        for i in range(n_joint):\n            for j in range(n_joint):\n                dist_mat[i, j] = min(dist_mat[i, j], dist_mat[i, k] + dist_mat[k, j])\n\n    neighbor_list = []\n    for i in range(n_joint):\n        neighbor = []\n        for j in range(n_joint):\n            if dist_mat[i, j] <= threshold:\n                neighbor.append(j)\n        neighbor_list.append(neighbor)\n\n    return neighbor_list\n"
  },
  {
    "path": "utils/transforms.py",
    "content": "import numpy as np\nimport torch\n\n\ndef batch_mm(matrix, matrix_batch):\n    \"\"\"\n    https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242\n    :param matrix: Sparse or dense matrix, size (m, n).\n    :param matrix_batch: Batched dense matrices, size (b, n, k).\n    :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).\n    \"\"\"\n    batch_size = matrix_batch.shape[0]\n    # Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)\n    vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)\n\n    # A matrix-matrix product is a batched matrix-vector product of the columns.\n    # And then reverse the reshaping. (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)\n    return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)\n\n\ndef aa2quat(rots, form='wxyz', unified_orient=True):\n    \"\"\"\n    Convert angle-axis representation to wxyz quaternion and to the half plan (w >= 0)\n    @param rots: angle-axis rotations, (*, 3)\n    @param form: quaternion format, either 'wxyz' or 'xyzw'\n    @param unified_orient: Use unified orientation for quaternion (quaternion is dual cover of SO3)\n    :return:\n    \"\"\"\n    angles = rots.norm(dim=-1, keepdim=True)\n    norm = angles.clone()\n    norm[norm < 1e-8] = 1\n    axis = rots / norm\n    quats = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)\n    angles = angles * 0.5\n    if form == 'wxyz':\n        quats[..., 0] = torch.cos(angles.squeeze(-1))\n        quats[..., 1:] = torch.sin(angles) * axis\n    elif form == 'xyzw':\n        quats[..., :3] = torch.sin(angles) * axis\n        quats[..., 3] = torch.cos(angles.squeeze(-1))\n\n    if unified_orient:\n        idx = quats[..., 0] < 0\n        quats[idx, :] *= -1\n\n    return quats\n\n\ndef quat2aa(quats):\n    \"\"\"\n    Convert wxyz quaternions to angle-axis representation\n    :param quats:\n    :return:\n    \"\"\"\n    _cos = quats[..., 0]\n    xyz = quats[..., 1:]\n    _sin = xyz.norm(dim=-1)\n    norm = _sin.clone()\n    norm[norm < 1e-7] = 1\n    axis = xyz / norm.unsqueeze(-1)\n    angle = torch.atan2(_sin, _cos) * 2\n    return axis * angle.unsqueeze(-1)\n\n\ndef quat2mat(quats: torch.Tensor):\n    \"\"\"\n    Convert (w, x, y, z) quaternions to 3x3 rotation matrix\n    :param quats: quaternions of shape (..., 4)\n    :return:  rotation matrices of shape (..., 3, 3)\n    \"\"\"\n    qw = quats[..., 0]\n    qx = quats[..., 1]\n    qy = quats[..., 2]\n    qz = quats[..., 3]\n\n    x2 = qx + qx\n    y2 = qy + qy\n    z2 = qz + qz\n    xx = qx * x2\n    yy = qy * y2\n    wx = qw * x2\n    xy = qx * y2\n    yz = qy * z2\n    wy = qw * y2\n    xz = qx * z2\n    zz = qz * z2\n    wz = qw * z2\n\n    m = torch.empty(quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)\n    m[..., 0, 0] = 1.0 - (yy + zz)\n    m[..., 0, 1] = xy - wz\n    m[..., 0, 2] = xz + wy\n    m[..., 1, 0] = xy + wz\n    m[..., 1, 1] = 1.0 - (xx + zz)\n    m[..., 1, 2] = yz - wx\n    m[..., 2, 0] = xz - wy\n    m[..., 2, 1] = yz + wx\n    m[..., 2, 2] = 1.0 - (xx + yy)\n\n    return m\n\n\ndef quat2euler(q, order='xyz', degrees=True):\n    \"\"\"\n    Convert (w, x, y, z) quaternions to xyz euler angles. This is  used for bvh output.\n    \"\"\"\n    q0 = q[..., 0]\n    q1 = q[..., 1]\n    q2 = q[..., 2]\n    q3 = q[..., 3]\n    es = torch.empty(q0.shape + (3,), device=q.device, dtype=q.dtype)\n\n    if order == 'xyz':\n        es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)\n        es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))\n        es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)\n    else:\n        raise NotImplementedError('Cannot convert to ordering %s' % order)\n\n    if degrees:\n        es = es * 180 / np.pi\n\n    return es\n\n\ndef euler2mat(rots, order='xyz'):\n    axis = {'x': torch.tensor((1, 0, 0), device=rots.device),\n            'y': torch.tensor((0, 1, 0), device=rots.device),\n            'z': torch.tensor((0, 0, 1), device=rots.device)}\n\n    rots = rots / 180 * np.pi\n    mats = []\n    for i in range(3):\n        aa = axis[order[i]] * rots[..., i].unsqueeze(-1)\n        mats.append(aa2mat(aa))\n    return mats[0] @ (mats[1] @ mats[2])\n\n\ndef aa2mat(rots):\n    \"\"\"\n    Convert angle-axis representation to rotation matrix\n    :param rots: angle-axis representation\n    :return:\n    \"\"\"\n    quat = aa2quat(rots)\n    mat = quat2mat(quat)\n    return mat\n\n\ndef mat2quat(R) -> torch.Tensor:\n    '''\n    https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py\n    Convert a rotation matrix to a unit quaternion.\n\n    This uses the Shepperd’s method for numerical stability.\n    '''\n\n    # The rotation matrix must be orthonormal\n\n    w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])\n    x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])\n    y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])\n    z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])\n\n    yz = (R[..., 1, 2] + R[..., 2, 1])\n    xz = (R[..., 2, 0] + R[..., 0, 2])\n    xy = (R[..., 0, 1] + R[..., 1, 0])\n\n    wx = (R[..., 2, 1] - R[..., 1, 2])\n    wy = (R[..., 0, 2] - R[..., 2, 0])\n    wz = (R[..., 1, 0] - R[..., 0, 1])\n\n    w = torch.empty_like(x2)\n    x = torch.empty_like(x2)\n    y = torch.empty_like(x2)\n    z = torch.empty_like(x2)\n\n    flagA = (R[..., 2, 2] < 0) * (R[..., 0, 0] > R[..., 1, 1])\n    flagB = (R[..., 2, 2] < 0) * (R[..., 0, 0] <= R[..., 1, 1])\n    flagC = (R[..., 2, 2] >= 0) * (R[..., 0, 0] < -R[..., 1, 1])\n    flagD = (R[..., 2, 2] >= 0) * (R[..., 0, 0] >= -R[..., 1, 1])\n\n    x[flagA] = torch.sqrt(x2[flagA])\n    w[flagA] = wx[flagA] / x[flagA]\n    y[flagA] = xy[flagA] / x[flagA]\n    z[flagA] = xz[flagA] / x[flagA]\n\n    y[flagB] = torch.sqrt(y2[flagB])\n    w[flagB] = wy[flagB] / y[flagB]\n    x[flagB] = xy[flagB] / y[flagB]\n    z[flagB] = yz[flagB] / y[flagB]\n\n    z[flagC] = torch.sqrt(z2[flagC])\n    w[flagC] = wz[flagC] / z[flagC]\n    x[flagC] = xz[flagC] / z[flagC]\n    y[flagC] = yz[flagC] / z[flagC]\n\n    w[flagD] = torch.sqrt(w2[flagD])\n    x[flagD] = wx[flagD] / w[flagD]\n    y[flagD] = wy[flagD] / w[flagD]\n    z[flagD] = wz[flagD] / w[flagD]\n\n    # if R[..., 2, 2] < 0:\n    #\n    #     if R[..., 0, 0] > R[..., 1, 1]:\n    #\n    #         x = torch.sqrt(x2)\n    #         w = wx / x\n    #         y = xy / x\n    #         z = xz / x\n    #\n    #     else:\n    #\n    #         y = torch.sqrt(y2)\n    #         w = wy / y\n    #         x = xy / y\n    #         z = yz / y\n    #\n    # else:\n    #\n    #     if R[..., 0, 0] < -R[..., 1, 1]:\n    #\n    #         z = torch.sqrt(z2)\n    #         w = wz / z\n    #         x = xz / z\n    #         y = yz / z\n    #\n    #     else:\n    #\n    #         w = torch.sqrt(w2)\n    #         x = wx / w\n    #         y = wy / w\n    #         z = wz / w\n\n    res = [w, x, y, z]\n    res = [z.unsqueeze(-1) for z in res]\n\n    return torch.cat(res, dim=-1) / 2\n\n\ndef quat2repr6d(quat):\n    mat = quat2mat(quat)\n    res = mat[..., :2, :]\n    res = res.reshape(res.shape[:-2] + (6, ))\n    return res\n\n\ndef repr6d2mat(repr):\n    x = repr[..., :3]\n    y = repr[..., 3:]\n    x = x / x.norm(dim=-1, keepdim=True)\n    z = torch.cross(x, y)\n    z = z / z.norm(dim=-1, keepdim=True)\n    y = torch.cross(z, x)\n    res = [x, y, z]\n    res = [v.unsqueeze(-2) for v in res]\n    mat = torch.cat(res, dim=-2)\n    return mat\n\n\ndef repr6d2quat(repr) -> torch.Tensor:\n    x = repr[..., :3]\n    y = repr[..., 3:]\n    x = x / x.norm(dim=-1, keepdim=True)\n    z = torch.cross(x, y)\n    z = z / z.norm(dim=-1, keepdim=True)\n    y = torch.cross(z, x)\n    res = [x, y, z]\n    res = [v.unsqueeze(-2) for v in res]\n    mat = torch.cat(res, dim=-2)\n    return mat2quat(mat)\n\n\ndef inv_affine(mat):\n    \"\"\"\n    Calculate the inverse of any affine transformation\n    \"\"\"\n    affine = torch.zeros((mat.shape[:2] + (1, 4)))\n    affine[..., 3] = 1\n    vert_mat = torch.cat((mat, affine), dim=2)\n    vert_mat_inv = torch.inverse(vert_mat)\n    return vert_mat_inv[..., :3, :]\n\n\ndef inv_rigid_affine(mat):\n    \"\"\"\n    Calculate the inverse of a rigid affine transformation\n    \"\"\"\n    res = mat.clone()\n    res[..., :3] = mat[..., :3].transpose(-2, -1)\n    res[..., 3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)\n    return res\n\n\ndef generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):\n    if n_bone is None: n_bone = 24\n    if ee is not None:\n        if root_rot:\n            ee.append(0)\n        n_bone_ = n_bone\n        n_bone = len(ee)\n    axis = torch.randn((batch_size, n_bone, 3), device=device)\n    axis /= axis.norm(dim=-1, keepdim=True)\n    if uniform:\n        angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi\n    else:\n        angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor\n        angle.clamp(-np.pi, np.pi)\n    poses = axis * angle\n    if ee is not None:\n        res = torch.zeros((batch_size, n_bone_, 3), device=device)\n        for i, id in enumerate(ee):\n            res[:, id] = poses[:, i]\n        poses = res\n    poses = poses.reshape(batch_size, -1)\n    if not root_rot:\n        poses[..., :3] = 0\n    return poses\n\n\ndef slerp(l, r, t, unit=True):\n    \"\"\"\n    :param l: shape = (*, n)\n    :param r: shape = (*, n)\n    :param t: shape = (*)\n    :param unit: If l and h are unit vectors\n    :return:\n    \"\"\"\n    eps = 1e-8\n    if not unit:\n        l_n = l / torch.norm(l, dim=-1, keepdim=True)\n        r_n = r / torch.norm(r, dim=-1, keepdim=True)\n    else:\n        l_n = l\n        r_n = r\n    omega = torch.acos((l_n * r_n).sum(dim=-1).clamp(-1, 1))\n    dom = torch.sin(omega)\n\n    flag = dom < eps\n\n    res = torch.empty_like(l_n)\n    t_t = t[flag].unsqueeze(-1)\n    res[flag] = (1 - t_t) * l_n[flag] + t_t * r_n[flag]\n\n    flag = ~ flag\n\n    t_t = t[flag]\n    d_t = dom[flag]\n    va = torch.sin((1 - t_t) * omega[flag]) / d_t\n    vb = torch.sin(t_t * omega[flag]) / d_t\n    res[flag] = (va.unsqueeze(-1) * l_n[flag] + vb.unsqueeze(-1) * r_n[flag])\n    return res\n\n\ndef slerp_quat(l, r, t):\n    \"\"\"\n    slerp for unit quaternions\n    :param l: (*, 4) unit quaternion\n    :param r: (*, 4) unit quaternion\n    :param t: (*) scalar between 0 and 1\n    \"\"\"\n    t = t.expand(l.shape[:-1])\n    flag = (l * r).sum(dim=-1) >= 0\n    res = torch.empty_like(l)\n    res[flag] = slerp(l[flag], r[flag], t[flag])\n    flag = ~ flag\n    res[flag] = slerp(-l[flag], r[flag], t[flag])\n    return res\n\n\n# def slerp_6d(l, r, t):\n#     l_q = repr6d2quat(l)\n#     r_q = repr6d2quat(r)\n#     res_q = slerp_quat(l_q, r_q, t)\n#     return quat2repr6d(res_q)\n\n\ndef interpolate_6d(input, size):\n    \"\"\"\n    :param input: (batch_size, n_channels, length)\n    :param size: required output size for temporal axis\n    :return:\n    \"\"\"\n    batch = input.shape[0]\n    length = input.shape[-1]\n    input = input.reshape((batch, -1, 6, length))\n    input = input.permute(0, 1, 3, 2)     # (batch_size, n_joint, length, 6)\n    input_q = repr6d2quat(input)\n    idx = torch.tensor(list(range(size)), device=input_q.device, dtype=torch.float) / size * (length - 1)\n    idx_l = torch.floor(idx)\n    t = idx - idx_l\n    idx_l = idx_l.long()\n    idx_r = idx_l + 1\n    t = t.reshape((1, 1, -1))\n    res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)\n    res = quat2repr6d(res_q)  # shape = (batch_size, n_joint, t, 6)\n    res = res.permute(0, 1, 3, 2)\n    res = res.reshape((batch, -1, size))\n    return res\n"
  }
]