Full Code of wyysf-98/GenMM for AI

main aee9bec5e1b5 cached
32 files
193.8 KB
52.0k tokens
239 symbols
1 requests
Download .txt
Showing preview only (204K chars total). Download the full file or copy to clipboard to get everything.
Repository: wyysf-98/GenMM
Branch: main
Commit: aee9bec5e1b5
Files: 32
Total size: 193.8 KB

Directory structure:
gitextract_jk8oqutc/

├── .gitignore
├── GenMM.py
├── LICENSE
├── README.md
├── __init__.py
├── configs/
│   ├── default.yaml
│   └── ganimator.yaml
├── dataset/
│   ├── blender_motion.py
│   ├── bvh/
│   │   ├── Quaternions.py
│   │   ├── bvh_io.py
│   │   ├── bvh_parser.py
│   │   └── bvh_writer.py
│   ├── bvh_motion.py
│   ├── motion.py
│   └── tracks_motion.py
├── demo.blend
├── docker/
│   ├── Dockerfile
│   ├── README.md
│   ├── apt-sources.list
│   ├── requirements.txt
│   └── requirements_blender.txt
├── fix_contact.py
├── nearest_neighbor/
│   ├── losses.py
│   └── utils.py
├── run_random_generation.py
├── run_web_server.py
└── utils/
    ├── base.py
    ├── contact.py
    ├── kinematics.py
    ├── rename_mixamo_rig.py
    ├── skeleton.py
    └── transforms.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.json

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
out/
# C extensions
*.so
*.pkl

# Distribution / packaging
.Python
build/
develop-eggs/
distf/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
wandb/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

.vscode/*
.vscode/settings.json



# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/


# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
# target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
venv/
env.bak/
venv.bak/
 

# Rope project settings
.ropeproject

# mkdocs documentation
/site

 # Pyre type checker
.pyre/
checkpoints/
data/*
output/
log/
runs/

*.png
*.jpg
*.mp4
*.gif
*.pkl
*.pt

================================================
FILE: GenMM.py
================================================
import os
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F

from utils.base import logger

class GenMM:
    # Default keyframe constraint used by `match_and_blend`.
    # NOTE(fix): the original code referenced `GenMM.KEYFRAME_INDICES` without
    # ever defining it, so every call to `match_and_blend` raised
    # AttributeError. An empty slice makes keyframe pinning a no-op by default;
    # assign e.g. `GenMM.KEYFRAME_INDICES = slice(0, 30)` (any slice/range with
    # an integer `.stop`) to pin those frames to the first target motion.
    KEYFRAME_INDICES = slice(0, 0)

    def __init__(self, mode = 'random_synthesis', noise_sigma = 1.0, coarse_ratio = 0.2, coarse_ratio_factor = 6, pyr_factor = 0.75, num_stages_limit = -1, device = 'cuda:0', silent = False):
        '''
        GenMM main constructor.

        Args:
            mode                : str = 'random_synthesis', generation mode (stored; `run()` takes its own configuration).
            noise_sigma         : float = 1.0, default noise level (stored for reference).
            coarse_ratio        : float = 0.2, default ratio at the coarse level (stored for reference).
            coarse_ratio_factor : float = 6, default coarse ratio factor (stored for reference).
            pyr_factor          : float = 0.75, default pyramid factor (stored for reference).
            num_stages_limit    : int = -1, default stage limit, -1 means no limit (stored for reference).
            device              : str = 'cuda:0', default device.
            silent              : bool = False, whether to mute the output.
        '''
        # NOTE(fix): the original constructor silently discarded every argument
        # except `device`/`silent`; keep the configured defaults around so
        # callers (and error messages) can refer to them.
        self.mode = mode
        self.noise_sigma = noise_sigma
        self.coarse_ratio = coarse_ratio
        self.coarse_ratio_factor = coarse_ratio_factor
        self.pyr_factor = pyr_factor
        self.num_stages_limit = num_stages_limit
        self.device = torch.device(device)
        self.silent = silent

    def _get_pyramid_lengths(self, final_len, coarse_ratio, pyr_factor):
        '''
        Get a list of pyramid lengths using the given target length and ratio.

        Args:
            final_len    : int, length of the finest (last) pyramid level.
            coarse_ratio : float, ratio of the coarsest level to `final_len`.
            pyr_factor   : float in (0, 1), downscale factor between levels.

        Returns:
            list[int], increasing lengths whose last entry is exactly `final_len`.
        '''
        lengths = [int(np.round(final_len * coarse_ratio))]
        while lengths[-1] < final_len:
            lengths.append(int(np.round(lengths[-1] / pyr_factor)))
            # Force strict growth so the loop terminates even when the rounded
            # division yields the same length again.
            if lengths[-1] == lengths[-2]:
                lengths[-1] += 1
        lengths[-1] = final_len

        return lengths

    def _get_target_pyramid(self, target, coarse_ratio, pyr_factor, num_stages_limit=-1):
        '''
        Read the target motion(s) and create a pyramid out of them, ordered in
        increasing size. All targets are truncated to the same number of levels
        (the minimum across targets).

        Args:
            target           : sequence of motion objects exposing `motion_data` and `sample(size)`.
            coarse_ratio     : float, ratio at the coarsest level.
            pyr_factor       : float, pyramid downscale factor.
            num_stages_limit : int = -1, maximum number of levels, -1 means no limit.

        Returns:
            list[list[torch.Tensor]], one list of per-target tensors per level.
        '''
        self.num_target = len(target)
        lengths = []
        min_len = 10000
        for i in range(len(target)):
            new_length = self._get_pyramid_lengths(len(target[i].motion_data), coarse_ratio, pyr_factor)
            min_len = min(min_len, len(new_length))
            if num_stages_limit != -1:
                new_length = new_length[:num_stages_limit]
            lengths.append(new_length)
        for i in range(len(target)):
            lengths[i] = lengths[i][-min_len:]
        # Keep the original (misspelled) attribute name for backward
        # compatibility, and expose a correctly spelled alias.
        self.pyramid_lengths = self.pyraimd_lengths = lengths

        target_pyramid = [[] for _ in range(len(lengths[0]))]
        for step in range(len(lengths[0])):
            for i in range(len(target)):
                length = lengths[i][step]
                target_pyramid[step].append(target[i].sample(size=length).to(self.device))

        if not self.silent:
            print('Levels:', lengths)
            for i in range(len(target_pyramid)):
                print(f'Number of clips in target pyramid {i} is {len(target_pyramid[i])}, ranging {[[tgt.min(), tgt.max()] for tgt in target_pyramid[i]]}')

        return target_pyramid

    def _get_initial_motion(self, init_length, noise_sigma):
        '''
        Prepare the initial motion for optimization: concatenate the coarsest
        target level along the last (time) axis, resample it to `init_length`,
        and optionally perturb it with Gaussian noise.

        Args:
            init_length : int, temporal length of the initial motion.
            noise_sigma : float, standard deviation of the added noise; <= 0 disables noise.

        Returns:
            torch.Tensor, the (possibly noisy) initial motion.
        '''
        initial_motion = F.interpolate(torch.cat([self.target_pyramid[0][i] for i in range(self.num_target)], dim=-1),
                                       size=init_length, mode='linear', align_corners=True)
        if noise_sigma > 0:
            initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * noise_sigma
            # Wrap into (-1, 1); presumably motion features are normalized to
            # that range — TODO confirm against the dataset code.
            initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)
        else:
            initial_motion_w_noise = initial_motion

        if not self.silent:
            print('Initial motion:', initial_motion.min(), initial_motion.max())
            print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())

        return initial_motion_w_noise

    def run(self, target, criteria, num_frames, num_steps, noise_sigma, patch_size, coarse_ratio, pyr_factor, ext=None, debug_dir=None):
        '''
        Coarse-to-fine generation: build a target pyramid, create an initial
        motion, then match-and-blend at each level.

        Args:
            target       : sequence of motion objects (see `_get_target_pyramid`).
            criteria     : callable optimization target (see `match_and_blend`).
            num_frames   : str, either '<x>x_nframes' (multiple of total target length)
                           or a plain digit string with the exact frame count.
            num_steps    : int, optimization steps per pyramid level.
            noise_sigma  : float, noise added to the initial motion.
            patch_size   : int, patch size used when `coarse_ratio` is patch-based.
            coarse_ratio : str, '<x>x_patchsize' or '<x>x_nframes'.
            pyr_factor   : float, pyramid downscale factor.
            ext          : optional extra configuration forwarded to `criteria`.
            debug_dir    : optional str, tensorboardX log directory.

        Returns:
            torch.Tensor, the synthesized motion (detached).

        Raises:
            ValueError : on an unsupported `coarse_ratio` or `num_frames` spec.
        '''
        if debug_dir is not None:
            from tensorboardX import SummaryWriter
            writer = SummaryWriter(log_dir=debug_dir)

        # build target pyramid
        if 'patchsize' in coarse_ratio:
            coarse_ratio = patch_size * float(coarse_ratio.split('x_')[0]) / max([len(t.motion_data) for t in target])
        elif 'nframes' in coarse_ratio:
            coarse_ratio = float(coarse_ratio.split('x_')[0])
        else:
            raise ValueError('Unsupported coarse ratio specified')
        self.target_pyramid = self._get_target_pyramid(target, coarse_ratio, pyr_factor)

        # get the initial motion data
        if 'nframes' in num_frames:
            syn_length = int(sum([i[-1] for i in self.pyraimd_lengths]) * float(num_frames.split('x_')[0]))
        elif num_frames.isdigit():
            syn_length = int(num_frames)
        else:
            # NOTE(fix): this used to raise AttributeError via the undefined
            # `self.mode`, masking the intended ValueError; report the actual
            # offending argument instead.
            raise ValueError(f'Unsupported num_frames {num_frames}')
        self.synthesized_lengths = self._get_pyramid_lengths(syn_length, coarse_ratio, pyr_factor)
        if not self.silent:
            print('Synthesized lengths:', self.synthesized_lengths)
        self.synthesized = self._get_initial_motion(self.synthesized_lengths[0], noise_sigma)

        # perform the optimization
        self.synthesized.requires_grad_(False)
        self.pbar = logger(num_steps, len(self.target_pyramid))
        for lvl, lvl_target in enumerate(self.target_pyramid):
            self.pbar.new_lvl()
            if lvl > 0:
                # Upsample the previous level's result to the next length.
                with torch.no_grad():
                    self.synthesized = F.interpolate(self.synthesized.detach(), size=self.synthesized_lengths[lvl], mode='linear')

            self.synthesized, losses = GenMM.match_and_blend(self.synthesized, lvl_target, criteria, num_steps, self.pbar, ext=ext)

            criteria.clean_cache()
            if debug_dir is not None:
                for itr in range(len(losses)):
                    writer.add_scalar(f'optimize/losses_lvl{lvl}', losses[itr], itr)
        self.pbar.pbar.close()

        return self.synthesized.detach()


    @staticmethod
    @torch.no_grad()
    def match_and_blend(synthesized, targets, criteria, n_steps, pbar, ext=None):
        '''
        Minimize `criteria` between synthesized and target, optionally pinning
        the frames in `GenMM.KEYFRAME_INDICES` to the first target motion.

        Args:
            synthesized : torch.Tensor, optimized motion data (time on the last axis).
            targets     : torch.Tensor or list of torch.Tensor, target motion data.
            criteria    : callable `(synthesized, targets, ext=..., return_blended_results=True)`
                          returning `(blended, loss)`.
            n_steps     : int, number of optimization steps.
            pbar        : logger, progress reporter with `step()`/`print()`.
            ext         : optional extra configuration forwarded to `criteria`.

        Returns:
            (torch.Tensor, list[float]), blended motion and per-step losses.
        '''
        losses = []
        keyframe_motion = targets[0] if isinstance(targets, list) else targets
        syn_length = synthesized.shape[-1]
        km_length = keyframe_motion.shape[-1]

        # NOTE(fix): removed leftover debug prints of the tensor shapes here.

        # Class-level constraint; defaults to an empty slice (no-op), see the
        # attribute definition at the top of the class.
        keyframe_indices = GenMM.KEYFRAME_INDICES

        for _i in range(n_steps):
            synthesized, loss = criteria(synthesized, targets, ext=ext, return_blended_results=True)

            # Overwrite the keyframes in the synthesized motion with the ones
            # from the input motion, when both tensors are long enough.
            if syn_length >= keyframe_indices.stop and km_length >= keyframe_indices.stop:
                synthesized[..., keyframe_indices] = keyframe_motion[..., keyframe_indices]

            # Update status
            losses.append(loss.item())
            pbar.step()
            pbar.print()

        return synthesized, losses



================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Example-based Motion Synthesis via Generative Motion Matching, ACM Transactions on Graphics (Proceedings of SIGGRAPH 2023)

#####  <p align="center"> [Weiyu Li*](https://wyysf-98.github.io/), [Xuelin Chen*†](https://xuelin-chen.github.io/), [Peizhuo Li](https://peizhuoli.github.io/), [Olga Sorkine-Hornung](https://igl.ethz.ch/people/sorkine/), [Baoquan Chen](https://cfcs.pku.edu.cn/baoquan/)</p>
 
#### <p align="center">[Project Page](https://wyysf-98.github.io/GenMM) | [ArXiv](https://arxiv.org/abs/2306.00378) | [Paper](https://wyysf-98.github.io/GenMM/paper/Paper_high_res.pdf) | [Video](https://youtu.be/lehnxcade4I)</p>

<p align="center">
  <img src="https://wyysf-98.github.io/GenMM/assets/images/teaser.png"/>
</p>

<p align="center"> All code and demos will be released this week (still ongoing...) 🏗️ 🚧 🔨</p>

- [x] Release main code
- [x] Release blender addon
- [x] Detailed README and installation guide
- [ ] Release skeleton-aware component, WIP as we need to split the joints into groups manually.
- [ ] Release codes for evaluation

## Prerequisite

<details> <summary>Setup environment</summary>

:smiley: We also provide a Dockerfile for easy installation, see [Setup using Docker](./docker/README.md).

 - Python 3.8
 - PyTorch 1.12.1
 - [unfoldNd](https://github.com/f-dangel/unfoldNd)

Clone this repository.

```sh
git clone git@github.com:wyysf-98/GenMM.git
```

Install the required packages.

```sh
conda create -n GenMM python=3.8
conda activate GenMM
conda install -c pytorch pytorch=1.12.1 torchvision=0.13.1 cudatoolkit=11.3 && \
pip install -r docker/requirements.txt
pip install torch-scatter==2.1.1
```

</details>

## Quick inference demo
For local quick inference demo using .bvh file, you can use

```sh
python run_random_generation.py -i './data/Malcolm/Gangnam-Style.bvh'
```
More configuration can be found in the `run_random_generation.py`.
We use an Apple M1 and NVIDIA Tesla V100 with 32 GB RAM to generate each motion, which takes about ~0.2s and ~0.05s as mentioned in our paper.

## Blender add-on
You can install and use the blender add-on with easy installation as our method is efficient and you do not need to install CUDA Toolkit.
We tested our code using Blender 3.2.x and will support 2.8.0 in the future.

Step 1: Find your Blender Python path. Common paths are as follows
```sh
(Windows) 'C:\Program Files\Blender Foundation\Blender 3.2\3.2\python\bin'
(Linux) '/path/to/blender/blender-path/3.2/python/bin'
(MacOS) '/Applications/Blender.app/Contents/Resources/3.2/python/bin'
```

Step 2: Install required packages. Open your shell(Linux) or powershell(Windows), 
```sh
cd {your python path} && pip3 install -r docker/requirements.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+${CUDA}.html
```
, where ${CUDA} should be replaced by either cpu, cu117, or cu118 depending on your PyTorch installation.
On my MacOS with M1 cpu,

```sh
cd /Applications/Blender.app/Contents/Resources/3.2/python/bin && pip3 install -r docker/requirements_blender.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+cpu.html
```

Step 3: Install add-on in blender. [Blender Add-ons Official Tutorial](https://docs.blender.org/manual/en/latest/editors/preferences/addons.html). `edit -> Preferences -> Add-ons -> Install -> Select the downloaded .zip file`

Step 4: Have fun! Click the armature and you will find a `GenMM` tag.

(GPU support) If you have a GPU and the CUDA Toolkit installed, we automatically detect the running device.

Feel free to submit an issue if you run into any issues during the installation :)

## Acknowledgement

We thank [@stefanonuvoli](https://github.com/stefanonuvoli/skinmixer) for the helpful discussion on the implementation of the `Motion Reassembly` part (we eventually manually merged the meshes of different characters), and [@Radamés Ajna](https://github.com/radames) for helping build a better Hugging Face demo.


## Citation

If you find our work useful for your research, please consider citing using the following BibTeX entry.

```BibTeX
@article{10.1145/weiyu23GenMM,
    author     = {Li, Weiyu and Chen, Xuelin and Li, Peizhuo and Sorkine-Hornung, Olga and Chen, Baoquan},
    title      = {Example-Based Motion Synthesis via Generative Motion Matching},
    journal    = {ACM Transactions on Graphics (TOG)},
    volume     = {42},
    number     = {4},
    year       = {2023},
    articleno  = {94},
    doi = {10.1145/3592395},
    publisher  = {Association for Computing Machinery},
}
```


================================================
FILE: __init__.py
================================================
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import sys
import bpy
import torch
import mathutils
import numpy as np
from math import degrees, radians, ceil
from mathutils import Vector, Matrix, Euler
from typing import List, Iterable, Tuple, Any, Dict

abs_path = os.path.abspath(__file__)
sys.path.append(os.path.dirname(abs_path))
from GenMM import GenMM
from nearest_neighbor.losses import PatchCoherentLoss
from dataset.blender_motion import BlenderMotion

# Blender add-on metadata.
# NOTE(fix): the original dict literal listed both "description" and
# "location" twice; Python silently keeps only the last value of a duplicated
# key, which blanked out "location" in Blender's add-on list. Each key now
# appears exactly once with its intended value.
bl_info = {
    "name" : "GenMM",
    "author" : "Weiyu Li",
    "description" : "Blender addon for SIGGRAPH paper 'Example-Based Motion Synthesis via Generative Motion Matching'",
    "blender" : (3, 2, 0),
    "version" : (0, 0, 1),
    "location" : "3D View",
    "support": "TESTING",
    "warning" : "",
    "category" : "Generic"
}

def capture_rest_pose(armature_obj):
    """Return a mapping of bone name -> rest-pose data (head, tail, roll, matrix_local).

    Temporarily switches the armature into EDIT mode (edit_bones are only
    available there), snapshots every bone, then restores OBJECT mode.
    """
    bpy.ops.object.mode_set(mode='EDIT')
    snapshot = {
        edit_bone.name: {
            'head': edit_bone.head.copy(),
            'tail': edit_bone.tail.copy(),
            'roll': edit_bone.roll,
            'matrix_local': edit_bone.matrix.copy(),
        }
        for edit_bone in armature_obj.data.edit_bones
    }
    bpy.ops.object.mode_set(mode='OBJECT')
    return snapshot

# This function is modified from
# https://github.com/bwrsandman/blender-addons/blob/master/io_anim_bvh
def get_bvh_data(context,
                 frame_end,
                 frame_start,
                 global_scale=1.0,
                 rotate_mode='NATIVE',
                 root_transform_only=False,
                 ):
    """Serialize the active armature's animation into a BVH-format string.

    Exports scene frames ``frame_start``..``frame_end`` (inclusive) of
    ``context.object``, which is expected to be an armature object.

    Args:
        context: Blender context; ``context.object`` is the armature to export.
        frame_end: Last scene frame to export (inclusive).
        frame_start: First scene frame to export.
        global_scale: Uniform scale applied to offsets and exported positions.
        rotate_mode: Euler order written on CHANNELS lines. 'NATIVE' takes
            each pose bone's own rotation mode (falling back to 'XYZ' for
            non-euler modes).
        root_transform_only: If True, position channels are only written for
            bones without a parent; connected/child bones get rotations only.

    Returns:
        str: Full BVH content (HIERARCHY and MOTION sections).
    """

    def ensure_rot_order(rot_order_str):
        # Rotation modes like 'QUATERNION' or 'AXIS_ANGLE' are not XYZ
        # permutations; fall back to a plain 'XYZ' euler order for those.
        if set(rot_order_str) != {'X', 'Y', 'Z'}:
            rot_order_str = "XYZ"
        return rot_order_str

    file_str = []

    obj = context.object
    arm = obj.data

    # Build a dictionary of children.
    # None for parentless
    children = {None: []}

    # initialize with blank lists
    for bone in arm.bones:
        children[bone.name] = []

    # keep bone order from armature, no sorting, not esspential but means
    # we can maintain order from import -> export which secondlife incorrectly expects.
    for bone in arm.bones:
        children[getattr(bone.parent, "name", None)].append(bone.name)

    # bone name list in the order that the bones are written
    serialized_names = []

    node_locations = {}

    file_str.append("HIERARCHY\n")

    def write_recursive_nodes(bone_name, indent):
        """Emit the ROOT/JOINT block for *bone_name* into ``file_str`` and
        recurse into its children (also recording them in serialized_names)."""
        my_children = children[bone_name]

        indent_str = "\t" * indent

        bone = arm.bones[bone_name]
        pose_bone = obj.pose.bones[bone_name]
        loc = bone.head_local
        node_locations[bone_name] = loc

        if rotate_mode == "NATIVE":
            rot_order_str = ensure_rot_order(pose_bone.rotation_mode)
        else:
            rot_order_str = rotate_mode

        # make relative if we can
        if bone.parent:
            loc = loc - node_locations[bone.parent.name]

        if indent:
            file_str.append("%sJOINT %s\n" % (indent_str, bone_name))
        else:
            file_str.append("%sROOT %s\n" % (indent_str, bone_name))

        file_str.append("%s{\n" % indent_str)
        file_str.append("%s\tOFFSET %.6f %.6f %.6f\n" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale))
        if (bone.use_connect or root_transform_only) and bone.parent:
            file_str.append("%s\tCHANNELS 3 %srotation %srotation %srotation\n" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2]))
        else:
            file_str.append("%s\tCHANNELS 6 Xposition Yposition Zposition %srotation %srotation %srotation\n" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2]))

        if my_children:
            # store the location for the children
            # to get their relative offset

            # Write children
            for child_bone in my_children:
                serialized_names.append(child_bone)
                write_recursive_nodes(child_bone, indent + 1)

        else:
            # Write the bone end.
            file_str.append("%s\tEnd Site\n" % indent_str)
            file_str.append("%s\t{\n" % indent_str)
            loc = bone.tail_local - node_locations[bone_name]
            file_str.append("%s\t\tOFFSET %.6f %.6f %.6f\n" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale))
            file_str.append("%s\t}\n" % indent_str)

        file_str.append("%s}\n" % indent_str)

    if len(children[None]) == 1:
        # Single root bone: write it directly at indent level 0.
        key = children[None][0]
        serialized_names.append(key)
        indent = 0

        write_recursive_nodes(key, indent)

    else:
        # Write a dummy parent node, with a dummy key name
        # Just be sure it's not used by another bone!
        i = 0
        key = "__%d" % i
        while key in children:
            i += 1
            key = "__%d" % i
        file_str.append("ROOT %s\n" % key)
        file_str.append("{\n")
        file_str.append("\tOFFSET 0.0 0.0 0.0\n")
        file_str.append("\tCHANNELS 0\n")  # Xposition Yposition Zposition Xrotation Yrotation Zrotation
        indent = 1

        # Write children
        for child_bone in children[None]:
            serialized_names.append(child_bone)
            write_recursive_nodes(child_bone, indent)

        file_str.append("}\n")
    # Hierarchy section is complete; collapse the list into one string
    # (subsequent MOTION lines are appended with '+=').
    file_str = ''.join(file_str)
    # redefine bones as sorted by serialized_names
    # so we can write motion

    class DecoratedBone:
        """Pairs an armature (rest) bone with its pose bone and caches the
        rest/pose matrices and their inverses, so the per-frame motion lines
        can be computed without recomputing inversions."""
        __slots__ = (
            # Bone name, used as key in many places.
            "name",
            "parent",  # decorated bone parent, set in a later loop
            # Blender armature bone.
            "rest_bone",
            # Blender pose bone.
            "pose_bone",
            # Blender pose matrix.
            "pose_mat",
            # Blender rest matrix (armature space).
            "rest_arm_mat",
            # Blender rest matrix (local space).
            "rest_local_mat",
            # Pose_mat inverted.
            "pose_imat",
            # Rest_arm_mat inverted.
            "rest_arm_imat",
            # Rest_local_mat inverted.
            "rest_local_imat",
            # Last used euler to preserve euler compatibility in between keyframes.
            "prev_euler",
            # Is the bone disconnected to the parent bone?
            "skip_position",
            "rot_order",
            "rot_order_str",
            # Needed for the euler order when converting from a matrix.
            "rot_order_str_reverse",
        )

        _eul_order_lookup = {
            'XYZ': (0, 1, 2),
            'XZY': (0, 2, 1),
            'YXZ': (1, 0, 2),
            'YZX': (1, 2, 0),
            'ZXY': (2, 0, 1),
            'ZYX': (2, 1, 0),
        }

        def __init__(self, bone_name):
            self.name = bone_name
            self.rest_bone = arm.bones[bone_name]
            self.pose_bone = obj.pose.bones[bone_name]

            if rotate_mode == "NATIVE":
                self.rot_order_str = ensure_rot_order(self.pose_bone.rotation_mode)
            else:
                self.rot_order_str = rotate_mode
            self.rot_order_str_reverse = self.rot_order_str[::-1]

            self.rot_order = DecoratedBone._eul_order_lookup[self.rot_order_str]

            self.pose_mat = self.pose_bone.matrix

            # mat = self.rest_bone.matrix  # UNUSED
            self.rest_arm_mat = self.rest_bone.matrix_local
            self.rest_local_mat = self.rest_bone.matrix

            # inverted mats
            self.pose_imat = self.pose_mat.inverted()
            self.rest_arm_imat = self.rest_arm_mat.inverted()
            self.rest_local_imat = self.rest_local_mat.inverted()

            self.parent = None
            self.prev_euler = Euler((0.0, 0.0, 0.0), self.rot_order_str_reverse)
            self.skip_position = ((self.rest_bone.use_connect or root_transform_only) and self.rest_bone.parent)

        def update_posedata(self):
            # Refresh the cached pose matrix after scene.frame_set().
            self.pose_mat = self.pose_bone.matrix
            self.pose_imat = self.pose_mat.inverted()

        def __repr__(self):
            if self.parent:
                return "[\"%s\" child on \"%s\"]\n" % (self.name, self.parent.name)
            else:
                return "[\"%s\" root bone]\n" % (self.name)

    bones_decorated = [DecoratedBone(bone_name) for bone_name in serialized_names]


    # Assign parents
    bones_decorated_dict = {dbone.name: dbone for dbone in bones_decorated}
    for dbone in bones_decorated:
        parent = dbone.rest_bone.parent
        if parent:
            dbone.parent = bones_decorated_dict[parent.name]
    del bones_decorated_dict
    # finish assigning parents

    scene = context.scene
    frame_current = scene.frame_current

    file_str += "MOTION\n"
    file_str += "Frames: %d\n" % (frame_end - frame_start + 1)
    file_str += "Frame Time: %.6f\n" % (1.0 / (scene.render.fps / scene.render.fps_base))

    for frame in range(frame_start, frame_end + 1):
        scene.frame_set(frame)

        for dbone in bones_decorated:
            dbone.update_posedata()

        for dbone in bones_decorated:
            trans = Matrix.Translation(dbone.rest_bone.head_local)
            itrans = Matrix.Translation(-dbone.rest_bone.head_local)

            if dbone.parent:
                # Express the bone's pose relative to its parent's pose in
                # rest space — presumably the BVH per-joint local transform;
                # NOTE(review): verify against the importer's convention.
                mat_final = dbone.parent.rest_arm_mat @ dbone.parent.pose_imat @ dbone.pose_mat @ dbone.rest_arm_imat
                mat_final = itrans @ mat_final @ trans
                loc = mat_final.to_translation() + (dbone.rest_bone.head_local - dbone.parent.rest_bone.head_local)
            else:
                mat_final = dbone.pose_mat @ dbone.rest_arm_imat
                mat_final = itrans @ mat_final @ trans
                loc = mat_final.to_translation() + dbone.rest_bone.head

            # keep eulers compatible, no jumping on interpolation.
            rot = mat_final.to_euler(dbone.rot_order_str_reverse, dbone.prev_euler)

            if not dbone.skip_position:
                file_str += "%.6f %.6f %.6f " % (loc * global_scale)[:]

            file_str += "%.6f %.6f %.6f " % (degrees(rot[dbone.rot_order[0]]), degrees(rot[dbone.rot_order[1]]), degrees(rot[dbone.rot_order[2]]))

            dbone.prev_euler = rot

        file_str += "\n"

    # Restore whatever frame the scene was on before the export.
    scene.frame_set(frame_current)

    return file_str


class BVH_Node:
    """One joint of a parsed BVH hierarchy.

    Holds the rest-pose head/tail positions (world and parent-relative),
    the CHANNELS-to-motion-column mapping, the euler rotation order, and
    the per-frame animation samples appended by read_bvh().
    """
    __slots__ = (
        # Bvh joint name.
        'name',
        # BVH_Node type or None for no parent.
        'parent',
        # A list of children of this type..
        'children',
        # Worldspace rest location for the head of this node.
        'rest_head_world',
        # Localspace rest location for the head of this node.
        'rest_head_local',
        # Worldspace rest location for the tail of this node.
        'rest_tail_world',
        # Localspace rest location for the tail of this node.
        'rest_tail_local',
        # List of 6 ints, -1 for an unused channel,
        # otherwise an index for the BVH motion data lines,
        # loc triple then rot triple.
        'channels',
        # A triple of indices as to the order rotation is applied.
        # [0,1,2] is x/y/z - [None, None, None] if no rotation..
        'rot_order',
        # Same as above but a string 'XYZ' format..
        'rot_order_str',
        # A list one tuple's one for each frame: (locx, locy, locz, rotx, roty, rotz),
        # euler rotation ALWAYS stored xyz order, even when native used.
        'anim_data',
        # Convenience function, bool, same as: (channels[0] != -1 or channels[1] != -1 or channels[2] != -1).
        'has_loc',
        # Convenience function, bool, same as: (channels[3] != -1 or channels[4] != -1 or channels[5] != -1).
        'has_rot',
        # Index from the file, not strictly needed but nice to maintain order.
        'index',
        # Use this for whatever you want.
        'temp',
    )

    _eul_order_lookup = {
        (None, None, None): 'XYZ',  # XXX Dummy one, no rotation anyway!
        (0, 1, 2): 'XYZ',
        (0, 2, 1): 'XZY',
        (1, 0, 2): 'YXZ',
        (1, 2, 0): 'YZX',
        (2, 0, 1): 'ZXY',
        (2, 1, 0): 'ZYX',
    }

    def __init__(self, name, rest_head_world, rest_head_local, parent, channels, rot_order, index):
        self.name = name
        self.rest_head_world = rest_head_world
        self.rest_head_local = rest_head_local
        # Tails are resolved later (End Site offset or children), in read_bvh().
        self.rest_tail_world = None
        self.rest_tail_local = None
        self.parent = parent
        self.channels = channels
        self.rot_order = tuple(rot_order)
        self.rot_order_str = BVH_Node._eul_order_lookup[self.rot_order]
        self.index = index

        # convenience functions
        self.has_loc = channels[0] != -1 or channels[1] != -1 or channels[2] != -1
        self.has_rot = channels[3] != -1 or channels[4] != -1 or channels[5] != -1

        self.children = []

        # List of 6 length tuples: (lx, ly, lz, rx, ry, rz)
        # even if the channels aren't used they will just be zero.
        self.anim_data = [(0, 0, 0, 0, 0, 0)]

    def __repr__(self):
        # BUGFIX: the second triple was labelled 'rest_tail' but printed
        # rest_head_world twice. Print the tail, falling back to the head
        # while the tail is still unresolved (None before read_bvh fills it).
        tail = self.rest_tail_world if self.rest_tail_world is not None else self.rest_head_world
        return (
            "BVH name: '%s', rest_loc:(%.3f,%.3f,%.3f), rest_tail:(%.3f,%.3f,%.3f)" % (
                self.name,
                *self.rest_head_world,
                *tail,
            )
        )


def sorted_nodes(bvh_nodes):
    """Return the values of the name -> BVH_Node mapping *bvh_nodes* as a
    list ordered by each node's original file index."""
    return sorted(bvh_nodes.values(), key=lambda node: node.index)


def read_bvh(context, bvh_str, rotate_mode='XYZ', global_scale=1.0):
    """Parse BVH text into a dict of BVH_Node objects.

    Args:
        context: Blender context (not referenced by the parser itself).
        bvh_str: The BVH content as a sequence of text lines.
        rotate_mode: Not referenced here; each joint's rotation order is
            taken from its own CHANNELS declaration.
        global_scale: Scale applied to joint offsets and position channels.

    Returns:
        tuple: (bvh_nodes, bvh_frame_time, bvh_frame_count) where bvh_nodes
        maps joint name -> BVH_Node. frame time/count are None when the
        MOTION header is absent or malformed.
    """
    # Separate into a list of lists, each line a list of words.
    file_lines = bvh_str
    # Non standard carriage returns?
    if len(file_lines) == 1:
        file_lines = file_lines[0].split('\r')

    # Split by whitespace.
    file_lines = [ll for ll in [l.split() for l in file_lines] if ll]

    # Create hierarchy as empties
    if file_lines[0][0].lower() == 'hierarchy':
        # print 'Importing the BVH Hierarchy for:', file_path
        pass
    else:
        raise Exception("This is not a BVH file")

    # bvh_nodes_serial acts as a parent stack; the None sentinel marks the
    # virtual parent of root joints.
    bvh_nodes = {None: None}
    bvh_nodes_serial = [None]
    bvh_frame_count = None
    bvh_frame_time = None

    # Running column index into each motion-data line; channels across all
    # joints share one flat numbering.
    channelIndex = -1

    lineIdx = 0  # An index for the file.
    while lineIdx < len(file_lines) - 1:
        if file_lines[lineIdx][0].lower() in {'root', 'joint'}:

            # Join spaces into 1 word with underscores joining it.
            if len(file_lines[lineIdx]) > 2:
                file_lines[lineIdx][1] = '_'.join(file_lines[lineIdx][1:])
                file_lines[lineIdx] = file_lines[lineIdx][:2]

            # MAY NEED TO SUPPORT MULTIPLE ROOTS HERE! Still unsure weather multiple roots are possible?

            # Make sure the names are unique - Object names will match joint names exactly and both will be unique.
            name = file_lines[lineIdx][1]

            # print '%snode: %s, parent: %s' % (len(bvh_nodes_serial) * '  ', name,  bvh_nodes_serial[-1])

            lineIdx += 2  # Increment to the next line (Offset)
            rest_head_local = global_scale * Vector((
                float(file_lines[lineIdx][1]),
                float(file_lines[lineIdx][2]),
                float(file_lines[lineIdx][3]),
            ))
            lineIdx += 1  # Increment to the next line (Channels)

            # newChannel[Xposition, Yposition, Zposition, Xrotation, Yrotation, Zrotation]
            # newChannel references indices to the motiondata,
            # if not assigned then -1 refers to the last value that will be added on loading at a value of zero, this is appended
            # We'll add a zero value onto the end of the MotionDATA so this always refers to a value.
            my_channel = [-1, -1, -1, -1, -1, -1]
            my_rot_order = [None, None, None]
            rot_count = 0
            for channel in file_lines[lineIdx][2:]:
                channel = channel.lower()
                channelIndex += 1  # So the index points to the right channel
                if channel == 'xposition':
                    my_channel[0] = channelIndex
                elif channel == 'yposition':
                    my_channel[1] = channelIndex
                elif channel == 'zposition':
                    my_channel[2] = channelIndex

                elif channel == 'xrotation':
                    my_channel[3] = channelIndex
                    my_rot_order[rot_count] = 0
                    rot_count += 1
                elif channel == 'yrotation':
                    my_channel[4] = channelIndex
                    my_rot_order[rot_count] = 1
                    rot_count += 1
                elif channel == 'zrotation':
                    my_channel[5] = channelIndex
                    my_rot_order[rot_count] = 2
                    rot_count += 1

            channels = file_lines[lineIdx][2:]

            my_parent = bvh_nodes_serial[-1]  # account for none

            # Apply the parents offset accumulatively
            if my_parent is None:
                rest_head_world = Vector(rest_head_local)
            else:
                rest_head_world = my_parent.rest_head_world + rest_head_local

            bvh_node = bvh_nodes[name] = BVH_Node(
                name,
                rest_head_world,
                rest_head_local,
                my_parent,
                my_channel,
                my_rot_order,
                len(bvh_nodes) - 1,
            )

            # If we have another child then we can call ourselves a parent, else
            bvh_nodes_serial.append(bvh_node)

        # Account for an end node.
        # There is sometimes a name after 'End Site' but we will ignore it.
        if file_lines[lineIdx][0].lower() == 'end' and file_lines[lineIdx][1].lower() == 'site':
            # Increment to the next line (Offset)
            lineIdx += 2
            rest_tail = global_scale * Vector((
                float(file_lines[lineIdx][1]),
                float(file_lines[lineIdx][2]),
                float(file_lines[lineIdx][3]),
            ))

            bvh_nodes_serial[-1].rest_tail_world = bvh_nodes_serial[-1].rest_head_world + rest_tail
            bvh_nodes_serial[-1].rest_tail_local = bvh_nodes_serial[-1].rest_head_local + rest_tail

            # Just so we can remove the parents in a uniform way,
            # the end has kids so this is a placeholder.
            bvh_nodes_serial.append(None)

        if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0] == '}':  # == ['}']
            bvh_nodes_serial.pop()  # Remove the last item

        # End of the hierarchy. Begin the animation section of the file with
        # the following header.
        #  MOTION
        #  Frames: n
        #  Frame Time: dt
        if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0].lower() == 'motion':
            lineIdx += 1  # Read frame count.
            if (
                    len(file_lines[lineIdx]) == 2 and
                    file_lines[lineIdx][0].lower() == 'frames:'
            ):
                bvh_frame_count = int(file_lines[lineIdx][1])

            lineIdx += 1  # Read frame rate.
            if (
                    len(file_lines[lineIdx]) == 3 and
                    file_lines[lineIdx][0].lower() == 'frame' and
                    file_lines[lineIdx][1].lower() == 'time:'
            ):
                bvh_frame_time = float(file_lines[lineIdx][2])

            lineIdx += 1  # Set the cursor to the first frame

            break

        lineIdx += 1

    # Remove the None value used for easy parent reference
    del bvh_nodes[None]
    # Don't use anymore
    del bvh_nodes_serial

    # importing world with any order but nicer to maintain order
    # second life expects it, which isn't to spec.
    bvh_nodes_list = sorted_nodes(bvh_nodes)

    # Motion section: one line per frame, one flat channel list per line.
    while lineIdx < len(file_lines):
        line = file_lines[lineIdx]
        for bvh_node in bvh_nodes_list:
            # for bvh_node in bvh_nodes_serial:
            lx = ly = lz = rx = ry = rz = 0.0
            channels = bvh_node.channels
            anim_data = bvh_node.anim_data
            if channels[0] != -1:
                lx = global_scale * float(line[channels[0]])

            if channels[1] != -1:
                ly = global_scale * float(line[channels[1]])

            if channels[2] != -1:
                lz = global_scale * float(line[channels[2]])

            if channels[3] != -1 or channels[4] != -1 or channels[5] != -1:

                rx = radians(float(line[channels[3]]))
                ry = radians(float(line[channels[4]]))
                rz = radians(float(line[channels[5]]))

            # Done importing motion data #
            anim_data.append((lx, ly, lz, rx, ry, rz))
        lineIdx += 1

    # Assign children
    for bvh_node in bvh_nodes_list:
        bvh_node_parent = bvh_node.parent
        if bvh_node_parent:
            bvh_node_parent.children.append(bvh_node)

    # Now set the tip of each bvh_node
    for bvh_node in bvh_nodes_list:

        if not bvh_node.rest_tail_world:
            if len(bvh_node.children) == 0:
                # could just fail here, but rare BVH files have childless nodes
                bvh_node.rest_tail_world = Vector(bvh_node.rest_head_world)
                bvh_node.rest_tail_local = Vector(bvh_node.rest_head_local)
            elif len(bvh_node.children) == 1:
                bvh_node.rest_tail_world = Vector(bvh_node.children[0].rest_head_world)
                bvh_node.rest_tail_local = bvh_node.rest_head_local + bvh_node.children[0].rest_head_local
            else:
                # allow this, see above
                # if not bvh_node.children:
                #     raise Exception("bvh node has no end and no children. bad file")

                # Removed temp for now
                rest_tail_world = Vector((0.0, 0.0, 0.0))
                rest_tail_local = Vector((0.0, 0.0, 0.0))
                for bvh_node_child in bvh_node.children:
                    rest_tail_world += bvh_node_child.rest_head_world
                    rest_tail_local += bvh_node_child.rest_head_local

                bvh_node.rest_tail_world = rest_tail_world * (1.0 / len(bvh_node.children))
                bvh_node.rest_tail_local = rest_tail_local * (1.0 / len(bvh_node.children))

        # Make sure tail isn't the same location as the head.
        if (bvh_node.rest_tail_local - bvh_node.rest_head_local).length <= 0.001 * global_scale:
            print("\tzero length node found:", bvh_node.name)
            bvh_node.rest_tail_local.y = bvh_node.rest_tail_local.y + global_scale / 10
            bvh_node.rest_tail_world.y = bvh_node.rest_tail_world.y + global_scale / 10

    return bvh_nodes, bvh_frame_time, bvh_frame_count


def bvh_node_dict2objects(context, bvh_name, bvh_nodes, rotate_mode='NATIVE', frame_start=1, IMPORT_LOOP=False):
    """Instantiate the parsed BVH hierarchy as a tree of empty objects and
    keyframe their animation via delta transforms.

    Args:
        context: Blender context; new empties are linked into its collection.
        bvh_name: Not referenced in this code path (kept for call symmetry
            with bvh_node_dict2armature).
        bvh_nodes: name -> BVH_Node mapping from read_bvh().
        rotate_mode: Not referenced; each empty's rotation mode is taken from
            the node's own channel order.
        frame_start: First keyframe frame number (clamped to >= 1).
        IMPORT_LOOP: Not referenced in this code path.

    Returns:
        list: The created empty objects (including the '_end' tail markers).
    """

    if frame_start < 1:
        frame_start = 1

    scene = context.scene
    for obj in scene.objects:
        obj.select_set(False)

    objects = []

    def add_ob(name):
        # Create a cube-drawn empty, link it into the active collection and
        # record it in the returned `objects` list.
        obj = bpy.data.objects.new(name, None)
        context.collection.objects.link(obj)
        objects.append(obj)
        obj.select_set(True)

        # nicer drawing.
        obj.empty_display_type = 'CUBE'
        obj.empty_display_size = 0.1

        return obj

    # Add objects
    for name, bvh_node in bvh_nodes.items():
        bvh_node.temp = add_ob(name)
        bvh_node.temp.rotation_mode = bvh_node.rot_order_str[::-1]

    # Parent the objects
    for bvh_node in bvh_nodes.values():
        for bvh_node_child in bvh_node.children:
            bvh_node_child.temp.parent = bvh_node.temp

    # Offset
    for bvh_node in bvh_nodes.values():
        # Make relative to parents offset
        bvh_node.temp.location = bvh_node.rest_head_local

    # Add tail objects
    for name, bvh_node in bvh_nodes.items():
        if not bvh_node.children:
            ob_end = add_ob(name + '_end')
            ob_end.parent = bvh_node.temp
            ob_end.location = bvh_node.rest_tail_world - bvh_node.rest_head_world

    # Keyframe each empty's delta transforms from the node's animation data.
    for name, bvh_node in bvh_nodes.items():
        obj = bvh_node.temp

        for frame_current in range(len(bvh_node.anim_data)):

            lx, ly, lz, rx, ry, rz = bvh_node.anim_data[frame_current]

            if bvh_node.has_loc:
                obj.delta_location = Vector((lx, ly, lz)) - bvh_node.rest_head_world
                obj.keyframe_insert("delta_location", index=-1, frame=frame_start + frame_current)

            if bvh_node.has_rot:
                obj.delta_rotation_euler = rx, ry, rz
                obj.keyframe_insert("delta_rotation_euler", index=-1, frame=frame_start + frame_current)

    return objects


def bvh_node_dict2armature(
        context,
        bvh_name,
        bvh_nodes,
        bvh_frame_time,
        rotate_mode='XYZ',
        frame_start=1,
        IMPORT_LOOP=False,
        global_matrix=None,
        use_fps_scale=False,
        original_rest_pose=None  # New parameter for the original rest pose
):
    """Build an armature object from parsed BVH nodes and keyframe its action.

    Args:
        context: Blender context; the new armature is linked into its
            collection and made the active object.
        bvh_name: Name used for the armature data, object and action.
        bvh_frame_time: Seconds per BVH frame (used when use_fps_scale is on).
        bvh_nodes: name -> BVH_Node mapping from read_bvh().
        rotate_mode: 'QUATERNION', 'NATIVE' (per-node channel order) or a
            fixed euler order applied to all pose bones.
        frame_start: First keyframe frame number (clamped to >= 1).
        IMPORT_LOOP: Currently a no-op placeholder in the fcurve pass.
        global_matrix: Optional world matrix for the armature object;
            None (the default) leaves the object transform untouched.
        use_fps_scale: When True, keyframe times are spaced by
            scene fps * bvh_frame_time instead of 1 frame apart.
        original_rest_pose: Optional dict from capture_rest_pose(); when
            given, bone head/tail/roll are restored from it instead of the
            BVH rest data.

    Returns:
        The created armature object.
    """
    if frame_start < 1:
        frame_start = 1

    scene = context.scene
    for obj in scene.objects:
        obj.select_set(False)

    arm_data = bpy.data.armatures.new(bvh_name)
    arm_ob = bpy.data.objects.new(bvh_name, arm_data)

    context.collection.objects.link(arm_ob)

    arm_ob.select_set(True)
    context.view_layer.objects.active = arm_ob

    bpy.ops.object.mode_set(mode='EDIT', toggle=False)

    bvh_nodes_list = sorted_nodes(bvh_nodes)

    # Average length of the non-zero bones; used as a fallback extent for
    # zero-length bones (Blender cannot keep a bone with head == tail).
    average_bone_length = 0.0
    nonzero_count = 0
    for bvh_node in bvh_nodes_list:
        l = (bvh_node.rest_head_local - bvh_node.rest_tail_local).length
        if l:
            average_bone_length += l
            nonzero_count += 1

    if not average_bone_length:
        average_bone_length = 0.1
    else:
        average_bone_length = average_bone_length / nonzero_count

    # BUGFIX: edit bones live on the armature *data*, not on the object;
    # the original `arm_ob.edit_bones` would raise AttributeError whenever
    # this cleanup loop actually iterated.
    while arm_data.edit_bones:
        arm_data.edit_bones.remove(arm_data.edit_bones[-1])

    ZERO_AREA_BONES = []
    # First pass: Create all bones and assign to temp
    for bvh_node in bvh_nodes_list:
        bone = arm_data.edit_bones.new(bvh_node.name)

        # Use the original rest pose if provided, otherwise fall back to BVH data
        if original_rest_pose and bvh_node.name in original_rest_pose:
            bone.head = original_rest_pose[bvh_node.name]['head']
            bone.tail = original_rest_pose[bvh_node.name]['tail']
            bone.roll = original_rest_pose[bvh_node.name]['roll']
        else:
            bone.head = bvh_node.rest_head_world
            bone.tail = bvh_node.rest_tail_world

            # Handle zero-length bones
            if (bone.head - bone.tail).length < 0.001:
                print("\tzero length bone found:", bone.name)
                if bvh_node.parent:
                    # Extend the tail away from the parent's own extent;
                    # fall back to the average length when the parent is
                    # itself degenerate.
                    ofs = bvh_node.parent.rest_head_local - bvh_node.parent.rest_tail_local
                    if ofs.length:
                        bone.tail = bone.tail - ofs
                    else:
                        bone.tail.y = bone.tail.y + average_bone_length
                else:
                    bone.tail.y = bone.tail.y + average_bone_length

                ZERO_AREA_BONES.append(bvh_node.name)

        # Assign the edit bone to the temp attribute
        bvh_node.temp = bone

    # Second pass: Set parenting and connection
    for bvh_node in bvh_nodes_list:
        if bvh_node.parent:
            # Now bvh_node.temp and bvh_node.parent.temp should both be valid
            bvh_node.temp.parent = bvh_node.parent.temp

            # Connect only when the child has no translation channels, the
            # parent is not a patched-up zero-length bone, and the joints
            # actually touch.
            if (
                (not bvh_node.has_loc) and
                (bvh_node.parent.temp.name not in ZERO_AREA_BONES) and
                (bvh_node.parent.rest_tail_local == bvh_node.rest_head_local)
            ):
                bvh_node.temp.use_connect = True

    # Replace temp with bone name for later use
    for bvh_node in bvh_nodes_list:
        bvh_node.temp = bvh_node.temp.name

    bpy.ops.object.mode_set(mode='OBJECT', toggle=False)

    pose = arm_ob.pose
    pose_bones = pose.bones

    if rotate_mode == 'NATIVE':
        for bvh_node in bvh_nodes_list:
            bone_name = bvh_node.temp
            pose_bone = pose_bones[bone_name]
            pose_bone.rotation_mode = bvh_node.rot_order_str
    elif rotate_mode != 'QUATERNION':
        for pose_bone in pose_bones:
            pose_bone.rotation_mode = rotate_mode

    context.view_layer.update()

    arm_ob.animation_data_create()
    action = bpy.data.actions.new(name=bvh_name)
    arm_ob.animation_data.action = action

    # Cache (pose_bone, rest_bone, rest matrix, inverse) per node and pick
    # up the frame count from the first node's animation data.
    num_frame = 0
    for bvh_node in bvh_nodes_list:
        bone_name = bvh_node.temp
        pose_bone = pose_bones[bone_name]
        rest_bone = arm_data.bones[bone_name]
        bone_rest_matrix = rest_bone.matrix_local.to_3x3()

        bone_rest_matrix_inv = Matrix(bone_rest_matrix)
        bone_rest_matrix_inv.invert()

        bone_rest_matrix_inv.resize_4x4()
        bone_rest_matrix.resize_4x4()
        bvh_node.temp = (pose_bone, rest_bone, bone_rest_matrix, bone_rest_matrix_inv)

        if 0 == num_frame:
            num_frame = len(bvh_node.anim_data)

    # anim_data[0] is the dummy all-zero frame added at parse time; skip it.
    skip_frame = 1
    if num_frame > skip_frame:
        num_frame = num_frame - skip_frame

    # Keyframe times (in scene frames), optionally rescaled to the fps.
    keyframe_times = [float(frame_start)] * num_frame
    if use_fps_scale:
        dt = scene.render.fps * bvh_frame_time
        for frame_i in range(1, num_frame):
            keyframe_times[frame_i] += float(frame_i) * dt
    else:
        for frame_i in range(1, num_frame):
            keyframe_times[frame_i] += float(frame_i)

    for bvh_node in bvh_nodes_list:
        pose_bone, bone, bone_rest_matrix, bone_rest_matrix_inv = bvh_node.temp

        if bvh_node.has_loc:
            data_path = f'pose.bones["{pose_bone.name}"].location'

            # Convert each frame's world-offset location into the bone's
            # rest space before keying.
            location = [(0.0, 0.0, 0.0)] * num_frame
            for frame_i in range(num_frame):
                bvh_loc = bvh_node.anim_data[frame_i + skip_frame][:3]
                bone_translate_matrix = Matrix.Translation(
                    Vector(bvh_loc) - bvh_node.rest_head_local)
                location[frame_i] = (bone_rest_matrix_inv @
                                     bone_translate_matrix).to_translation()

            for axis_i in range(3):
                curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name)
                keyframe_points = curve.keyframe_points
                keyframe_points.add(num_frame)
                for frame_i in range(num_frame):
                    keyframe_points[frame_i].co = (
                        keyframe_times[frame_i],
                        location[frame_i][axis_i],
                    )

        if bvh_node.has_rot:
            data_path = None
            rotate = None
            if 'QUATERNION' == rotate_mode:
                rotate = [(1.0, 0.0, 0.0, 0.0)] * num_frame
                data_path = f'pose.bones["{pose_bone.name}"].rotation_quaternion'
            else:
                rotate = [(0.0, 0.0, 0.0)] * num_frame
                data_path = f'pose.bones["{pose_bone.name}"].rotation_euler'

            # Track the previous euler so to_euler() picks the compatible
            # solution (avoids sign flips between consecutive keyframes).
            prev_euler = Euler((0.0, 0.0, 0.0))
            for frame_i in range(num_frame):
                bvh_rot = bvh_node.anim_data[frame_i + skip_frame][3:]
                euler = Euler(bvh_rot, bvh_node.rot_order_str[::-1])
                bone_rotation_matrix = euler.to_matrix().to_4x4()
                bone_rotation_matrix = (
                    bone_rest_matrix_inv @
                    bone_rotation_matrix @
                    bone_rest_matrix
                )

                if len(rotate[frame_i]) == 4:
                    rotate[frame_i] = bone_rotation_matrix.to_quaternion()
                else:
                    rotate[frame_i] = bone_rotation_matrix.to_euler(
                        pose_bone.rotation_mode, prev_euler)
                    prev_euler = rotate[frame_i]

            for axis_i in range(len(rotate[0])):
                curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name)
                keyframe_points = curve.keyframe_points
                keyframe_points.add(num_frame)
                for frame_i in range(num_frame):
                    keyframe_points[frame_i].co = (
                        keyframe_times[frame_i],
                        rotate[frame_i][axis_i],
                    )

    for cu in action.fcurves:
        if IMPORT_LOOP:
            pass
        for bez in cu.keyframe_points:
            bez.interpolation = 'LINEAR'

    # Apply the optional world transform; None (the default) means leave the
    # object transform alone. (Replaces the original bare `try/except: pass`,
    # which only ever guarded against global_matrix being None.)
    if global_matrix is not None:
        arm_ob.matrix_world = global_matrix
    bpy.ops.object.transform_apply(location=False, rotation=True, scale=False)

    return arm_ob


def load(
        context,
        bvh_str,
        *,
        target='ARMATURE',
        rotate_mode='NATIVE',
        global_scale=1.0,
        use_cyclic=False,
        frame_start=1,
        global_matrix=None,
        use_fps_scale=False,
        update_scene_fps=False,
        update_scene_duration=False,
        original_rest_pose=None,
        bvh_name='synsized',  # Added parameter
        report=print,
):
    """Parse BVH text and import it into the scene.

    Returns {'FINISHED'} on success, {'CANCELLED'} on an invalid *target*.
    """
    import time

    parse_start = time.time()
    bvh_nodes, bvh_frame_time, bvh_frame_count = read_bvh(
        context, bvh_str,
        rotate_mode=rotate_mode,
        global_scale=global_scale,
    )
    print("%.4f" % (time.time() - parse_start))

    scn = context.scene
    saved_frame = scn.frame_current

    if bvh_frame_time is None:
        report(
            {'WARNING'},
            "The BVH file does not contain frame duration in its MOTION "
            "section, assuming the BVH and Blender scene have the same "
            "frame rate"
        )
        bvh_frame_time = scn.render.fps_base / scn.render.fps
        # Without an explicit frame time there is nothing to rescale against.
        use_fps_scale = False

    if update_scene_fps:
        _update_scene_fps(context, report, bvh_frame_time)
        # The scene now matches the BVH frame rate, so no scaling is needed.
        use_fps_scale = False

    if update_scene_duration:
        _update_scene_duration(context, report, bvh_frame_count, bvh_frame_time, frame_start, use_fps_scale)

    import_start = time.time()
    print("\timporting to blender...", end="")

    if target == 'ARMATURE':
        bvh_node_dict2armature(
            context, bvh_name, bvh_nodes, bvh_frame_time,
            rotate_mode=rotate_mode,
            frame_start=frame_start,
            IMPORT_LOOP=use_cyclic,
            global_matrix=global_matrix,
            use_fps_scale=use_fps_scale,
            original_rest_pose=original_rest_pose
        )
    elif target == 'OBJECT':
        bvh_node_dict2objects(
            context, bvh_name, bvh_nodes,
            rotate_mode=rotate_mode,
            frame_start=frame_start,
            IMPORT_LOOP=use_cyclic,
        )
    else:
        report({'ERROR'}, tip_("Invalid target %r (must be 'ARMATURE' or 'OBJECT')") % target)
        return {'CANCELLED'}

    print('Done in %.4f\n' % (time.time() - import_start))
    # Restore whatever frame the user had selected before the import.
    context.scene.frame_set(saved_frame)
    return {'FINISHED'}


def _update_scene_fps(context, report, bvh_frame_time):
    """Update the scene's FPS settings from the BVH, but only if the BVH contains enough info."""

    # Broken BVH handling: prevent division by zero.
    if bvh_frame_time == 0.0:
        report(
            {'WARNING'},
            "Unable to update scene frame rate, as the BVH file "
            "contains a zero frame duration in its MOTION section",
        )
        return

    scene = context.scene
    scene_fps = scene.render.fps / scene.render.fps_base
    new_fps = 1.0 / bvh_frame_time

    if scene.render.fps != new_fps or scene.render.fps_base != 1.0:
        print("\tupdating scene FPS (was %f) to BVH FPS (%f)" % (scene_fps, new_fps))
    scene.render.fps = int(round(new_fps))
    scene.render.fps_base = scene.render.fps / new_fps


def _update_scene_duration(
        context, report, bvh_frame_count, bvh_frame_time, frame_start,
        use_fps_scale):
    """Extend the scene's duration so that the BVH file fits in its entirety."""

    if bvh_frame_count is None:
        report(
            {'WARNING'},
            "Unable to extend the scene duration, as the BVH file does not "
            "contain the number of frames in its MOTION section",
        )
        return

    # Not likely, but it can happen when a BVH is just used to store an armature.
    if bvh_frame_count == 0:
        return

    if use_fps_scale:
        scene_fps = context.scene.render.fps / context.scene.render.fps_base
        scaled_frame_count = int(ceil(bvh_frame_count * bvh_frame_time * scene_fps))
        bvh_last_frame = frame_start + scaled_frame_count
    else:
        bvh_last_frame = frame_start + bvh_frame_count

    # Only extend the scene, never shorten it.
    if context.scene.frame_end < bvh_last_frame:
        context.scene.frame_end = bvh_last_frame


# This function is from
# https://github.com/yuki-koyama/blender-cli-rendering
def set_smooth_shading(mesh: bpy.types.Mesh) -> None:
    """Enable smooth shading on every polygon of *mesh*."""
    for poly in mesh.polygons:
        poly.use_smooth = True


# This function is from
# https://github.com/yuki-koyama/blender-cli-rendering
def create_mesh_from_pydata(scene: bpy.types.Scene,
                            vertices: Iterable[Iterable[float]],
                            faces: Iterable[Iterable[int]],
                            mesh_name: str,
                            object_name: str,
                            use_smooth: bool = True) -> bpy.types.Object:
    """Build a mesh object from raw vertex/face data and link it into *scene*."""
    # Edges are derived implicitly from the faces, so an empty edge list is passed.
    mesh: bpy.types.Mesh = bpy.data.meshes.new(mesh_name)
    mesh.from_pydata(vertices, [], faces)
    # from_pydata leaves the mesh un-validated; update() finalizes it.
    mesh.update()
    if use_smooth:
        set_smooth_shading(mesh)

    obj: bpy.types.Object = bpy.data.objects.new(object_name, mesh)
    scene.collection.objects.link(obj)
    return obj


# This function is from
# https://github.com/yuki-koyama/blender-cli-rendering
def add_subdivision_surface_modifier(mesh_object: bpy.types.Object, level: int, is_simple: bool = False) -> None:
    '''
    https://docs.blender.org/api/current/bpy.types.SubsurfModifier.html
    '''
    subsurf: bpy.types.SubsurfModifier = mesh_object.modifiers.new(name="Subsurf", type='SUBSURF')
    # Use the same subdivision depth in the viewport and at render time.
    subsurf.levels = level
    subsurf.render_levels = level
    subsurf.subdivision_type = 'SIMPLE' if is_simple else 'CATMULL_CLARK'


# This function is from
# https://github.com/yuki-koyama/blender-cli-rendering
def create_armature_mesh(scene: bpy.types.Scene, armature_object: bpy.types.Object, mesh_name: str) -> bpy.types.Object:
    """Build a proxy "bone mesh" for *armature_object* and skin it to the bones.

    For every bone a tapered box with pyramid end caps is generated in the
    bone's local space, transformed by the bone's rest matrix, and bound
    rigidly to that bone through a vertex group so the mesh deforms with
    the armature.

    Args:
        scene: scene the new mesh object is linked into.
        armature_object: object of type 'ARMATURE' with at least one bone.
        mesh_name: name used for both the mesh datablock and the object.

    Returns:
        The newly created mesh object, parented to the armature.
    """
    assert armature_object.type == 'ARMATURE', 'Error'
    assert len(armature_object.data.bones) != 0, 'Error'

    def add_rigid_vertex_group(target_object: bpy.types.Object, name: str, vertex_indices: Iterable[int]) -> None:
        # Weight 1.0 everywhere: each group follows exactly one bone, no blending.
        new_vertex_group = target_object.vertex_groups.new(name=name)
        for vertex_index in vertex_indices:
            new_vertex_group.add([vertex_index], 1.0, 'REPLACE')

    def generate_bone_mesh_pydata(radius: float, length: float) -> Tuple[List[mathutils.Vector], List[List[int]]]:
        # One bone shape: square base cross-section, half-size top cross-section,
        # plus an apex vertex at each end. Y is the bone's long axis.
        base_radius = radius
        top_radius = 0.5 * radius

        vertices = [
            # Cross section of the base part
            mathutils.Vector((-base_radius, 0.0, +base_radius)),
            mathutils.Vector((+base_radius, 0.0, +base_radius)),
            mathutils.Vector((+base_radius, 0.0, -base_radius)),
            mathutils.Vector((-base_radius, 0.0, -base_radius)),

            # Cross section of the top part
            mathutils.Vector((-top_radius, length, +top_radius)),
            mathutils.Vector((+top_radius, length, +top_radius)),
            mathutils.Vector((+top_radius, length, -top_radius)),
            mathutils.Vector((-top_radius, length, -top_radius)),

            # End points
            mathutils.Vector((0.0, -base_radius, 0.0)),
            mathutils.Vector((0.0, length + top_radius, 0.0))
        ]

        faces = [
            # End point for the base part
            [8, 1, 0],
            [8, 2, 1],
            [8, 3, 2],
            [8, 0, 3],

            # End point for the top part
            [9, 4, 5],
            [9, 5, 6],
            [9, 6, 7],
            [9, 7, 4],

            # Side faces
            [0, 1, 5, 4],
            [1, 2, 6, 5],
            [2, 3, 7, 6],
            [3, 0, 4, 7],
        ]

        return vertices, faces

    armature_data: bpy.types.Armature = armature_object.data

    vertices: List[mathutils.Vector] = []
    faces: List[List[int]] = []
    vertex_groups: List[Dict[str, Any]] = []

    for bone in armature_data.bones:
        # Bone thickness scales with its length (plus a floor for tiny bones).
        radius = 0.10 * (0.10 + bone.length)
        temp_vertices, temp_faces = generate_bone_mesh_pydata(radius, bone.length)

        # All per-bone indices are shifted into the shared vertex list.
        vertex_index_offset = len(vertices)

        temp_vertex_group = {'name': bone.name, 'vertex_indices': []}
        for local_index, vertex in enumerate(temp_vertices):
            # matrix_local places the bone shape at the bone's rest pose.
            vertices.append(bone.matrix_local @ vertex)
            temp_vertex_group['vertex_indices'].append(local_index + vertex_index_offset)
        vertex_groups.append(temp_vertex_group)

        for face in temp_faces:
            if len(face) == 3:
                faces.append([
                    face[0] + vertex_index_offset,
                    face[1] + vertex_index_offset,
                    face[2] + vertex_index_offset,
                ])
            else:
                faces.append([
                    face[0] + vertex_index_offset,
                    face[1] + vertex_index_offset,
                    face[2] + vertex_index_offset,
                    face[3] + vertex_index_offset,
                ])

    new_object = create_mesh_from_pydata(scene, vertices, faces, mesh_name, mesh_name)
    new_object.matrix_world = armature_object.matrix_world

    for vertex_group in vertex_groups:
        add_rigid_vertex_group(new_object, vertex_group['name'], vertex_group['vertex_indices'])

    armature_modifier = new_object.modifiers.new('Armature', 'ARMATURE')
    armature_modifier.object = armature_object
    armature_modifier.use_vertex_groups = True

    # Simple subdivision first to add resolution, then Catmull-Clark to round it off.
    add_subdivision_surface_modifier(new_object, 1, is_simple=True)
    add_subdivision_surface_modifier(new_object, 2, is_simple=False)

    # Set the armature as the parent of the new object
    # (selection order matters: the active object becomes the parent).
    bpy.ops.object.select_all(action='DESELECT')
    new_object.select_set(True)
    armature_object.select_set(True)
    bpy.context.view_layer.objects.active = armature_object
    bpy.ops.object.parent_set(type='OBJECT')

    return new_object


class OP_AddMesh(bpy.types.Operator):
    bl_idname = "genmm.add_mesh"
    bl_label = "Add mesh"
    bl_description = ""
    bl_options = {"REGISTER", "UNDO"}

    # NOTE: the previous no-op `__init__` override (which only delegated to
    # super) was removed — it added nothing and is fragile across Blender
    # versions where Operator.__init__'s signature differs.

    def execute(self, context: bpy.types.Context):
        """Build a skinned proxy mesh for the active object (expected: an armature)."""
        name = bpy.context.object.name + "_proxy"
        create_armature_mesh(bpy.context.scene, bpy.context.object, name)
        return {'FINISHED'}

class OP_RunSynthesis(bpy.types.Operator):
    # Runs GenMM on the active armature's current action and imports each
    # synthesized variation back into the scene as a new armature.
    bl_idname = "genmm.run_synthesis"
    bl_label = "Run synthesis"
    bl_description = ""
    bl_options = {"REGISTER", "UNDO"}

    def execute(self, context: bpy.types.Context):
        setting = context.scene.setting
        original_armature = context.object
        # Captured so the imported copies can be aligned to the same rest pose.
        rest_pose_data = capture_rest_pose(original_armature)

        # Exemplar frame range: taken from the action unless the user
        # overrode it (a value of -1 means "use the action's own range").
        anim = original_armature.animation_data.action
        start_frame, end_frame = map(int, anim.frame_range)
        start_frame = start_frame if setting.start_frame == -1 else setting.start_frame
        end_frame = end_frame if setting.end_frame == -1 else setting.end_frame

        # Export the exemplar as BVH text, then split it at the MOTION
        # section: line 0 = frame count, line 1 = frame time, rest = data.
        bvh_str = get_bvh_data(context,
                               frame_start=start_frame,
                               frame_end=end_frame)
        frames_str, frame_time_str = bvh_str.split('MOTION\n')[1].split('\n')[:2]
        # NOTE(review): the [2:-1] slice assumes the BVH text ends with a
        # trailing newline — confirm against get_bvh_data's output.
        motion_data_str = bvh_str.split('MOTION\n')[1].split('\n')[2:-1]
        motion_data = np.array([item.strip().split(' ') for item in motion_data_str], dtype=np.float32)

        model = GenMM(device='cuda' if torch.cuda.is_available() else 'cpu', silent=True)
        criteria = PatchCoherentLoss(patch_size=setting.patch_size, 
                                     alpha=setting.alpha, 
                                     loop=setting.loop, cache=True)

        for i in range(setting.num_output):
            print(f"Generating motion {i+1} of {setting.num_output}")
            # Create a new BlenderMotion instance for each iteration
            motion = [BlenderMotion(motion_data.copy(), repr='repr6d', use_velo=True, 
                                    keep_up_pos=True, up_axis=setting.up_axis, padding_last=False)]
            syn = model.run(motion, criteria,
                            num_frames=str(setting.num_syn_frames),
                            num_steps=setting.num_steps,
                            noise_sigma=setting.noise,
                            patch_size=setting.patch_size, 
                            coarse_ratio=f'{setting.coarse_ratio}x_nframes',
                            pyr_factor=setting.pyr_factor)
            # Convert the synthesized tensor back to BVH motion lines and
            # re-assemble a full BVH: original header + new MOTION block.
            # NOTE(review): frames_str still carries the exemplar's frame
            # count, not num_syn_frames — verify the reader uses line count.
            motion_data_str = [' '.join(str(x) for x in item) for item in motion[0].parse(syn)]
            bvh_name = f"synsized_{i+1}"
            load(context,
                 bvh_str.split('MOTION\n')[0].split('\n') + ['MOTION'] + [frames_str] + [frame_time_str] + motion_data_str,
                 rotate_mode='QUATERNION',
                 global_matrix=original_armature.matrix_world,
                 original_rest_pose=rest_pose_data,
                 target='ARMATURE',
                 use_fps_scale=False,
                 bvh_name=bvh_name)

        return {'FINISHED'}

class GENMM_PT_ControlPanel(bpy.types.Panel):
    """Sidebar panel exposing the GenMM exemplar and synthesis settings."""
    bl_label = "GenMM"
    bl_space_type = 'VIEW_3D'
    bl_region_type = 'UI'
    bl_category = "GenMM"

    @classmethod
    def poll(cls, context: bpy.types.Context):
        # The panel is always available.
        return True

    def draw_header(self, context: bpy.types.Context):
        self.layout.label(text="", icon='PLUGIN')

    def draw(self, context: bpy.types.Context):
        layout = self.layout
        setting = bpy.context.scene.setting

        layout.operator(OP_AddMesh.bl_idname, text=OP_AddMesh.bl_label)

        exemplar_box = layout.box()
        exemplar_box.label(text="Exemplar config:")
        row = exemplar_box.row()
        row.prop(setting, "start_frame")
        row.prop(setting, "end_frame")
        row = exemplar_box.row()
        row.prop(setting, "up_axis")

        synthesis_box = layout.box()
        synthesis_box.label(text="Synthesis config:")
        for prop_name in ("loop", "noise", "num_syn_frames", "patch_size",
                          "coarse_ratio", "pyr_factor", "alpha", "num_steps",
                          "num_output"):
            synthesis_box.prop(setting, prop_name)

        layout.operator(OP_RunSynthesis.bl_idname, text=OP_RunSynthesis.bl_label)

class PropertyGroup(bpy.types.PropertyGroup):
    '''Property container for options and paths of GenMM'''
    # Exemplar range: -1 means "use the action's own start/end frame".
    start_frame: bpy.props.IntProperty(
        name="Start Frame",
        description="Start Frame of the Exemplar Motion.",
        default=1)
    end_frame: bpy.props.IntProperty(
        name="End Frame",
        description="End Frame of the Exemplar Motion.",
        default=-1)
    up_axis: bpy.props.EnumProperty(
        name="Up Axis", 
        default='Z_UP',
        description="Up axis of the Exemplar Motion",
        items=[('Z_UP', "Z-Up", 'Z Up'),
               ('Y_UP', "Y-Up", 'Y Up'),
               ('X_UP', "X-Up", 'X Up'),
               ]
    )
    # Passed to GenMM.run as noise_sigma.
    noise: bpy.props.FloatProperty(
        name="Noise Intensity",
        description="Intensity of Noise Added to the Synthesized Motion.",
        default=10)
    num_syn_frames: bpy.props.IntProperty(
        name="Num. of Frames",
        description="Number of the Synthesized Motion.",
        default=600)
    patch_size: bpy.props.IntProperty(
        name="Patch Size",
        description="Size for Patch Extraction.",
        min=7,
        default=15)
    # Expanded to the '<ratio>x_nframes' form when handed to GenMM.run.
    coarse_ratio: bpy.props.FloatProperty(
        name="Coarse Ratio",
        description="Ratio of the Coarest Pyramid.",
        min=0.0,
        default=0.2)
    pyr_factor: bpy.props.FloatProperty(
        name="Pyramid Factor",
        description="Pyramid Downsample Factor.",
        min=0.1,
        default=0.75)
    alpha: bpy.props.FloatProperty(
        name="Completeness Alpha",
        description="Alpha Value for Completeness/Diversity Trade-off.",
        default=0.05)
    loop: bpy.props.BoolProperty(
        name="Endless Loop",
        description="Whether to Use Loop Constrain.",
        default=False)
    num_steps: bpy.props.IntProperty(
        name="Num of Steps",
        description="Number of Optimized Steps.",
        default=5)
    num_output: bpy.props.IntProperty(
        name="Num. of Output",
        description="Number of different motions to generate.",
        min=1,
        default=1)

# Classes registered with Blender. The PropertyGroup is handled separately
# in register()/unregister() because it backs the Scene.setting pointer.
classes = [
    OP_AddMesh,
    OP_RunSynthesis,
    GENMM_PT_ControlPanel,
]

def register():
    """Register the property group, attach it to Scene, then register the UI classes."""
    bpy.utils.register_class(PropertyGroup)
    bpy.types.Scene.setting = bpy.props.PointerProperty(type=PropertyGroup)
    for blender_class in classes:
        bpy.utils.register_class(blender_class)

def unregister():
    """Undo register() in reverse order.

    Fixes over the previous version: the UI classes are unregistered before
    the PropertyGroup they reference (mirror of register()), and the
    ``Scene.setting`` pointer property is deleted so it does not dangle
    after the PropertyGroup is gone.
    """
    for blender_class in reversed(classes):
        bpy.utils.unregister_class(blender_class)
    del bpy.types.Scene.setting
    bpy.utils.unregister_class(PropertyGroup)

# Allow running the script directly from Blender's text editor.
if __name__ == "__main__":
    register()

================================================
FILE: configs/default.yaml
================================================
# motion data config
repr: 'repr6d'             # rotation representation: 'euler' | 'quat' | 'repr6d'
skeleton_name: null        # optional named skeleton preset (e.g. 'mixamo')
use_velo: true             # transform joint positions to per-frame velocities
keep_up_pos: true          # keep the up-axis position when converting to velocity
up_axis: 'Y_UP'
padding_last: false        # whether to pad the last position
requires_contact: false    # whether to concatenate contact information
joint_reduction: false
skeleton_aware: false
joints_group: null

# generate parameters
num_frames: '2x_nframes'        # synthesized length, relative to the exemplar length
alpha: 0.01                     # completeness/diversity trade-off
num_steps: 3                    # number of optimization steps
noise_sigma: 10.0               # intensity of noise added to the synthesis
coarse_ratio: '5x_patchsize'    # size of the coarsest pyramid level
# coarse_ratio: '0.2x_nframes'
pyr_factor: 0.75                # pyramid downsample factor
num_stages_limit: -1            # -1 = no limit on the number of pyramid stages
patch_size: 11                  # temporal patch size
loop: false                     # whether to use the endless-loop constraint

================================================
FILE: configs/ganimator.yaml
================================================
################################################################
# This configuration uses the same input format as GANimator for generation
################################################################
# NOTE(review): key spelling ('outout_dir') looks like a typo of
# 'output_dir' — confirm against the config loader before renaming.
outout_dir: './output/ganimator_format'

# for GANimator BVH data
repr: 'repr6d'            # rotation representation: 'euler' | 'quat' | 'repr6d'
skeleton_name: 'mixamo'   # named skeleton preset for GANimator-style BVH
use_velo: true            # transform joint positions to per-frame velocities
keep_up_pos: true         # keep the up-axis position when converting to velocity
up_axis: 'Y_UP'
padding_last: true        # pad the last position
requires_contact: true    # concatenate foot-contact information
joint_reduction: true
skeleton_aware: false
joints_group: null

# generate parameters
num_frames: '2x_nframes'        # synthesized length, relative to the exemplar length
alpha: 0.01                     # completeness/diversity trade-off
num_steps: 3                    # number of optimization steps
noise_sigma: 10.0               # intensity of noise added to the synthesis
coarse_ratio: '3x_patchsize'    # size of the coarsest pyramid level
# coarse_ratio: '0.1x_nframes'
pyr_factor: 0.75                # pyramid downsample factor
num_stages_limit: -1            # -1 = no limit on the number of pyramid stages
patch_size: 11                  # temporal patch size
loop: false                     # whether to use the endless-loop constraint

================================================
FILE: dataset/blender_motion.py
================================================
import os
import os.path as osp
import torch
import numpy as np
import torch.nn.functional as F
from .motion import MotionData
from utils.transforms import quat2repr6d, euler2mat, mat2quat, repr6d2quat, quat2euler

class BlenderMotion:
    def __init__(self, motion_data, repr='quat', use_velo=True, keep_up_pos=True, up_axis=None, padding_last=False):
        '''
        BlenderMotion constructor
        Args:
            motion_data  : np.array of shape (n_frames, 3 + 3 * n_joints),
                           root position (x, y, z) followed by per-joint euler angles
            repr         : string, rotation representation, support ['quat', 'repr6d', 'euler']
            use_velo     : bool, whether to transform the joints positions to velocities
            keep_up_pos  : bool, whether to keep the up-axis position when converting to velocity
            up_axis      : string, up axis of the motion data
            padding_last : bool, whether to pad the last position
        '''
        def to_tensor(motion_data, repr='euler', rot_only=False):
            '''Flatten raw blender data into a (n_frames, n_channels) tensor in the requested representation.'''
            if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
                raise Exception('Unknown rotation representation')
            # Blender hands us euler angles; start from that view.
            # Fix: previously repr='euler' left `rotations` unbound
            # (UnboundLocalError) and the 'quaternion' alias silently
            # skipped the euler->quaternion conversion.
            rotations = torch.tensor(motion_data[:, 3:], dtype=torch.float).view(motion_data.shape[0], -1, 3)
            if repr in ('quat', 'quaternion'):
                rotations = mat2quat(euler2mat(rotations))
            elif repr == 'repr6d':
                rotations = quat2repr6d(mat2quat(euler2mat(rotations)))

            rotations = rotations.reshape(rotations.shape[0], -1)
            if rot_only:
                return rotations

            positions = torch.tensor(motion_data[:, :3], dtype=torch.float32)
            return torch.cat((rotations, positions), dim=-1)

        # MotionData expects channel-major layout: (1, C, T), rotations first, positions last.
        self.motion_data = MotionData(to_tensor(motion_data, repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo,
                                      keep_up_pos=keep_up_pos, up_axis=up_axis, padding_last=padding_last, contact_id=None)

    @property
    def repr(self):
        # Rotation representation of the wrapped MotionData.
        return self.motion_data.repr

    @property
    def use_velo(self):
        # Whether positions are stored as per-frame velocities.
        return self.motion_data.use_velo

    @property
    def keep_up_pos(self):
        return self.motion_data.keep_up_pos

    @property
    def padding_last(self):
        return self.motion_data.padding_last

    @property
    def concat_id(self):
        # NOTE(review): name looks like a typo of 'contact_id'; kept as-is
        # for backward compatibility with existing callers.
        return self.motion_data.contact_id

    @property
    def n_pad(self):
        return self.motion_data.n_pad

    @property
    def n_contact(self):
        return self.motion_data.n_contact

    @property
    def n_rot(self):
        # Number of channels per joint used by the rotation representation.
        return self.motion_data.n_rot

    def sample(self, size=None, slerp=False):
        '''
        Sample motion data, support slerp
        '''
        return self.motion_data.sample(size, slerp)

    def parse(self, motion, keep_velo=False):
        """
        Convert a synthesized motion tensor back to the blender data layout.
        No batch support here!!!
        Args:
            motion    : tensor of shape (1, C, T) as produced by sample()/GenMM
            keep_velo : keep the velocity encoding instead of integrating to positions
        Returns:
            np.array of shape (n_frames, 3 + 3 * n_joints): root position followed by per-joint euler angles
        """
        motion = motion.clone()

        if self.use_velo and not keep_velo:
            motion = self.motion_data.to_position(motion)
        if self.n_pad:
            motion = motion[:, :-self.n_pad]

        # (1, C, T) -> (T, C); positions are the trailing 3 channels.
        motion = motion.squeeze().permute(1, 0)
        pos = motion[..., -3:]
        rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
        if self.repr == 'quat':
            rot = quat2euler(rot)
        elif self.repr == 'repr6d':
            rot = quat2euler(repr6d2quat(rot))

        return torch.cat([pos, rot.view(motion.shape[0], -1)], dim=-1).cpu().numpy()


================================================
FILE: dataset/bvh/Quaternions.py
================================================
"""
This code is modified from:
http://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing

by Daniel Holden et al
"""


import numpy as np

class Quaternions:
    """
    Quaternions is a wrapper around a numpy ndarray
    that allows it to act as if it were an narray of
    a quater data type.
    
    Therefore addition, subtraction, multiplication,
    division, negation, absolute, are all defined
    in terms of quater operations such as quater
    multiplication.
    
    This allows for much neater code and many routines
    which conceptually do the same thing to be written
    in the same way for point data and for rotation data.
    
    The Quaternions class has been desgined such that it
    should support broadcasting and slicing in all of the
    usual ways.
    """
    
    def __init__(self, qs):
        """Wrap *qs* as quaternion data.

        Args:
            qs: np.ndarray of shape (..., 4) — a bare 1D quaternion is
                promoted to shape (1, 4) — or another Quaternions instance.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(qs, np.ndarray):
            # Promote a single quaternion to a batch of one.
            if len(qs.shape) == 1:
                qs = np.array([qs])
            self.qs = qs
            return

        if isinstance(qs, Quaternions):
            # Fix: store the wrapped ndarray (qs.qs), not the wrapper object
            # itself — self.qs must always be an np.ndarray, otherwise every
            # downstream `self.qs[..., i]` access breaks.
            self.qs = qs.qs
            return

        raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs))
    
    def __str__(self): return "Quaternions("+ str(self.qs) + ")"
    def __repr__(self): return "Quaternions("+ repr(self.qs) + ")"
    
    """ Helper Methods for Broadcasting and Data extraction """
    
    @classmethod
    def _broadcast(cls, sqs, oqs, scalar=False):
        if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1])
        
        ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])
        os = np.array(oqs.shape)

        if len(ss) != len(os):
            raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
            
        if np.all(ss == os): return sqs, oqs
        
        if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):
            raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))

        sqsn, oqsn = sqs.copy(), oqs.copy()

        for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a)
        for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a)
        
        return sqsn, oqsn
        
    """ Adding Quaterions is just Defined as Multiplication """
    
    def __add__(self, other): return self * other
    def __sub__(self, other): return self / other
    
    """ Quaterion Multiplication """
    
    def __mul__(self, other):
        """
        Quaternion multiplication has three main methods.
        
        When multiplying a Quaternions array by Quaternions
        normal quater multiplication is performed.
        
        When multiplying a Quaternions array by a vector
        array of the same shape, where the last axis is 3,
        it is assumed to be a Quaternion by 3D-Vector 
        multiplication and the 3D-Vectors are rotated
        in space by the Quaternions.
        
        When multipplying a Quaternions array by a scalar
        or vector of different shape it is assumed to be
        a Quaternions by Scalars multiplication and the
        Quaternions are scaled using Slerp and the identity
        quaternions.
        """
        
        """ If Quaternions type do Quaternions * Quaternions """
        if isinstance(other, Quaternions):
            sqs, oqs = Quaternions._broadcast(self.qs, other.qs)

            q0 = sqs[...,0]; q1 = sqs[...,1]; 
            q2 = sqs[...,2]; q3 = sqs[...,3]; 
            r0 = oqs[...,0]; r1 = oqs[...,1]; 
            r2 = oqs[...,2]; r3 = oqs[...,3]; 
            
            qs = np.empty(sqs.shape)
            qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
            qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
            qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
            qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0
            
            return Quaternions(qs)
        
        """ If array type do Quaternions * Vectors """
        if isinstance(other, np.ndarray) and other.shape[-1] == 3:
            vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))

            return (self * (vs * -self)).imaginaries

        """ If float do Quaternions * Scalars """
        if isinstance(other, np.ndarray) or isinstance(other, float):
            return Quaternions.slerp(Quaternions.id_like(self), self, other)
        
        raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))
        
    def __div__(self, other):
        """
        When a Quaternion type is supplied, division is defined
        as multiplication by the inverse of that Quaternion.

        When a scalar or vector is supplied it is defined
        as multiplication of one over the supplied value.
        Essentially a scaling.
        """

        if isinstance(other, Quaternions): return self * (-other)
        if isinstance(other, np.ndarray): return self * (1.0 / other)
        if isinstance(other, float): return self * (1.0 / other)
        # Fix: the message previously used `+` instead of `%`, so the type
        # was never interpolated into the '%s' placeholder.
        raise TypeError('Cannot divide/subtract Quaternions with type %s' % str(type(other)))

    # Fix: Python 3 dispatches the `/` operator to __truediv__, not the
    # Python 2 __div__ slot, so alias it — this also repairs __sub__,
    # which is implemented via `self / other`.
    __truediv__ = __div__
        
    def __eq__(self, other): return self.qs == other.qs
    def __ne__(self, other): return self.qs != other.qs
    
    def __neg__(self):
        """ Invert Quaternions """
        return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))
    
    def __abs__(self):
        """ Unify Quaternions To Single Pole """
        qabs = self.normalized().copy()
        top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1)
        bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1)
        qabs.qs[top < bot] = -qabs.qs[top <  bot]
        return qabs
    
    def __iter__(self): return iter(self.qs)
    def __len__(self): return len(self.qs)
    
    def __getitem__(self, k):    return Quaternions(self.qs[k]) 
    def __setitem__(self, k, v): self.qs[k] = v.qs
        
    @property
    def lengths(self):
        return np.sum(self.qs**2.0, axis=-1)**0.5
    
    @property
    def reals(self):
        return self.qs[...,0]
        
    @property
    def imaginaries(self):
        return self.qs[...,1:4]
    
    @property
    def shape(self): return self.qs.shape[:-1]
    
    def repeat(self, n, **kwargs):
        return Quaternions(self.qs.repeat(n, **kwargs))
    
    def normalized(self):
        return Quaternions(self.qs / self.lengths[...,np.newaxis])
    
    def log(self):
        norm = abs(self.normalized())
        imgs = norm.imaginaries
        lens = np.sqrt(np.sum(imgs**2, axis=-1))
        lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)
        return imgs * lens[...,np.newaxis]
    
    def constrained(self, axis):
        
        rl = self.reals
        im = np.sum(axis * self.imaginaries, axis=-1)
        
        t1 = -2 * np.arctan2(rl, im) + np.pi
        t2 = -2 * np.arctan2(rl, im) - np.pi
        
        top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0))
        bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0))
        img = self.dot(top) > self.dot(bot)
        
        ret = top.copy()
        ret[ img] = top[ img]
        ret[~img] = bot[~img]
        return ret
    
    def constrained_x(self): return self.constrained(np.array([1,0,0]))
    def constrained_y(self): return self.constrained(np.array([0,1,0]))
    def constrained_z(self): return self.constrained(np.array([0,0,1]))
    
    def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)
    
    def copy(self): return Quaternions(np.copy(self.qs))
    
    def reshape(self, s):
        self.qs.reshape(s)
        return self
    
    def interpolate(self, ws):
        return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws))
    
    def euler(self, order='xyz'):
        """Convert to Euler angles (radians) for the given rotation order.

        Only the 'xyz' order is implemented; any other order raises
        NotImplementedError.  Earlier experimental conversions for other
        orders were removed because, per the original author's notes, they
        produced incorrect results.  For reference implementations of the
        other orders see:
        http://bediyap.com/programming/convert-quaternion-to-euler-rotations/
        https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp

        Returns
        -------
        ndarray of shape ``self.shape + (3,)`` with angles in radians.
        """
        q = self.normalized().qs
        q0 = q[...,0]  # w
        q1 = q[...,1]  # x
        q2 = q[...,2]  # y
        q3 = q[...,3]  # z
        es = np.zeros(self.shape + (3,))

        if   order == 'xyz':
            es[..., 2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
            es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
            es[..., 0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
        else:
            raise NotImplementedError('Cannot convert from ordering %s' % order)

        return es
        
    
    def average(self):
        """Average a 1-D array of quaternions.

        Uses the eigenvector method: accumulate the outer products
        ``q_i q_i^T`` and take the eigenvector best aligned with the inputs.

        Fix: ``numpy.core.umath_tests`` (a private test module) has been
        removed from modern NumPy; ``np.einsum`` computes the same sum of
        outer products.

        Raises
        ------
        NotImplementedError
            If the quaternion array is not one-dimensional.
        """
        if len(self.shape) == 1:
            # Sum of outer products q_i q_i^T over all quaternions.
            system = np.einsum('ij,ik->jk', self.qs, self.qs)
            w, v = np.linalg.eigh(system)
            qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1)
            # Pick the eigenvector most aligned with the input quaternions.
            return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])

        else:

            raise NotImplementedError('Cannot average multi-dimensionsal Quaternions')

    def angle_axis(self):
        """Decompose into (angles, axes).

        Returns
        -------
        angles : ndarray, shape ``self.shape``
            Rotation angle in radians per quaternion.
        axis : ndarray, shape ``self.shape + (3,)``
            Rotation axis per quaternion (imaginary part scaled by
            1/sin(theta/2)).
        """
        norm = self.normalized()
        s = np.sqrt(1 - (norm.reals**2.0))
        # Guard against division by zero for (near-)identity rotations,
        # where sin(theta/2) == 0.
        s[s == 0] = 0.001

        angles = 2.0 * np.arccos(norm.reals)
        axis = norm.imaginaries / s[...,np.newaxis]

        return angles, axis
        
    
    def transforms(self):
        """Return the equivalent 3x3 rotation matrices, shape ``self.shape + (3, 3)``."""
        w = self.qs[...,0]
        x = self.qs[...,1]
        y = self.qs[...,2]
        z = self.qs[...,3]

        # Doubled products used by the standard quaternion-to-matrix formula.
        xx = x * (x + x)
        yy = y * (y + y)
        zz = z * (z + z)
        xy = x * (y + y)
        yz = y * (z + z)
        xz = x * (z + z)
        wx = w * (x + x)
        wy = w * (y + y)
        wz = w * (z + z)

        m = np.empty(self.shape + (3,3))
        m[...,0,0] = 1.0 - (yy + zz)
        m[...,0,1] = xy - wz
        m[...,0,2] = xz + wy
        m[...,1,0] = xy + wz
        m[...,1,1] = 1.0 - (xx + zz)
        m[...,1,2] = yz - wx
        m[...,2,0] = xz - wy
        m[...,2,1] = yz + wx
        m[...,2,2] = 1.0 - (xx + yy)
        return m
    
    def ravel(self):
        return self.qs.ravel()
    
    @classmethod
    def id(cls, n):
        """Construct identity quaternions.

        Parameters
        ----------
        n : tuple or int
            Either a shape tuple (result has shape ``n + (4,)``) or a count
            (result has shape ``(n, 4)``).

        Bug fix: the original also tested ``isinstance(n, long)`` — ``long``
        does not exist in Python 3, so any non-int, non-tuple argument
        raised NameError instead of the intended TypeError.  NumPy integer
        scalars are now accepted as well.
        """
        if isinstance(n, tuple):
            qs = np.zeros(n + (4,))
            qs[...,0] = 1.0
            return Quaternions(qs)

        if isinstance(n, (int, np.integer)):
            qs = np.zeros((n,4))
            qs[:,0] = 1.0
            return Quaternions(qs)

        raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))

    @classmethod
    def id_like(cls, a):
        """Identity quaternions whose leading shape matches ``a.shape``."""
        identity = np.zeros(a.shape + (4,))
        identity[...,0] = 1.0
        return Quaternions(identity)
        
    @classmethod
    def exp(cls, ws):
        """Exponential map: axis-angle vectors ``ws`` (rotation angle encoded
        as the vector norm) to unit quaternions."""
        angles = np.sum(ws**2.0, axis=-1)**0.5
        # Guard against division by zero at the identity rotation.
        angles[angles == 0] = 0.001
        scale = np.sin(angles) / angles

        qs = np.empty(ws.shape[:-1] + (4,))
        qs[...,0] = np.cos(angles)
        qs[...,1] = ws[...,0] * scale
        qs[...,2] = ws[...,1] * scale
        qs[...,3] = ws[...,2] * scale

        return Quaternions(qs).normalized()
        
    @classmethod
    def slerp(cls, q0s, q1s, a):
        """Spherical linear interpolation between ``q0s`` and ``q1s`` by
        amount ``a`` (broadcast against the quaternion arrays).

        Where the two quaternions are nearly parallel the sin(omega)
        denominator becomes unstable, so plain linear interpolation is used
        instead.

        Fix: the local variable ``len`` shadowed the builtin; renamed to
        ``dots`` (behavior unchanged).
        """
        fst, snd = cls._broadcast(q0s.qs, q1s.qs)
        fst, a = cls._broadcast(fst, a, scalar=True)
        snd, a = cls._broadcast(snd, a, scalar=True)

        # Cosine of the angle between corresponding quaternions.
        dots = np.sum(fst * snd, axis=-1)

        # Take the shorter arc: flip the second quaternion where needed.
        neg = dots < 0.0
        dots[neg] = -dots[neg]
        snd[neg] = -snd[neg]

        amount0 = np.zeros(a.shape)
        amount1 = np.zeros(a.shape)

        linear = (1.0 - dots) < 0.01
        omegas = np.arccos(dots[~linear])
        sinoms = np.sin(omegas)

        amount0[ linear] = 1.0 - a[linear]
        amount1[ linear] =       a[linear]
        amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
        amount1[~linear] = np.sin(       a[~linear]  * omegas) / sinoms

        return Quaternions(
            amount0[...,np.newaxis] * fst +
            amount1[...,np.newaxis] * snd)
    
    @classmethod
    def between(cls, v0s, v1s):
        """Quaternions rotating each vector in ``v0s`` onto the matching
        vector in ``v1s`` (half-angle construction, then normalized)."""
        imag = np.cross(v0s, v1s)
        real = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)
        raw = np.concatenate([real[...,np.newaxis], imag], axis=-1)
        return Quaternions(raw).normalized()
    
    @classmethod
    def from_angle_axis(cls, angles, axis):
        """Build quaternions from rotation ``angles`` (radians) and ``axis``
        vectors; axes are normalized with a small epsilon for stability."""
        unit_axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis]
        half = angles / 2.0
        sines = np.sin(half)[...,np.newaxis]
        cosines = np.cos(half)[...,np.newaxis]
        return Quaternions(np.concatenate([cosines, unit_axis * sines], axis=-1))
    
    @classmethod
    def from_euler(cls, es, order='xyz', world=False):
        """Build quaternions from Euler angles ``es`` (radians) with the
        given axis ``order``.  With ``world=True`` the per-axis rotations are
        composed in world space (reversed multiplication order)."""
        basis = {
            'x' : np.array([1,0,0]),
            'y' : np.array([0,1,0]),
            'z' : np.array([0,0,1]),
        }

        q0s, q1s, q2s = (
            Quaternions.from_angle_axis(es[..., i], basis[order[i]])
            for i in range(3)
        )

        return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))
    
    @classmethod
    def from_transforms(cls, ts):
        """Construct quaternions from 3x3 rotation matrices ``ts`` of shape
        ``(..., 3, 3)``.

        Computes all four squared components from the diagonal, then for each
        matrix resolves the signs relative to its largest component — this
        branch-on-largest approach keeps the conversion numerically stable.
        """

        d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2]

        # Squared quaternion components from trace combinations.
        q0 = ( d0 + d1 + d2 + 1.0) / 4.0
        q1 = ( d0 - d1 - d2 + 1.0) / 4.0
        q2 = (-d0 + d1 - d2 + 1.0) / 4.0
        q3 = (-d0 - d1 + d2 + 1.0) / 4.0

        # Clip before sqrt to guard against tiny negatives from round-off.
        q0 = np.sqrt(q0.clip(0,None))
        q1 = np.sqrt(q1.clip(0,None))
        q2 = np.sqrt(q2.clip(0,None))
        q3 = np.sqrt(q3.clip(0,None))

        # Masks selecting which component is largest for each matrix.
        c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)
        c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)
        c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)
        c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)

        # Fix the signs of the remaining components from off-diagonal terms,
        # relative to the dominant component.
        q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2])
        q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0])
        q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1])

        q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2])
        q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1])
        q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0])

        q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0])
        q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1])
        q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2])

        q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1])
        q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2])
        q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2])

        qs = np.empty(ts.shape[:-2] + (4,))
        qs[...,0] = q0
        qs[...,1] = q1
        qs[...,2] = q2
        qs[...,3] = q3

        return cls(qs)


================================================
FILE: dataset/bvh/bvh_io.py
================================================
"""
This code is modified from:
http://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing

by Daniel Holden et al
"""


import re
import numpy as np
from dataset.bvh.Quaternions import Quaternions

# Maps between BVH rotation-channel names ('Xrotation', ...) and the
# single-letter axis codes ('x', 'y', 'z') used for Euler orders.
channelmap = {
    'Xrotation' : 'x',
    'Yrotation' : 'y',
    'Zrotation' : 'z'   
}

channelmap_inv = {
    'x': 'Xrotation',
    'y': 'Yrotation',
    'z': 'Zrotation',
}

# Axis letter -> component index within an Euler-angle triple.
ordermap = {
    'x': 0,
    'y': 1,
    'z': 2,
}


class Animation:
    """Container for a loaded BVH animation.

    Attributes
    ----------
    rotations : per-frame, per-joint rotations (Euler degrees, or
        Quaternions when loaded with ``need_quater=True``)
    positions : per-frame, per-joint positions
    orients   : rest-pose joint orientations
    offsets   : per-joint offsets from the parent joint
    parent    : per-joint parent indices (root is -1)
    names     : joint names
    frametime : seconds per frame
    """

    def __init__(self, rotations, positions, orients, offsets, parents, names, frametime):
        self.rotations = rotations
        self.positions = positions
        self.orients   = orients
        self.offsets   = offsets
        self.parent    = parents
        self.names     = names
        self.frametime = frametime

    @property
    def parents(self):
        # `save`/`save_joint` in this module access `anim.parents`, but the
        # constructor stores the array as `parent`; expose an alias so both
        # spellings work (fixes an AttributeError when saving).
        return self.parent

    @property
    def shape(self):
        # Leading (frames, joints, ...) dimensions of the rotation data.
        return self.rotations.shape


def load(filename, start=None, end=None, order=None, world=False, need_quater=False) -> Animation:
    """
    Reads a BVH file and constructs an animation

    Parameters
    ----------
    filename: str
        File to be opened

    start : int
        Optional Starting Frame

    end : int
        Optional Ending Frame

    order : str
        Optional Specifier for joint order.
        Given as string E.G 'xyz', 'zxy'

    world : bool
        If set to true euler angles are applied
        together in world space rather than local
        space

    need_quater : bool
        If set to true rotations are returned as
        Quaternions rather than Euler degrees

    Returns
    -------
    Animation
        The loaded animation (rotations, positions, orients, offsets,
        parents, names, frametime)
    """

    # NOTE(review): file is not closed on parse errors; a `with` block would
    # be safer, left as-is here.
    f = open(filename, "r")

    i = 0           # current motion-data frame counter
    active = -1     # index of the joint currently being parsed
    end_site = False

    names = []
    orients = Quaternions.id(0)
    offsets = np.array([]).reshape((0, 3))
    parents = np.array([], dtype=int)
    orders = []     # per-joint Euler order as read from CHANNELS lines

    for line in f:

        if "HIERARCHY" in line: continue
        if "MOTION" in line: continue

        """ Modified line read to handle mixamo data """
        #        rmatch = re.match(r"ROOT (\w+)", line)
        rmatch = re.match(r"ROOT (\w+:?\w+)", line)
        if rmatch:
            names.append(rmatch.group(1))
            offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
            orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
            parents = np.append(parents, active)
            active = (len(parents) - 1)
            continue

        if "{" in line: continue

        if "}" in line:
            # Closing an End Site block does not pop the active joint.
            if end_site:
                end_site = False
            else:
                active = parents[active]
            continue

        offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line)
        if offmatch:
            if not end_site:
                offsets[active] = np.array([list(map(float, offmatch.groups()))])
            continue

        chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line)
        if chanmatch:
            channels = int(chanmatch.group(1))

            # With 6 channels the first three are positions; the rotation
            # channels are the last three either way.
            channelis = 0 if channels == 3 else 3
            channelie = 3 if channels == 3 else 6
            parts = line.split()[2 + channelis:2 + channelie]
            if any([p not in channelmap for p in parts]):
                continue
            order = "".join([channelmap[p] for p in parts])
            orders.append(order)
            continue

        """ Modified line read to handle mixamo data """
        #        jmatch = re.match("\s*JOINT\s+(\w+)", line)
        jmatch = re.match("\s*JOINT\s+(\w+:?\w+)", line)
        if jmatch:
            names.append(jmatch.group(1))
            offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
            orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
            parents = np.append(parents, active)
            active = (len(parents) - 1)
            continue

        if "End Site" in line:
            end_site = True
            continue

        fmatch = re.match("\s*Frames:\s+(\d+)", line)
        if fmatch:
            # Pre-allocate the motion arrays once the frame count is known.
            if start and end:
                fnum = (end - start) - 1
            else:
                fnum = int(fmatch.group(1))
            jnum = len(parents)
            positions = offsets[np.newaxis].repeat(fnum, axis=0)
            rotations = np.zeros((fnum, len(orients), 3))
            continue

        fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line)
        if fmatch:
            frametime = float(fmatch.group(1))
            continue

        # Skip motion rows outside the requested [start, end) window.
        if (start and end) and (i < start or i >= end - 1):
            i += 1
            continue

        # dmatch = line.strip().split(' ')
        dmatch = line.strip().split()
        if dmatch:
            data_block = np.array(list(map(float, dmatch)))
            N = len(parents)
            fi = i - start if start else i
            if channels == 3:
                # Root position followed by per-joint rotations.
                positions[fi, 0:1] = data_block[0:3]
                rotations[fi, :] = data_block[3:].reshape(N, 3)
            elif channels == 6:
                # Per-joint position + rotation.
                data_block = data_block.reshape(N, 6)
                positions[fi, :] = data_block[:, 0:3]
                rotations[fi, :] = data_block[:, 3:6]
            elif channels == 9:
                positions[fi, 0] = data_block[0:3]
                data_block = data_block[3:].reshape(N - 1, 9)
                rotations[fi, 1:] = data_block[:, 3:6]
                positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]
            else:
                raise Exception("Too many channels! %i" % channels)

            i += 1

    f.close()

    # Convert each joint's rotations to quaternions, or re-express Euler
    # angles in the canonical 'xyz' order when a joint used a different one.
    all_rotations = []
    canonical_order = 'xyz'
    for i, order in enumerate(orders):
        rot = rotations[:, i:i + 1]
        if need_quater:
            quat = Quaternions.from_euler(np.radians(rot), order=order, world=world)
            all_rotations.append(quat)
            continue
        elif order != canonical_order:
            quat = Quaternions.from_euler(np.radians(rot), order=order, world=world)
            rot = np.degrees(quat.euler(order=canonical_order))
        all_rotations.append(rot)
    rotations = np.concatenate(all_rotations, axis=1)

    return Animation(rotations, positions, orients, offsets, parents, names, frametime)

    
def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True):
    """
    Saves an Animation to file as BVH
    
    Parameters
    ----------
    filename: str
        File to be saved to
        
    anim : Animation
        Animation to save
        
    names : [str]
        List of joint names
    
    order : str
        Optional Specifier for joint order.
        Given as string E.G 'xyz', 'zxy'
    
    frametime : float
        Optional Animation Frame time
        
    positions : bool
        Optional specfier to save bone
        positions for each frame
        
    orients : bool
        Multiply joint orients to the rotations
        before saving.

    Notes
    -----
    NOTE(review): this function accesses ``anim.parents`` while the
    Animation class in this module stores the array as ``anim.parent`` —
    verify the passed object exposes a ``parents`` attribute.
    NOTE(review): ``anim.rotations.euler(...)`` below assumes rotations are
    a Quaternions object (e.g. loaded with ``need_quater=True``) — confirm.
    """
    
    if names is None:
        names = ["joint_" + str(i) for i in range(len(anim.parents))]
    
    with open(filename, 'w') as f:

        # Write the hierarchy: root joint gets 6 channels (position +
        # rotation), children are written recursively by save_joint.
        t = ""
        f.write("%sHIERARCHY\n" % t)
        f.write("%sROOT %s\n" % (t, names[0]))
        f.write("%s{\n" % t)
        t += '\t'

        f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0,0], anim.offsets[0,1], anim.offsets[0,2]) )
        f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 
            (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))

        for i in range(anim.shape[1]):
            if anim.parents[i] == 0:
                t = save_joint(f, anim, names, t, i, order=order, positions=positions)

        t = t[:-1]
        f.write("%s}\n" % t)

        f.write("MOTION\n")
        f.write("Frames: %i\n" % anim.shape[0]);
        f.write("Frame Time: %f\n" % frametime);
            
        #if orients:        
        #    rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1]))
        #else:
        #    rots = np.degrees(anim.rotations.euler(order=order[::-1]))
        # Convert rotations to Euler degrees in the file's channel order.
        rots = np.degrees(anim.rotations.euler(order=order[::-1]))
        poss = anim.positions
        
        # Motion block: one line per frame; only the root (j == 0) gets
        # positions unless `positions` is set.
        for i in range(anim.shape[0]):
            for j in range(anim.shape[1]):
                
                if positions or j == 0:
                
                    f.write("%f %f %f %f %f %f " % (
                        poss[i,j,0],                  poss[i,j,1],                  poss[i,j,2], 
                        rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]]))
                
                else:
                    
                    f.write("%f %f %f " % (
                        rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]]))

            f.write("\n")
    
def save_joint(f, anim, names, t, i, order='zyx', positions=False):
    """Recursively write joint ``i`` and its subtree to the BVH file ``f``.

    Emits the JOINT header, OFFSET, CHANNELS (6 with positions, else 3),
    then each child joint; a leaf joint gets a zero-offset End Site block.
    Returns the indentation string after closing this joint's braces.
    """
    f.write("%sJOINT %s\n" % (t, names[i]))
    f.write("%s{\n" % t)
    t += '\t'

    f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i,0], anim.offsets[i,1], anim.offsets[i,2]))

    if positions:
        f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t, 
            channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
    else:
        f.write("%sCHANNELS 3 %s %s %s\n" % (t, 
            channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))

    is_leaf = True
    for child in range(anim.shape[1]):
        if anim.parents[child] == i:
            is_leaf = False
            t = save_joint(f, anim, names, t, child, order=order, positions=positions)

    if is_leaf:
        f.write("%sEnd Site\n" % t)
        f.write("%s{\n" % t)
        t += '\t'
        f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0))
        t = t[:-1]
        f.write("%s}\n" % t)

    t = t[:-1]
    f.write("%s}\n" % t)

    return t


================================================
FILE: dataset/bvh/bvh_parser.py
================================================
import torch
import numpy as np
import dataset.bvh.bvh_io as bvh_io
from utils.kinematics import ForwardKinematicsJoint
from utils.transforms import quat2repr6d
from utils.contact import foot_contact
from dataset.bvh.Quaternions import Quaternions
from dataset.bvh.bvh_writer import WriterWrapper


class Skeleton:
    """Simplified skeleton extracted from a BVH hierarchy.

    When ``joint_reduction`` is enabled the joint set is reduced to
    ``skeleton_conf['corps_names']``; joints not listed there are recorded
    in ``self.details`` and dropped from the simplified skeleton.  Contact
    joints (for foot-contact detection) are taken from
    ``skeleton_conf['contact_names']``.
    """

    def __init__(self, names, parent, offsets, joint_reduction=True, skeleton_conf=None):
        self._names = names
        self.original_parent = parent
        self._offsets = offsets
        self._parent = None
        self._ee_id = None
        self.contact_names = []
        # Always defined so the error path below can print it
        # (previously only set when joint_reduction was disabled).
        self.skeleton_type = -1

        # Strip namespace prefixes such as 'mixamorig:' from joint names.
        for i, name in enumerate(self._names):
            if ':' in name:
                self._names[i] = name[name.find(':')+1:]

        if joint_reduction or skeleton_conf is not None:
            assert skeleton_conf is not None, 'skeleton_conf can not be None if you use joint reduction'
            corps_names = skeleton_conf['corps_names']
            # Bug fix: contact joints come from 'contact_names'; the original
            # read 'corps_names' here, which marked *every* joint as a
            # contact joint.
            self.contact_names = skeleton_conf['contact_names']
            self.contact_threshold = skeleton_conf['contact_threshold']

            self.contact_id = []
            for i in self.contact_names:
                self.contact_id.append(corps_names.index(i))
        else:
            corps_names = self._names

        self.details = []    # joints that does not belong to the corps (we are not interested in them)
        for i, name in enumerate(self._names):
            if name not in corps_names: self.details.append(i)

        self.corps = []
        self.simplified_name = []
        self.simplify_map = {}
        self.inverse_simplify_map = {}

        # Repermute the skeleton id according to the databse
        for name in corps_names:
            for j in range(len(self._names)):
                if name in self._names[j]:
                    self.corps.append(j)
                    break
        if len(self.corps) != len(corps_names):
            for i in self.corps:
                print(self._names[i], end=' ')
            print(self.corps, self.skeleton_type, len(self.corps), sep='\n')
            raise Exception('Problem in this skeleton')

        # Build mappings between original joint indices and simplified ones.
        self.joint_num_simplify = len(self.corps)
        for i, j in enumerate(self.corps):
            self.simplify_map[j] = i
            self.inverse_simplify_map[i] = j
            self.simplified_name.append(self._names[j])
        self.inverse_simplify_map[0] = -1
        for i in range(len(self._names)):
            if i in self.details:
                self.simplify_map[i] = -1

    @property
    def parent(self):
        """Parent indices of the simplified skeleton (tuple, root is -1)."""
        if self._parent is None:
            self._parent = self.original_parent[self.corps].copy()
            for i in range(self._parent.shape[0]):
                if i >= 1: self._parent[i] = self.simplify_map[self._parent[i]]
            self._parent = tuple(self._parent)
        return self._parent

    @property
    def offsets(self):
        """Offsets of the simplified joints as a float tensor."""
        return torch.tensor(self._offsets[self.corps], dtype=torch.float)

    @property
    def names(self):
        """Names of the simplified joints."""
        return self.simplified_name

    @property
    def ee_id(self):
        raise Exception('Abaddoned')
        # if self._ee_id is None:
        #     self._ee_id = []
        #     for i in SkeletonDatabase.ee_names[self.skeleton_type]:
        #         self.ee_id._ee_id(corps_names[self.skeleton_type].index(i))


class BVH_file:
    """Loads a BVH file, simplifies its skeleton, optionally rescales it and
    computes foot-contact labels; provides tensor views of the motion."""

    def __init__(self, file_path, skeleton_conf=None, requires_contact=False, joint_reduction=True, auto_scale=True):
        self.anim = bvh_io.load(file_path)
        self._names = self.anim.names
        self.frametime = self.anim.frametime
        if requires_contact or joint_reduction:
            assert skeleton_conf is not None, 'Please provide a skeleton configuration for contact or joint reduction'
        self.skeleton = Skeleton(self.anim.names, self.anim.parent, self.anim.offsets, joint_reduction, skeleton_conf)

        # Downsample to 30 fps for our application
        # (the two blocks cascade: e.g. 120 fps is halved twice to 30 fps).
        if self.frametime < 0.0084:
            self.frametime *= 2
            self.anim.positions = self.anim.positions[::2]
            self.anim.rotations = self.anim.rotations[::2]
        if self.frametime < 0.017:
            self.frametime *= 2
            self.anim.positions = self.anim.positions[::2]
            self.anim.rotations = self.anim.rotations[::2]

        self.requires_contact = requires_contact

        if requires_contact:
            self.contact_names = self.skeleton.contact_names
        else:
            self.contact_names = []

        self.fk = ForwardKinematicsJoint(self.skeleton.parent, self.skeleton.offsets)
        self.writer = WriterWrapper(self.skeleton.parent, self.skeleton.offsets)

        # Normalize the skeleton size so the largest offset component is ~1.
        self.auto_scale = auto_scale
        if auto_scale:
            self.scale = 1. / np.ceil(self.skeleton.offsets.max().cpu().numpy())
            print(f'rescale the skeleton with scale: {self.scale}')
            self.rescale(self.scale)
        else:
            self.scale = 1.0

        if self.requires_contact:
            # Per-frame contact labels for the configured contact joints.
            gl_pos = self.joint_position()
            self.contact_label = foot_contact(gl_pos[:, self.skeleton.contact_id],
                                              threshold=self.skeleton.contact_threshold)
            self.gl_pos = gl_pos

    def local_pos(self):
        """Joint positions relative to the root, root joint excluded."""
        gl_pos = self.joint_position()
        local_pos = gl_pos - gl_pos[:, 0:1, :]
        return local_pos[:, 1:]

    def rescale(self, ratio):
        """Scale all offsets and positions by `ratio` (in place)."""
        self.anim.offsets *= ratio
        self.anim.positions *= ratio

    def to_tensor(self, repr='euler', rot_only=False):
        """Flatten the motion to a 2D tensor (frames x features).

        Features are per-joint rotations in the chosen representation,
        optionally followed by virtual contact channels and the root
        position.  Raises on unknown `repr`.
        """
        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
            raise Exception('Unknown rotation representation')
        positions = self.get_position()
        rotations = self.get_rotation(repr=repr)

        if rot_only:
            return rotations.reshape(rotations.shape[0], -1)

        if self.requires_contact:
            # Contact labels ride along as extra "virtual joints" whose first
            # channel stores the label.
            virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
            virtual_contact[..., 0] = self.contact_label
            rotations = torch.cat([rotations, virtual_contact], dim=1)

        rotations = rotations.reshape(rotations.shape[0], -1)
        return torch.cat((rotations, positions), dim=-1)

    def joint_position(self):
        """Global joint positions via forward kinematics, shape (frames, joints, 3)."""
        positions = torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float)
        rotations = self.anim.rotations[:, self.skeleton.corps, :]
        rotations = Quaternions.from_euler(np.radians(rotations)).qs
        rotations = torch.tensor(rotations, dtype=torch.float)
        j_loc = self.fk.forward(rotations, positions)
        return j_loc

    def get_rotation(self, repr='quat'):
        """Rotations of the simplified joints as a tensor in the requested
        representation ('euler' degrees, 'quat'/'quaternion', or 'repr6d')."""
        rotations = self.anim.rotations[:, self.skeleton.corps, :]
        if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':
            rotations = Quaternions.from_euler(np.radians(rotations)).qs
            rotations = torch.tensor(rotations, dtype=torch.float)
        if repr == 'repr6d':
            rotations = quat2repr6d(rotations)
        if repr == 'euler':
            rotations = torch.tensor(rotations, dtype=torch.float)
        return rotations

    def get_position(self):
        """Root joint positions per frame as a float tensor."""
        return torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float)

    def dfs(self, x, vis, dist):
        """Depth-first traversal over the joint adjacency (parent links in
        either direction), filling `dist` with hop counts from joint `x`."""
        fa = self.skeleton.parent
        vis[x] = 1
        for y in range(len(fa)):
            if (fa[y] == x or fa[x] == y) and vis[y] == 0:
                dist[y] = dist[x] + 1
                self.dfs(y, vis, dist)

    def get_neighbor(self, threshold, enforce_contact=False):
        """Per-joint neighbor lists: joints within `threshold` hops, plus
        virtual contact joints and a virtual root-position entry."""
        fa = self.skeleton.parent
        neighbor_list = []
        for x in range(0, len(fa)):
            vis = [0 for _ in range(len(fa))]
            dist = [0 for _ in range(len(fa))]
            self.dfs(x, vis, dist)
            neighbor = []
            for j in range(0, len(fa)):
                if dist[j] <= threshold:
                    neighbor.append(j)
            neighbor_list.append(neighbor)

        # Append one virtual joint per contact joint, sharing its neighbors.
        contact_list = []
        if self.requires_contact:
            for i, p_id in enumerate(self.skeleton.contact_id):
                v_id = len(neighbor_list)
                neighbor_list[p_id].append(v_id)
                neighbor_list.append(neighbor_list[p_id])
                contact_list.append(v_id)

        root_neighbor = neighbor_list[0]
        id_root = len(neighbor_list)

        if enforce_contact:
            root_neighbor = root_neighbor + contact_list
            for j in contact_list:
                neighbor_list[j] = list(set(neighbor_list[j]))

        # Virtual root-position joint: connected to the root's neighborhood.
        root_neighbor = list(set(root_neighbor))
        for j in root_neighbor:
            neighbor_list[j].append(id_root)
        root_neighbor.append(id_root)
        neighbor_list.append(root_neighbor)  # Neighbor for root position
        return neighbor_list

================================================
FILE: dataset/bvh/bvh_writer.py
================================================
import torch
from utils.transforms import quat2euler, repr6d2quat


# rotation with shape frame * J * 3
# rotation with shape frame * J * 3
def write_bvh(parent, offset, rotation, position, names, frametime, order, path, endsite=None):
    """Write a BVH file to `path`.

    Parameters: `parent` per-joint parent indices, `offset` per-joint
    offsets, `rotation` Euler angles of shape (frames, joints, 3),
    `position` root positions of shape (frames, 3), `names` joint names,
    `frametime` seconds per frame, `order` Euler order string (e.g. 'xyz').
    Returns the full file contents as a string.
    """
    file = open(path, 'w')
    frame = rotation.shape[0]
    joint_num = rotation.shape[1]
    order = order.upper()

    file_string = 'HIERARCHY\n'

    # `seq` records the depth-first order in which joints are written, so
    # the motion rows below can emit channels in the same order.
    seq = []

    def write_static(idx, prefix):
        # Recursively append the hierarchy section for joint `idx` to
        # `file_string`, indenting children by one extra tab.
        nonlocal parent, offset, rotation, names, order, endsite, file_string, seq
        seq.append(idx)
        if idx == 0:
            name_label = 'ROOT ' + names[idx]
            channel_label = 'CHANNELS 6 Xposition Yposition Zposition {}rotation {}rotation {}rotation'.format(*order)
        else:
            name_label = 'JOINT ' + names[idx]
            channel_label = 'CHANNELS 3 {}rotation {}rotation {}rotation'.format(*order)
        offset_label = 'OFFSET %.6f %.6f %.6f' % (offset[idx][0], offset[idx][1], offset[idx][2])

        file_string += prefix + name_label + '\n'
        file_string += prefix + '{\n'
        file_string += prefix + '\t' + offset_label + '\n'
        file_string += prefix + '\t' + channel_label + '\n'

        has_child = False
        for y in range(idx+1, rotation.shape[1]):
            if parent[y] == idx:
                has_child = True
                write_static(y, prefix + '\t')
        if not has_child:
            # Leaf joint: emit a zero-offset End Site block.
            file_string += prefix + '\t' + 'End Site\n'
            file_string += prefix + '\t' + '{\n'
            file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n'
            file_string += prefix + '\t' + '}\n'

        file_string += prefix + '}\n'

    write_static(0, '')

    # Motion section: root position followed by all joint rotations per frame.
    file_string += 'MOTION\n' + 'Frames: {}\n'.format(frame) + 'Frame Time: %.8f\n' % frametime
    for i in range(frame):
        file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1], position[i][2])
        for j in range(joint_num):
            idx = seq[j]
            file_string += '%.6f %.6f %.6f ' % (rotation[i][idx][0], rotation[i][idx][1], rotation[i][idx][2])
        file_string += '\n'

    file.write(file_string)
    return file_string


class WriterWrapper:
    """Thin helper that writes animations for a fixed skeleton to BVH."""

    def __init__(self, parents, offset=None):
        self.parents = parents
        self.offset = offset  # default per-joint offsets used when write() gets none

    def write(self, filename, rot, pos, offset=None, names=None, repr='quat', frametime=1):
        """
        Write animation to bvh file
        :param filename:
        :param rot: Quaternion as (w, x, y, z)
        :param pos:
        :param offset:
        :param names: joint names (defaults to zero-padded indices)
        :param repr: rotation representation ('euler', 'quat'/'quaternion', 'repr6d')
        :param frametime: seconds per frame (default 1, matching the
            previously hard-coded value, so existing callers are unchanged)
        :return:
        """
        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
            raise Exception('Unknown rotation representation')
        if offset is None:
            offset = self.offset
        if not isinstance(offset, torch.Tensor):
            offset = torch.tensor(offset)
        n_bone = offset.shape[0]

        if repr == 'repr6d':
            rot = rot.reshape(rot.shape[0], -1, 6)
            rot = repr6d2quat(rot)
        if repr == 'repr6d' or repr == 'quat' or repr == 'quaternion':
            rot = rot.reshape(rot.shape[0], -1, 4)
            # NOTE(review): dividing by norm ** 0.5 does not yield unit
            # quaternions (that would require dividing by the norm itself);
            # left unchanged — confirm against quat2euler's expectations.
            rot /= rot.norm(dim=-1, keepdim=True) ** 0.5
            euler = quat2euler(rot, order='xyz')
            rot = euler

        if names is None:
            names = ['%02d' % i for i in range(n_bone)]
        write_bvh(self.parents, offset, rot, pos, names, frametime, 'xyz', filename)


================================================
FILE: dataset/bvh_motion.py
================================================
import os
import os.path as osp
import torch
import numpy as np
import torch.nn.functional as F
from .motion import MotionData
from .bvh.bvh_parser import BVH_file


## Some skeleton configurations
# Joint (bone) names of the "crab dance" rig, used below to build its
# skeleton configuration entry.
crab_dance_corps_names = ['ORG_Hips', 'ORG_BN_Bip01_Pelvis', 'DEF_BN_Eye_L_01', 'DEF_BN_Eye_L_02', 'DEF_BN_Eye_L_03', 'DEF_BN_Eye_L_03_end', 'DEF_BN_Eye_R_01', 'DEF_BN_Eye_R_02', 'DEF_BN_Eye_R_03', 'DEF_BN_Eye_R_03_end', 'DEF_BN_Leg_L_11', 'DEF_BN_Leg_L_12', 'DEF_BN_Leg_L_13', 'DEF_BN_Leg_L_14', 'DEF_BN_Leg_L_15', 'DEF_BN_Leg_L_15_end', 'DEF_BN_Leg_R_11', 'DEF_BN_Leg_R_12', 'DEF_BN_Leg_R_13', 'DEF_BN_Leg_R_14', 'DEF_BN_Leg_R_15', 'DEF_BN_Leg_R_15_end', 'DEF_BN_leg_L_01', 'DEF_BN_leg_L_02', 'DEF_BN_leg_L_03', 'DEF_BN_leg_L_04', 'DEF_BN_leg_L_05', 'DEF_BN_leg_L_05_end',
                         'DEF_BN_leg_L_06', 'DEF_BN_Leg_L_07', 'DEF_BN_Leg_L_08', 'DEF_BN_Leg_L_09', 'DEF_BN_Leg_L_10', 'DEF_BN_Leg_L_10_end', 'DEF_BN_leg_R_01', 'DEF_BN_leg_R_02', 'DEF_BN_leg_R_03', 'DEF_BN_leg_R_04', 'DEF_BN_leg_R_05', 'DEF_BN_leg_R_05_end', 'DEF_BN_leg_R_06', 'DEF_BN_Leg_R_07', 'DEF_BN_Leg_R_08', 'DEF_BN_Leg_R_09', 'DEF_BN_Leg_R_10', 'DEF_BN_Leg_R_10_end', 'DEF_BN_Bip01_Pelvis', 'DEF_BN_Bip01_Pelvis_end', 'DEF_BN_Arm_L_01', 'DEF_BN_Arm_L_02', 'DEF_BN_Arm_L_03', 'DEF_BN_Arm_L_03_end', 'DEF_BN_Arm_R_01', 'DEF_BN_Arm_R_02', 'DEF_BN_Arm_R_03', 'DEF_BN_Arm_R_03_end']
# Predefined skeleton configurations, keyed by skeleton name:
#   corps_names       : joint names kept when joint_reduction is enabled
#   contact_names     : joints used for ground-contact detection
#   contact_threshold : threshold used by the contact labeler for these joints
#                       (presumably a velocity threshold -- TODO confirm in
#                       the contact-detection code)
skeleton_confs = {
    'mixamo': {
        'corps_names': ['Hips', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'LeftToe_End', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'RightToe_End', 'Spine', 'Spine1', 'Spine2', 'Neck', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand'],
        'contact_names': ['LeftToe_End', 'RightToe_End', 'LeftToeBase', 'RightToeBase'],
        'contact_threshold': 0.018
    },
    'crab_dance': {
        'corps_names': crab_dance_corps_names,
        # contact joints: the leg-tip "end" bones (chains 05, 10 and 15)
        'contact_names': [name for name in crab_dance_corps_names if 'end' in name and ('05' in name or '10' in name or '15' in name)],
        'contact_threshold': 0.006
    },
    'xia': {
        'corps_names': ['Hips', 'LHipJoint', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'RHipJoint', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'LowerBack', 'Spine', 'Spine1', 'Neck', 'Neck1', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'LeftFingerBase', 'LeftHandIndex1', 'LThumb', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand', 'RightFingerBase', 'RightHandIndex1', 'RThumb'],
        'contact_names': ['LeftToeBase', 'RightToeBase'],
        'contact_threshold': 0.006
    }
}

class BVHMotion:
    """Wraps a parsed BVH file and its tensorized MotionData representation."""

    def __init__(self, bvh_file, skeleton_name=None, repr='quat', use_velo=True, keep_up_pos=False, up_axis='Y_UP', padding_last=False, requires_contact=False, joint_reduction=False):
        '''
        BVHMotion constructor
        Args:
            bvh_file         : string, bvh file path to load from
            skeleton_name    : string, name of predefined skeleton, used when joint_reduction==True or requires_contact==True
            repr             : string, rotation representation, support ['quat', 'repr6d', 'euler']
            use_velo         : bool, whether to transform the joints positions to velocities
            keep_up_pos      : bool, whether to keep the up-axis position when converting to velocity
            up_axis          : string, up axis of the motion data
            padding_last     : bool, whether to pad the last position
            requires_contact : bool, whether to concatenate contact information
            joint_reduction  : bool, whether to reduce the joint number
        '''
        self.bvh_file = bvh_file
        self.skeleton_name = skeleton_name
        if skeleton_name is not None:
            assert skeleton_name in skeleton_confs, f'{skeleton_name} not found, please add a skeleton configuration.'
        self.requires_contact = requires_contact
        self.joint_reduction = joint_reduction

        # Parse the bvh (auto-scaled) and pack it into a [1 x C x T] tensor.
        self.raw_data = BVH_file(bvh_file, skeleton_confs[skeleton_name] if skeleton_name is not None else None, requires_contact, joint_reduction, auto_scale=True)
        self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis,
                                      padding_last=padding_last, contact_id=self.raw_data.skeleton.contact_id if requires_contact else None)

    @property
    def repr(self):
        # rotation representation name ('quat', 'repr6d' or 'euler')
        return self.motion_data.repr

    @property
    def use_velo(self):
        # whether global positions are stored as velocities
        return self.motion_data.use_velo

    @property
    def keep_up_pos(self):
        return self.motion_data.keep_up_pos

    @property
    def padding_last(self):
        return self.motion_data.padding_last

    @property
    def contact_id(self):
        # contact joint indices (None when requires_contact is False)
        return self.motion_data.contact_id

    @property
    def concat_id(self):
        # Backward-compatible alias for `contact_id` (historical misspelling).
        return self.motion_data.contact_id

    @property
    def n_pad(self):
        # number of zero-padding channels appended after the position channels
        return self.motion_data.n_pad

    @property
    def n_contact(self):
        # number of contact joints
        return self.motion_data.n_contact

    @property
    def n_rot(self):
        # number of channels per joint rotation (4/6/3)
        return self.motion_data.n_rot

    def sample(self, size=None, slerp=False):
        '''
        Sample motion data, support slerp
        '''
        return self.motion_data.sample(size, slerp)

    def write(self, filename, data):
        '''
        Parse motion data into position, rotation and contact (if enabled) and
        write it to a bvh file.  Contact labels are saved alongside as
        `<filename>.contact.npy`.
        data should be [1 x n_channels x n_frames].
        No batch support here!!!
        '''
        assert len(data.shape) == 3, 'The data format should be [batch_size x n_channels x n_frames]'

        data = data.clone()
        # BUG FIX: convert velocities back to positions *before* stripping the
        # padding channels.  MotionData.to_position indexes the position
        # channels relative to the padded layout (msk = velo_mask - n_pad), so
        # stripping first made it write into rotation channels whenever
        # n_pad > 0.  This also matches the order used by TracksMotion.parse.
        if self.use_velo:
            data = self.motion_data.to_position(data)
        if self.n_pad:
            data = data[:, :-self.n_pad]
        data = data.squeeze().permute(1, 0)
        pos = data[..., -3:]
        rot = data[..., :-3].reshape(data.shape[0], -1, self.n_rot)
        if self.requires_contact:
            # contact labels are stored as the first channel of the trailing
            # pseudo-joints appended after the real rotations
            contact = rot[..., -self.n_contact:, 0]
            rot = rot[..., :-self.n_contact, :]
        else:
            contact = None

        if contact is not None:
            np.save(filename + '.contact', contact.detach().cpu().numpy())

        # Undo the auto-scaling applied at load time before writing.
        # NOTE(review): assumes raw_data.rescale does not change
        # raw_data.scale before the next line reads it -- confirm in BVH_file.
        self.raw_data.rescale(1. / self.raw_data.scale)
        pos *= 1. / self.raw_data.scale
        self.raw_data.writer.write(filename, rot, pos, names=self.raw_data.skeleton.names, repr=self.repr)


def load_multiple_dataset(name_list, **kargs):
    '''
    Load several BVHMotion datasets listed (one filename per line) in a text
    file; each path is resolved relative to the list file's directory.
    Remaining keyword arguments are forwarded to the BVHMotion constructor.
    '''
    base_dir = osp.dirname(name_list)
    with open(name_list, 'r') as list_file:
        motion_files = [line.strip() for line in list_file]
    datasets = []
    for motion_file in motion_files:
        kargs['bvh_file'] = osp.join(base_dir, motion_file)
        datasets.append(BVHMotion(**kargs))
    return datasets

================================================
FILE: dataset/motion.py
================================================
import torch
import torch.nn.functional as F


class MotionData:
    def __init__(self, data, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y', padding_last=False, contact_id=None):
        '''
        MotionData constructor
        Args:
            data         : torch.Tensor, [batch_size x n_channels x n_frames] input motion data,
                           the channels dim should be [n_joints x n_dim_of_rotation + 3(global position)]
            repr         : string, rotation representation, support ['quat', 'repr6d', 'euler']
            use_velo     : bool, whether to transform the joints positions to velocities
            keep_up_pos  : bool, whether to keep up position when converting to velocity
            up_axis      : string, up axis of the motion data ('X_UP'/'Y_UP'/'Z_UP', bare 'X'/'Y'/'Z' also accepted)
            padding_last : bool, whether to pad the last position
            contact_id   : list, contact joints id
        '''
        self.data = data
        self.repr = repr
        self.use_velo = use_velo
        self.keep_up_pos = keep_up_pos
        self.up_axis = up_axis
        self.padding_last = padding_last
        self.contact_id = contact_id
        self.begin_pos = None  # first-frame positions, saved while data is in velocity form

        # validate the rotation representation and derive channels per joint
        if self.repr == 'quat':
            self.n_rot = 4
        elif self.repr == 'repr6d':
            self.n_rot = 6
        elif self.repr in ('euler', 'eluer'):
            # BUG FIX: the original only matched the typo 'eluer', so passing
            # 'euler' silently skipped this branch and left n_rot undefined.
            self.n_rot = 3
        else:
            raise ValueError(f'unknown rotation representation: {self.repr}')
        assert (self.data.shape[1] - 3) % self.n_rot == 0, f'rotation is not "{self.repr}" representation'

        # whether to pad the position data with zero
        if self.padding_last:
            self.n_pad = self.data.shape[1] - 3  # pad position channels to match the n_channels of rotation
            paddings = torch.zeros_like(self.data[:, :self.n_pad])
            self.data = torch.cat((self.data, paddings), dim=1)
        else:
            self.n_pad = 0

        # get the contact information
        if self.contact_id is not None:
            self.n_contact = len(contact_id)
        else:
            self.n_contact = 0

        # Position channels (negative indices, relative to the *unpadded*
        # layout) that are converted to velocities; the up axis is excluded
        # when keep_up_pos is set.
        if self.keep_up_pos:
            # BUG FIX: also accept the bare axis names; the default parameter
            # value is 'Y', which previously matched no branch and left
            # `velo_mask` undefined (AttributeError on first use).
            if self.up_axis in ('X_UP', 'X'):
                self.velo_mask = [-2, -1]
            elif self.up_axis in ('Y_UP', 'Y'):
                self.velo_mask = [-3, -1]
            elif self.up_axis in ('Z_UP', 'Z'):
                self.velo_mask = [-3, -2]
            else:
                raise ValueError(f'unknown up axis: {self.up_axis}')
        else:
            self.velo_mask = [-3, -2, -1]

        # whether to convert global position to velocity
        if self.use_velo:
            self.data = self.to_velocity(self.data)

    def __len__(self):
        '''
        return the number of motion frames
        '''
        return self.data.shape[-1]

    def sample(self, size=None, slerp=False):
        '''
        sample the motion data using given size
        '''
        if size is None:
            return self.data
        else:
            if slerp:
                # NOTE(review): `self.slerp` is not defined on this class --
                # sample(..., slerp=True) raises AttributeError; confirm
                # whether slerp was meant to be supplied by a subclass.
                motion = self.slerp(self.data, size=size)
            else:
                motion = F.interpolate(self.data, size=size, mode='linear', align_corners=False)
            return motion

    def to_velocity(self, pos):
        '''
        convert motion data to velocity (in the channels given by velo_mask)
        '''
        assert self.begin_pos is None, 'the motion data had been converted to velocity'
        # shift the mask left by n_pad so it addresses the padded layout
        msk = [i - self.n_pad for i in self.velo_mask]
        velo = pos.detach().clone().to(pos.device)
        velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
        # remember the first frame so to_position can reconstruct exactly
        self.begin_pos = pos[:, msk, 0].clone()
        velo[:, msk, 0] = pos[:, msk, 1]
        return velo

    def to_position(self, velo):
        '''
        convert motion data back to position (inverse of to_velocity)
        '''
        assert self.begin_pos is not None, 'the motion data is already position'
        msk = [i - self.n_pad for i in self.velo_mask]
        pos = velo.detach().clone().to(velo.device)
        pos[:, msk, 0] = self.begin_pos.to(velo.device)
        # cumulative sum of (begin_pos, deltas...) restores absolute positions
        pos[:, msk] = torch.cumsum(pos[:, msk], dim=-1)
        self.begin_pos = None
        return pos

================================================
FILE: dataset/tracks_motion.py
================================================
import os
from os.path import join as pjoin
import numpy as np
import copy
import torch
from .motion import MotionData
from ..utils.transforms import quat2repr6d, quat2euler, repr6d2quat

class TracksParser():
    """Parses three.js-style animation tracks (one vector track for the root
    position followed by one quaternion track per joint) into tensors."""

    def __init__(self, tracks_json, scale):
        '''
        Args:
            tracks_json : list of track dicts; tracks_json[0] must be the root
                          position ('vector'), the rest quaternion rotations
            scale       : float, multiplier applied to the root positions
        '''
        self.tracks_json = tracks_json
        self.scale = scale

        self.skeleton_names = []
        self.rotations = []
        for i, track in enumerate(self.tracks_json):
            self.skeleton_names.append(track['name'])
            if i == 0:
                assert track['type'] == 'vector'
                self.position = np.array(track['values']).reshape(-1, 3) * self.scale
                self.num_frames = self.position.shape[0]
            else:
                assert track['type'] == 'quaternion' # DEAFULT: quaternion
                rotation = np.array(track['values']).reshape(-1, 4)
                # pad/trim each rotation track to num_frames
                if rotation.shape[0] == 0:
                    rotation = np.zeros((self.num_frames, 4))
                elif rotation.shape[0] < self.num_frames:
                    rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)
                elif rotation.shape[0] > self.num_frames:
                    rotation = rotation[:self.num_frames]
                self.rotations += [rotation]
        # (n_joints, n_frames, 4)
        self.rotations = np.array(self.rotations, dtype=np.float32)

    def to_tensor(self, repr='euler', rot_only=False):
        """Return a (n_frames, n_joints*n_rot [+ 3]) tensor of rotations
        (and, unless rot_only, the root positions appended)."""
        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
            raise Exception('Unknown rotation representation')
        rotations = self.get_rotation(repr=repr)
        positions = self.get_position()

        if rot_only:
            return rotations.reshape(rotations.shape[0], -1)

        rotations = rotations.reshape(rotations.shape[0], -1)
        return torch.cat((rotations, positions), dim=-1)

    def get_rotation(self, repr='quat'):
        """Return rotations as (n_frames, n_joints, n_rot) in the requested
        representation."""
        # BUG FIX: build the quaternion tensor unconditionally.  Previously it
        # was only created for the quat-family representations, so the 'euler'
        # branch hit `rotations` unbound (UnboundLocalError).
        rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
        if repr == 'repr6d':
            rotations = quat2repr6d(rotations)
        elif repr == 'euler':
            rotations = quat2euler(rotations)
        return rotations

    def get_position(self):
        # (n_frames, 3) scaled root positions
        return torch.tensor(self.position, dtype=torch.float32)

class TracksMotion:
    # Wraps TracksParser output in a MotionData tensor and converts generated
    # tensors back into the original tracks_json format.
    def __init__(self, tracks_json, scale=1.0, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y_UP', padding_last=False):
        '''
        TracksMotion constructor
        Args:
            tracks_json      : dict, json format tracks data to load from
            scale            : float, scale of the tracks motion data
            repr             : string, rotation representation, support ['quat', 'repr6d', 'euler']
            use_velo         : bool, whether to transform the joints positions to velocities
            keep_up_pos      : bool, whether to keep y position when converting to velocity
            up_axis          : string, up axis of the motion data
            padding_last     : bool, whether to pad the last position
        '''
        self.tracks_json = tracks_json

        # parse the tracks, pack into [1 x C x T] and wrap in MotionData
        self.raw_data = TracksParser(tracks_json, scale)
        self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis, 
                                      padding_last=padding_last, contact_id=None)
    @property
    def repr(self):
        # rotation representation name
        return self.motion_data.repr

    @property
    def use_velo(self):
        # whether global positions are stored as velocities
        return self.motion_data.use_velo

    @property
    def keep_up_pos(self):
        return self.motion_data.keep_up_pos
    
    @property
    def padding_last(self):
        return self.motion_data.padding_last

    @property
    def n_pad(self):
        # number of zero-padding channels after the position channels
        return self.motion_data.n_pad

    @property
    def n_rot(self):
        # number of channels per joint rotation (4/6/3)
        return self.motion_data.n_rot

    def sample(self, size=None, slerp=False):
        '''
        Sample motion data, support slerp
        '''
        return self.motion_data.sample(size, slerp)


    def parse(self, motion, keep_velo=False,):
        """
        Convert a generated [1 x C x T] motion tensor back into tracks_json
        format (positions un-scaled, rotations converted back to quaternions).
        No batch support here!!!
        :param motion: generated motion tensor
        :param keep_velo: if True, leave positions in velocity form
        :returns tracks_json
        """
        motion = motion.clone()

        # convert velocities to positions first -- to_position indexes the
        # position channels relative to the padded layout
        if self.use_velo and not keep_velo:
            motion = self.motion_data.to_position(motion)
        if self.n_pad:
            motion = motion[:, :-self.n_pad]

        motion = motion.squeeze().permute(1, 0)
        pos = motion[..., -3:] / self.raw_data.scale
        rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
        if self.repr == 'repr6d':
            rot = repr6d2quat(rot)
        elif self.repr == 'euler':
            raise NotImplementedError('parse "euler is not implemented yet!!!')

        times = []
        out_tracks_json = copy.deepcopy(self.tracks_json)
        for i, _track in enumerate(out_tracks_json):
            if i == 0:
                # rebuild uniform timestamps from the original frame interval
                # (assumes tracks_json[0]['times'][1] is the frame time -- TODO confirm)
                times = [ j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]
                out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist() 
            else:
                # track i holds joint i-1's rotations
                out_tracks_json[i]['values'] = rot[:, i-1, :].flatten().detach().cpu().numpy().tolist()
            out_tracks_json[i]['times'] = times

        return out_tracks_json


================================================
FILE: docker/Dockerfile
================================================
# CUDA-enabled PyTorch base image (Python + conda preinstalled)
FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel

# For the convenience for users in China mainland
COPY apt-sources.list /etc/apt/sources.list
# Install some basic utilities
# Drop the stale NVIDIA apt repo lists shipped with the base image (their
# keys/URLs break `apt-get update`)
RUN rm /etc/apt/sources.list.d/cuda.list
RUN rm /etc/apt/sources.list.d/nvidia-ml.list
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    gcc \
    g++ \
    libusb-1.0-0 \
    libgl1-mesa-glx \
    libglib2.0-dev \
    openssh-server \
    openssh-client \
    iputils-ping \
    unzip \
    cmake \
    libssl-dev \
    libosmesa6-dev \
    freeglut3-dev \
    ffmpeg \
    iputils-ping \
 && rm -rf /var/lib/apt/lists/*

# For the convenience for users in China mainland
# (pip and conda mirrors; harmless elsewhere)
RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple \
  && export PATH="/usr/local/bin:$PATH" \
  && /bin/bash -c "source ~/.bashrc"
RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ \
 && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ \
 && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ \
 && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \
 && conda config --set show_channel_urls yes

# Install dependencies
COPY requirements.txt requirements.txt 
RUN pip install -r requirements.txt --user 

CMD ["python3"]

================================================
FILE: docker/README.md
================================================
## Build Docker Environment and use with GPU Support

Before you can use this Docker environment, you need to have the following:

- Docker installed on your system
- NVIDIA drivers installed on your system
- NVIDIA Container Toolkit installed on your system


### Build and Run
1. Build docker image (note: Docker repository names must be lowercase):
   ```sh
   docker build -t genmm:latest .
   ```
2. Start the docker container:
   ```sh
   docker run --gpus all -it genmm:latest /bin/bash
   ```
3. Clone the repository:
   ```sh
   git clone git@github.com:wyysf-98/GenMM.git
   ```

## Troubleshooting

If you encounter any issues with the Docker environment with GPU support, please check the following:

- Make sure that you have installed the NVIDIA drivers and NVIDIA Container Toolkit on your system.
- Make sure that you have specified the --gpus all option when starting the Docker container.
- Make sure that your deep learning application is configured to use the GPU.

================================================
FILE: docker/apt-sources.list
================================================
deb https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse
deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse
deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse
deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse
deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse
deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse
deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse
deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse
deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse
deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse

================================================
FILE: docker/requirements.txt
================================================
torch==1.12.1
torchvision==0.13.1
tensorboardX==2.5
tqdm==4.62.3
unfoldNd==0.2.0
pyyaml>=5.3.1
gradio==3.34.0
matplotlib==3.3.2

================================================
FILE: docker/requirements_blender.txt
================================================
torch==2.2.0
torchvision==0.17.0
tqdm==4.62.3
unfoldNd==0.2.0
pyyaml>=5.3.1


================================================
FILE: fix_contact.py
================================================
from dataset.bvh.bvh_parser import BVH_file
from os.path import join as pjoin
import numpy as np
import torch
from utils.contact import constrain_from_contact
from utils.kinematics import InverseKinematicsJoint2
from utils.transforms import repr6d2quat
from tqdm import tqdm
import argparse
import matplotlib.pyplot as plt
from dataset.bvh_motion import skeleton_confs

def continuous_filter(contact, length=2):
    '''
    Suppress short runs in a per-joint contact signal.
    Any run of identical values lasting at most `length` frames is overwritten
    with the value that follows it.  Works on a copy; `contact` is a
    (n_frames, n_joints) array and is returned filtered.
    '''
    filtered = contact.copy()
    n_frames, n_joints = filtered.shape
    for joint in range(n_joints):
        column = filtered[:, joint]
        run_len = 0
        prev_val = column[0]
        for frame in range(n_frames):
            cur_val = column[frame]
            if prev_val == cur_val:
                run_len += 1
                continue
            # value changed: erase the preceding run if it was too short
            if run_len <= length:
                column[frame - run_len:frame] = cur_val
            run_len = 1
            prev_val = cur_val
    return filtered


def fix_negative_height(contact, constrain, cid):
    '''
    Clamp the height channel (index 1) of the IK constraints to the floor
    level of -1, so no constraint target lies below the ground plane.
    `contact` and `cid` are unused; kept for interface compatibility.
    Returns a clamped copy; `constrain` itself is not modified.
    '''
    floor = -1
    constrain = constrain.clone()
    # vectorized replacement for the original per-element double loop
    constrain[..., 1] = constrain[..., 1].clamp(min=floor)
    return constrain


def fix_contact(bvh_file, contact):
    '''
    Fix foot sliding by optimizing joint rotations/positions with an IK solver
    so that joints flagged as "in contact" stay at their constraint targets.
    Args:
        bvh_file : parsed BVH_file with skeleton + contact joint ids
        contact  : per-frame contact scores; thresholded at 0.5
    Returns:
        (quaternion rotations, global positions) of the optimized motion.
    Side effect: plots the loss curve on the current matplotlib figure.
    '''
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    cid = bvh_file.skeleton.contact_id
    glb = bvh_file.joint_position()
    rotation = bvh_file.get_rotation(repr='repr6d').to(device)
    position = bvh_file.get_position().to(device)
    contact = contact > 0.5
    # contact = continuous_filter(contact)
    constrain = constrain_from_contact(contact, glb, cid)
    constrain = fix_negative_height(contact, constrain, cid).to(device)
    # constrain every joint, not only the contact joints
    cid = list(range(glb.shape[1]))
    ik_solver = InverseKinematicsJoint2(rotation, position, bvh_file.skeleton.offsets.to(device), bvh_file.skeleton.parent,
                                        constrain[:, cid], cid, 0.1, 0.01, use_velo=True)

    loop = tqdm(range(500))
    losses = []
    for i in loop:
        loss = ik_solver.step()
        loop.set_description(f'loss = {loss:.07f}')
        losses += [loss]
    # PERF FIX: plot the loss curve once after optimization.  Previously
    # plt.plot(losses) ran inside the loop, re-drawing the growing curve every
    # iteration (O(n^2)) and layering 500 overlapping lines on the figure.
    plt.plot(losses)

    return repr6d2quat(ik_solver.rotations.detach()), ik_solver.get_position()


def fix_contact_on_file(prefix, name):
    '''
    Best-effort contact fix for `<prefix>/<name>.bvh`: loads the saved contact
    labels, runs the IK contact fix, and writes `<name>_fixed.bvh`.
    Returns silently when the contact file cannot be loaded.
    '''
    try:
        contact = np.load(pjoin(prefix, name + '.bvh.contact.npy'))
    except (OSError, ValueError):  # narrowed from a bare `except:`; missing or unreadable file
        print(f'{name} not found')
        return
    # NOTE(review): `no_scale=True` is inconsistent with the __main__ block,
    # which calls BVH_file(..., auto_scale=False, ...) -- confirm which kwarg
    # BVH_file actually accepts.
    bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), no_scale=True, requires_contact=True)
    print('Fixing foot contact with IK...')
    res = fix_contact(bvh_file, contact)
    bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat')


if __name__ == '__main__':
    # CLI entry point: load a generated bvh and its saved contact labels, run
    # the IK-based contact fix, and write `<name>_fixed.bvh` next to the input.
    parser = argparse.ArgumentParser()
    parser.add_argument('--prefix', type=str, required=True)
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--skeleton_name', type=str, required=True)
    args = parser.parse_args()
    # relative prefixes are resolved under ./results/
    if args.prefix[0] == '/':
        prefix = args.prefix
    else:
        prefix = f'./results/{args.prefix}'
    name = args.name
    contact = np.load(pjoin(prefix, name + '.bvh.contact.npy'))
    bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), skeleton_confs[args.skeleton_name], auto_scale=False, requires_contact=True)

    res = fix_contact(bvh_file, contact)
    # fix_contact plotted the loss curve on the current figure; save it
    plt.savefig(f'{prefix}/losses.png')

    bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat')

================================================
FILE: nearest_neighbor/losses.py
================================================
import torch    
import torch.nn as nn

from .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists

class PatchCoherentLoss(torch.nn.Module):
    """Patch-based nearest-neighbor coherence loss: every patch of the input
    motion is matched to its nearest patch in the exemplar motions."""

    def __init__(self, patch_size=7, stride=1, alpha=None, loop=False, cache=False):
        '''
        Args:
            patch_size : int, temporal patch length (must be odd)
            stride     : int, patch stride (only 1 supported)
            alpha      : float or None, completeness/diversity normalization weight
            loop       : bool, wrap the input sequence so it can loop seamlessly
            cache      : bool, reuse the exemplar patches across forward calls
        '''
        super(PatchCoherentLoss, self).__init__()
        self.patch_size = patch_size
        assert self.patch_size % 2 == 1, "Only support odd patch size"
        self.stride = stride
        assert self.stride == 1, "Only support stride of 1"
        self.alpha = alpha
        self.loop = loop
        self.cache = cache
        # always defined, so forward()/clean_cache() are safe for any `cache`
        self.cached_data = None

    def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):
        """For each patch in input X find its NN in target Ys and sum their distances.
        Optionally returns the NN-blended motion alongside the mean distance."""
        assert X.shape[0] == 1, "Only support batch size of 1 for X"
        dist_fn = lambda X, Y: dist_wrapper(efficient_cdist, X, Y) if dist_wrapper is not None else efficient_cdist(X, Y)

        x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)

        if not self.cache or self.cached_data is None:
            y_patches = [extract_patches(y, self.patch_size, self.stride, loop=False) for y in Ys]
            y_patches = torch.cat(y_patches, dim=1)
            # BUG FIX (memory): only retain the exemplar patches when caching
            # is enabled; previously they were stored even with cache=False,
            # pinning the tensor in memory although it was never read back.
            if self.cache:
                self.cached_data = y_patches
        else:
            y_patches = self.cached_data

        nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.alpha)

        if return_blended_results:
            return combine_patches(X.shape, y_patches[:, nnf, :], self.patch_size, self.stride, loop=self.loop), dist.mean()
        else:
            return dist.mean()

    def clean_cache(self):
        """Drop the cached exemplar patches (e.g. when the targets change)."""
        self.cached_data = None

================================================
FILE: nearest_neighbor/utils.py
================================================
"""
this file borrows some codes from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py.
"""
import torch
import torch.nn.functional as F
import unfoldNd

def extract_patches(x, patch_size, stride, loop=False):
    """Extract temporal patches from a motion sequence.

    Args:
        x          : (batch, channels, time) tensor
        patch_size : int, temporal patch length
        stride     : int, patch stride
        loop       : bool, wrap patch_size // 2 frames from each end so
                     patches can span the loop seam
    Returns:
        (batch, n_patches, channels * patch_size) tensor.
    """
    b, c, _t = x.shape

    # manually padding to loop
    half = patch_size // 2
    # BUG FIX: only pad when half > 0.  With patch_size == 1, `x[:, :, -0:]`
    # is the *entire* tensor, so the old code wrongly tripled the sequence.
    if loop and half > 0:
        front, tail = x[:, :, :half], x[:, :, -half:]
        x = torch.cat([tail, x, front], dim=-1)

    x_patches = unfoldNd.unfoldNd(x, kernel_size=patch_size, stride=stride).transpose(1, 2).reshape(b, -1, c, patch_size)

    return x_patches.view(b, -1, c * patch_size)

def combine_patches(x_shape, ys, patch_size, stride, loop=False):
    """Combine motion patches back into a sequence of shape `x_shape`.

    Inverse of `extract_patches`: folds the (possibly overlapping) patches and
    averages each position by its overlap count.  When `loop` is True the
    sequence is treated as padded by patch_size // 2 on both ends; the wrapped
    borders are averaged with their counterparts and the padding is removed.
    """
    
    # manually handle the loop situation
    out_shape = [*x_shape]
    if loop:
        padding = patch_size // 2
        out_shape[-1] = out_shape[-1] + padding * 2

    combined = unfoldNd.foldNd(ys.permute(0, 2, 1), output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)

    # normal fold matrix: foldNd *sums* overlapping patches, so divide by the
    # per-position overlap count (obtained by unfolding/folding an all-ones tensor)
    input_ones = torch.ones(tuple(out_shape), dtype=ys.dtype, device=ys.device)
    divisor = unfoldNd.unfoldNd(input_ones, kernel_size=patch_size, stride=stride)
    divisor = unfoldNd.foldNd(divisor, output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
    combined = (combined / divisor).squeeze(dim=0).unsqueeze(0)
    
    if loop:
        # average the wrapped borders with their counterparts, then cut the padding
        half = patch_size // 2
        front, tail = combined[:,:,:half], combined[:,:,-half:]
        combined[:, :, half:2 * half] = (combined[:, :, half:2 * half] + tail) / 2
        combined[:, :, - 2 * half:-half] = (front + combined[:, :, - 2 * half:-half]) / 2
        combined = combined[:, :, half:-half]

    return combined


def efficient_cdist(X, Y):
    """
    Memory-efficient all-pairs squared L2 distance via the expansion
    ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, normalized by the vector length
    so distances are comparable across patch sizes.
    :param X:  (n1, d) tensor
    :param Y:  (n2, d) tensor
    Returns an (n1, n2) tensor of normalized squared distances (no sqrt).
    """
    x_sq = X.pow(2).sum(dim=1, keepdim=True)      # (n1, 1)
    y_sq = Y.pow(2).sum(dim=1).unsqueeze(0)       # (1, n2)
    cross = torch.mm(X, Y.t())                    # (n1, n2)
    dist = x_sq + y_sq - 2.0 * cross
    # normalize by d so the same alpha works for every patch size
    return dist / X.shape[1]  # DO NOT use torch.sqrt


def get_col_mins_efficient(dist_fn, X, Y, b=1024):
    """
    For every row of Y, compute its distance to the closest row of X,
    processing Y in chunks of size `b` to bound memory use.
    :param X:  (n1, d) tensor
    :param Y:  (n2, d) tensor
    Returns an n2-long tensor of minimal distances.
    """
    mins = torch.zeros(Y.shape[0], dtype=X.dtype, device=X.device)
    for start in range(0, len(Y), b):
        chunk = Y[start:start + b]
        # min over all of X for each y in this chunk
        mins[start:start + len(chunk)] = dist_fn(X, chunk).min(0)[0]
    return mins


def get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024):
    """
    For each row of X, find the index of and distance to its nearest row of Y,
    processing X in chunks of size `b` so at most a (b, n2) matrix is held.
    When `alpha` is given, distances are first divided column-wise by
    (alpha + the distance from each y to its closest x), which rewards using
    rarely-matched exemplar patches.
    :param X:  (n1, d) tensor
    :param Y:  (n2, d) tensor
    Returns (indices, distances), both n1-long tensors.
    """
    if alpha is None:
        normalizing_row = 1
    else:
        col_mins = get_col_mins_efficient(dist_fn, X, Y, b=b)
        normalizing_row = alpha + col_mins[None, :]

    NNs = torch.zeros(X.shape[0], dtype=torch.long, device=X.device)
    Dists = torch.zeros(X.shape[0], dtype=torch.float, device=X.device)

    for start in range(0, len(X), b):
        chunk_dists = dist_fn(X[start:start + b], Y) / normalizing_row
        vals, idxs = chunk_dists.min(dim=1)
        NNs[start:start + len(idxs)] = idxs
        Dists[start:start + len(vals)] = vals

    return NNs, Dists


================================================
FILE: run_random_generation.py
================================================
import os
import os.path as osp
import argparse
from GenMM import GenMM
from nearest_neighbor.losses import PatchCoherentLoss
from dataset.bvh_motion import BVHMotion, load_multiple_dataset
from utils.base import ConfigParser, set_seed

# Command-line interface. Values left unset here are filled in from the YAML
# file given by --config; explicitly passed CLI options take precedence.
args = argparse.ArgumentParser(
    description='Random shuffle the input motion sequence')
args.add_argument('-m', '--mode', default='run',
                  choices=['run', 'eval', 'debug'], type=str, help='current run mode.')
args.add_argument('-i', '--input', required=True,
                  type=str, help='exemplar motion path.')
args.add_argument('-o', '--output_dir', default='./output',
                  type=str, help='output folder path for saving results.')
args.add_argument('-c', '--config', default='./configs/default.yaml',
                  type=str, help='config file path.')
args.add_argument('-s', '--seed', default=None,
                  type=int, help='random seed used.')
args.add_argument('-d', '--device', default="cuda:0",
                  type=str, help='device to use.')
# NOTE: the flag name is misspelled ('precess'), but it is read downstream as
# cfg.post_precess, so it is kept for backward compatibility.
args.add_argument('--post_precess', action='store_true',
                  help='whether to use IK post-process to fix foot contact.')

# Use argsparser to overwrite the configuration
# for dataset
args.add_argument('--skeleton_name', type=str,
                  help='(used when joint_reduction==True or contact==True) skeleton name to load pre-defined joints configuration.')
args.add_argument('--use_velo', type=int,
                  help='whether to use velocity rather than global position of each joint.')
args.add_argument('--repr', choices=['repr6d', 'quat', 'euler'], type=str,
                  help='rotation representation, support [repr6d, quat, euler].')
args.add_argument('--requires_contact', type=int,
                  help='whether to use contact label.')
args.add_argument('--keep_up_pos', type=int,
                  help='whether to do not use velocity and keep the y(up) position.')
args.add_argument('--up_axis', type=str, choices=['X_UP', 'Y_UP', 'Z_UP'],
                  help='up axis of the motion.')
args.add_argument('--padding_last', type=int,
                  help='whether to pad the last position channel to match the rotation dimension.')
args.add_argument('--joint_reduction', type=int,
                  help='whether to simplify the skeleton using provided skeleton config.')
args.add_argument('--skeleton_aware', type=int,
                  help='whether to enable skeleton-aware component.')
args.add_argument('--joints_group', type=str,
                  help='joints splitting group for using skeleton-aware component.')
# for synthesis
args.add_argument('--num_frames', type=str, 
                  help='number of synthesized frames, supported Nx(N times) and int input.')
args.add_argument('--alpha', type=float,
                  help='completeness/diversity trade-off alpha value.')
args.add_argument('--num_steps', type=int,
                  help='number of optimization steps at each pyramid level.')
args.add_argument('--noise_sigma', type=float,
                  help='standard deviation of the zero mean normal noise added to the initialization.')
args.add_argument('--coarse_ratio', type=float,
                  help='downscale ratio of the coarse level.')
args.add_argument('--coarse_ratio_factor', type=float,
                  help='scaling factor used when computing the coarse level ratio (see coarse_ratio).')
args.add_argument('--pyr_factor', type=float,
                  help='upsample ratio of each pyramid level.')
args.add_argument('--num_stages_limit', type=int,
                  help='limit of the number of stages.')
args.add_argument('--patch_size', type=int, help='patch size for generation.')
args.add_argument('--loop', type=int, help='whether to loop the sequence.')
cfg = ConfigParser(args)


def generate(cfg):
    """
    Generate a novel motion from the exemplar(s) in cfg.input using GenMM and
    write the synthesized BVH (plus optional contact fix-up) under cfg.output_dir.

    :param cfg: merged ConfigParser holding dataset, synthesis and run options
    :raises ValueError: if cfg.input is neither a .bvh file nor a .txt list
    """
    # set seed for reproducibility
    set_seed(cfg.seed)

    # set save path and prepare data for generation
    if cfg.input.endswith('.bvh'):
        base_dir = osp.join(
            cfg.output_dir, cfg.input.split('/')[-1].split('.')[0])
        motion_data = [BVHMotion(cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr,
                                 use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last,
                                 requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)]
    elif cfg.input.endswith('.txt'):
        base_dir = osp.join(cfg.output_dir, cfg.input.split(
            '/')[-2], cfg.input.split('/')[-1].split('.')[0])
        motion_data = load_multiple_dataset(name_list=cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr,
                                            use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last,
                                            requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)
    else:
        raise ValueError('exemplar must be a bvh file or a txt file')
    prefix = f"s{cfg.seed}+{cfg.num_frames}+{cfg.repr}+use_velo_{cfg.use_velo}+kypose_{cfg.keep_up_pos}+padding_{cfg.padding_last}" \
             f"+contact_{cfg.requires_contact}+jredu_{cfg.joint_reduction}+n{cfg.noise_sigma}+pyr{cfg.pyr_factor}" \
             f"+r{cfg.coarse_ratio}_{cfg.coarse_ratio_factor}+itr{cfg.num_steps}+ps_{cfg.patch_size}+alpha_{cfg.alpha}" \
             f"+loop_{cfg.loop}"

    # create the output folder before running the model: it is needed for
    # debug dumps during generation (bugfix: save_dir was previously
    # referenced before assignment when mode == 'debug')
    save_dir = osp.join(base_dir, prefix)
    os.makedirs(save_dir, exist_ok=True)

    # perform the generation
    model = GenMM(device=cfg.device, silent=True if cfg.mode == 'eval' else False)
    criteria = PatchCoherentLoss(patch_size=cfg.patch_size, alpha=cfg.alpha, loop=cfg.loop, cache=True)
    syn = model.run(motion_data, criteria,
                    num_frames=cfg.num_frames,
                    num_steps=cfg.num_steps,
                    noise_sigma=cfg.noise_sigma,
                    patch_size=cfg.patch_size, 
                    coarse_ratio=cfg.coarse_ratio,
                    pyr_factor=cfg.pyr_factor,
                    debug_dir=save_dir if cfg.mode == 'debug' else None)

    # save the generated results
    motion_data[0].write(f"{save_dir}/syn.bvh", syn)

    if cfg.post_precess:
        # run the IK-based foot-contact clean-up as a separate process
        cmd = f"python fix_contact.py --prefix {osp.abspath(save_dir)} --name syn --skeleton_name={cfg.skeleton_name}"
        os.system(cmd)

if __name__ == '__main__':
    # entry point: run generation with the configuration parsed above
    generate(cfg)


================================================
FILE: run_web_server.py
================================================
import json
import time
import torch
import argparse
import gradio as gr

from GenMM import GenMM
from nearest_neighbor.losses import PatchCoherentLoss
from dataset.tracks_motion import TracksMotion

# CLI options for hosting the web demo; `args` is read below by generate()
# to pick the compute device.
args = argparse.ArgumentParser(description='Web server for GenMM')
args.add_argument('-d', '--device', default="cuda:0", type=str, help='device to use.')
args.add_argument('--ip', default="0.0.0.0", type=str, help='interface url to host.')
args.add_argument('--port', default=8000, type=int, help='interface port to serve.')
args.add_argument('--debug', action='store_true', help='debug mode.')
args = args.parse_args()  # rebind: from parser object to parsed Namespace

def generate(data):
    """
    Web endpoint: synthesize a new motion from the JSON payload.

    The payload carries the exemplar keyframe 'tracks' and a 'setting' dict
    with the synthesis hyper-parameters; the response echoes the payload with
    'tracks' replaced by the synthesized motion and 'time' set to the wall
    time of the generation.
    """
    data = json.loads(data)

    # create track object wrapping the incoming keyframe tracks
    tracks = [TracksMotion(data['tracks'], repr='repr6d', use_velo=True, keep_y_pos=True, padding_last=False)]
    generator = GenMM(device=args.device, silent=True)
    setting = data['setting']
    coherence = PatchCoherentLoss(
        patch_size=setting['patch_size'],
        alpha=setting['alpha'] if setting['completeness'] else None,
        loop=setting['loop'],
        cache=True,
    )

    # run and time the synthesis
    t0 = time.time()
    syn = generator.run(tracks, coherence,
                        num_frames=str(setting['frames']),
                        num_steps=setting['num_steps'],
                        noise_sigma=setting['noise_sigma'],
                        patch_size=setting['patch_size'],
                        coarse_ratio=f'{data["setting"]["coarse_ratio"]}x_nframes',
                        pyr_factor=setting['pyr_factor'])
    t1 = time.time()

    data['time'] = t1 - t0
    data['tracks'] = tracks[0].parse(syn)

    return data

if __name__ == '__main__':
    # expose generate() as a JSON-in/JSON-out Gradio interface
    demo = gr.Interface(fn=generate, inputs="json", outputs="json")
    demo.launch(debug=args.debug, server_name=args.ip, server_port=args.port)

================================================
FILE: utils/base.py
================================================
import os
import os.path as osp
import sys
import time
import yaml
import imageio
import random
import shutil
import random
import numpy as np
import torch
from tqdm import tqdm

# configuration
class ConfigParser():
    """
    Parse a YAML config file, merge it with command-line arguments
    (explicit CLI values win), and expose the merged configuration
    through both attribute and item access.
    """

    def __init__(self, args):
        """
        :param args: an argparse.ArgumentParser whose options mirror the
                     keys of the YAML file given by args.config
        """
        args = args.parse_args()
        self.cfg = self.merge_config_file(args)

        # set random seed
        self.set_seed()

    def __str__(self):
        return str(self.cfg.__dict__)

    def __getattr__(self, name):
        """
        Access items using dot notation.
        Raises AttributeError (not KeyError) for unknown names so that
        hasattr()/getattr()-with-default behave as expected; the 'cfg'
        guard prevents infinite recursion before __init__ assigns it.
        """
        if name == 'cfg':
            raise AttributeError(name)
        try:
            return self.cfg.__dict__[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(self, name):
        """
        Access items like an ordinary dict.
        """
        return self.cfg.__dict__[name]

    def merge_config_file(self, args, allow_invalid=True):
        """
        Load the YAML config file and merge it into the parsed arguments.

        :param args: parsed argparse.Namespace (must carry .config)
        :param allow_invalid: if False, raise on YAML keys unknown to argparse
        :return: the namespace, updated in place with the merged configuration
        """
        assert args.config is not None
        with open(args.config, 'r') as f:
            cfg = yaml.safe_load(f)
            if 'config' in cfg.keys():
                del cfg['config']
        invalid_args = list(set(cfg.keys()) - set(dir(args)))
        if invalid_args and not allow_invalid:
            raise ValueError(f"Invalid args {invalid_args} in {args.config}.")

        # drop YAML keys that were explicitly set on the command line so the
        # CLI value survives the update below
        for k in list(cfg.keys()):
            if k in args.__dict__.keys() and args.__dict__[k] is not None:
                print('=========>  overwrite config: {} = {}'.format(k, args.__dict__[k]))
                del cfg[k]

        args.__dict__.update(cfg)

        return args

    def set_seed(self):
        ''' set random seed for random, numpy and torch. '''
        if 'seed' not in self.cfg.__dict__.keys():
            return
        if self.cfg.seed is None:
            # derive a seed from the current time when none was provided
            self.cfg.seed = int(time.time()) % 1000000
        print('=========>  set random seed: {}'.format(self.cfg.seed))
        # fix random seeds for reproducibility
        random.seed(self.cfg.seed)
        np.random.seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)
        torch.cuda.manual_seed(self.cfg.seed)

    def save_codes_and_config(self, save_path):
        """
        save codes and config to $save_path.
        """
        cur_codes_path = osp.dirname(osp.dirname(os.path.abspath(__file__)))
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        shutil.copytree(cur_codes_path, osp.join(save_path, 'codes'), \
            ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz'))

        with open(osp.join(save_path, 'config.yaml'), 'w') as f:
            f.write(yaml.dump(self.cfg.__dict__))


# logger util
class logger:
    """
    Tracks the current pyramid level and optimization step, reporting
    progress through a TQDM progress bar.
    """
    def __init__(self, n_steps, n_lvls):
        self.n_steps = n_steps            # steps per pyramid level
        self.n_lvls = n_lvls              # total number of pyramid levels
        self.lvl = -1                     # current level; bumped by new_lvl()
        self.lvl_step = 0                 # steps taken within the current level
        self.steps = 0                    # steps taken overall
        self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting')

    def step(self):
        """Advance one optimization step (bar, overall and per-level counters)."""
        self.pbar.update(1)
        self.steps = self.steps + 1
        self.lvl_step = self.lvl_step + 1

    def new_lvl(self):
        """Enter the next pyramid level and reset the per-level step counter."""
        self.lvl_step = 0
        self.lvl = self.lvl + 1

    def print(self):
        """Refresh the progress-bar description with the current level/step."""
        desc = 'Lvl {}/{}, step {}/{}'.format(self.lvl, self.n_lvls - 1, self.lvl_step, self.n_steps)
        self.pbar.set_description(desc)


# other utils
def set_seed(seed=None):
    """
    Seed random, numpy and torch (CPU and CUDA) for reproducible runs.
    A None seed is a no-op, leaving every generator untouched.
    """
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

================================================
FILE: utils/contact.py
================================================
import torch


def foot_contact_by_height(pos, eps=0.25):
    """
    Detect contact by proximity of a joint to the ground plane (index 1 of the
    last axis is treated as the up/height coordinate).

    :param pos: (..., 3) tensor of joint positions
    :param eps: height threshold below which a joint counts as in contact
                (default 0.25 preserves the original hard-coded behavior)
    :return: boolean tensor of shape pos.shape[:-1]
    """
    return (-eps < pos[..., 1]) * (pos[..., 1] < eps)


def velocity(pos, padding=False):
    """
    Per-frame speed of each joint: L2 norm of consecutive-frame differences.

    :param pos: (T, J, 3) joint positions over time -- assumed; TODO confirm
    :param padding: if True, prepend a zero row so the output keeps T frames
    :return: (T-1, J) tensor, or (T, J) when padding is True
    """
    speed = torch.norm(torch.diff(pos, dim=0), dim=-1)
    if not padding:
        return speed
    zeros = torch.zeros_like(speed[:1, :])
    return torch.cat([zeros, speed], dim=0)


def foot_contact(pos, ref_height=1., threshold=0.018):
    """
    Binary contact labels from joint velocity: a joint is labeled in contact
    when it is (nearly) stationary between consecutive frames.

    :param pos: (T, J, 3) tensor of joint positions over time
    :param ref_height: unused; kept for backward compatibility of the signature
    :param threshold: speed below which a joint is labeled as in contact
    :return: (T, J) int tensor; frame 0 is zero-padded (it has no velocity)
    """
    velo_norm = velocity(pos)
    contact = (velo_norm < threshold).int()
    # pad a single zero row for frame 0 instead of allocating a full
    # zeros_like(contact) only to use its first row
    pad = torch.zeros_like(contact[:1, :])
    return torch.cat([pad, contact], dim=0)


def alpha(t):
    """
    Cubic falloff weight: 1 at t=0, 0 at t=1, with zero slope at both ends
    (equals 1 - smoothstep(t)).
    """
    return 1.0 + t * t * (2.0 * t - 3.0)


def lerp(a, l, r):
    """Linearly interpolate between l and r: returns l at a=0 and r at a=1."""
    return a * r + (1 - a) * l


def constrain_from_contact(contact, glb, fid='TBD', L=5):
    """
    Enforce foot-contact constraints on global joint positions (in place).

    Phase 1: each contiguous run of contact frames is pinned to the run's
    mean position. Phase 2: non-contact frames within L frames of a pinned
    frame are blended toward it with a cubic falloff to hide discontinuities.

    :param contact: contact label, shape (T, len(fid))
    :param glb: original global position, shape (T, J, 3); modified in place
    :param fid: joint id to fix, corresponding to the order in contact
                (NOTE(review): the 'TBD' default is an unusable placeholder --
                callers must pass a real sequence of joint indices)
    :param L: frame to look forward/backward
    :return: glb, the same tensor, modified in place
    """
    T = glb.shape[0]

    for i, fidx in enumerate(fid):  # fidx: index of the foot joint
        fixed = contact[:, i]  # [T]
        # --- phase 1: average every contiguous contact segment ---
        s = 0
        while s < T:
            # advance to the start of the next contact segment
            while s < T and fixed[s] == 0:
                s += 1
            if s >= T:
                break
            t = s
            avg = glb[t, fidx].clone()
            # extend the segment while contact continues, accumulating positions
            while t + 1 < T and fixed[t + 1] == 1:
                t += 1
                avg += glb[t, fidx].clone()
            avg /= (t - s + 1)

            # pin every frame of the segment to the segment mean
            for j in range(s, t + 1):
                glb[j, fidx] = avg.clone()
            s = t + 1

        # --- phase 2: blend nearby non-contact frames toward the pins ---
        for s in range(T):
            if fixed[s] == 1:
                continue
            l, r = None, None
            consl, consr = False, False
            # nearest pinned frame within L frames to the left
            for k in range(L):
                if s - k - 1 < 0:
                    break
                if fixed[s - k - 1]:
                    l = s - k - 1
                    consl = True
                    break
            # nearest pinned frame within L frames to the right
            for k in range(L):
                if s + k + 1 >= T:
                    break
                if fixed[s + k + 1]:
                    r = s + k + 1
                    consr = True
                    break
            if not consl and not consr:
                continue
            if consl and consr:
                # constrained on both sides: combine the two one-sided
                # interpolants, weighted by relative distance
                litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
                            glb[s, fidx], glb[l, fidx])
                ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
                            glb[s, fidx], glb[r, fidx])
                itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)),
                           ritp, litp)
                glb[s, fidx] = itp.clone()
                continue
            if consl:
                litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
                            glb[s, fidx], glb[l, fidx])
                glb[s, fidx] = litp.clone()
                continue
            if consr:
                ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
                            glb[s, fidx], glb[r, fidx])
                glb[s, fidx] = ritp.clone()
    return glb


================================================
FILE: utils/kinematics.py
================================================
import torch
from utils.transforms import quat2mat, repr6d2mat, euler2mat


class ForwardKinematics:
    """
    Forward kinematics over rotation matrices, producing per-bone 3x4 global
    transforms, with helpers to convert between local and global rotations.
    """

    def __init__(self, parents, offsets=None):
        # parents[i] is the parent joint index of joint i; joint 0 is the root
        self.parents = parents
        if offsets is not None and len(offsets.shape) == 2:
            # promote (bone_num, 3) to a singleton batch: (1, bone_num, 3)
            offsets = offsets.unsqueeze(0)
        self.offsets = offsets

    def forward(self, rots, offsets=None, global_pos=None):
        """
        Forward Kinematics: returns a per-bone transformation
        @param rots: local joint rotations (batch_size, bone_num, 3, 3)
        @param offsets: (batch_size, bone_num, 3) or None
        @param global_pos: global_position: (batch_size, 3) or keep it as in offsets (default)
        @return: (batch_size, bone_num, 3, 4)
        """
        rots = rots.clone()
        if offsets is None:
            offsets = self.offsets.to(rots.device)
        if global_pos is None:
            # default the root position to the root offset
            global_pos = offsets[:, 0]

        pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device)
        rest_pos = torch.zeros_like(pos)  # joint positions in the rest pose
        res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device)

        pos[:, 0] = global_pos
        rest_pos[:, 0] = offsets[:, 0]

        # walk the hierarchy; assumes parents[i] < i so the parent's
        # accumulated rotation/position are ready -- TODO confirm ordering
        for i, p in enumerate(self.parents):
            if i != 0:
                # accumulate global rotation and position of joint i
                rots[:, i] = torch.matmul(rots[:, p], rots[:, i])
                pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]
                rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]

            # 3x4 transform mapping the joint's rest-pose position to its
            # posed position (R | -R*rest + pos)
            res[:, i, :3, :3] = rots[:, i]
            res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]

        return res

    def accumulate(self, local_rots):
        """
        Get global joint rotation from local rotations
        @param local_rots: (batch_size, n_bone, 3, 3)
        @return: global_rotations
        """
        res = torch.empty_like(local_rots)
        for i, p in enumerate(self.parents):
            if i == 0:
                res[:, i] = local_rots[:, i]
            else:
                res[:, i] = torch.matmul(res[:, p], local_rots[:, i])
        return res

    def unaccumulate(self, global_rots):
        """
        Get local joint rotation from global rotations
        @param global_rots: (batch_size, n_bone, 3, 3)
        @return: local_rotations
        """
        res = torch.empty_like(global_rots)
        # inv[i] caches the inverse (transpose) of joint i's global rotation
        inv = torch.empty_like(global_rots)

        for i, p in enumerate(self.parents):
            if i == 0:
                inv[:, i] = global_rots[:, i].transpose(-2, -1)
                res[:, i] = global_rots[:, i]
                continue
            # local = parent_global^{-1} @ global
            res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])
            inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])

        return res


class ForwardKinematicsJoint:
    """
    Forward kinematics producing global joint positions (not transforms).
    Accepts repr6d, quaternion or euler rotations, dispatched by the size of
    the rotation's last axis.
    """

    def __init__(self, parents, offset):
        # parents: per-joint parent indices; the root has parent -1
        # offset: rest-pose bone offsets, shape (..., joint_num, 3)
        self.parents = parents
        self.offset = offset

    '''
        rotation should have shape batch_size * Joint_num * (3/4) * Time
        position should have shape batch_size * 3 * Time
        offset should have shape batch_size * Joint_num * 3
        output have shape batch_size * Time * Joint_num * 3
    '''

    def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,
                world=True):
        '''
        if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation')
        if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation')
        rotation = rotation.permute(0, 3, 1, 2)
        position = position.permute(0, 2, 1)
        '''
        # choose the rotation representation from the size of the last axis:
        # 6 -> repr6d, 4 -> (normalized) quaternion, 3 -> euler angles
        if rotation.shape[-1] == 6:
            transform = repr6d2mat(rotation)
        elif rotation.shape[-1] == 4:
            norm = torch.norm(rotation, dim=-1, keepdim=True)
            rotation = rotation / norm
            transform = quat2mat(rotation)
        elif rotation.shape[-1] == 3:
            transform = euler2mat(rotation)
        else:
            raise Exception('Only accept quaternion rotation input')
        result = torch.empty(transform.shape[:-2] + (3,), device=position.device)

        if offset is None:
            offset = self.offset
        # reshape offsets to broadcast as column vectors against per-frame transforms
        offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))

        result[..., 0, :] = position
        # walk the hierarchy; assumes parents appear before children (pi < i)
        # so transform[..., pi] is already globalized -- TODO confirm ordering
        for i, pi in enumerate(self.parents):
            if pi == -1:
                assert i == 0
                continue

            # bone vector rotated by the parent's global rotation
            result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze()
            transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone())
            if world: result[..., i, :] += result[..., pi, :]
        return result


class InverseKinematicsJoint:
    """
    Gradient-based inverse kinematics: optimizes joint rotations and root
    positions with Adam so that the forward-kinematics result matches the
    given global-position constraints.
    """

    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):
        # optimize detached copies so the caller's tensors stay untouched
        self.rotations = rotations.detach().clone().requires_grad_(True)
        self.position = positions.detach().clone().requires_grad_(True)

        self.parents = parents
        self.offset = offset
        self.constrains = constrains

        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()
        self.fk = ForwardKinematicsJoint(parents, offset)

        self.glb = None  # last FK result, updated by step()

    def step(self):
        """Run one optimization step; returns the scalar MSE loss."""
        self.optimizer.zero_grad()
        current = self.fk.forward(self.rotations, self.position)
        err = self.criteria(current, self.constrains)
        err.backward()
        self.optimizer.step()
        self.glb = current
        return err.item()


class InverseKinematicsJoint2:
    """
    Gradient-based IK with reconstruction regularizers: matches the FK result
    at the selected joints (cid) to the constraints while keeping rotations
    and positions close to their original values. The root trajectory may be
    optimized in velocity (frame-delta) space.
    """

    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
                 lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):
        self.use_velo = use_velo
        # frozen copies of the originals, used by the reconstruction terms
        self.rotations_ori = rotations.detach().clone()
        self.position_ori = positions.detach().clone()

        # optimized copies
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        self.position = positions.detach().clone()
        if self.use_velo:
            # store frame-to-frame deltas instead of absolute positions
            # NOTE(review): position_ori stays absolute while position holds
            # deltas here, so rec_loss_pos mixes spaces -- verify intended
            self.position[1:] = self.position[1:] - self.position[:-1]
        self.position.requires_grad_(True)

        self.parents = parents
        self.offset = offset
        self.constrains = constrains.detach().clone()
        self.cid = cid  # ids of the constrained joints

        self.lambda_rec_rot = lambda_rec_rot
        self.lambda_rec_pos = lambda_rec_pos

        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()
        self.fk = ForwardKinematicsJoint(parents, offset)

        self.glb = None  # last FK result, updated by step()

    def step(self):
        """One Adam step on constraint + reconstruction losses; returns the total loss."""
        self.optimizer.zero_grad()
        position = torch.cumsum(self.position, dim=0) if self.use_velo else self.position
        current = self.fk.forward(self.rotations, position)
        self.constrain_loss = self.criteria(current[:, self.cid], self.constrains)
        self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
        self.rec_loss_pos = self.criteria(self.position, self.position_ori)
        loss = (self.constrain_loss
                + self.rec_loss_rot * self.lambda_rec_rot
                + self.rec_loss_pos * self.lambda_rec_pos)
        loss.backward()
        self.optimizer.step()
        self.glb = current
        return loss.item()

    def get_position(self):
        """Absolute root positions, integrating the deltas when use_velo is set."""
        pos = self.position.detach()
        return torch.cumsum(pos, dim=0) if self.use_velo else pos


================================================
FILE: utils/rename_mixamo_rig.py
================================================
# rename_mixamo_prefix.py
# Blender utility: normalize Mixamo bone-name prefixes like "mixamorig1:" /
# "mixamorig2:" back to the plain "mixamorig:" across all armatures.
import bpy, re
rx = re.compile(r"mixamorig\d+:")          # any number before the colon

for obj in bpy.data.objects:
    if obj.type == 'ARMATURE':
        for b in obj.data.bones:
            # strip the numeric suffix from each bone's prefix
            b.name = rx.sub("mixamorig:", b.name)

================================================
FILE: utils/skeleton.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np


class SkeletonConv(nn.Module):
    """
    Skeleton-aware 1D convolution: a dense Conv1d whose weight is masked so
    that each joint's output channels are computed only from the channels of
    its neighbouring joints (given by neighbour_list). Can optionally mix in
    an encoding of the static skeleton offsets (add_offset).
    """

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        """
        :param neighbour_list: per-joint lists of neighbouring joint indices
        :param in_channels: total input channels (must divide by joint_num)
        :param out_channels: total output channels (must divide by joint_num)
        :param add_offset: if truthy, add an encoded skeleton-offset term to
                           the convolution output (see set_offset/forward)
        :param in_offset_channel: channels per joint in the offset input
        """
        super(SkeletonConv, self).__init__()

        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise Exception('in/out channels should be divided by joint_num')
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # F.pad uses 'constant' where nn.Conv1d calls it 'zeros'
        if padding_mode == 'zeros': padding_mode = 'constant'

        self.expanded_neighbour_list = []
        self.expanded_neighbour_list_offset = []
        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)

        # expand joint-level neighbourhoods to channel-level index lists
        for neighbour in neighbour_list:
            expanded = []
            for k in neighbour:
                for i in range(self.in_channels_per_joint):
                    expanded.append(k * self.in_channels_per_joint + i)
            self.expanded_neighbour_list.append(expanded)

        if self.add_offset:
            # encodes the flattened skeleton offsets into out_channels values
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)

            for neighbour in neighbour_list:
                expanded = []
                for k in neighbour:
                    for i in range(add_offset):
                        expanded.append(k * in_offset_channel + i)
                self.expanded_neighbour_list_offset.append(expanded)

        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)

        # binary mask selecting, for each joint's output block, only the
        # input channels of its neighbouring joints
        self.mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
        self.mask = nn.Parameter(self.mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
            in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
        )

        self.reset_parameters()

    def reset_parameters(self):
        # initialize each joint's weight/bias block independently so the
        # fan-in statistics reflect only the unmasked (neighbour) channels
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            """ Use temporary variable to avoid assign to copy of slice, which might lead to un expected result """
            tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
                                   neighbour, ...])
            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
            self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
                        neighbour, ...] = tmp
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                    self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
                bound = 1 / math.sqrt(fan_in)
                tmp = torch.zeros_like(
                    self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
                nn.init.uniform_(tmp, -bound, bound)
                self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp

        # promote the initialized tensors to trainable parameters
        self.weight = nn.Parameter(self.weight)
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        # stash the (flattened) static skeleton offsets for the next forward()
        if not self.add_offset: raise Exception('Wrong Combination of Parameters')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        # mask out non-neighbour connections, then run a plain conv1d with
        # explicit padding (mode handled by F.pad rather than conv1d itself)
        weight_masked = self.weight * self.mask
        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                       weight_masked, self.bias, self.stride,
                       0, self.dilation, self.groups)

        if self.add_offset:
            # add the (scaled-down) encoded skeleton offsets per frame
            offset_res = self.offset_enc(self.offset)
            offset_res = offset_res.reshape(offset_res.shape + (1, ))
            res += offset_res / 100
        return res

    def __repr__(self):
        return self.description


class SkeletonLinear(nn.Module):
    """
    Skeleton-aware fully-connected layer: a dense Linear whose weight is
    masked so each joint's output block depends only on the input channels
    of that joint's neighbours.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        """
        :param neighbour_list: per-joint lists of neighbouring joint indices
        :param in_channels: total input channels (split evenly across joints)
        :param out_channels: total output channels (split evenly across joints)
        :param extra_dim1: if True, append a trailing singleton dim to the output
        """
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1
        self.expanded_neighbour_list = []

        # expand joint-level neighbourhoods to channel-level index lists
        for neighbour in neighbour_list:
            expanded = []
            for k in neighbour:
                for i in range(self.in_channels_per_joint):
                    expanded.append(k * self.in_channels_per_joint + i)
            self.expanded_neighbour_list.append(expanded)

        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        # initialize per-joint weight blocks and build the neighbour mask;
        # a temporary tensor avoids assigning into a copy of a fancy-indexed slice
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            tmp = torch.zeros_like(
                self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]
            )
            self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1
            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
            self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

        # promote to trainable parameter / frozen mask
        self.weight = nn.Parameter(self.weight)
        self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        # flatten per-sample features, apply the masked linear map
        input = input.reshape(input.shape[0], -1)
        weight_masked = self.weight * self.mask
        res = F.linear(input, weight_masked, self.bias)
        if self.extra_dim1: res = res.reshape(res.shape + (1,))
        return res


class SkeletonPoolJoint(nn.Module):
    """
    Joint-level skeleton pooling ('mean'): merges chains of single-child
    joints into groups, producing a coarser topology (`new_topology`) and a
    fixed averaging matrix over per-joint feature channels.
    """
    def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
        """
        :param topology: parent index per joint; the root has parent < 0
        :param pooling_mode: only 'mean' is implemented
        :param channels_per_joint: feature channels per joint
        :param last_pool: if True, collapse each whole single-child chain into
                          one group instead of merging pairwise
        """
        super(SkeletonPoolJoint, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.joint_num = len(topology)
        self.parent = topology
        self.pooling_list = []
        self.pooling_mode = pooling_mode

        # pooling_map[j]: index of the merged (output) joint that original
        # joint j belongs to. child[x]: the (last-seen) child of joint x --
        # only meaningful when x has exactly one child.
        self.pooling_map = [-1 for _ in range(len(self.parent))]
        self.child = [-1 for _ in range(len(self.parent))]
        children_cnt = [0 for _ in range(len(self.parent))]
        for x, pa in enumerate(self.parent):
            if pa < 0: continue
            children_cnt[pa] += 1
            self.child[pa] = x
        self.pooling_map[0] = 0
        # Group joints: leaves and chain joints get merged upward; branching
        # joints (more than one child) always stay as their own group.
        for x in range(len(self.parent)):
            if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
                while children_cnt[x] <= 1:
                    pa = self.parent[x]
                    if last_pool:
                        # Collapse the whole upward chain of single-child
                        # joints into a single group.
                        seq = [x]
                        while pa != -1 and children_cnt[pa] == 1:
                            seq = [pa] + seq
                            x = pa
                            pa = self.parent[x]
                        self.pooling_list.append(seq)
                        break
                    else:
                        if pa != -1 and children_cnt[pa] == 1:
                            # Merge joint x with its single-child parent, then
                            # continue from the grandparent.
                            self.pooling_list.append([pa, x])
                            x = self.parent[pa]
                        else:
                            # Parent branches or x is the root: keep x alone.
                            self.pooling_list.append([x, ])
                            break
            elif children_cnt[x] > 1:
                self.pooling_list.append([x, ])

        self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
            len(topology), len(self.pooling_list),
        )

        # Order the groups by their topmost joint, then record the
        # original-joint -> merged-joint mapping.
        self.pooling_list.sort(key=lambda x:x[0])
        for i, a in enumerate(self.pooling_list):
            for j in a:
                self.pooling_map[j] = i

        # Parent indices of the pooled topology; entry 0 (root) stays -1.
        self.output_joint_num = len(self.pooling_list)
        self.new_topology = [-1 for _ in range(len(self.pooling_list))]
        for i, x in enumerate(self.pooling_list):
            if i < 1: continue
            self.new_topology[i] = self.pooling_map[self.parent[x[0]]]

        # Fixed pooling matrix: each output channel is the mean of the
        # corresponding channels of the joints merged into that group.
        self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_joint):
                    self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        # (out_ch, in_ch) @ (..., in_ch, 1) -> (..., out_ch): averaged features.
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)


class SkeletonPool(nn.Module):
    """
    Edge-level skeleton pooling ('mean'): splits the skeleton into maximal
    unbranched edge chains and merges consecutive edge pairs within each
    chain, with one extra pass-through slot for the global position channels.
    """
    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        """
        :param edges: list of (parent_joint, child_joint) index pairs
        :param pooling_mode: only 'mean' is implemented
        :param channels_per_edge: feature channels per edge
        :param last_pool: if True, each whole chain collapses into one output
        """
        super(SkeletonPool, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        # +1: an extra pseudo-edge slot for the global (root) position.
        self.edge_num = len(edges) + 1
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []
        # Joint degree table; fixed size assumes fewer than 100 joints --
        # TODO(review): confirm this bound for all skeletons used.
        degree = [0] * 100

        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        def find_seq(j, seq):
            # DFS from joint j, accumulating edge indices of the current
            # unbranched chain in `seq`; a chain ends at a branching joint
            # (degree > 2) or a leaf (degree == 1).
            nonlocal self, degree, edges

            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []

            if degree[j] == 1:
                self.seq_list.append(seq)
                return

            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])
        for seq in self.seq_list:
            if last_pool:
                # Whole chain becomes a single pooled edge.
                self.pooling_list.append(seq)
                continue
            # Odd-length chain: keep its first edge unmerged, then merge the
            # remaining edges in consecutive pairs.
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                # Merged edge spans from the first edge's start joint to the
                # second edge's end joint.
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        # add global position (last slot is passed through unmerged)
        self.pooling_list.append([self.edge_num - 1])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        # Fixed averaging matrix over per-edge feature channels.
        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        # (out_ch, in_ch) @ (..., in_ch, k) -> pooled edge features.
        return torch.matmul(self.weight, input)


class SkeletonUnpool(nn.Module):
    """
    Inverse of skeleton pooling: copies each pooled joint's channels back to
    every original joint in its group, via a fixed 0/1 expansion matrix.
    """
    def __init__(self, pooling_list, channels_per_edge):
        """
        :param pooling_list: for each pooled joint, the original joint indices
                             that were merged into it
        :param channels_per_edge: feature channels per joint/edge
        """
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_joint_num = len(pooling_list)
        self.output_joint_num = sum(len(group) for group in pooling_list)
        self.channels_per_edge = channels_per_edge

        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
            self.input_joint_num, self.output_joint_num,
        )

        # Block (dst, src) of the expansion matrix is the identity whenever
        # original joint `dst` belongs to pooled group `src`.
        expand = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)
        block = torch.eye(channels_per_edge)
        for src, group in enumerate(self.pooling_list):
            for dst in group:
                expand[dst * channels_per_edge:(dst + 1) * channels_per_edge,
                       src * channels_per_edge:(src + 1) * channels_per_edge] = block

        self.weight = nn.Parameter(expand, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Expand pooled features back to the original joint count."""
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)


def find_neighbor_joint(parents, threshold):
    """
    For every joint, find all joints within `threshold` hops in the
    kinematic tree.

    :param parents: parent index per joint; joint 0 is the root (its parent
                    entry is ignored)
    :param threshold: maximum tree distance (in edges) to count as neighbour
    :return: list of neighbour-index lists, one per joint (each includes the
             joint itself, since dist 0 <= threshold for threshold >= 0)
    """
    n_joint = len(parents)
    # `np.int` was removed in NumPy 1.24; use the builtin int dtype instead.
    # 100000 acts as "infinity" (any real tree distance is < n_joint).
    dist_mat = np.full((n_joint, n_joint), 100000, dtype=int)
    for i, p in enumerate(parents):
        dist_mat[i, i] = 0
        if i != 0:
            dist_mat[i, p] = dist_mat[p, i] = 1

    # Floyd-Warshall all-pairs shortest paths over the skeleton graph.
    for k in range(n_joint):
        for i in range(n_joint):
            for j in range(n_joint):
                dist_mat[i, j] = min(dist_mat[i, j], dist_mat[i, k] + dist_mat[k, j])

    neighbor_list = []
    for i in range(n_joint):
        neighbor_list.append([j for j in range(n_joint) if dist_mat[i, j] <= threshold])

    return neighbor_list


================================================
FILE: utils/transforms.py
================================================
import numpy as np
import torch


def batch_mm(matrix, matrix_batch):
    """
    Multiply one (m, n) matrix with a batch of (n, k) matrices at once.
    https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
    :param matrix: Sparse or dense matrix, size (m, n).
    :param matrix_batch: Batched dense matrices, size (b, n, k).
    :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
    """
    b = matrix_batch.shape[0]
    # Fold the batch into columns: (b, n, k) -> (n, b, k) -> (n, b*k),
    # so the batched product becomes a single mm.
    columns = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
    product = matrix.mm(columns)
    # Undo the folding: (m, b*k) -> (m, b, k) -> (b, m, k).
    return product.reshape(matrix.shape[0], b, -1).transpose(1, 0)


def aa2quat(rots, form='wxyz', unified_orient=True):
    """
    Convert angle-axis representation to wxyz quaternion and to the half plan (w >= 0)
    @param rots: angle-axis rotations, (*, 3)
    @param form: quaternion format, either 'wxyz' or 'xyzw'
    @param unified_orient: Use unified orientation for quaternion (quaternion is dual cover of SO3)
    :return:
    """
    angle = rots.norm(dim=-1, keepdim=True)
    # Guard against division by zero for (near-)identity rotations.
    safe_norm = angle.clone()
    safe_norm[safe_norm < 1e-8] = 1
    axis = rots / safe_norm

    half = angle * 0.5
    q = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)
    vec = torch.sin(half) * axis
    if form == 'wxyz':
        q[..., 0] = torch.cos(half.squeeze(-1))
        q[..., 1:] = vec
    elif form == 'xyzw':
        q[..., :3] = vec
        q[..., 3] = torch.cos(half.squeeze(-1))

    if unified_orient:
        # Flip to the w >= 0 half-space (checks component 0, as before).
        flip = q[..., 0] < 0
        q[flip, :] *= -1

    return q


def quat2aa(quats):
    """
    Convert wxyz quaternions to angle-axis representation
    :param quats:
    :return:
    """
    w = quats[..., 0]
    xyz = quats[..., 1:]
    sin_half = xyz.norm(dim=-1)
    # Avoid dividing by zero when the rotation is (near-)identity.
    denom = sin_half.clone()
    denom[denom < 1e-7] = 1
    axis = xyz / denom.unsqueeze(-1)
    angle = 2 * torch.atan2(sin_half, w)
    return axis * angle.unsqueeze(-1)


def quat2mat(quats: torch.Tensor):
    """
    Convert (w, x, y, z) quaternions to 3x3 rotation matrix
    :param quats: quaternions of shape (..., 4)
    :return:  rotation matrices of shape (..., 3, 3)
    """
    qw, qx, qy, qz = quats.unbind(-1)

    # Doubled components and their products (standard quaternion-to-matrix
    # identities).
    tx, ty, tz = qx + qx, qy + qy, qz + qz
    xx, yy, zz = qx * tx, qy * ty, qz * tz
    xy, xz, yz = qx * ty, qx * tz, qy * tz
    wx, wy, wz = qw * tx, qw * ty, qw * tz

    row0 = torch.stack([1.0 - (yy + zz), xy - wz, xz + wy], dim=-1)
    row1 = torch.stack([xy + wz, 1.0 - (xx + zz), yz - wx], dim=-1)
    row2 = torch.stack([xz - wy, yz + wx, 1.0 - (xx + yy)], dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)


def quat2euler(q, order='xyz', degrees=True):
    """
    Convert (w, x, y, z) quaternions to xyz euler angles. This is  used for bvh output.
    """
    if order != 'xyz':
        raise NotImplementedError('Cannot convert to ordering %s' % order)

    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Standard quaternion -> intrinsic xyz Euler decomposition.
    roll = torch.atan2(2 * (w * x - y * z), w * w - x * x - y * y + z * z)
    pitch = torch.asin((2 * (x * z + w * y)).clip(-1, 1))
    yaw = torch.atan2(2 * (w * z - x * y), w * w + x * x - y * y - z * z)
    angles = torch.stack([roll, pitch, yaw], dim=-1)

    return angles * 180 / np.pi if degrees else angles


def euler2mat(rots, order='xyz'):
    """
    Convert per-axis Euler angles (degrees, last dim of size 3) to rotation
    matrices, composed in the given axis order.
    """
    basis = {'x': torch.tensor((1, 0, 0), device=rots.device),
             'y': torch.tensor((0, 1, 0), device=rots.device),
             'z': torch.tensor((0, 0, 1), device=rots.device)}

    rad = rots / 180 * np.pi
    # One angle-axis rotation per ordered axis, then compose.
    mats = [aa2mat(basis[order[i]] * rad[..., i].unsqueeze(-1)) for i in range(3)]
    return mats[0] @ (mats[1] @ mats[2])


def aa2mat(rots):
    """
    Convert angle-axis representation to rotation matrix
    :param rots: angle-axis representation
    :return:
    """
    # Route through the quaternion form.
    return quat2mat(aa2quat(rots))


def mat2quat(R) -> torch.Tensor:
    '''
    https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
    Convert a rotation matrix to a unit quaternion.

    This uses the Shepperd's method for numerical stability.
    '''

    # The rotation matrix must be orthonormal.

    # Squared doubled quaternion components: (2w)^2, (2x)^2, (2y)^2, (2z)^2.
    w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
    x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
    y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
    z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])

    # Symmetric and antisymmetric off-diagonal combinations (scaled products).
    yz = (R[..., 1, 2] + R[..., 2, 1])
    xz = (R[..., 2, 0] + R[..., 0, 2])
    xy = (R[..., 0, 1] + R[..., 1, 0])
    wx = (R[..., 2, 1] - R[..., 1, 2])
    wy = (R[..., 0, 2] - R[..., 2, 0])
    wz = (R[..., 1, 0] - R[..., 0, 1])

    neg_trace = R[..., 2, 2] < 0
    comp = {k: torch.empty_like(w2) for k in 'wxyz'}

    # Shepperd's case analysis: per element, pick the numerically largest
    # component, take its square root, and recover the other three by
    # division. The four masks form a partition of all elements.
    cases = [
        (neg_trace & (R[..., 0, 0] > R[..., 1, 1]), 'x', x2,
         (('w', wx), ('y', xy), ('z', xz))),
        (neg_trace & (R[..., 0, 0] <= R[..., 1, 1]), 'y', y2,
         (('w', wy), ('x', xy), ('z', yz))),
        (~neg_trace & (R[..., 0, 0] < -R[..., 1, 1]), 'z', z2,
         (('w', wz), ('x', xz), ('y', yz))),
        (~neg_trace & (R[..., 0, 0] >= -R[..., 1, 1]), 'w', w2,
         (('x', wx), ('y', wy), ('z', wz))),
    ]
    for mask, lead, lead_sq, others in cases:
        lead_val = torch.sqrt(lead_sq[mask])
        comp[lead][mask] = lead_val
        for name, numer in others:
            comp[name][mask] = numer[mask] / lead_val

    quat = torch.stack([comp['w'], comp['x'], comp['y'], comp['z']], dim=-1)
    # Components were doubled throughout, so halve to get the unit quaternion.
    return quat / 2


def quat2repr6d(quat):
    """
    Convert (w, x, y, z) quaternions to the 6D rotation representation
    (the first two rows of the rotation matrix, flattened).
    """
    rot = quat2mat(quat)
    return rot[..., :2, :].reshape(quat.shape[:-1] + (6,))


def repr6d2mat(repr):
    """
    Convert the 6D rotation representation (two 3-vectors) to a rotation
    matrix via Gram-Schmidt orthonormalization.
    """
    a1 = repr[..., :3]
    a2 = repr[..., 3:]
    # Orthonormal basis: normalize a1, derive the third axis from the cross
    # product, then re-orthogonalize the second axis.
    b1 = a1 / a1.norm(dim=-1, keepdim=True)
    b3 = torch.cross(b1, a2)
    b3 = b3 / b3.norm(dim=-1, keepdim=True)
    b2 = torch.cross(b3, b1)
    return torch.stack([b1, b2, b3], dim=-2)


def repr6d2quat(repr) -> torch.Tensor:
    """
    Convert the 6D rotation representation to a (w, x, y, z) quaternion.

    Delegates to repr6d2mat + mat2quat instead of duplicating the
    Gram-Schmidt orthonormalization inline (the previous body was a verbatim
    copy of repr6d2mat followed by mat2quat).
    """
    return mat2quat(repr6d2mat(repr))


def inv_affine(mat):
    """
    Calculate the inverse of any affine transformation.

    :param mat: affine matrices of shape (b, t, 3, 4) -- shape[:2] is used as
                the leading batch dims (assumed exactly two; TODO confirm)
    :return: inverted affine matrices, same shape as `mat`
    """
    # Homogeneous padding row [0, 0, 0, 1]; created on mat's device/dtype so
    # the cat below also works for CUDA and non-default-dtype inputs
    # (the previous version always allocated on CPU with the default dtype).
    pad = torch.zeros(mat.shape[:2] + (1, 4), device=mat.device, dtype=mat.dtype)
    pad[..., 3] = 1
    homogeneous = torch.cat((mat, pad), dim=2)
    return torch.inverse(homogeneous)[..., :3, :]


def inv_rigid_affine(mat):
    """
    Calculate the inverse of a rigid affine transformation
    (rotation part is inverted by transposition).
    """
    out = mat.clone()
    rot_t = mat[..., :3].transpose(-2, -1)
    out[..., :3] = rot_t
    # Inverse translation: -R^T @ t.
    out[..., 3] = -(rot_t @ mat[..., 3].unsqueeze(-1)).squeeze(-1)
    return out


def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
    if n_bone is None: n_bone = 24
    if ee is not None:
        if root_rot:
            ee.append(0)
        n_bone_ = n_bone
        n_bone = len(ee)
    axis = torch.randn((batch_size, n_bone, 3), device=device)
    axis /= axis.norm(dim=-1, keepdim=True)
    if uniform:
        angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi
    else:
        angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor
        angle.clamp(-np.pi, np.pi)
    poses = axis * angle
    if ee is not None:
        res = torch.zeros((batch_size, n_bone_, 3), device=device)
        for i, id in enumerate(ee):
            res[:, id] = poses[:, i]
        poses 
Download .txt
gitextract_jk8oqutc/

├── .gitignore
├── GenMM.py
├── LICENSE
├── README.md
├── __init__.py
├── configs/
│   ├── default.yaml
│   └── ganimator.yaml
├── dataset/
│   ├── blender_motion.py
│   ├── bvh/
│   │   ├── Quaternions.py
│   │   ├── bvh_io.py
│   │   ├── bvh_parser.py
│   │   └── bvh_writer.py
│   ├── bvh_motion.py
│   ├── motion.py
│   └── tracks_motion.py
├── demo.blend
├── docker/
│   ├── Dockerfile
│   ├── README.md
│   ├── apt-sources.list
│   ├── requirements.txt
│   └── requirements_blender.txt
├── fix_contact.py
├── nearest_neighbor/
│   ├── losses.py
│   └── utils.py
├── run_random_generation.py
├── run_web_server.py
└── utils/
    ├── base.py
    ├── contact.py
    ├── kinematics.py
    ├── rename_mixamo_rig.py
    ├── skeleton.py
    └── transforms.py
Download .txt
SYMBOL INDEX (239 symbols across 20 files)

FILE: GenMM.py
  class GenMM (line 9) | class GenMM:
    method __init__ (line 10) | def __init__(self, mode = 'random_synthesis', noise_sigma = 1.0, coars...
    method _get_pyramid_lengths (line 20) | def _get_pyramid_lengths(self, final_len, coarse_ratio, pyr_factor):
    method _get_target_pyramid (line 33) | def _get_target_pyramid(self, target, coarse_ratio, pyr_factor, num_st...
    method _get_initial_motion (line 63) | def _get_initial_motion(self, init_length, noise_sigma):
    method run (line 81) | def run(self, target, criteria, num_frames, num_steps, noise_sigma, pa...
    method match_and_blend (line 139) | def match_and_blend(synthesized, targets, criteria, n_steps, pbar, ext...

FILE: __init__.py
  function capture_rest_pose (line 43) | def capture_rest_pose(armature_obj):
  function get_bvh_data (line 60) | def get_bvh_data(context,
  class BVH_Node (line 312) | class BVH_Node:
    method __init__ (line 360) | def __init__(self, name, rest_head_world, rest_head_local, parent, cha...
    method __repr__ (line 382) | def __repr__(self):
  function sorted_nodes (line 392) | def sorted_nodes(bvh_nodes):
  function read_bvh (line 398) | def read_bvh(context, bvh_str, rotate_mode='XYZ', global_scale=1.0):
  function bvh_node_dict2objects (line 623) | def bvh_node_dict2objects(context, bvh_name, bvh_nodes, rotate_mode='NAT...
  function bvh_node_dict2armature (line 686) | def bvh_node_dict2armature(
  function load (line 908) | def load(
  function _update_scene_fps (line 985) | def _update_scene_fps(context, report, bvh_frame_time):
  function _update_scene_duration (line 1007) | def _update_scene_duration(
  function set_smooth_shading (line 1038) | def set_smooth_shading(mesh: bpy.types.Mesh) -> None:
  function create_mesh_from_pydata (line 1045) | def create_mesh_from_pydata(scene: bpy.types.Scene,
  function add_subdivision_surface_modifier (line 1068) | def add_subdivision_surface_modifier(mesh_object: bpy.types.Object, leve...
  function create_armature_mesh (line 1082) | def create_armature_mesh(scene: bpy.types.Scene, armature_object: bpy.ty...
  class OP_AddMesh (line 1191) | class OP_AddMesh(bpy.types.Operator):
    method __init__ (line 1197) | def __init__(self) -> None:
    method execute (line 1200) | def execute(self, context: bpy.types.Context):
  class OP_RunSynthesis (line 1205) | class OP_RunSynthesis(bpy.types.Operator):
    method execute (line 1211) | def execute(self, context: bpy.types.Context):
  class GENMM_PT_ControlPanel (line 1258) | class GENMM_PT_ControlPanel(bpy.types.Panel):
    method poll (line 1265) | def poll(cls, context: bpy.types.Context):
    method draw_header (line 1268) | def draw_header(self, context: bpy.types.Context):
    method draw (line 1272) | def draw(self, context: bpy.types.Context):
  class PropertyGroup (line 1308) | class PropertyGroup(bpy.types.PropertyGroup):
  function register (line 1374) | def register():
  function unregister (line 1380) | def unregister():

FILE: dataset/blender_motion.py
  class BlenderMotion (line 9) | class BlenderMotion:
    method __init__ (line 10) | def __init__(self, motion_data, repr='quat', use_velo=True, keep_up_po...
    method repr (line 48) | def repr(self):
    method use_velo (line 52) | def use_velo(self):
    method keep_up_pos (line 56) | def keep_up_pos(self):
    method padding_last (line 60) | def padding_last(self):
    method concat_id (line 64) | def concat_id(self):
    method n_pad (line 68) | def n_pad(self):
    method n_contact (line 72) | def n_contact(self):
    method n_rot (line 76) | def n_rot(self):
    method sample (line 79) | def sample(self, size=None, slerp=False):
    method parse (line 85) | def parse(self, motion, keep_velo=False,):

FILE: dataset/bvh/Quaternions.py
  class Quaternions (line 11) | class Quaternions:
    method __init__ (line 31) | def __init__(self, qs):
    method __str__ (line 43) | def __str__(self): return "Quaternions("+ str(self.qs) + ")"
    method __repr__ (line 44) | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")"
    method _broadcast (line 49) | def _broadcast(cls, sqs, oqs, scalar=False):
    method __add__ (line 72) | def __add__(self, other): return self * other
    method __sub__ (line 73) | def __sub__(self, other): return self / other
    method __mul__ (line 77) | def __mul__(self, other):
    method __div__ (line 126) | def __div__(self, other):
    method __eq__ (line 141) | def __eq__(self, other): return self.qs == other.qs
    method __ne__ (line 142) | def __ne__(self, other): return self.qs != other.qs
    method __neg__ (line 144) | def __neg__(self):
    method __abs__ (line 148) | def __abs__(self):
    method __iter__ (line 156) | def __iter__(self): return iter(self.qs)
    method __len__ (line 157) | def __len__(self): return len(self.qs)
    method __getitem__ (line 159) | def __getitem__(self, k):    return Quaternions(self.qs[k])
    method __setitem__ (line 160) | def __setitem__(self, k, v): self.qs[k] = v.qs
    method lengths (line 163) | def lengths(self):
    method reals (line 167) | def reals(self):
    method imaginaries (line 171) | def imaginaries(self):
    method shape (line 175) | def shape(self): return self.qs.shape[:-1]
    method repeat (line 177) | def repeat(self, n, **kwargs):
    method normalized (line 180) | def normalized(self):
    method log (line 183) | def log(self):
    method constrained (line 190) | def constrained(self, axis):
    method constrained_x (line 207) | def constrained_x(self): return self.constrained(np.array([1,0,0]))
    method constrained_y (line 208) | def constrained_y(self): return self.constrained(np.array([0,1,0]))
    method constrained_z (line 209) | def constrained_z(self): return self.constrained(np.array([0,0,1]))
    method dot (line 211) | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)
    method copy (line 213) | def copy(self): return Quaternions(np.copy(self.qs))
    method reshape (line 215) | def reshape(self, s):
    method interpolate (line 219) | def interpolate(self, ws):
    method euler (line 222) | def euler(self, order='xyz'):
    method average (line 292) | def average(self):
    method angle_axis (line 306) | def angle_axis(self):
    method transforms (line 318) | def transforms(self):
    method ravel (line 343) | def ravel(self):
    method id (line 347) | def id(cls, n):
    method id_like (line 362) | def id_like(cls, a):
    method exp (line 368) | def exp(cls, ws):
    method slerp (line 383) | def slerp(cls, q0s, q1s, a):
    method between (line 412) | def between(cls, v0s, v1s):
    method from_angle_axis (line 418) | def from_angle_axis(cls, angles, axis):
    method from_euler (line 425) | def from_euler(cls, es, order='xyz', world=False):
    method from_transforms (line 440) | def from_transforms(cls, ts):

FILE: dataset/bvh/bvh_io.py
  class Animation (line 32) | class Animation:
    method __init__ (line 33) | def __init__(self, rotations, positions, orients, offsets, parents, na...
    method shape (line 43) | def shape(self):
  function load (line 47) | def load(filename, start=None, end=None, order=None, world=False, need_q...
  function save (line 210) | def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', po...
  function save_joint (line 292) | def save_joint(f, anim, names, t, i, order='zyx', positions=False):

FILE: dataset/bvh/bvh_parser.py
  class Skeleton (line 11) | class Skeleton:
    method __init__ (line 12) | def __init__(self, names, parent, offsets, joint_reduction=True, skele...
    method parent (line 69) | def parent(self):
    method offsets (line 78) | def offsets(self):
    method names (line 82) | def names(self):
    method ee_id (line 86) | def ee_id(self):
  class BVH_file (line 94) | class BVH_file:
    method __init__ (line 95) | def __init__(self, file_path, skeleton_conf=None, requires_contact=Fal...
    method local_pos (line 137) | def local_pos(self):
    method rescale (line 142) | def rescale(self, ratio):
    method to_tensor (line 146) | def to_tensor(self, repr='euler', rot_only=False):
    method joint_position (line 163) | def joint_position(self):
    method get_rotation (line 171) | def get_rotation(self, repr='quat'):
    method get_position (line 182) | def get_position(self):
    method dfs (line 185) | def dfs(self, x, vis, dist):
    method get_neighbor (line 193) | def get_neighbor(self, threshold, enforce_contact=False):

FILE: dataset/bvh/bvh_writer.py
  function write_bvh (line 6) | def write_bvh(parent, offset, rotation, position, names, frametime, orde...
  class WriterWrapper (line 59) | class WriterWrapper:
    method __init__ (line 60) | def __init__(self, parents, offset=None):
    method write (line 64) | def write(self, filename, rot, pos, offset=None, names=None, repr='qua...

FILE: dataset/bvh_motion.py
  class BVHMotion (line 31) | class BVHMotion:
    method __init__ (line 32) | def __init__(self, bvh_file, skeleton_name=None, repr='quat', use_velo...
    method repr (line 57) | def repr(self):
    method use_velo (line 61) | def use_velo(self):
    method keep_up_pos (line 65) | def keep_up_pos(self):
    method padding_last (line 69) | def padding_last(self):
    method concat_id (line 73) | def concat_id(self):
    method n_pad (line 77) | def n_pad(self):
    method n_contact (line 81) | def n_contact(self):
    method n_rot (line 85) | def n_rot(self):
    method sample (line 88) | def sample(self, size=None, slerp=False):
    method write (line 95) | def write(self, filename, data):
  function load_multiple_dataset (line 125) | def load_multiple_dataset(name_list, **kargs):

FILE: dataset/motion.py
  class MotionData (line 5) | class MotionData:
    method __init__ (line 6) | def __init__(self, data, repr='quat', use_velo=True, keep_up_pos=True,...
    method __len__ (line 69) | def __len__(self):
    method sample (line 76) | def sample(self, size=None, slerp=False):
    method to_velocity (line 90) | def to_velocity(self, pos):
    method to_position (line 102) | def to_position(self, velo):

FILE: dataset/tracks_motion.py
  class TracksParser (line 9) | class TracksParser():
    method __init__ (line 10) | def __init__(self, tracks_json, scale):
    method to_tensor (line 34) | def to_tensor(self, repr='euler', rot_only=False):
    method get_rotation (line 46) | def get_rotation(self, repr='quat'):
    method get_position (line 55) | def get_position(self):
  class TracksMotion (line 58) | class TracksMotion:
    method __init__ (line 59) | def __init__(self, tracks_json, scale=1.0, repr='quat', use_velo=True,...
    method repr (line 77) | def repr(self):
    method use_velo (line 81) | def use_velo(self):
    method keep_up_pos (line 85) | def keep_up_pos(self):
    method padding_last (line 89) | def padding_last(self):
    method n_pad (line 93) | def n_pad(self):
    method n_rot (line 97) | def n_rot(self):
    method sample (line 100) | def sample(self, size=None, slerp=False):
    method parse (line 107) | def parse(self, motion, keep_velo=False,):

FILE: fix_contact.py
  function continuous_filter (line 13) | def continuous_filter(contact, length=2):
  function fix_negative_height (line 30) | def fix_negative_height(contact, constrain, cid):
  function fix_contact (line 40) | def fix_contact(bvh_file, contact):
  function fix_contact_on_file (line 66) | def fix_contact_on_file(prefix, name):

FILE: nearest_neighbor/losses.py
  class PatchCoherentLoss (line 6) | class PatchCoherentLoss(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, patch_size=7, stride=1, alpha=None, loop=False, cac...
    method forward (line 19) | def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_r...
    method clean_cache (line 42) | def clean_cache(self):

FILE: nearest_neighbor/utils.py
  function extract_patches (line 8) | def extract_patches(x, patch_size, stride, loop=False):
  function combine_patches (line 22) | def combine_patches(x_shape, ys, patch_size, stride, loop=False):
  function efficient_cdist (line 49) | def efficient_cdist(X, Y):
  function get_col_mins_efficient (line 64) | def get_col_mins_efficient(dist_fn, X, Y, b=1024):
  function get_NNs_Dists (line 82) | def get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024):

FILE: run_random_generation.py
  function generate (line 70) | def generate(cfg):

FILE: run_web_server.py
  function generate (line 18) | def generate(data):

FILE: utils/base.py
  class ConfigParser (line 15) | class ConfigParser():
    method __init__ (line 16) | def __init__(self, args):
    method __str__ (line 26) | def __str__(self):
    method __getattr__ (line 29) | def __getattr__(self, name):
    method __getitem__ (line 35) | def __getitem__(self, name):
    method merge_config_file (line 41) | def merge_config_file(self, args, allow_invalid=True):
    method set_seed (line 64) | def set_seed(self):
    method save_codes_and_config (line 77) | def save_codes_and_config(self, save_path):
  class logger (line 93) | class logger:
    method __init__ (line 97) | def __init__(self, n_steps, n_lvls):
    method step (line 105) | def step(self):
    method new_lvl (line 110) | def new_lvl(self):
    method print (line 114) | def print(self):
  function set_seed (line 119) | def set_seed(seed=None):

FILE: utils/contact.py
  function foot_contact_by_height (line 4) | def foot_contact_by_height(pos):
  function velocity (line 9) | def velocity(pos, padding=False):
  function foot_contact (line 18) | def foot_contact(pos, ref_height=1., threshold=0.018):
  function alpha (line 27) | def alpha(t):
  function lerp (line 31) | def lerp(a, l, r):
  function constrain_from_contact (line 35) | def constrain_from_contact(contact, glb, fid='TBD', L=5):

FILE: utils/kinematics.py
  class ForwardKinematics (line 5) | class ForwardKinematics:
    method __init__ (line 6) | def __init__(self, parents, offsets=None):
    method forward (line 12) | def forward(self, rots, offsets=None, global_pos=None):
    method accumulate (line 44) | def accumulate(self, local_rots):
    method unaccumulate (line 58) | def unaccumulate(self, global_rots):
  class ForwardKinematicsJoint (line 78) | class ForwardKinematicsJoint:
    method __init__ (line 79) | def __init__(self, parents, offset):
    method forward (line 90) | def forward(self, rotation: torch.Tensor, position: torch.Tensor, offs...
  class InverseKinematicsJoint (line 126) | class InverseKinematicsJoint:
    method __init__ (line 127) | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, o...
    method step (line 144) | def step(self):
  class InverseKinematicsJoint2 (line 154) | class InverseKinematicsJoint2:
    method __init__ (line 155) | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, o...
    method step (line 182) | def step(self):
    method get_position (line 198) | def get_position(self):

FILE: utils/skeleton.py
  class SkeletonConv (line 8) | class SkeletonConv(nn.Module):
    method __init__ (line 9) | def __init__(self, neighbour_list, in_channels, out_channels, kernel_s...
    method reset_parameters (line 68) | def reset_parameters(self):
    method set_offset (line 89) | def set_offset(self, offset):
    method forward (line 93) | def forward(self, input):
    method __repr__ (line 105) | def __repr__(self):
  class SkeletonLinear (line 109) | class SkeletonLinear(nn.Module):
    method __init__ (line 110) | def __init__(self, neighbour_list, in_channels, out_channels, extra_di...
    method reset_parameters (line 133) | def reset_parameters(self):
    method forward (line 149) | def forward(self, input):
  class SkeletonPoolJoint (line 157) | class SkeletonPoolJoint(nn.Module):
    method __init__ (line 158) | def __init__(self, topology, pooling_mode, channels_per_joint, last_po...
    method forward (line 223) | def forward(self, input: torch.Tensor):
  class SkeletonPool (line 227) | class SkeletonPool(nn.Module):
    method __init__ (line 228) | def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=F...
    method forward (line 290) | def forward(self, input: torch.Tensor):
  class SkeletonUnpool (line 294) | class SkeletonUnpool(nn.Module):
    method __init__ (line 295) | def __init__(self, pooling_list, channels_per_edge):
    method forward (line 318) | def forward(self, input: torch.Tensor):
  function find_neighbor_joint (line 322) | def find_neighbor_joint(parents, threshold):

FILE: utils/transforms.py
  function batch_mm (line 5) | def batch_mm(matrix, matrix_batch):
  function aa2quat (line 21) | def aa2quat(rots, form='wxyz', unified_orient=True):
  function quat2aa (line 49) | def quat2aa(quats):
  function quat2mat (line 65) | def quat2mat(quats: torch.Tensor):
  function quat2euler (line 103) | def quat2euler(q, order='xyz', degrees=True):
  function euler2mat (line 126) | def euler2mat(rots, order='xyz'):
  function aa2mat (line 139) | def aa2mat(rots):
  function mat2quat (line 150) | def mat2quat(R) -> torch.Tensor:
  function quat2repr6d (line 241) | def quat2repr6d(quat):
  function repr6d2mat (line 248) | def repr6d2mat(repr):
  function repr6d2quat (line 261) | def repr6d2quat(repr) -> torch.Tensor:
  function inv_affine (line 274) | def inv_affine(mat):
  function inv_rigid_affine (line 285) | def inv_rigid_affine(mat):
  function generate_pose (line 295) | def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=...
  function slerp (line 321) | def slerp(l, r, t, unit=True):
  function slerp_quat (line 355) | def slerp_quat(l, r, t):
  function interpolate_6d (line 378) | def interpolate_6d(input, size):
Condensed preview — 32 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (206K chars).
[
  {
    "path": ".gitignore",
    "chars": 1832,
    "preview": "*.json\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\nout/\n# C extensions\n*.so\n*.pkl\n\n# Dist"
  },
  {
    "path": "GenMM.py",
    "chars": 7513,
    "preview": "import os\nimport os.path as osp\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom utils.base import "
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 4548,
    "preview": "# Example-based Motion Synthesis via Generative Motion Matching, ACM Transactions on Graphics (Proceedings of SIGGRAPH 2"
  },
  {
    "path": "__init__.py",
    "chars": 50732,
    "preview": "# This program is free software; you can redistribute it and/or modify\n# it under the terms of the GNU General Public Li"
  },
  {
    "path": "configs/default.yaml",
    "chars": 428,
    "preview": "# motion data config\nrepr: 'repr6d'\nskeleton_name: null\nuse_velo: true\nkeep_up_pos: true\nup_axis: 'Y_UP'\npadding_last: f"
  },
  {
    "path": "configs/ganimator.yaml",
    "chars": 681,
    "preview": "################################################################\n# This configuration uses the same input format of GANi"
  },
  {
    "path": "dataset/blender_motion.py",
    "chars": 3897,
    "preview": "import os\nimport os.path as osp\nimport torch\nimport numpy as np\nimport torch.nn.functional as F\nfrom .motion import Moti"
  },
  {
    "path": "dataset/bvh/Quaternions.py",
    "chars": 17668,
    "preview": "\"\"\"\nThis code is modified from:\nhttp://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-edi"
  },
  {
    "path": "dataset/bvh/bvh_io.py",
    "chars": 9887,
    "preview": "\"\"\"\nThis code is modified from:\nhttp://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-edi"
  },
  {
    "path": "dataset/bvh/bvh_parser.py",
    "chars": 8831,
    "preview": "import torch\nimport numpy as np\nimport dataset.bvh.bvh_io as bvh_io\nfrom utils.kinematics import ForwardKinematicsJoint\n"
  },
  {
    "path": "dataset/bvh/bvh_writer.py",
    "chars": 3354,
    "preview": "import torch\nfrom utils.transforms import quat2euler, repr6d2quat\n\n\n# rotation with shape frame * J * 3\ndef write_bvh(pa"
  },
  {
    "path": "dataset/bvh_motion.py",
    "chars": 6749,
    "preview": "import os\nimport os.path as osp\nimport torch\nimport numpy as np\nimport torch.nn.functional as F\nfrom .motion import Moti"
  },
  {
    "path": "dataset/motion.py",
    "chars": 4226,
    "preview": "import torch\nimport torch.nn.functional as F\n\n\nclass MotionData:\n    def __init__(self, data, repr='quat', use_velo=True"
  },
  {
    "path": "dataset/tracks_motion.py",
    "chars": 5339,
    "preview": "import os\nfrom os.path import join as pjoin\nimport numpy as np\nimport copy\nimport torch\nfrom .motion import MotionData\nf"
  },
  {
    "path": "docker/Dockerfile",
    "chars": 1429,
    "preview": "FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel\n\n# For the convenience for users in China mainland\nCOPY apt-sources.li"
  },
  {
    "path": "docker/README.md",
    "chars": 934,
    "preview": "## Build Docker Environment and use with GPU Support\n\nBefore you can use this Docker environment, you need to have the f"
  },
  {
    "path": "docker/apt-sources.list",
    "chars": 921,
    "preview": "deb https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse\ndeb-src https://mirrors.ustc.edu.cn/u"
  },
  {
    "path": "docker/requirements.txt",
    "chars": 127,
    "preview": "torch==1.12.1\ntorchvision==0.13.1\ntensorboardX==2.5\ntqdm==4.62.3\nunfoldNd==0.2.0\npyyaml>=5.3.1\ngradio==3.34.0\nmatplotlib"
  },
  {
    "path": "docker/requirements_blender.txt",
    "chars": 76,
    "preview": "torch==2.2.0\ntorchvision==0.17.0\ntqdm==4.62.3\nunfoldNd==0.2.0\npyyaml>=5.3.1\n"
  },
  {
    "path": "fix_contact.py",
    "chars": 3437,
    "preview": "from dataset.bvh.bvh_parser import BVH_file\nfrom os.path import join as pjoin\nimport numpy as np\nimport torch\nfrom utils"
  },
  {
    "path": "nearest_neighbor/losses.py",
    "chars": 1816,
    "preview": "import torch    \nimport torch.nn as nn\n\nfrom .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Di"
  },
  {
    "path": "nearest_neighbor/utils.py",
    "chars": 4260,
    "preview": "\"\"\"\nthis file borrows some codes from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py.\n\"\"\"\nimport tor"
  },
  {
    "path": "run_random_generation.py",
    "chars": 6406,
    "preview": "import os\nimport os.path as osp\nimport argparse\nfrom GenMM import GenMM\nfrom nearest_neighbor.losses import PatchCoheren"
  },
  {
    "path": "run_web_server.py",
    "chars": 1949,
    "preview": "import json\nimport time\nimport torch\nimport argparse\nimport gradio as gr\n\nfrom GenMM import GenMM\nfrom nearest_neighbor."
  },
  {
    "path": "utils/base.py",
    "chars": 3695,
    "preview": "import os\nimport os.path as osp\nimport sys\nimport time\nimport yaml\nimport imageio\nimport random\nimport shutil\nimport ran"
  },
  {
    "path": "utils/contact.py",
    "chars": 3138,
    "preview": "import torch\n\n\ndef foot_contact_by_height(pos):\n    eps = 0.25\n    return (-eps < pos[..., 1]) * (pos[..., 1] < eps)\n\n\nd"
  },
  {
    "path": "utils/kinematics.py",
    "chars": 7684,
    "preview": "import torch\nfrom utils.transforms import quat2mat, repr6d2mat, euler2mat\n\n\nclass ForwardKinematics:\n    def __init__(se"
  },
  {
    "path": "utils/rename_mixamo_rig.py",
    "chars": 257,
    "preview": "# rename_mixamo_prefix.py\nimport bpy, re\nrx = re.compile(r\"mixamorig\\d+:\")          # any number before the colon\n\nfor o"
  },
  {
    "path": "utils/skeleton.py",
    "chars": 13902,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nimport numpy as np\n\n\nclass SkeletonConv(n"
  },
  {
    "path": "utils/transforms.py",
    "chars": 11397,
    "preview": "import numpy as np\nimport torch\n\n\ndef batch_mm(matrix, matrix_batch):\n    \"\"\"\n    https://github.com/pytorch/pytorch/iss"
  }
]

// ... and 1 more file (download for full content)

About this extraction

This page contains the full source code of the wyysf-98/GenMM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 32 files (193.8 KB), approximately 52.0k tokens, and a symbol index with 239 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!