[
  {
    "path": ".gitattributes",
    "content": ""
  },
  {
    "path": ".gitignore",
    "content": "videos\n/timechamber/logs\n*train_dir*\n*ige_logs*\n*.egg-info\n/.vs\n/.vscode\n/_package\n/shaders\n._tmptext.txt\n__pycache__/\n/timechamber/tasks/__pycache__\n/timechamber/utils/__pycache__\n/timechamber/tasks/base/__pycache__\n/tools/format/.lastrun\n*.pyc\n_doxygen\n/rlisaacgymenvsgpu/logs\n/timechamber/benchmarks/results\n/timechamber/simpletests/results\n*.pxd2\n/tests/logs\n/timechamber/balance_bot.xml\n/timechamber/quadcopter.xml\n/timechamber/ingenuity.xml\nlogs*\nnn/\nruns/\n.idea\noutputs/\n*.hydra*\n/timechamber/wandb\n/test\n.gitlab\n\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 MIT Inspir.ai\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "LISENCE/isaacgymenvs/LICENSE",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "README.md",
    "content": "# TimeChamber: A Massively Parallel Large Scale Self-Play Framework\n\n****\n\n**TimeChamber** is a large scale self-play framework running on parallel simulation.\nRunning self-play algorithms always need lots of hardware resources, especially on 3D physically simulated\nenvironments.\nWe provide a self-play framework that can achieve fast training and evaluation with **ONLY ONE GPU**.\nTimeChamber is developed with the following key features:\n\n- **Parallel Simulation**: TimeChamber is built within [Isaac Gym](https://developer.nvidia.com/isaac-gym). Isaac Gym is\n  a fast GPU-based simulation platform. It supports running thousands of environments in parallel on a single GPU.For\n  example, on one NVIDIA Laptop RTX 3070Ti GPU, TimeChamber can reach **80,000+\n  mean FPS** by running 4,096 environments in parallel.\n- **Parallel Evaluation**: TimeChamber can fast calculate dozens of policies' ELO\n  rating(represent their combat power). It also supports multi-player ELO calculations\n  by [multi-elo](https://github.com/djcunningham0/multielo). Inspired by Vectorization techniques\n  for [fast population-based training](https://github.com/instadeepai/fastpbrl), we leverage the\n  vectorized models to evaluate different policy in parallel.\n- **Prioritized Fictitious Self-Play Benchmark**: We implement a classic PPO self-play algorithm on top\n  of [rl_games](https://github.com/Denys88/rl_games), with a prioritized player pool to avoid cycles and improve the\n  diversity of training policy.\n\n<div align=center>\n<img src=\"assets/images/algorithm.jpg\" align=\"center\" width=\"600\"/>\n</div> \n\n- **Competitive Multi-Agent Tasks**: Inspired by [OpenAI RoboSumo](https://github.com/openai/robosumo) and [ASE](https://github.com/nv-tlabs/ASE), we introduce three\n  competitive multi-agent tasks(e.g.,Ant Sumo,Ant\n  Battle and Humanoid Strike) as examples.\n  The efficiency of our self-play framework has been tested on these tasks. 
After days of training, our agents discover interesting physical skills such as pulling and jumping. **Welcome to contribute your own environments!**\n\n\n## Installation\n\n****\nDownload and follow the installation instructions of Isaac Gym: https://developer.nvidia.com/isaac-gym  \nEnsure that Isaac Gym works on your system by running one of the examples from the `python/examples` directory, like `joint_monkey.py`. If you have any trouble running the samples, please follow the troubleshooting steps described in the [Isaac Gym Preview Release 3/4 installation instructions](https://developer.nvidia.com/isaac-gym).  \nThen install this repo:\n\n```bash\npip install -e .\n```\n\n## Quick Start\n\n****\n\n### Tasks\n\nSource code for tasks can be found in `timechamber/tasks`. The detailed settings of state/action/reward are described [here](./docs/environments.md).\nMore interesting tasks will come soon.\n\n#### Humanoid Strike\n\nHumanoid Strike is a 3D environment with two simulated humanoid physics characters. Each character is equipped with a sword and shield and has 37 degrees of freedom.\nThe game restarts if an agent goes outside the arena. At the terminal step, we compare how much damage each player dealt to its opponent and how much damage it received to determine the winner.\n\n<div align=center>\n<img src=\"assets/images/humanoid_strike.gif\" align=\"center\" width=\"600\"/>\n</div> \n\n\n\n#### Ant Sumo\n\nAnt Sumo is a 3D environment with simulated physics that allows pairs of ant agents to compete against each other.\nTo win, an agent has to push its opponent out of the ring. Every agent has 100 HP. At each step, if an agent's body touches the ground, its HP is reduced by 1. An agent whose HP reaches 0 is eliminated.\n<div align=center>\n<img src=\"assets/images/ant_sumo.gif\" align=\"center\" width=\"600\"/>\n</div> \n\n#### Ant Battle\n\nAnt Battle is an expanded version of Ant Sumo.
It supports more than two agents competing against each other. The battle ring shrinks over time, and any agent that goes out of the ring is eliminated.\n<div align=center>\n<img src=\"assets/images/ant_battle.gif\" align=\"center\" width=\"600\"/>\n</div>  \n\n### Self-Play Training\n\nTo train a policy for a task, for example:\n\n```bash\n# run self-play training for Humanoid Strike task\npython train.py task=MA_Humanoid_Strike headless=True\n```\n\n```bash\n# run self-play training for Ant Sumo task\npython train.py task=MA_Ant_Sumo train=MA_Ant_SumoPPO headless=True\n```\n\n```bash\n# run self-play training for Ant Battle task\npython train.py task=MA_Ant_Battle train=MA_Ant_BattlePPO headless=True\n```\n\nKey arguments to the training script follow [IsaacGymEnvs Configuration and command line arguments](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs/blob/main/README.md#configuration-and-command-line-arguments).\nOther training arguments follow [rl_games config parameters](https://github.com/Denys88/rl_games#config-parameters); you can change them in `timechamber/tasks/train/*.yaml`.
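The prioritized player pool used during training can be sketched as a fixed-size FIFO buffer of past policy snapshots, with opponents sampled in proportion to how often they beat the current policy. The sketch below is only an illustration of the idea; the class and method names are ours, not the actual TimeChamber API:

```python
import random
from collections import deque

class PlayerPool:
    # Illustrative sketch of a prioritized FIFO player pool:
    # old snapshots are evicted first, and opponents that beat the
    # current policy more often are sampled more frequently, which
    # helps avoid strategy cycles in self-play.
    def __init__(self, max_length):
        self.players = deque(maxlen=max_length)  # each entry: policy + win stats

    def add(self, policy):
        self.players.append({'policy': policy, 'wins': 0, 'games': 0})

    def record(self, idx, won):
        # Track how often opponent idx beats the current policy.
        self.players[idx]['games'] += 1
        self.players[idx]['wins'] += int(won)

    def sample(self):
        # Laplace-smoothed win rate as the sampling weight, so that
        # unplayed opponents still get drawn occasionally.
        weights = [(p['wins'] + 1) / (p['games'] + 2) for p in self.players]
        return random.choices(range(len(self.players)), weights=weights)[0]

pool = PlayerPool(max_length=3)
for name in ['v1', 'v2', 'v3', 'v4']:
    pool.add(name)
# FIFO eviction: the oldest snapshot 'v1' is gone.
print([p['policy'] for p in pool.players])  # ['v2', 'v3', 'v4']
```

With this picture in mind, the training arguments below (pool size, win-rate threshold, warm-up) map directly onto when snapshots enter the pool and how opponents are drawn from it.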
There are some arguments specific to self-play training:\n\n- `num_agents`: The number of agents in the Ant Battle environment; it must be larger than 1.\n- `op_checkpoint`: Path to the checkpoint used to load the initial opponent policy. If empty, the opponent uses a random policy.\n- `update_win_rate`: Win-rate threshold for adding the current policy to the opponent player pool.\n- `player_pool_length`: The maximum size of the player pool, which follows FIFO rules.\n- `games_to_check`: Warm-up for training; the player pool is not updated until the current policy has played this many games.\n- `max_update_steps`: If the current policy's update iterations exceed this number, the current policy is added to the opponent player pool.\n\n### Policy Evaluation\n\nTo evaluate your policies, for example:\n\n```bash\n# run testing for Ant Sumo policy\npython train.py task=MA_Ant_Sumo train=MA_Ant_SumoPPO test=True num_envs=4 minibatch_size=32 headless=False checkpoint='models/ant_sumo/policy.pth'\n```\n\n```bash\n# run testing for Humanoid Strike policy\npython train.py task=MA_Humanoid_Strike train=MA_Humanoid_StrikeHRL test=True num_envs=4 minibatch_size=32 headless=False checkpoint='models/Humanoid_Strike/policy.pth' op_checkpoint='models/Humanoid_Strike/policy_op.pth'\n```\n\nYou can set the opponent policy using `op_checkpoint`. If it's empty, the opponent uses the same policy as `checkpoint`.  \nWe use vectorized models to accelerate the evaluation of policies.
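The idea behind vectorized evaluation is to stack the parameters of several policies along a leading axis and evaluate them all with one batched operation, instead of looping over policies. A minimal single-layer sketch with NumPy (TimeChamber itself does this with stacked PyTorch models; the shapes here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
num_policies, num_envs, obs_dim, act_dim = 4, 16, 8, 2

# One weight matrix per policy, stacked along a leading axis.
weights = rng.standard_normal((num_policies, obs_dim, act_dim))
# Each policy observes its own batch of environments.
obs = rng.standard_normal((num_policies, num_envs, obs_dim))

# Batched matmul evaluates all policies in a single call.
actions = np.tanh(obs @ weights)
print(actions.shape)  # (4, 16, 2)

# Equivalent (slower) per-policy loop, for comparison.
looped = np.stack([np.tanh(obs[i] @ weights[i]) for i in range(num_policies)])
assert np.allclose(actions, looped)
```

Because every policy's forward pass runs in the same batched call, adding more policies to an evaluation round costs roughly one larger kernel launch rather than N separate ones.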
Put your policies into a checkpoint directory and let them compete with each other in parallel:\n\n```bash\n# run testing for Ant Sumo policy\npython train.py task=MA_Ant_Sumo train=MA_Ant_SumoPPO test=True headless=True checkpoint='models/ant_sumo' player_pool_type=vectorized\n```\n\nThere are some arguments specific to self-play evaluation; you can change them in `timechamber/tasks/train/*.yaml`:\n\n- `games_num`: Total number of evaluation episodes.\n- `record_elo`: Set to `True` to record the Elo ratings of your policies; after evaluation, check `elo.jpg` in your checkpoint directory.\n\n<div align=center>\n  <img src=\"assets/images/elo.jpg\" align=\"center\" width=\"400\"/>\n</div>\n\n- `init_elo`: Initial Elo rating of each policy.\n\n### Building Your Own Task\n\nYou can build your own task following [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs/blob/main/README.md#creating-an-environment); make sure the observation shape is correct and that `info` contains `win`, `lose`, and `draw`:\n\n```python\nimport isaacgym\nimport timechamber\nimport torch\n\nenvs = timechamber.make(\n    seed=0,\n    task=\"MA_Ant_Sumo\",\n    num_envs=2,\n    sim_device=\"cuda:0\",\n    rl_device=\"cuda:0\",\n)\n# the obs shape should be (num_agents * num_envs, num_obs);\n# the obs of the training agent is obs[:num_envs]\nprint(\"Observation space is\", envs.observation_space)\nprint(\"Action space is\", envs.action_space)\nobs = envs.reset()\nfor _ in range(20):\n    obs, reward, done, info = envs.step(\n        torch.rand((2 * 2,) + envs.action_space.shape, device=\"cuda:0\")\n    )\n# info:\n# {'win': tensor([Bool, Bool]),\n#  'lose': tensor([Bool, Bool]),\n#  'draw': tensor([Bool, Bool])}\n\n```\n\n## Citing\n\nIf you use TimeChamber in your research, please use the following citation:\n\n````\n@misc{InspirAI,\n  author = {Huang Ziming and Ziyi Liu and Wu Yutong and Flood Sung},\n  title = {TimeChamber: A Massively Parallel Large Scale Self-Play Framework},\n  year = {2022},\n  publisher = {GitHub},\n  
journal = {GitHub repository},\n  howpublished = {\\url{https://github.com/inspirai/TimeChamber}},\n}\n````"
  },
  {
    "path": "assets/mjcf/nv_ant.xml",
    "content": "<mujoco model=\"ant\">\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name=\"init_qpos\"/>\n  </custom>\n\n  <default>\n    <joint armature=\"0.01\" damping=\"0.1\" limited=\"true\"/>\n    <geom condim=\"3\" density=\"5.0\" friction=\"1.5 0.1 0.1\" margin=\"0.01\" rgba=\"0.97 0.38 0.06 1\"/>\n  </default>\n\n  <compiler inertiafromgeom=\"true\" angle=\"degree\"/>\n\n  <option timestep=\"0.016\" iterations=\"50\" tolerance=\"1e-10\" solver=\"Newton\" jacobian=\"dense\" cone=\"pyramidal\"/>\n\n  <size nconmax=\"50\" njmax=\"200\" nstack=\"10000\"/>\n  <visual>\n      <map force=\"0.1\" zfar=\"30\"/>\n      <rgba haze=\"0.15 0.25 0.35 1\"/>\n      <quality shadowsize=\"2048\"/>\n      <global offwidth=\"800\" offheight=\"800\"/>\n  </visual>\n\n  <asset>\n      <texture type=\"skybox\" builtin=\"gradient\" rgb1=\"0.3 0.5 0.7\" rgb2=\"0 0 0\" width=\"512\" height=\"512\"/> \n      <texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\".2 .3 .4\" rgb2=\".1 0.15 0.2\" width=\"512\" height=\"512\" mark=\"cross\" markrgb=\".8 .8 .8\"/>\n      <texture name=\"texgeom\" type=\"cube\" builtin=\"flat\" mark=\"cross\" width=\"127\" height=\"1278\" \n          rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" markrgb=\"1 1 1\" random=\"0.01\"/>  \n\n      <material name=\"matplane\" reflectance=\"0.3\" texture=\"texplane\" texrepeat=\"1 1\" texuniform=\"true\"/>\n      <material name=\"matgeom\" texture=\"texgeom\" texuniform=\"true\" rgba=\"0.8 0.6 .4 1\"/>\n  </asset>\n\n  <worldbody>\n    <geom name=\"floor\" pos=\"0 0 0\" size=\"0 0 .25\" type=\"plane\" material=\"matplane\" condim=\"3\"/>\n\n    <light directional=\"false\" diffuse=\".2 .2 .2\" specular=\"0 0 0\" pos=\"0 0 5\" dir=\"0 0 -1\" castshadow=\"false\"/>\n    <light mode=\"targetbodycom\" target=\"torso\" directional=\"false\" diffuse=\".8 .8 .8\" specular=\"0.3 0.3 0.3\" pos=\"0 0 4.0\" dir=\"0 0 -1\"/>\n\n    <body name=\"torso\" pos=\"0 0 
0.75\">\n      <freejoint name=\"root\"/>\n      <geom name=\"torso_geom\" pos=\"0 0 0\" size=\"0.25\" type=\"sphere\"/>\n      <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"aux_1_geom\" size=\"0.08\" type=\"capsule\" rgba=\".999 .2 .1 1\"/>\n      <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"aux_2_geom\" size=\"0.08\" type=\"capsule\"/>\n      <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"aux_3_geom\" size=\"0.08\" type=\"capsule\"/>\n      <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"aux_4_geom\" size=\"0.08\" type=\"capsule\" rgba=\".999 .2 .02 1\"/>\n\n      <body name=\"front_left_leg\" pos=\"0.2 0.2 0\">\n        <joint axis=\"0 0 1\" name=\"hip_1\" pos=\"0.0 0.0 0.0\" range=\"-40 40\" type=\"hinge\"/>\n        <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"left_leg_geom\" size=\"0.08\" type=\"capsule\" rgba=\".999 .2 .1 1\"/>\n        <body pos=\"0.2 0.2 0\" name=\"front_left_foot\">\n          <joint axis=\"-1 1 0\" name=\"ankle_1\" pos=\"0.0 0.0 0.0\" range=\"30 100\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" name=\"left_ankle_geom\" size=\"0.08\" type=\"capsule\" rgba=\".999 .2 .1 1\"/>\n        </body>\n      </body>\n      <body name=\"front_right_leg\" pos=\"-0.2 0.2 0\">\n        <joint axis=\"0 0 1\" name=\"hip_2\" pos=\"0.0 0.0 0.0\" range=\"-40 40\" type=\"hinge\"/>\n        <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body pos=\"-0.2 0.2 0\" name=\"front_right_foot\">\n          <joint axis=\"1 1 0\" name=\"ankle_2\" pos=\"0.0 0.0 0.0\" range=\"-100 -30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n        </body>\n      </body>\n      <body name=\"left_back_leg\" pos=\"-0.2 -0.2 0\">\n        <joint axis=\"0 0 1\" name=\"hip_3\" pos=\"0.0 0.0 0.0\" range=\"-40 40\" type=\"hinge\"/>\n        <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom\" 
size=\"0.08\" type=\"capsule\"/>\n        <body pos=\"-0.2 -0.2 0\" name=\"left_back_foot\">\n          <joint axis=\"-1 1 0\" name=\"ankle_3\" pos=\"0.0 0.0 0.0\" range=\"-100 -30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n        </body>\n      </body>\n      <body name=\"right_back_leg\" pos=\"0.2 -0.2 0\">\n        <joint axis=\"0 0 1\" name=\"hip_4\" pos=\"0.0 0.0 0.0\" range=\"-40 40\" type=\"hinge\"/>\n        <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"rightback_leg_geom\" size=\"0.08\" type=\"capsule\" rgba=\".999 .2 .1 1\"/>\n        <body pos=\"0.2 -0.2 0\" name=\"right_back_foot\">\n          <joint axis=\"1 1 0\" name=\"ankle_4\" pos=\"0.0 0.0 0.0\" range=\"30 100\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" name=\"fourth_ankle_geom\" size=\"0.08\" type=\"capsule\" rgba=\".999 .2 .1 1\"/>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_4\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_4\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_1\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_1\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_2\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_2\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_3\" gear=\"15\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_3\" gear=\"15\"/>\n  </actuator>\n</mujoco>\n"
  },
  {
    "path": "docs/environments.md",
    "content": "## Environments\n\nWe provide a detailed description of the environment here.\n\n### Humanoid Strike\n\nHumanoid Strike is a 3D environment with two simulated humanoid physics characters. Each character is equipped with a sword and shield with 37 degrees-of-freedom.\nThe game will be restarted if one agent goes outside the arena or the game reaches the maximum episode steps. We measure how much the player damaged the opponent and how much the player was damaged by the opponent in the terminated step to determine the winner.\n\n#### <span id=\"obs1-1\">Low-Level Observation Space</span>\n\n|  Index  |          Description           |\n|:-------:|:------------------------------:|\n|  0   |           Height of the root from the ground.            |\n|  1 - 48  |         Position of the body in the character’s local coordinate frame.         |\n|  49 - 150  |      Rotation of the body in the character’s local coordinate frame.      |\n| 151 - 201 |      Linear velocity of the root in the character’s local coordinate frame.       |\n| 202 - 252 |      angular velocity of the root in the character’s local coordinate frame.          |\n\n\n#### <span id=\"obs1-2\">High-Level Observation Space</span>\n\n|  Index  |          Description           |\n|:-------:|:------------------------------:|\n|  0 - 1  |    relative distance from the borderline            |\n|  2 - 4  |    relative distance from the opponent          |\n|  5 - 10  |      Rotation of the opponent's root in the character’s local coordinate frame.      |\n| 11 - 13 |      Linear velocity of the opponent'root in the character’s local coordinate frame.       |\n| 14 - 16 |      angular velocity of the opponent'root in the character’s local coordinate frame.         |\n| 17 - 19 |      relative distance between ego agent and opponent's sword         |\n| 20 - 22 |      Linear velocity of the opponent' sword in the character’s local coordinate frame.          
|\n| 23 - 25 | Relative distance between the ego agent's shield and the opponent's sword. |\n| 26 - 28 | Relative velocity between the ego agent's shield and the opponent's sword. |\n| 29 - 31 | Relative distance between the ego agent's sword and the opponent's torso. |\n| 32 - 34 | Relative velocity between the ego agent's sword and the opponent's torso. |\n| 35 - 37 | Relative distance between the ego agent's sword and the opponent's head. |\n| 38 - 40 | Relative velocity between the ego agent's sword and the opponent's head. |\n| 41 - 43 | Relative distance between the ego agent's sword and the opponent's right arm. |\n| 44 - 46 | Relative distance between the ego agent's sword and the opponent's right thigh. |\n| 47 - 49 | Relative distance between the ego agent's sword and the opponent's left thigh. |\n\n\n#### <span id=\"action1-1\">Low-Level Action Space</span>\n\n| Index |    Description    |\n|:-----:|:-----------------:|\n| 0 - 30 | Target rotations of the character's joints. |\n\n#### <span id=\"action1-2\">High-Level Action Space</span>\n\n| Index |    Description    |\n|:-----:|:-----------------:|\n| 0 - 63 | Latent skill variables. |\n\n#### <span id=\"r1\">Rewards</span>\n\nThe weights of the reward components are as follows:\n\n```python\nop_fall_reward_w = 200.0\nego_fall_out_reward_w = 50.0\nshield_to_sword_pos_reward_w = 1.0\ndamage_reward_w = 8.0\nsword_to_op_reward_w = 0.8\nreward_energy_w = 3.0\nreward_strike_vel_acc_w = 3.0\nreward_face_w = 4.0\nreward_foot_to_op_w = 10.0\nreward_kick_w = 2.0\n```\n\n\n### Ant Sumo\n\nAnt Sumo is a 3D environment with simulated physics that allows pairs of ant agents to compete against each other.\nTo win, an agent has to push its opponent out of the ring. Every agent has 100 HP.
At each step, if an agent's body touches the ground, its HP is reduced by 1. An agent whose HP reaches 0 is eliminated.\n\n#### <span id=\"obs2\">Observation Space</span>\n\n|  Index  |          Description           |\n|:-------:|:------------------------------:|\n| 0 - 2 | Self pose. |\n| 3 - 6 | Self rotation. |\n| 7 - 9 | Self linear velocity. |\n| 10 - 12 | Self angular velocity. |\n| 13 - 20 | Self DOF positions. |\n| 21 - 28 | Self DOF velocities. |\n| 29 - 31 | Opponent pose. |\n| 32 - 35 | Opponent rotation. |\n| 36 - 37 | Self-to-opponent pose vector (x, y). |\n| 38 | Whether the self body touches the ground. |\n| 39 | Whether the opponent body touches the ground. |\n\n#### <span id=\"action2\">Action Space</span>\n\n| Index |    Description    |\n|:-----:|:-----------------:|\n| 0 - 7 | Self DOF positions. |\n\n#### <span id=\"r2\">Rewards</span>\n\nThe reward consists of two parts: a sparse reward and a dense reward.\n\n```python\n# sparse reward terms; which one applies depends on the episode outcome\nwin_reward = 2000\nlose_penalty = -2000\ndraw_penalty = -1000\ndense_reward_scale = 1.\ndof_at_limit_cost = torch.sum(obs_buf[:, 13:21] > 0.99, dim=-1) * joints_at_limit_cost_scale\npush_reward = -push_scale * torch.exp(-torch.linalg.norm(obs_buf_op[:, :2], dim=-1))\naction_cost_penalty = torch.sum(torch.square(torques), dim=1) * action_cost_scale\nnot_move_penalty = -10 * torch.exp(-torch.sum(torch.abs(torques), dim=1))\ndense_reward = move_reward + dof_at_limit_cost + push_reward + action_cost_penalty + not_move_penalty\ntotal_reward = win_reward + lose_penalty + draw_penalty + dense_reward * dense_reward_scale\n```\n\n### Ant Battle\n\nAnt Battle is an expanded version of Ant Sumo. It supports more than two agents competing against each other.
The battle ring shrinks over time, and any agent that goes out of the ring is eliminated.\n\n#### <span id=\"obs3\">Observation Space</span>\n\n|  Index  |              Description               |\n|:-------:|:--------------------------------------:|\n| 0 - 2 | Self pose. |\n| 3 - 6 | Self rotation. |\n| 7 - 9 | Self linear velocity. |\n| 10 - 12 | Self angular velocity. |\n| 13 - 20 | Self DOF positions. |\n| 21 - 28 | Self DOF velocities. |\n| 29 | Border radius minus self distance to the centre. |\n| 30 | Border radius. |\n| 31 | Whether the self body touches the ground. |\n| 32 - 34 | Opponent 1 pose. |\n| 35 - 38 | Opponent 1 rotation. |\n| 39 - 40 | Self-to-opponent-1 pose vector (x, y). |\n| 41 - 48 | Opponent 1 DOF positions. |\n| 49 - 56 | Opponent 1 DOF velocities. |\n| 57 | Border radius minus opponent 1 distance to the centre. |\n| 58 | Whether the opponent 1 body touches the ground. |\n| ... | ... 
|\n\n#### <span id=\"action3\">Action Space</span>\n\n| Index |    Description    |\n|:-----:|:-----------------:|\n| 0 - 7 | Self DOF positions. |\n\n#### <span id=\"r3\">Rewards</span>\n\nThe reward consists of two parts: a sparse reward and a dense reward.\n\n```python\nwin_reward_scale = 2000\nreward_per_rank = 2 * win_reward_scale / (num_agents - 1)\nsparse_reward = sparse_reward * (win_reward_scale - (nxt_rank[:, 0] - 1) * reward_per_rank)\nstay_in_center_reward = stay_in_center_reward_scale * torch.exp(-torch.linalg.norm(obs[0, :, :2], dim=-1))\ndof_at_limit_cost = torch.sum(obs[0, :, 13:21] > 0.99, dim=-1) * joints_at_limit_cost_scale\naction_cost_penalty = torch.sum(torch.square(torques), dim=1) * action_cost_scale\nnot_move_penalty = torch.exp(-torch.sum(torch.abs(torques), dim=1))\ndense_reward = dof_at_limit_cost + action_cost_penalty + not_move_penalty + stay_in_center_reward\ntotal_reward = sparse_reward + dense_reward * dense_reward_scale\n```"
  },
  {
    "path": "setup.py",
    "content": "\"\"\"Installation script for the 'timechamber' python package.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nfrom setuptools import setup, find_packages\n\nimport os\n\nroot_dir = os.path.dirname(os.path.realpath(__file__))\n\n# Minimum dependencies required prior to installation\nINSTALL_REQUIRES = [\n    # RL\n    \"gym==0.24\",\n    \"torch\",\n    \"omegaconf\",\n    \"termcolor\",\n    \"dill\",\n    \"hydra-core>=1.1\",\n    \"rl-games==1.5.2\",\n    \"pyvirtualdisplay\",\n    \"multielo @ git+https://github.com/djcunningham0/multielo.git@440f7922b90ff87009f8283d6491eb0f704e6624\",\n    \"matplotlib==3.5.2\",\n    \"pytest==7.1.2\",\n]\n\n# Installation operation\nsetup(\n    name=\"timechamber\",\n    author=\"ZeldaHuang, Ziyi Liu\",\n    version=\"0.0.1\",\n    description=\"A Massively Parallel Large Scale Self-Play Framework\",\n    keywords=[\"robotics\", \"rl\"],\n    include_package_data=True,\n    python_requires=\">=3.6.*\",\n    install_requires=INSTALL_REQUIRES,\n    packages=find_packages(\".\"),\n    classifiers=[\"Natural Language :: English\", \"Programming Language :: Python :: 3.7, 3.8\"],\n    zip_safe=False,\n)\n\n# EOF\n"
  },
  {
    "path": "timechamber/__init__.py",
    "content": "import hydra\nfrom hydra import compose, initialize\nfrom hydra.core.hydra_config import HydraConfig\nfrom omegaconf import DictConfig, OmegaConf\nfrom timechamber.utils.reformat import omegaconf_to_dict\n\n\nOmegaConf.register_new_resolver('eq', lambda x, y: x.lower()==y.lower())\nOmegaConf.register_new_resolver('contains', lambda x, y: x.lower() in y.lower())\nOmegaConf.register_new_resolver('if', lambda pred, a, b: a if pred else b)\nOmegaConf.register_new_resolver('resolve_default', lambda default, arg: default if arg=='' else arg)\n\n\ndef make(\n    seed: int, \n    task: str, \n    num_envs: int, \n    sim_device: str,\n    rl_device: str,\n    graphics_device_id: int = -1,\n    device_type: str = \"cuda\",\n    headless: bool = False,\n    multi_gpu: bool = False,\n    virtual_screen_capture: bool = False,\n    force_render: bool = True,\n    cfg: DictConfig = None\n):\n    from timechamber.utils.rlgames_utils import get_rlgames_env_creator\n    # create hydra config if no config passed in\n    if cfg is None:\n        # reset current hydra config if already parsed (but not passed in here)\n        if HydraConfig.initialized():\n            task = HydraConfig.get().runtime.choices['task']\n            hydra.core.global_hydra.GlobalHydra.instance().clear()\n\n        with initialize(config_path=\"./cfg\"):\n            cfg = compose(config_name=\"config\", overrides=[f\"task={task}\"])\n            task_dict = omegaconf_to_dict(cfg.task)\n            task_dict['env']['numEnvs'] = num_envs\n    # reuse existing config\n    else:\n        task_dict = omegaconf_to_dict(cfg.task)\n    task_dict['seed'] = cfg.seed\n    task_dict['rl_device'] = rl_device\n    if cfg.motion_file:\n        task_dict['env']['motion_file'] = cfg.motion_file\n    \n    create_rlgpu_env = get_rlgames_env_creator(\n        seed=seed,\n        cfg=cfg,\n        task_config=task_dict,\n        task_name=task_dict[\"name\"],\n        sim_device=sim_device,\n        
rl_device=rl_device,\n        graphics_device_id=graphics_device_id,\n        headless=headless,\n        device_type=device_type,\n        multi_gpu=multi_gpu,\n        virtual_screen_capture=virtual_screen_capture,\n        force_render=force_render,\n    )\n    return create_rlgpu_env()\n"
  },
  {
    "path": "timechamber/ase/ase_agent.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom isaacgym.torch_utils import *\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.common import a2c_common\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\n\nfrom timechamber.ase import ase_network_builder\nfrom timechamber.ase.utils import amp_agent\n\n\nclass ASEAgent(amp_agent.AMPAgent):\n    def __init__(self, base_name, config):\n        super().__init__(base_name, config)\n        return\n\n    def init_tensors(self):\n        super().init_tensors()\n        \n        batch_shape = self.experience_buffer.obs_base_shape\n        self.experience_buffer.tensor_dict['ase_latents'] = torch.zeros(batch_shape + (self._latent_dim,),\n                                                                dtype=torch.float32, device=self.ppo_device)\n        \n        self._ase_latents = torch.zeros((batch_shape[-1], self._latent_dim), dtype=torch.float32,\n                                         device=self.ppo_device)\n        \n        self.tensor_list += ['ase_latents']\n\n        self._latent_reset_steps = torch.zeros(batch_shape[-1], dtype=torch.int32, device=self.ppo_device)\n        num_envs = self.vec_env.env.task.num_envs\n        env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.ppo_device)\n        self._reset_latent_step_count(env_ids)\n\n        return\n    \n    def play_steps(self):\n        self.set_eval()\n\n        epinfos = []\n        
done_indices = []\n        update_list = self.update_list\n\n        for n in range(self.horizon_length):\n            self.obs = self.env_reset(done_indices)\n            self.experience_buffer.update_data('obses', n, self.obs['obs'])\n\n            self._update_latents()\n\n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, self._ase_latents, masks)\n            else:\n                res_dict = self.get_action_values(self.obs, self._ase_latents, self._rand_action_probs)\n\n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k]) \n\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])\n            shaped_rewards = self.rewards_shaper(rewards)\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n            self.experience_buffer.update_data('amp_obs', n, infos['amp_obs'])\n            self.experience_buffer.update_data('ase_latents', n, self._ase_latents)\n            self.experience_buffer.update_data('rand_action_mask', n, res_dict['rand_action_mask'])\n\n            terminated = infos['terminate'].float()\n            terminated = terminated.unsqueeze(-1)\n            next_vals = self._eval_critic(self.obs, self._ase_latents)\n            next_vals *= (1.0 - terminated)\n            self.experience_buffer.update_data('next_values', n, next_vals)\n\n            self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            done_indices = all_done_indices[::self.num_agents]\n\n        
    self.game_rewards.update(self.current_rewards[done_indices])\n            self.game_lengths.update(self.current_lengths[done_indices])\n            self.algo_observer.process_infos(infos, done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n        \n            if (self.vec_env.env.task.viewer):\n                self._amp_debug(infos, self._ase_latents)\n\n            done_indices = done_indices[:, 0]\n\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_next_values = self.experience_buffer.tensor_dict['next_values']\n        \n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        mb_amp_obs = self.experience_buffer.tensor_dict['amp_obs']\n        mb_ase_latents = self.experience_buffer.tensor_dict['ase_latents']\n        amp_rewards = self._calc_amp_rewards(mb_amp_obs, mb_ase_latents)\n        mb_rewards = self._combine_rewards(mb_rewards, amp_rewards)\n        \n        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n\n        for k, v in amp_rewards.items():\n            batch_dict[k] = a2c_common.swap_and_flatten01(v)\n\n        return batch_dict\n\n    def get_action_values(self, obs_dict, ase_latents, rand_action_probs):\n        processed_obs = self._preproc_obs(obs_dict['obs'])\n\n        self.model.eval()\n        input_dict = {\n            'is_train': False,\n            'prev_actions': None, \n            'obs' : processed_obs,\n            
'rnn_states' : self.rnn_states,\n            'ase_latents': ase_latents\n        }\n\n        with torch.no_grad():\n            res_dict = self.model(input_dict)\n            if self.has_central_value:\n                states = obs_dict['states']\n                input_dict = {\n                    'is_train': False,\n                    'states' : states,\n                }\n                value = self.get_central_value(input_dict)\n                res_dict['values'] = value\n\n        if self.normalize_value:\n            res_dict['values'] = self.value_mean_std(res_dict['values'], True)\n        \n        rand_action_mask = torch.bernoulli(rand_action_probs)\n        det_action_mask = rand_action_mask == 0.0\n        res_dict['actions'][det_action_mask] = res_dict['mus'][det_action_mask]\n        res_dict['rand_action_mask'] = rand_action_mask\n\n        return res_dict\n\n    def prepare_dataset(self, batch_dict):\n        super().prepare_dataset(batch_dict)\n        \n        ase_latents = batch_dict['ase_latents']\n        self.dataset.values_dict['ase_latents'] = ase_latents\n        \n        return\n\n    def calc_gradients(self, input_dict):\n        self.set_train()\n\n        value_preds_batch = input_dict['old_values']\n        old_action_log_probs_batch = input_dict['old_logp_actions']\n        advantage = input_dict['advantages']\n        old_mu_batch = input_dict['mu']\n        old_sigma_batch = input_dict['sigma']\n        return_batch = input_dict['returns']\n        actions_batch = input_dict['actions']\n        obs_batch = input_dict['obs']\n        obs_batch = self._preproc_obs(obs_batch)\n\n        amp_obs = input_dict['amp_obs'][0:self._amp_minibatch_size]\n        amp_obs = self._preproc_amp_obs(amp_obs)\n        if (self._enable_enc_grad_penalty()):\n            amp_obs.requires_grad_(True)\n\n        amp_obs_replay = input_dict['amp_obs_replay'][0:self._amp_minibatch_size]\n        amp_obs_replay = 
self._preproc_amp_obs(amp_obs_replay)\n\n        amp_obs_demo = input_dict['amp_obs_demo'][0:self._amp_minibatch_size]\n        amp_obs_demo = self._preproc_amp_obs(amp_obs_demo)\n        amp_obs_demo.requires_grad_(True)\n\n        ase_latents = input_dict['ase_latents']\n        \n        # mask of samples that used stochastic (rather than deterministic) actions;\n        # the PPO losses below are averaged over these samples only\n        rand_action_mask = input_dict['rand_action_mask']\n        rand_action_sum = torch.sum(rand_action_mask)\n\n        lr = self.last_lr\n        kl = 1.0\n        lr_mul = 1.0\n        curr_e_clip = lr_mul * self.e_clip\n\n        batch_dict = {\n            'is_train': True,\n            'prev_actions': actions_batch, \n            'obs' : obs_batch,\n            'amp_obs' : amp_obs,\n            'amp_obs_replay' : amp_obs_replay,\n            'amp_obs_demo' : amp_obs_demo,\n            'ase_latents': ase_latents\n        }\n\n        rnn_masks = None\n        if self.is_rnn:\n            rnn_masks = input_dict['rnn_masks']\n            batch_dict['rnn_states'] = input_dict['rnn_states']\n            batch_dict['seq_length'] = self.seq_len\n\n        with torch.cuda.amp.autocast(enabled=self.mixed_precision):\n            res_dict = self.model(batch_dict)\n            action_log_probs = res_dict['prev_neglogp']\n            values = res_dict['values']\n            entropy = res_dict['entropy']\n            mu = res_dict['mus']\n            sigma = res_dict['sigmas']\n            disc_agent_logit = res_dict['disc_agent_logit']\n            disc_agent_replay_logit = res_dict['disc_agent_replay_logit']\n            disc_demo_logit = res_dict['disc_demo_logit']\n            enc_pred = res_dict['enc_pred']\n\n            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)\n            a_loss = 
a_info['actor_loss']\n            a_clipped = a_info['actor_clipped'].float()\n\n            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)\n            c_loss = c_info['critic_loss']\n\n            b_loss = self.bound_loss(mu)\n\n            c_loss = torch.mean(c_loss)\n            a_loss = torch.sum(rand_action_mask * a_loss) / rand_action_sum\n            entropy = torch.sum(rand_action_mask * entropy) / rand_action_sum\n            b_loss = torch.sum(rand_action_mask * b_loss) / rand_action_sum\n            a_clip_frac = torch.sum(rand_action_mask * a_clipped) / rand_action_sum\n            \n            disc_agent_cat_logit = torch.cat([disc_agent_logit, disc_agent_replay_logit], dim=0)\n            disc_info = self._disc_loss(disc_agent_cat_logit, disc_demo_logit, amp_obs_demo)\n            disc_loss = disc_info['disc_loss']\n            \n            enc_latents = batch_dict['ase_latents'][0:self._amp_minibatch_size]\n            enc_loss_mask = rand_action_mask[0:self._amp_minibatch_size]\n            enc_info = self._enc_loss(enc_pred, enc_latents, batch_dict['amp_obs'], enc_loss_mask)\n            enc_loss = enc_info['enc_loss']\n\n            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss \\\n                 + self._disc_coef * disc_loss + self._enc_coef * enc_loss\n            \n            if (self._enable_amp_diversity_bonus()):\n                diversity_loss = self._diversity_loss(batch_dict['obs'], mu, batch_dict['ase_latents'])\n                diversity_loss = torch.sum(rand_action_mask * diversity_loss) / rand_action_sum\n                loss += self._amp_diversity_bonus * diversity_loss\n                a_info['amp_diversity_loss'] = diversity_loss\n                \n            a_info['actor_loss'] = a_loss\n            a_info['actor_clip_frac'] = a_clip_frac\n            c_info['critic_loss'] = c_loss\n\n            if 
self.multi_gpu:\n                self.optimizer.zero_grad()\n            else:\n                for param in self.model.parameters():\n                    param.grad = None\n\n        self.scaler.scale(loss).backward()\n        #TODO: Refactor this ugliest code of the year\n        if self.truncate_grads:\n            if self.multi_gpu:\n                self.optimizer.synchronize()\n                self.scaler.unscale_(self.optimizer)\n                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)\n                with self.optimizer.skip_synchronize():\n                    self.scaler.step(self.optimizer)\n                    self.scaler.update()\n            else:\n                self.scaler.unscale_(self.optimizer)\n                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)\n                self.scaler.step(self.optimizer)\n                self.scaler.update()    \n        else:\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n        with torch.no_grad():\n            reduce_kl = not self.is_rnn\n            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)\n            if self.is_rnn:\n                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()  #/ sum_mask\n        \n        self.train_result = {\n            'entropy': entropy,\n            'kl': kl_dist,\n            'last_lr': self.last_lr, \n            'lr_mul': lr_mul, \n            'b_loss': b_loss\n        }\n        self.train_result.update(a_info)\n        self.train_result.update(c_info)\n        self.train_result.update(disc_info)\n        self.train_result.update(enc_info)\n\n        return\n    \n    def env_reset(self, env_ids=None):\n        obs = super().env_reset(env_ids)\n        \n        if (env_ids is None):\n            num_envs = self.vec_env.env.task.num_envs\n            env_ids = to_torch(np.arange(num_envs), dtype=torch.long, 
device=self.ppo_device)\n\n        if (len(env_ids) > 0):\n            self._reset_latents(env_ids)\n            self._reset_latent_step_count(env_ids)\n\n        return obs\n\n    def _reset_latent_step_count(self, env_ids):\n        self._latent_reset_steps[env_ids] = torch.randint_like(self._latent_reset_steps[env_ids], low=self._latent_steps_min, \n                                                         high=self._latent_steps_max)\n        return\n\n    def _load_config_params(self, config):\n        super()._load_config_params(config)\n        \n        self._latent_dim = config['latent_dim']\n        self._latent_steps_min = config.get('latent_steps_min', np.inf)\n        self._latent_steps_max = config.get('latent_steps_max', np.inf)\n        self._amp_diversity_bonus = config['amp_diversity_bonus']\n        self._amp_diversity_tar = config['amp_diversity_tar']\n        \n        self._enc_coef = config['enc_coef']\n        self._enc_weight_decay = config['enc_weight_decay']\n        self._enc_reward_scale = config['enc_reward_scale']\n        self._enc_grad_penalty = config['enc_grad_penalty']\n\n        self._enc_reward_w = config['enc_reward_w']\n\n        return\n    \n    def _build_net_config(self):\n        config = super()._build_net_config()\n        config['ase_latent_shape'] = (self._latent_dim,)\n        return config\n\n    def _reset_latents(self, env_ids):\n        n = len(env_ids)\n        z = self._sample_latents(n)\n        self._ase_latents[env_ids] = z\n\n        if (self.vec_env.env.task.viewer):\n            self._change_char_color(env_ids)\n\n        return\n\n    def _sample_latents(self, n):\n        z = self.model.a2c_network.sample_latents(n)\n        return z\n\n    def _update_latents(self):\n        new_latent_envs = self._latent_reset_steps <= self.vec_env.env.task.progress_buf\n\n        need_update = torch.any(new_latent_envs)\n        if (need_update):\n            
new_latent_env_ids = new_latent_envs.nonzero(as_tuple=False).flatten()\n            self._reset_latents(new_latent_env_ids)\n            self._latent_reset_steps[new_latent_env_ids] += torch.randint_like(self._latent_reset_steps[new_latent_env_ids],\n                                                                               low=self._latent_steps_min, \n                                                                               high=self._latent_steps_max)\n            if (self.vec_env.env.task.viewer):\n                self._change_char_color(new_latent_env_ids)\n\n        return\n\n    def _eval_actor(self, obs, ase_latents):\n        output = self.model.eval_actor(obs=obs, ase_latents=ase_latents)\n        return output\n\n    def _eval_critic(self, obs_dict, ase_latents):\n        self.model.eval()\n        obs = obs_dict['obs']\n        processed_obs = self._preproc_obs(obs)\n        value = self.model.eval_critic(processed_obs, ase_latents)\n\n        if self.normalize_value:\n            value = self.value_mean_std(value, True)\n        return value\n\n    def _calc_amp_rewards(self, amp_obs, ase_latents):\n        disc_r = self._calc_disc_rewards(amp_obs)\n        enc_r = self._calc_enc_rewards(amp_obs, ase_latents)\n        output = {\n            'disc_rewards': disc_r,\n            'enc_rewards': enc_r\n        }\n        return output\n\n    def _calc_enc_rewards(self, amp_obs, ase_latents):\n        with torch.no_grad():\n            enc_pred = self._eval_enc(amp_obs)\n            err = self._calc_enc_error(enc_pred, ase_latents)\n            enc_r = torch.clamp_min(-err, 0.0)\n            enc_r *= self._enc_reward_scale\n\n        return enc_r\n\n    def _enc_loss(self, enc_pred, ase_latent, enc_obs, loss_mask):\n        enc_err = self._calc_enc_error(enc_pred, ase_latent)\n        #mask_sum = torch.sum(loss_mask)\n        #enc_err = enc_err.squeeze(-1)\n        #enc_loss = torch.sum(loss_mask * enc_err) / mask_sum\n        enc_loss = 
torch.mean(enc_err)\n\n        # weight decay\n        if (self._enc_weight_decay != 0):\n            enc_weights = self.model.a2c_network.get_enc_weights()\n            enc_weights = torch.cat(enc_weights, dim=-1)\n            enc_weight_decay = torch.sum(torch.square(enc_weights))\n            enc_loss += self._enc_weight_decay * enc_weight_decay\n            \n        enc_info = {\n            'enc_loss': enc_loss\n        }\n\n        if (self._enable_enc_grad_penalty()):\n            enc_obs_grad = torch.autograd.grad(enc_err, enc_obs, grad_outputs=torch.ones_like(enc_err),\n                                               create_graph=True, retain_graph=True, only_inputs=True)\n            enc_obs_grad = enc_obs_grad[0]\n            enc_obs_grad = torch.sum(torch.square(enc_obs_grad), dim=-1)\n            #enc_grad_penalty = torch.sum(loss_mask * enc_obs_grad) / mask_sum\n            enc_grad_penalty = torch.mean(enc_obs_grad)\n\n            enc_loss += self._enc_grad_penalty * enc_grad_penalty\n\n            enc_info['enc_grad_penalty'] = enc_grad_penalty.detach()\n\n        return enc_info\n\n    def _diversity_loss(self, obs, action_params, ase_latents):\n        assert(self.model.a2c_network.is_continuous)\n\n        n = obs.shape[0]\n        assert(n == action_params.shape[0])\n\n        new_z = self._sample_latents(n)\n        mu, sigma = self._eval_actor(obs=obs, ase_latents=new_z)\n\n        clipped_action_params = torch.clamp(action_params, -1.0, 1.0)\n        clipped_mu = torch.clamp(mu, -1.0, 1.0)\n\n        a_diff = clipped_action_params - clipped_mu\n        a_diff = torch.mean(torch.square(a_diff), dim=-1)\n\n        z_diff = new_z * ase_latents\n        z_diff = torch.sum(z_diff, dim=-1)\n        z_diff = 0.5 - 0.5 * z_diff\n\n        diversity_bonus = a_diff / (z_diff + 1e-5)\n        diversity_loss = torch.square(self._amp_diversity_tar - diversity_bonus)\n\n        return diversity_loss\n\n    def _calc_enc_error(self, enc_pred, ase_latent):\n 
       err = enc_pred * ase_latent\n        err = -torch.sum(err, dim=-1, keepdim=True)\n        return err\n\n    def _enable_enc_grad_penalty(self):\n        return self._enc_grad_penalty != 0\n\n    def _enable_amp_diversity_bonus(self):\n        return self._amp_diversity_bonus != 0\n\n    def _eval_enc(self, amp_obs):\n        proc_amp_obs = self._preproc_amp_obs(amp_obs)\n        return self.model.a2c_network.eval_enc(proc_amp_obs)\n\n    def _combine_rewards(self, task_rewards, amp_rewards):\n        disc_r = amp_rewards['disc_rewards']\n        enc_r = amp_rewards['enc_rewards']\n        combined_rewards = self._task_reward_w * task_rewards \\\n                         + self._disc_reward_w * disc_r \\\n                         + self._enc_reward_w * enc_r\n        return combined_rewards\n\n    def _record_train_batch_info(self, batch_dict, train_info):\n        super()._record_train_batch_info(batch_dict, train_info)\n        train_info['enc_rewards'] = batch_dict['enc_rewards']\n        return\n\n    def _log_train_info(self, train_info, frame):\n        super()._log_train_info(train_info, frame)\n        \n        self.writer.add_scalar('losses/enc_loss', torch_ext.mean_list(train_info['enc_loss']).item(), frame)\n         \n        if (self._enable_amp_diversity_bonus()):\n            self.writer.add_scalar('losses/amp_diversity_loss', torch_ext.mean_list(train_info['amp_diversity_loss']).item(), frame)\n        \n        enc_reward_std, enc_reward_mean = torch.std_mean(train_info['enc_rewards'])\n        self.writer.add_scalar('info/enc_reward_mean', enc_reward_mean.item(), frame)\n        self.writer.add_scalar('info/enc_reward_std', enc_reward_std.item(), frame)\n\n        if (self._enable_enc_grad_penalty()):\n            self.writer.add_scalar('info/enc_grad_penalty', torch_ext.mean_list(train_info['enc_grad_penalty']).item(), frame)\n\n        return\n\n    def _change_char_color(self, env_ids):\n        base_col = np.array([0.4, 0.4, 0.4])\n     
   range_col = np.array([0.0706, 0.149, 0.2863])\n        range_sum = np.linalg.norm(range_col)\n\n        rand_col = np.random.uniform(0.0, 1.0, size=3)\n        rand_col = range_sum * rand_col / np.linalg.norm(rand_col)\n        rand_col += base_col\n        self.vec_env.env.task.set_char_color(rand_col, env_ids)\n        return\n\n    def _amp_debug(self, info, ase_latents):\n        with torch.no_grad():\n            amp_obs = info['amp_obs']\n            disc_pred = self._eval_disc(amp_obs)\n            amp_rewards = self._calc_amp_rewards(amp_obs, ase_latents)\n            disc_reward = amp_rewards['disc_rewards']\n            enc_reward = amp_rewards['enc_rewards']\n\n            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]\n            disc_reward = disc_reward.cpu().numpy()[0, 0]\n            enc_reward = enc_reward.cpu().numpy()[0, 0]\n            print(\"disc_pred: \", disc_pred, disc_reward, enc_reward)\n        return"
  },
  {
    "path": "timechamber/ase/ase_models.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom timechamber.ase.utils import amp_models\n\nclass ModelASEContinuous(amp_models.ModelAMPContinuous):\n    def __init__(self, network):\n        super().__init__(network)\n        return\n\n    def build(self, config):\n        net = self.network_builder.build('ase', **config)\n        obs_shape = config['input_shape']\n        normalize_value = config.get('normalize_value', False)\n        normalize_input = config.get('normalize_input', False)\n        value_size = config.get('value_size', 1)\n        return ModelASEContinuous.Network(net, obs_shape=obs_shape, normalize_value=normalize_value,\n                                          normalize_input=normalize_input, value_size=value_size)\n\n\n    class Network(amp_models.ModelAMPContinuous.Network):\n        def __init__(self, a2c_network, obs_shape, normalize_value, normalize_input, value_size):\n            super().__init__(a2c_network,\n                             obs_shape=obs_shape, \n                             normalize_value=normalize_value,\n                             normalize_input=normalize_input, \n                             value_size=value_size)\n            return\n\n        def forward(self, input_dict):\n            is_train = input_dict.get('is_train', True)\n            result = super().forward(input_dict)\n\n            if (is_train):\n        
        amp_obs = input_dict['amp_obs']\n                enc_pred = self.a2c_network.eval_enc(amp_obs)\n                result[\"enc_pred\"] = enc_pred\n\n            return result\n\n        def eval_actor(self, obs, ase_latents, use_hidden_latents=False):\n            processed_obs = self.norm_obs(obs)\n            mu, sigma = self.a2c_network.eval_actor(obs=processed_obs, ase_latents=ase_latents)\n            return mu, sigma\n\n        def eval_critic(self, obs, ase_latents, use_hidden_latents=False):\n            processed_obs = self.norm_obs(obs)\n            value = self.a2c_network.eval_critic(processed_obs, ase_latents, use_hidden_latents)\n            return value"
  },
  {
    "path": "timechamber/ase/ase_network_builder.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import layers\nfrom rl_games.algos_torch import network_builder\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport enum\n\nfrom timechamber.ase.utils import amp_network_builder\n\nENC_LOGIT_INIT_SCALE = 0.1\n\nclass LatentType(enum.Enum):\n    uniform = 0\n    sphere = 1\n\nclass ASEBuilder(amp_network_builder.AMPBuilder):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        return\n\n    class Network(amp_network_builder.AMPBuilder.Network):\n        def __init__(self, params, **kwargs):\n            actions_num = kwargs.get('actions_num')\n            input_shape = kwargs.get('input_shape')\n            self.value_size = kwargs.get('value_size', 1)\n            self.num_seqs = num_seqs = kwargs.get('num_seqs', 1)\n            amp_input_shape = kwargs.get('amp_input_shape')\n            self._ase_latent_shape = kwargs.get('ase_latent_shape')\n\n            network_builder.NetworkBuilder.BaseNetwork.__init__(self)\n            \n            self.load(params)\n\n            actor_out_size, critic_out_size = self._build_actor_critic_net(input_shape, self._ase_latent_shape)\n\n            self.value = torch.nn.Linear(critic_out_size, self.value_size)\n            self.value_act = self.activations_factory.create(self.value_activation)\n            \n            if self.is_discrete:\n                self.logits = 
torch.nn.Linear(actor_out_size, actions_num)\n            '''\n                for multidiscrete actions num is a tuple\n            '''\n            if self.is_multi_discrete:\n                self.logits = torch.nn.ModuleList([torch.nn.Linear(actor_out_size, num) for num in actions_num])\n            if self.is_continuous:\n                self.mu = torch.nn.Linear(actor_out_size, actions_num)\n                self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) \n                mu_init = self.init_factory.create(**self.space_config['mu_init'])\n                self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) \n\n                sigma_init = self.init_factory.create(**self.space_config['sigma_init'])\n\n                if (not self.space_config['learn_sigma']):\n                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=False, dtype=torch.float32), requires_grad=False)\n                elif self.space_config['fixed_sigma']:\n                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)\n                else:\n                    self.sigma = torch.nn.Linear(actor_out_size, actions_num)\n\n            mlp_init = self.init_factory.create(**self.initializer)\n            if self.has_cnn:\n                cnn_init = self.init_factory.create(**self.cnn['initializer'])\n\n            for m in self.modules():         \n                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):\n                    cnn_init(m.weight)\n                    if getattr(m, \"bias\", None) is not None:\n                        torch.nn.init.zeros_(m.bias)\n                if isinstance(m, nn.Linear):\n                    mlp_init(m.weight)\n                    if getattr(m, \"bias\", None) is not None:\n                        torch.nn.init.zeros_(m.bias)    \n\n            self.actor_mlp.init_params()\n            
self.critic_mlp.init_params()\n\n            if self.is_continuous:\n                mu_init(self.mu.weight)\n                if self.space_config['fixed_sigma']:\n                    sigma_init(self.sigma)\n                else:\n                    sigma_init(self.sigma.weight)\n\n            self._build_disc(amp_input_shape)\n            self._build_enc(amp_input_shape)\n\n            return\n\n        def load(self, params):\n            super().load(params)\n\n            self._enc_units = params['enc']['units']\n            self._enc_activation = params['enc']['activation']\n            self._enc_initializer = params['enc']['initializer']\n            self._enc_separate = params['enc']['separate']\n\n            return\n\n        def forward(self, obs_dict):\n            obs = obs_dict['obs']\n            ase_latents = obs_dict['ase_latents']\n            states = obs_dict.get('rnn_states', None)\n            use_hidden_latents = obs_dict.get('use_hidden_latents', False)\n\n            actor_outputs = self.eval_actor(obs, ase_latents, use_hidden_latents)\n            value = self.eval_critic(obs, ase_latents, use_hidden_latents)\n\n            output = actor_outputs + (value, states)\n\n            return output\n\n        def eval_critic(self, obs, ase_latents, use_hidden_latents=False):\n            c_out = self.critic_cnn(obs)\n            c_out = c_out.contiguous().view(c_out.size(0), -1)\n            \n            c_out = self.critic_mlp(c_out, ase_latents, use_hidden_latents)\n            value = self.value_act(self.value(c_out))\n            return value\n\n        def eval_actor(self, obs, ase_latents, use_hidden_latents=False):\n            a_out = self.actor_cnn(obs)\n            a_out = a_out.contiguous().view(a_out.size(0), -1)\n            a_out = self.actor_mlp(a_out, ase_latents, use_hidden_latents)\n                     \n            if self.is_discrete:\n                logits = self.logits(a_out)\n                return logits\n\n            
if self.is_multi_discrete:\n                logits = [logit(a_out) for logit in self.logits]\n                return logits\n\n            if self.is_continuous:\n                mu = self.mu_act(self.mu(a_out))\n                if self.space_config['fixed_sigma']:\n                    sigma = mu * 0.0 + self.sigma_act(self.sigma)\n                else:\n                    sigma = self.sigma_act(self.sigma(a_out))\n\n                return mu, sigma\n            return\n\n        def get_enc_weights(self):\n            weights = []\n            for m in self._enc_mlp.modules():\n                if isinstance(m, nn.Linear):\n                    weights.append(torch.flatten(m.weight))\n\n            weights.append(torch.flatten(self._enc.weight))\n            return weights\n\n        def _build_actor_critic_net(self, input_shape, ase_latent_shape):\n            style_units = [512, 256]\n            style_dim = ase_latent_shape[-1]\n\n            self.actor_cnn = nn.Sequential()\n            self.critic_cnn = nn.Sequential()\n            \n            act_fn = self.activations_factory.create(self.activation)\n            initializer = self.init_factory.create(**self.initializer)\n\n            self.actor_mlp = AMPStyleCatNet1(obs_size=input_shape[-1],\n                                             ase_latent_size=ase_latent_shape[-1],\n                                             units=self.units,\n                                             activation=act_fn,\n                                             style_units=style_units,\n                                             style_dim=style_dim,\n                                             initializer=initializer)\n\n            if self.separate:\n                self.critic_mlp = AMPMLPNet(obs_size=input_shape[-1],\n                                            ase_latent_size=ase_latent_shape[-1],\n                                            units=self.units,\n                                            
activation=act_fn,\n                                            initializer=initializer)\n\n            actor_out_size = self.actor_mlp.get_out_size()\n            critic_out_size = self.critic_mlp.get_out_size()\n\n            return actor_out_size, critic_out_size\n\n        def _build_enc(self, input_shape):\n            if (self._enc_separate):\n                self._enc_mlp = nn.Sequential()\n                mlp_args = {\n                    'input_size' : input_shape[0], \n                    'units' : self._enc_units, \n                    'activation' : self._enc_activation, \n                    'dense_func' : torch.nn.Linear\n                }\n                self._enc_mlp = self._build_mlp(**mlp_args)\n\n                mlp_init = self.init_factory.create(**self._enc_initializer)\n                for m in self._enc_mlp.modules():\n                    if isinstance(m, nn.Linear):\n                        mlp_init(m.weight)\n                        if getattr(m, \"bias\", None) is not None:\n                            torch.nn.init.zeros_(m.bias)\n            else:\n                self._enc_mlp = self._disc_mlp\n\n            mlp_out_layer = list(self._enc_mlp.modules())[-2]\n            mlp_out_size = mlp_out_layer.out_features\n            self._enc = torch.nn.Linear(mlp_out_size, self._ase_latent_shape[-1])\n            \n            torch.nn.init.uniform_(self._enc.weight, -ENC_LOGIT_INIT_SCALE, ENC_LOGIT_INIT_SCALE)\n            torch.nn.init.zeros_(self._enc.bias) \n            \n            return\n\n        def eval_enc(self, amp_obs):\n            enc_mlp_out = self._enc_mlp(amp_obs)\n            enc_output = self._enc(enc_mlp_out)\n            enc_output = torch.nn.functional.normalize(enc_output, dim=-1)\n\n            return enc_output\n\n        def sample_latents(self, n):\n            device = next(self._enc.parameters()).device\n            z = torch.normal(torch.zeros([n, self._ase_latent_shape[-1]], device=device))\n            z = 
torch.nn.functional.normalize(z, dim=-1)\n            return z\n\n    def build(self, name, **kwargs):\n        net = ASEBuilder.Network(self.params, **kwargs)\n        return net\n\n\nclass AMPMLPNet(torch.nn.Module):\n    def __init__(self, obs_size, ase_latent_size, units, activation, initializer):\n        super().__init__()\n\n        input_size = obs_size + ase_latent_size\n        print('build amp mlp net:', input_size)\n        \n        self._units = units\n        self._initializer = initializer\n        self._mlp = []\n\n        in_size = input_size\n        for i in range(len(units)):\n            unit = units[i]\n            curr_dense = torch.nn.Linear(in_size, unit)\n            self._mlp.append(curr_dense)\n            self._mlp.append(activation)\n            in_size = unit\n\n        self._mlp = nn.Sequential(*self._mlp)\n        self.init_params()\n        return\n\n    def forward(self, obs, latent, skip_style):\n        inputs = [obs, latent]\n        input = torch.cat(inputs, dim=-1)\n        output = self._mlp(input)\n        return output\n\n    def init_params(self):\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                self._initializer(m.weight)\n                if getattr(m, \"bias\", None) is not None:\n                    torch.nn.init.zeros_(m.bias)\n        return\n\n    def get_out_size(self):\n        out_size = self._units[-1]\n        return out_size\n\nclass AMPStyleCatNet1(torch.nn.Module):\n    def __init__(self, obs_size, ase_latent_size, units, activation,\n                 style_units, style_dim, initializer):\n        super().__init__()\n\n        print('build amp style cat net:', obs_size, ase_latent_size)\n            \n        self._activation = activation\n        self._initializer = initializer\n        self._dense_layers = []\n        self._units = units\n        self._style_dim = style_dim\n        self._style_activation = torch.tanh\n\n        self._style_mlp = 
self._build_style_mlp(style_units, ase_latent_size)\n        self._style_dense = torch.nn.Linear(style_units[-1], style_dim)\n\n        in_size = obs_size + style_dim\n        for i in range(len(units)):\n            unit = units[i]\n            out_size = unit\n            curr_dense = torch.nn.Linear(in_size, out_size)\n            self._dense_layers.append(curr_dense)\n            \n            in_size = out_size\n\n        self._dense_layers = nn.ModuleList(self._dense_layers)\n\n        self.init_params()\n\n        return\n\n    def forward(self, obs, latent, skip_style):\n        if (skip_style):\n            style = latent\n        else:\n            style = self.eval_style(latent)\n\n        h = torch.cat([obs, style], dim=-1)\n\n        for i in range(len(self._dense_layers)):\n            curr_dense = self._dense_layers[i]\n            h = curr_dense(h)\n            h = self._activation(h)\n\n        return h\n\n    def eval_style(self, latent):\n        style_h = self._style_mlp(latent)\n        style = self._style_dense(style_h)\n        style = self._style_activation(style)\n        return style\n\n    def init_params(self):\n        scale_init_range = 1.0\n\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                self._initializer(m.weight)\n                if getattr(m, \"bias\", None) is not None:\n                    torch.nn.init.zeros_(m.bias)\n\n        nn.init.uniform_(self._style_dense.weight, -scale_init_range, scale_init_range)\n        return\n\n    def get_out_size(self):\n        out_size = self._units[-1]\n        return out_size\n\n    def _build_style_mlp(self, style_units, input_size):\n        in_size = input_size\n        layers = []\n        for unit in style_units:\n            layers.append(torch.nn.Linear(in_size, unit))\n            layers.append(self._activation)\n            in_size = unit\n\n        enc_mlp = nn.Sequential(*layers)\n        return enc_mlp"
  },
  {
    "path": "timechamber/ase/ase_players.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nfrom isaacgym.torch_utils import *\nfrom rl_games.algos_torch import players\n\nfrom timechamber.ase.utils import amp_players\nimport timechamber.ase.ase_network_builder as ase_network_builder\n\nclass ASEPlayer(amp_players.AMPPlayerContinuous):\n    def __init__(self, params):\n        config = params['config']\n        self._latent_dim = config['latent_dim']\n        self._latent_steps_min = config.get('latent_steps_min', np.inf)\n        self._latent_steps_max = config.get('latent_steps_max', np.inf)\n\n        self._enc_reward_scale = config['enc_reward_scale']\n\n        super().__init__(params)\n        \n        if hasattr(self, 'env') and self.env is not None:\n            batch_size = self.env.task.num_envs\n        else:\n            batch_size = self.env_info['num_envs']\n        self._ase_latents = torch.zeros((batch_size, self._latent_dim), dtype=torch.float32,\n                                         device=self.device)\n\n        return\n\n    def run(self):\n        self._reset_latent_step_count()\n        super().run()\n        return\n\n    def get_action(self, obs_dict, is_determenistic=False):\n        self._update_latents()\n\n        obs = obs_dict['obs']\n        if len(obs.size()) == len(self.obs_shape):\n            obs = obs.unsqueeze(0)\n        obs = self._preproc_obs(obs)\n        ase_latents = self._ase_latents\n\n        input_dict = {\n            'is_train': 
False,\n            'prev_actions': None, \n            'obs' : obs,\n            'rnn_states' : self.states,\n            'ase_latents': ase_latents\n        }\n        with torch.no_grad():\n            res_dict = self.model(input_dict)\n        mu = res_dict['mus']\n        action = res_dict['actions']\n        self.states = res_dict['rnn_states']\n        if is_determenistic:\n            current_action = mu\n        else:\n            current_action = action\n        current_action = torch.squeeze(current_action.detach())\n        return  players.rescale_actions(self.actions_low, self.actions_high, torch.clamp(current_action, -1.0, 1.0))\n\n    def env_reset(self, env_ids=None):\n        obs = super().env_reset(env_ids)\n        self._reset_latents(env_ids)\n        return obs\n    \n    def _build_net_config(self):\n        config = super()._build_net_config()\n        config['ase_latent_shape'] = (self._latent_dim,)\n        return config\n    \n    def _reset_latents(self, done_env_ids=None):\n        if (done_env_ids is None):\n            num_envs = self.env.task.num_envs\n            done_env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.device)\n\n        rand_vals = self.model.a2c_network.sample_latents(len(done_env_ids))\n        self._ase_latents[done_env_ids] = rand_vals\n        self._change_char_color(done_env_ids)\n\n        return\n\n    def _update_latents(self):\n        if (self._latent_step_count <= 0):\n            self._reset_latents()\n            self._reset_latent_step_count()\n\n            if (self.env.task.viewer):\n                print(\"Sampling new amp latents------------------------------\")\n                num_envs = self.env.task.num_envs\n                env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.device)\n                self._change_char_color(env_ids)\n        else:\n            self._latent_step_count -= 1\n        return\n    \n    def _reset_latent_step_count(self):\n       
 self._latent_step_count = np.random.randint(self._latent_steps_min, self._latent_steps_max)\n        return\n\n    def _calc_amp_rewards(self, amp_obs, ase_latents):\n        disc_r = self._calc_disc_rewards(amp_obs)\n        enc_r = self._calc_enc_rewards(amp_obs, ase_latents)\n        output = {\n            'disc_rewards': disc_r,\n            'enc_rewards': enc_r\n        }\n        return output\n    \n    def _calc_enc_rewards(self, amp_obs, ase_latents):\n        with torch.no_grad():\n            enc_pred = self._eval_enc(amp_obs)\n            err = self._calc_enc_error(enc_pred, ase_latents)\n            enc_r = torch.clamp_min(-err, 0.0)\n            enc_r *= self._enc_reward_scale\n\n        return enc_r\n    \n    def _calc_enc_error(self, enc_pred, ase_latent):\n        err = enc_pred * ase_latent\n        err = -torch.sum(err, dim=-1, keepdim=True)\n        return err\n    \n    def _eval_enc(self, amp_obs):\n        proc_amp_obs = self._preproc_amp_obs(amp_obs)\n        return self.model.a2c_network.eval_enc(proc_amp_obs)\n\n    def _amp_debug(self, info):\n        with torch.no_grad():\n            amp_obs = info['amp_obs']\n            amp_obs = amp_obs\n            ase_latents = self._ase_latents\n            disc_pred = self._eval_disc(amp_obs)\n            amp_rewards = self._calc_amp_rewards(amp_obs, ase_latents)\n            disc_reward = amp_rewards['disc_rewards']\n            enc_reward = amp_rewards['enc_rewards']\n\n            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]\n            disc_reward = disc_reward.cpu().numpy()[0, 0]\n            enc_reward = enc_reward.cpu().numpy()[0, 0]\n            print(\"disc_pred: \", disc_pred, disc_reward, enc_reward)\n        return\n\n    def _change_char_color(self, env_ids):\n        base_col = np.array([0.4, 0.4, 0.4])\n        range_col = np.array([0.0706, 0.149, 0.2863])\n        range_sum = np.linalg.norm(range_col)\n\n        rand_col = np.random.uniform(0.0, 1.0, size=3)\n        
rand_col = range_sum * rand_col / np.linalg.norm(rand_col)\n        rand_col += base_col\n        self.env.task.set_char_color(rand_col, env_ids)\n        return"
  },
  {
    "path": "timechamber/ase/hrl_agent.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport copy\nfrom datetime import datetime\nfrom gym import spaces\nimport numpy as np\nimport os\nimport time\nimport yaml\n\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import central_value\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.common import a2c_common\nfrom rl_games.common import datasets\nfrom rl_games.common import schedulers\nfrom rl_games.common import vecenv\n\nimport torch\nfrom torch import optim\n\nimport timechamber.ase.utils.common_agent as common_agent\nimport timechamber.ase.ase_agent as ase_agent\nimport timechamber.ase.ase_models as ase_models\nimport timechamber.ase.ase_network_builder as ase_network_builder\n\nfrom tensorboardX import SummaryWriter\n\nclass HRLAgent(common_agent.CommonAgent):\n    def __init__(self, base_name, params):\n        config = params['config']\n        with open(os.path.join(os.getcwd(), config['llc_config']), 'r') as f:\n            llc_config = yaml.load(f, Loader=yaml.SafeLoader)\n            llc_config_params = llc_config['params']\n            self._latent_dim = llc_config_params['config']['latent_dim']\n\n        super().__init__(base_name, params)\n\n        self._task_size = self.vec_env.env.task.get_task_obs_size()\n\n        self._llc_steps = config['llc_steps']\n        llc_checkpoint = config['llc_checkpoint']\n        assert(llc_checkpoint != \"\")\n        
self._build_llc(llc_config_params, llc_checkpoint)\n\n        return\n\n    def env_step(self, actions):\n        actions = self.preprocess_actions(actions)\n        obs = self.obs['obs']\n\n        rewards = 0.0\n        disc_rewards = 0.0\n        done_count = 0.0\n        terminate_count = 0.0\n        for t in range(self._llc_steps):\n            llc_actions = self._compute_llc_action(obs, actions)\n            obs_dict, curr_rewards, curr_dones, infos = self.vec_env.step(llc_actions)\n\n            # TODO\n            obs = obs_dict['obs']\n            \n            rewards += curr_rewards\n            done_count += curr_dones\n            terminate_count += infos['terminate']\n\n            amp_obs = infos['amp_obs']\n            curr_disc_reward = self._calc_disc_reward(amp_obs)\n            disc_rewards += curr_disc_reward\n\n        rewards /= self._llc_steps\n        disc_rewards /= self._llc_steps\n\n        dones = torch.zeros_like(done_count)\n        dones[done_count > 0] = 1.0\n        terminate = torch.zeros_like(terminate_count)\n        terminate[terminate_count > 0] = 1.0\n        infos['terminate'] = terminate\n        infos['disc_rewards'] = disc_rewards\n\n        if self.is_tensor_obses:\n            if self.value_size == 1:\n                rewards = rewards.unsqueeze(1)\n            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos\n        else:\n            if self.value_size == 1:\n                rewards = np.expand_dims(rewards, axis=1)\n            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos\n\n    def cast_obs(self, obs):\n        obs = super().cast_obs(obs)\n        self._llc_agent.is_tensor_obses = self.is_tensor_obses\n        return obs\n\n    def preprocess_actions(self, actions):\n        clamped_actions = torch.clamp(actions, -1.0, 1.0)\n        if not self.is_tensor_obses:\n            
clamped_actions = clamped_actions.cpu().numpy()\n        return clamped_actions\n\n    def play_steps(self):\n        self.set_eval()\n        \n        epinfos = []\n        done_indices = torch.tensor([], device=self.device, dtype=torch.long)\n        update_list = self.update_list\n\n        for n in range(self.horizon_length):\n            self.obs = self.env_reset(done_indices)\n            self.experience_buffer.update_data('obses', n, self.obs['obs'])\n\n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, masks)\n            else:\n                res_dict = self.get_action_values(self.obs)\n\n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k]) \n\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])\n            shaped_rewards = self.rewards_shaper(rewards)\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n            \n            self.experience_buffer.update_data('disc_rewards', n, infos['disc_rewards'])\n\n            terminated = infos['terminate'].float()\n            terminated = terminated.unsqueeze(-1)\n            next_vals = self._eval_critic(self.obs)\n            next_vals *= (1.0 - terminated)\n            self.experience_buffer.update_data('next_values', n, next_vals)\n\n            self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            done_indices = all_done_indices[::self.num_agents]\n  \n            
self.game_rewards.update(self.current_rewards[done_indices])\n            self.game_lengths.update(self.current_lengths[done_indices])\n            self.algo_observer.process_infos(infos, done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n\n            done_indices = done_indices[:, 0]\n\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_next_values = self.experience_buffer.tensor_dict['next_values']\n\n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        mb_disc_rewards = self.experience_buffer.tensor_dict['disc_rewards']\n        mb_rewards = self._combine_rewards(mb_rewards, mb_disc_rewards)\n\n        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n\n        return batch_dict\n    \n    def _load_config_params(self, config):\n        super()._load_config_params(config)\n        \n        self._task_reward_w = config['task_reward_w']\n        self._disc_reward_w = config['disc_reward_w']\n        return\n\n    def _get_mean_rewards(self):\n        rewards = super()._get_mean_rewards()\n        rewards *= self._llc_steps\n        return rewards\n\n    def _setup_action_space(self):\n        super()._setup_action_space()\n        self.actions_num = self._latent_dim\n        return\n\n    def init_tensors(self):\n        super().init_tensors()\n\n        del self.experience_buffer.tensor_dict['actions']\n        del self.experience_buffer.tensor_dict['mus']\n    
    del self.experience_buffer.tensor_dict['sigmas']\n\n        batch_shape = self.experience_buffer.obs_base_shape\n        self.experience_buffer.tensor_dict['actions'] = torch.zeros(batch_shape + (self._latent_dim,),\n                                                                dtype=torch.float32, device=self.ppo_device)\n        self.experience_buffer.tensor_dict['mus'] = torch.zeros(batch_shape + (self._latent_dim,),\n                                                                dtype=torch.float32, device=self.ppo_device)\n        self.experience_buffer.tensor_dict['sigmas'] = torch.zeros(batch_shape + (self._latent_dim,),\n                                                                dtype=torch.float32, device=self.ppo_device)\n        \n        self.experience_buffer.tensor_dict['disc_rewards'] = torch.zeros_like(self.experience_buffer.tensor_dict['rewards'])\n        self.tensor_list += ['disc_rewards']\n\n        return\n\n    def _build_llc(self, config_params, checkpoint_file):\n        llc_agent_config = self._build_llc_agent_config(config_params)\n        self._llc_agent = ase_agent.ASEAgent('llc', llc_agent_config)\n        self._llc_agent.restore(checkpoint_file)\n        print(\"Loaded LLC checkpoint from {:s}\".format(checkpoint_file))\n        self._llc_agent.set_eval()\n        return\n\n    def _build_llc_agent_config(self, config_params, network=None):\n        llc_env_info = copy.deepcopy(self.env_info)\n        obs_space = llc_env_info['observation_space']\n        obs_size = obs_space.shape[0]\n        obs_size -= self._task_size\n        llc_env_info['observation_space'] = spaces.Box(obs_space.low[:obs_size], obs_space.high[:obs_size])\n\n        params = config_params\n        params['config']['network'] = network\n        params['config']['num_actors'] = self.num_actors\n        params['config']['features'] = {'observer' : self.algo_observer}\n        params['config']['env_info'] = llc_env_info\n        
params['config']['device'] = self.device\n\n        return params\n\n    def _compute_llc_action(self, obs, actions):\n        llc_obs = self._extract_llc_obs(obs)\n        processed_obs = self._llc_agent._preproc_obs(llc_obs)\n\n        z = torch.nn.functional.normalize(actions, dim=-1)\n        mu, _ = self._llc_agent.model.eval_actor(obs=processed_obs, ase_latents=z)\n        llc_action = mu\n        llc_action = self._llc_agent.preprocess_actions(llc_action)\n\n        return llc_action\n\n    def _extract_llc_obs(self, obs):\n        obs_size = obs.shape[-1]\n        llc_obs = obs[..., :obs_size - self._task_size]\n        return llc_obs\n\n    def _calc_disc_reward(self, amp_obs):\n        disc_reward = self._llc_agent._calc_disc_rewards(amp_obs)\n        return disc_reward\n\n    def _combine_rewards(self, task_rewards, disc_rewards):\n        combined_rewards = self._task_reward_w * task_rewards \\\n                         + self._disc_reward_w * disc_rewards\n        \n        #combined_rewards = task_rewards * disc_rewards\n        return combined_rewards\n\n    def _record_train_batch_info(self, batch_dict, train_info):\n        super()._record_train_batch_info(batch_dict, train_info)\n        train_info['disc_rewards'] = batch_dict['disc_rewards']\n        return\n\n    def _log_train_info(self, train_info, frame):\n        super()._log_train_info(train_info, frame)\n\n        disc_reward_std, disc_reward_mean = torch.std_mean(train_info['disc_rewards'])\n        self.writer.add_scalar('info/disc_reward_mean', disc_reward_mean.item(), frame)\n        self.writer.add_scalar('info/disc_reward_std', disc_reward_std.item(), frame)\n        return"
  },
  {
    "path": "timechamber/ase/hrl_models.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch.nn as nn\nfrom rl_games.algos_torch.models import ModelA2CContinuousLogStd\n\nclass ModelHRLContinuous(ModelA2CContinuousLogStd):\n    def __init__(self, network):\n        super().__init__(network)\n        return\n\n    def build(self, config):\n        net = self.network_builder.build('amp', **config)\n        for name, _ in net.named_parameters():\n            print(name)\n        # print(f\"ASE config: {config}\")\n        obs_shape = config['input_shape']\n        normalize_value = config.get('normalize_value', False)\n        normalize_input = config.get('normalize_input', False)\n        value_size = config.get('value_size', 1)\n        return ModelHRLContinuous.Network(net, obs_shape=obs_shape, normalize_value=normalize_value,\n                                          normalize_input=normalize_input, value_size=value_size)\n\n    class Network(ModelA2CContinuousLogStd.Network):\n        def __init__(self, a2c_network, obs_shape, normalize_value, normalize_input, value_size):\n            super().__init__(a2c_network,\n                             obs_shape=obs_shape,\n                             normalize_value=normalize_value,\n                             normalize_input=normalize_input, \n                             value_size=value_size)\n            return\n\n        def eval_critic(self, obs):\n            processed_obs = self.norm_obs(obs)\n            value = self.a2c_network.eval_critic(processed_obs)\n    
        values = self.unnorm_value(value)\n            return values"
  },
  {
    "path": "timechamber/ase/hrl_network_builder.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom rl_games.algos_torch import network_builder\n\nimport torch\nimport torch.nn as nn\n\nfrom timechamber.ase import ase_network_builder\n\nclass HRLBuilder(network_builder.A2CBuilder):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        return\n\n    class Network(network_builder.A2CBuilder.Network):\n        def __init__(self, params, **kwargs):\n            super().__init__(params, **kwargs)\n\n            if self.is_continuous:\n                if (not self.space_config['learn_sigma']):\n                    actions_num = kwargs.get('actions_num')\n                    sigma_init = self.init_factory.create(**self.space_config['sigma_init'])\n                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=False, dtype=torch.float32), requires_grad=False)\n                    sigma_init(self.sigma)\n\n            return\n        \n        def forward(self, obs_dict):\n            mu, sigma, value, states = super().forward(obs_dict)\n            norm_mu = torch.tanh(mu)\n            return norm_mu, sigma, value, states\n\n        def eval_critic(self, obs):\n            c_out = self.critic_cnn(obs)\n            c_out = c_out.contiguous().view(c_out.size(0), -1)\n            c_out = self.critic_mlp(c_out)              \n            value = self.value_act(self.value(c_out))\n            return value\n\n    def build(self, name, **kwargs):\n        net = HRLBuilder.Network(self.params, 
**kwargs)\n        return net"
  },
  {
    "path": "timechamber/ase/hrl_players.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport copy\nfrom gym import spaces\nimport numpy as np\nimport os\nimport torch \nimport yaml\nimport time\n\nfrom rl_games.algos_torch import players\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.common.player import BasePlayer\n\nimport timechamber.ase.utils.common_player as common_player\nimport timechamber.ase.ase_models as ase_models\nimport timechamber.ase.ase_network_builder as ase_network_builder\nimport timechamber.ase.ase_players as ase_players\n\nclass HRLPlayer(common_player.CommonPlayer):\n    def __init__(self, params):\n        config = params['config']\n        with open(os.path.join(os.getcwd(), config['llc_config']), 'r') as f:\n            llc_config = yaml.load(f, Loader=yaml.SafeLoader)\n            llc_config_params = llc_config['params']\n            self._latent_dim = llc_config_params['config']['latent_dim']\n\n        super().__init__(params)\n\n        self._task_size = self.env.task.get_task_obs_size()\n        \n        self._llc_steps = config['llc_steps']\n        llc_checkpoint = config['llc_checkpoint']\n        assert(llc_checkpoint != \"\")\n        self._build_llc(llc_config_params, llc_checkpoint)\n\n        return\n\n    def get_action(self, obs_dict, is_determenistic = False):\n        obs = obs_dict['obs']\n\n        if len(obs.size()) == len(self.obs_shape):\n            obs = obs.unsqueeze(0)\n        proc_obs = 
self._preproc_obs(obs)\n        input_dict = {\n            'is_train': False,\n            'prev_actions': None, \n            'obs' : proc_obs,\n            'rnn_states' : self.states\n        }\n        with torch.no_grad():\n            res_dict = self.model(input_dict)\n        mu = res_dict['mus']\n        action = res_dict['actions']\n        self.states = res_dict['rnn_states']\n        if is_determenistic:\n            current_action = mu\n        else:\n            current_action = action\n        current_action = torch.squeeze(current_action.detach())\n        clamped_actions = torch.clamp(current_action, -1.0, 1.0)\n        \n        return clamped_actions\n\n    def run(self):\n        n_games = self.games_num\n        render = self.render_env\n        n_game_life = self.n_game_life\n        is_determenistic = self.is_determenistic\n        sum_rewards = 0\n        sum_steps = 0\n        sum_game_res = 0\n        n_games = n_games * n_game_life\n        games_played = 0\n        has_masks = False\n        has_masks_func = getattr(self.env, \"has_action_mask\", None) is not None\n\n        op_agent = getattr(self.env, \"create_agent\", None)\n        if op_agent:\n            agent_inited = True\n\n        if has_masks_func:\n            has_masks = self.env.has_action_mask()\n\n        need_init_rnn = self.is_rnn\n        for _ in range(n_games):\n            if games_played >= n_games:\n                break\n\n            obs_dict = self.env_reset()\n            batch_size = 1\n            if len(obs_dict['obs'].size()) > len(self.obs_shape):\n                batch_size = obs_dict['obs'].size()[0]\n            self.batch_size = batch_size\n\n            if need_init_rnn:\n                self.init_rnn()\n                need_init_rnn = False\n\n            cr = torch.zeros(batch_size, dtype=torch.float32)\n            steps = torch.zeros(batch_size, dtype=torch.float32)\n\n            print_game_res = False\n\n            done_indices = []\n\n        
    for n in range(self.max_steps):\n                obs_dict = self.env_reset(done_indices)\n\n                if has_masks:\n                    masks = self.env.get_action_mask()\n                    action = self.get_masked_action(obs_dict, masks, is_determenistic)\n                else:\n                    action = self.get_action(obs_dict, is_determenistic)\n                obs_dict, r, done, info = self.env_step(self.env, obs_dict, action)\n                cr += r\n                steps += 1\n  \n                self._post_step(info)\n\n                if render:\n                    self.env.render(mode = 'human')\n                    time.sleep(self.render_sleep)\n\n                all_done_indices = done.nonzero(as_tuple=False)\n                done_indices = all_done_indices[::self.num_agents]\n                done_count = len(done_indices)\n                games_played += done_count\n\n                if done_count > 0:\n                    if self.is_rnn:\n                        for s in self.states:\n                            s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0\n\n                    cur_rewards = cr[done_indices].sum().item()\n                    cur_steps = steps[done_indices].sum().item()\n\n                    cr = cr * (1.0 - done.float())\n                    steps = steps * (1.0 - done.float())\n                    sum_rewards += cur_rewards\n                    sum_steps += cur_steps\n\n                    game_res = 0.0\n                    if isinstance(info, dict):\n                        if 'battle_won' in info:\n                            print_game_res = True\n                            game_res = info.get('battle_won', 0.5)\n                        if 'scores' in info:\n                            print_game_res = True\n                            game_res = info.get('scores', 0.5)\n                    if self.print_stats:\n                        if print_game_res:\n                            
print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count, 'w:', game_res)\n                        else:\n                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count)\n\n                    sum_game_res += game_res\n                    if batch_size//self.num_agents == 1 or games_played >= n_games:\n                        break\n        \n                done_indices = done_indices[:, 0]\n\n        print(sum_rewards)\n        if print_game_res:\n            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life)\n        else:\n            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life)\n\n        return\n\n    def env_step(self, env, obs_dict, action):\n        if not self.is_tensor_obses:\n            # convert the high-level action to numpy when the env does not use tensor obses\n            action = action.cpu().numpy()\n\n        obs = obs_dict['obs']\n        rewards = 0.0\n        done_count = 0.0\n        disc_rewards = 0.0\n        for t in range(self._llc_steps):\n            llc_actions = self._compute_llc_action(obs, action)\n            obs, curr_rewards, curr_dones, infos = env.step(llc_actions)\n\n            rewards += curr_rewards\n            done_count += curr_dones\n\n            amp_obs = infos['amp_obs']\n            curr_disc_reward = self._calc_disc_reward(amp_obs)\n            curr_disc_reward = curr_disc_reward[0, 0].cpu().numpy()\n            disc_rewards += curr_disc_reward\n\n        rewards /= self._llc_steps\n        dones = torch.zeros_like(done_count)\n        dones[done_count > 0] = 1.0\n\n        disc_rewards /= self._llc_steps\n\n        if isinstance(obs, dict):\n            obs = obs['obs']\n        if obs.dtype == np.float64:\n            obs = np.float32(obs)\n        if self.value_size > 1:\n            rewards = rewards[0]\n        if self.is_tensor_obses:\n            return obs, 
rewards.cpu(), dones.cpu(), infos\n        else:\n            if np.isscalar(dones):\n                rewards = np.expand_dims(np.asarray(rewards), 0)\n                dones = np.expand_dims(np.asarray(dones), 0)\n            return torch.from_numpy(obs).to(self.device), torch.from_numpy(rewards), torch.from_numpy(dones), infos\n\n    def _build_llc(self, config_params, checkpoint_file):\n        llc_agent_config = self._build_llc_agent_config(config_params)\n\n        self._llc_agent = ase_players.ASEPlayer(llc_agent_config)\n        self._llc_agent.restore(checkpoint_file)\n        print(\"Loaded LLC checkpoint from {:s}\".format(checkpoint_file))\n        return\n\n    def _build_llc_agent_config(self, config_params, network=None):\n        llc_env_info = copy.deepcopy(self.env_info)\n        obs_space = llc_env_info['observation_space']\n        obs_size = obs_space.shape[0]\n        obs_size -= self._task_size\n        llc_env_info['observation_space'] = spaces.Box(obs_space.low[:obs_size], obs_space.high[:obs_size])\n        llc_env_info['amp_observation_space'] = self.env.amp_observation_space.shape\n        llc_env_info['num_envs'] = self.env.task.num_envs\n\n        params = config_params\n        params['config']['network'] = network\n        params['config']['env_info'] = llc_env_info\n\n        return params\n\n    def _setup_action_space(self):\n        super()._setup_action_space()\n        self.actions_num = self._latent_dim\n        return\n\n    def _compute_llc_action(self, obs, actions):\n        llc_obs = self._extract_llc_obs(obs)\n        processed_obs = self._llc_agent._preproc_obs(llc_obs)\n\n        z = torch.nn.functional.normalize(actions, dim=-1)\n        mu, _ = self._llc_agent.model.eval_actor(obs=processed_obs, ase_latents=z)\n        llc_action = players.rescale_actions(self.actions_low, self.actions_high, torch.clamp(mu, -1.0, 1.0))\n\n        return llc_action\n\n    def _extract_llc_obs(self, obs):\n        obs_size = 
obs.shape[-1]\n        llc_obs = obs[..., :obs_size - self._task_size]\n        return llc_obs\n\n    def _calc_disc_reward(self, amp_obs):\n        disc_reward = self._llc_agent._calc_disc_rewards(amp_obs)\n        return disc_reward\n"
  },
  {
    "path": "timechamber/ase/utils/amp_agent.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.common import a2c_common\nfrom rl_games.common import schedulers\nfrom rl_games.common import vecenv\n\nfrom isaacgym.torch_utils import *\n\nimport time\nfrom datetime import datetime\nimport numpy as np\nfrom torch import optim\nimport torch \nfrom torch import nn\n\nimport timechamber.ase.utils.replay_buffer as replay_buffer\nimport timechamber.ase.utils.common_agent as common_agent \n\nfrom tensorboardX import SummaryWriter\n\nclass AMPAgent(common_agent.CommonAgent):\n    def __init__(self, base_name, params):\n        super().__init__(base_name, params)\n\n        if self._normalize_amp_input:\n            self._amp_input_mean_std = RunningMeanStd(self._amp_observation_space.shape).to(self.ppo_device)\n\n        return\n\n    def init_tensors(self):\n        super().init_tensors()\n        self._build_amp_buffers()\n        return\n    \n    def set_eval(self):\n        super().set_eval()\n        if self._normalize_amp_input:\n            self._amp_input_mean_std.eval()\n        return\n\n    def set_train(self):\n        super().set_train()\n        if self._normalize_amp_input:\n            self._amp_input_mean_std.train()\n        return\n\n    def get_stats_weights(self):\n        state = super().get_stats_weights()\n        if self._normalize_amp_input:\n            state['amp_input_mean_std'] = 
self._amp_input_mean_std.state_dict()\n        \n        return state\n\n    def set_stats_weights(self, weights):\n        super().set_stats_weights(weights)\n        if self._normalize_amp_input:\n            self._amp_input_mean_std.load_state_dict(weights['amp_input_mean_std'])\n        \n        return\n\n    def play_steps(self):\n        self.set_eval()\n\n        epinfos = []\n        done_indices = []\n        update_list = self.update_list\n\n        for n in range(self.horizon_length):\n\n            self.obs = self.env_reset(done_indices)\n            self.experience_buffer.update_data('obses', n, self.obs['obs'])\n\n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, masks)\n            else:\n                res_dict = self.get_action_values(self.obs, self._rand_action_probs)\n\n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k]) \n\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])\n            shaped_rewards = self.rewards_shaper(rewards)\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n            self.experience_buffer.update_data('amp_obs', n, infos['amp_obs'])\n            self.experience_buffer.update_data('rand_action_mask', n, res_dict['rand_action_mask'])\n\n            terminated = infos['terminate'].float()\n            terminated = terminated.unsqueeze(-1)\n            next_vals = self._eval_critic(self.obs)\n            next_vals *= (1.0 - terminated)\n            self.experience_buffer.update_data('next_values', n, next_vals)\n\n       
     self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            done_indices = all_done_indices[::self.num_agents]\n  \n            self.game_rewards.update(self.current_rewards[done_indices])\n            self.game_lengths.update(self.current_lengths[done_indices])\n            self.algo_observer.process_infos(infos, done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n            \n            if (self.vec_env.env.task.viewer):\n                self._amp_debug(infos)\n                \n            done_indices = done_indices[:, 0]\n\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_next_values = self.experience_buffer.tensor_dict['next_values']\n\n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        mb_amp_obs = self.experience_buffer.tensor_dict['amp_obs']\n        amp_rewards = self._calc_amp_rewards(mb_amp_obs)\n        mb_rewards = self._combine_rewards(mb_rewards, amp_rewards)\n\n        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n\n        for k, v in amp_rewards.items():\n            batch_dict[k] = a2c_common.swap_and_flatten01(v)\n\n        return batch_dict\n    \n    def get_action_values(self, obs_dict, rand_action_probs):\n        processed_obs = self._preproc_obs(obs_dict['obs'])\n\n        self.model.eval()\n        input_dict = {\n            
'is_train': False,\n            'prev_actions': None, \n            'obs' : processed_obs,\n            'rnn_states' : self.rnn_states\n        }\n\n        with torch.no_grad():\n            res_dict = self.model(input_dict)\n            if self.has_central_value:\n                states = obs_dict['states']\n                input_dict = {\n                    'is_train': False,\n                    'states' : states,\n                }\n                value = self.get_central_value(input_dict)\n                res_dict['values'] = value\n\n        if self.normalize_value:\n            res_dict['values'] = self.value_mean_std(res_dict['values'], True)\n        \n        rand_action_mask = torch.bernoulli(rand_action_probs)\n        det_action_mask = rand_action_mask == 0.0\n        res_dict['actions'][det_action_mask] = res_dict['mus'][det_action_mask]\n        res_dict['rand_action_mask'] = rand_action_mask\n\n        return res_dict\n\n    def prepare_dataset(self, batch_dict):\n        super().prepare_dataset(batch_dict)\n        self.dataset.values_dict['amp_obs'] = batch_dict['amp_obs']\n        self.dataset.values_dict['amp_obs_demo'] = batch_dict['amp_obs_demo']\n        self.dataset.values_dict['amp_obs_replay'] = batch_dict['amp_obs_replay']\n        \n        rand_action_mask = batch_dict['rand_action_mask']\n        self.dataset.values_dict['rand_action_mask'] = rand_action_mask\n        return\n\n    def train_epoch(self):\n        play_time_start = time.time()\n\n        with torch.no_grad():\n            if self.is_rnn:\n                batch_dict = self.play_steps_rnn()\n            else:\n                batch_dict = self.play_steps() \n\n        play_time_end = time.time()\n        update_time_start = time.time()\n        rnn_masks = batch_dict.get('rnn_masks', None)\n        \n        self._update_amp_demos()\n        num_obs_samples = batch_dict['amp_obs'].shape[0]\n        amp_obs_demo = 
self._amp_obs_demo_buffer.sample(num_obs_samples)['amp_obs']\n        batch_dict['amp_obs_demo'] = amp_obs_demo\n\n        if (self._amp_replay_buffer.get_total_count() == 0):\n            batch_dict['amp_obs_replay'] = batch_dict['amp_obs']\n        else:\n            batch_dict['amp_obs_replay'] = self._amp_replay_buffer.sample(num_obs_samples)['amp_obs']\n\n        self.set_train()\n\n        self.curr_frames = batch_dict.pop('played_frames')\n        self.prepare_dataset(batch_dict)\n        self.algo_observer.after_steps()\n\n        if self.has_central_value:\n            self.train_central_value()\n\n        train_info = None\n\n        if self.is_rnn:\n            frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement())\n            print(frames_mask_ratio)\n\n        for _ in range(0, self.mini_epochs_num):\n            ep_kls = []\n            for i in range(len(self.dataset)):\n                curr_train_info = self.train_actor_critic(self.dataset[i])\n                \n                if self.schedule_type == 'legacy':  \n                    if self.multi_gpu:\n                        curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls')\n                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item())\n                    self.update_lr(self.last_lr)\n\n                if (train_info is None):\n                    train_info = dict()\n                    for k, v in curr_train_info.items():\n                        train_info[k] = [v]\n                else:\n                    for k, v in curr_train_info.items():\n                        train_info[k].append(v)\n            \n            av_kls = torch_ext.mean_list(train_info['kl'])\n\n            if self.schedule_type == 'standard':\n                if self.multi_gpu:\n                    av_kls = self.hvd.average_value(av_kls, 'ep_kls')\n                self.last_lr, 
self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())\n                self.update_lr(self.last_lr)\n\n        if self.schedule_type == 'standard_epoch':\n            if self.multi_gpu:\n                av_kls = self.hvd.average_value(torch_ext.mean_list(train_info['kl']), 'ep_kls')\n            self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())\n            self.update_lr(self.last_lr)\n\n        update_time_end = time.time()\n        play_time = play_time_end - play_time_start\n        update_time = update_time_end - update_time_start\n        total_time = update_time_end - play_time_start\n\n        self._store_replay_amp_obs(batch_dict['amp_obs'])\n\n        train_info['play_time'] = play_time\n        train_info['update_time'] = update_time\n        train_info['total_time'] = total_time\n        self._record_train_batch_info(batch_dict, train_info)\n\n        return train_info\n\n    def calc_gradients(self, input_dict):\n        self.set_train()\n\n        value_preds_batch = input_dict['old_values']\n        old_action_log_probs_batch = input_dict['old_logp_actions']\n        advantage = input_dict['advantages']\n        old_mu_batch = input_dict['mu']\n        old_sigma_batch = input_dict['sigma']\n        return_batch = input_dict['returns']\n        actions_batch = input_dict['actions']\n        obs_batch = input_dict['obs']\n        obs_batch = self._preproc_obs(obs_batch)\n\n        amp_obs = input_dict['amp_obs'][0:self._amp_minibatch_size]\n        amp_obs = self._preproc_amp_obs(amp_obs)\n        amp_obs_replay = input_dict['amp_obs_replay'][0:self._amp_minibatch_size]\n        amp_obs_replay = self._preproc_amp_obs(amp_obs_replay)\n\n        amp_obs_demo = input_dict['amp_obs_demo'][0:self._amp_minibatch_size]\n        amp_obs_demo = self._preproc_amp_obs(amp_obs_demo)\n        amp_obs_demo.requires_grad_(True)\n        \n        
rand_action_mask = input_dict['rand_action_mask']\n        rand_action_sum = torch.sum(rand_action_mask)\n\n        lr = self.last_lr\n        kl = 1.0\n        lr_mul = 1.0\n        curr_e_clip = lr_mul * self.e_clip\n\n        batch_dict = {\n            'is_train': True,\n            'prev_actions': actions_batch, \n            'obs' : obs_batch,\n            'amp_obs' : amp_obs,\n            'amp_obs_replay' : amp_obs_replay,\n            'amp_obs_demo' : amp_obs_demo\n        }\n\n        rnn_masks = None\n        if self.is_rnn:\n            rnn_masks = input_dict['rnn_masks']\n            batch_dict['rnn_states'] = input_dict['rnn_states']\n            batch_dict['seq_length'] = self.seq_len\n\n        with torch.cuda.amp.autocast(enabled=self.mixed_precision):\n            res_dict = self.model(batch_dict)\n            action_log_probs = res_dict['prev_neglogp']\n            values = res_dict['values']\n            entropy = res_dict['entropy']\n            mu = res_dict['mus']\n            sigma = res_dict['sigmas']\n            disc_agent_logit = res_dict['disc_agent_logit']\n            disc_agent_replay_logit = res_dict['disc_agent_replay_logit']\n            disc_demo_logit = res_dict['disc_demo_logit']\n\n            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)\n            a_loss = a_info['actor_loss']\n            a_clipped = a_info['actor_clipped'].float()\n\n            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)\n            c_loss = c_info['critic_loss']\n\n            b_loss = self.bound_loss(mu)\n            \n            c_loss = torch.mean(c_loss)\n            a_loss = torch.sum(rand_action_mask * a_loss) / rand_action_sum\n            entropy = torch.sum(rand_action_mask * entropy) / rand_action_sum\n            b_loss = torch.sum(rand_action_mask * b_loss) / rand_action_sum\n            a_clip_frac = torch.sum(rand_action_mask * 
a_clipped) / rand_action_sum\n\n            disc_agent_cat_logit = torch.cat([disc_agent_logit, disc_agent_replay_logit], dim=0)\n            disc_info = self._disc_loss(disc_agent_cat_logit, disc_demo_logit, amp_obs_demo)\n            disc_loss = disc_info['disc_loss']\n\n            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss \\\n                 + self._disc_coef * disc_loss\n            \n            a_info['actor_loss'] = a_loss\n            a_info['actor_clip_frac'] = a_clip_frac\n            c_info['critic_loss'] = c_loss\n\n            if self.multi_gpu:\n                self.optimizer.zero_grad()\n            else:\n                for param in self.model.parameters():\n                    param.grad = None\n\n        self.scaler.scale(loss).backward()\n        #TODO: Refactor this ugliest code of the year\n        if self.truncate_grads:\n            if self.multi_gpu:\n                self.optimizer.synchronize()\n                self.scaler.unscale_(self.optimizer)\n                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)\n                with self.optimizer.skip_synchronize():\n                    self.scaler.step(self.optimizer)\n                    self.scaler.update()\n            else:\n                self.scaler.unscale_(self.optimizer)\n                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)\n                self.scaler.step(self.optimizer)\n                self.scaler.update()    \n        else:\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n        with torch.no_grad():\n            reduce_kl = not self.is_rnn\n            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)\n            if self.is_rnn:\n                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()  #/ sum_mask\n                    \n        self.train_result = {\n         
   'entropy': entropy,\n            'kl': kl_dist,\n            'last_lr': self.last_lr, \n            'lr_mul': lr_mul, \n            'b_loss': b_loss\n        }\n        self.train_result.update(a_info)\n        self.train_result.update(c_info)\n        self.train_result.update(disc_info)\n\n        return\n\n    def _load_config_params(self, config):\n        super()._load_config_params(config)\n        \n        # When eps-greedy is enabled, rollouts are generated using a mixture of\n        # deterministic and stochastic actions. The deterministic actions help to\n        # produce smoother, less noisy motions that can be used to train a better\n        # discriminator. If the discriminator is only trained with jittery motions\n        # from noisy actions, it can learn to home in on the jitteriness to\n        # differentiate between real and fake samples.\n        self._enable_eps_greedy = bool(config['enable_eps_greedy'])\n\n        self._task_reward_w = config['task_reward_w']\n        self._disc_reward_w = config['disc_reward_w']\n\n        self._amp_observation_space = self.env_info['amp_observation_space']\n        self._amp_batch_size = int(config['amp_batch_size'])\n        self._amp_minibatch_size = int(config['amp_minibatch_size'])\n        assert(self._amp_minibatch_size <= self.minibatch_size)\n\n        self._disc_coef = config['disc_coef']\n        self._disc_logit_reg = config['disc_logit_reg']\n        self._disc_grad_penalty = config['disc_grad_penalty']\n        self._disc_weight_decay = config['disc_weight_decay']\n        self._disc_reward_scale = config['disc_reward_scale']\n        self._normalize_amp_input = config.get('normalize_amp_input', True)\n        return\n\n    def _build_net_config(self):\n        config = super()._build_net_config()\n        config['amp_input_shape'] = self._amp_observation_space.shape\n        return config\n    \n    def _build_rand_action_probs(self):\n        num_envs = 
self.vec_env.env.task.num_envs\n        env_ids = to_torch(np.arange(num_envs), dtype=torch.float32, device=self.ppo_device)\n\n        self._rand_action_probs = 1.0 - torch.exp(10 * (env_ids / (num_envs - 1.0) - 1.0))\n        self._rand_action_probs[0] = 1.0\n        self._rand_action_probs[-1] = 0.0\n        \n        if not self._enable_eps_greedy:\n            self._rand_action_probs[:] = 1.0\n\n        return\n\n    def _init_train(self):\n        super()._init_train()\n        self._init_amp_demo_buf()\n        return\n\n    def _disc_loss(self, disc_agent_logit, disc_demo_logit, obs_demo):\n        # prediction loss\n        disc_loss_agent = self._disc_loss_neg(disc_agent_logit)\n        disc_loss_demo = self._disc_loss_pos(disc_demo_logit)\n        disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)\n\n        # logit reg\n        logit_weights = self.model.a2c_network.get_disc_logit_weights()\n        disc_logit_loss = torch.sum(torch.square(logit_weights))\n        disc_loss += self._disc_logit_reg * disc_logit_loss\n\n        # grad penalty\n        disc_demo_grad = torch.autograd.grad(disc_demo_logit, obs_demo, grad_outputs=torch.ones_like(disc_demo_logit),\n                                             create_graph=True, retain_graph=True, only_inputs=True)\n        disc_demo_grad = disc_demo_grad[0]\n        disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)\n        disc_grad_penalty = torch.mean(disc_demo_grad)\n        disc_loss += self._disc_grad_penalty * disc_grad_penalty\n\n        # weight decay\n        if (self._disc_weight_decay != 0):\n            disc_weights = self.model.a2c_network.get_disc_weights()\n            disc_weights = torch.cat(disc_weights, dim=-1)\n            disc_weight_decay = torch.sum(torch.square(disc_weights))\n            disc_loss += self._disc_weight_decay * disc_weight_decay\n\n        disc_agent_acc, disc_demo_acc = self._compute_disc_acc(disc_agent_logit, disc_demo_logit)\n\n        disc_info = 
{\n            'disc_loss': disc_loss,\n            'disc_grad_penalty': disc_grad_penalty.detach(),\n            'disc_logit_loss': disc_logit_loss.detach(),\n            'disc_agent_acc': disc_agent_acc.detach(),\n            'disc_demo_acc': disc_demo_acc.detach(),\n            'disc_agent_logit': disc_agent_logit.detach(),\n            'disc_demo_logit': disc_demo_logit.detach()\n        }\n        return disc_info\n\n    def _disc_loss_neg(self, disc_logits):\n        bce = torch.nn.BCEWithLogitsLoss()\n        loss = bce(disc_logits, torch.zeros_like(disc_logits))\n        return loss\n    \n    def _disc_loss_pos(self, disc_logits):\n        bce = torch.nn.BCEWithLogitsLoss()\n        loss = bce(disc_logits, torch.ones_like(disc_logits))\n        return loss\n\n    def _compute_disc_acc(self, disc_agent_logit, disc_demo_logit):\n        agent_acc = disc_agent_logit < 0\n        agent_acc = torch.mean(agent_acc.float())\n        demo_acc = disc_demo_logit > 0\n        demo_acc = torch.mean(demo_acc.float())\n        return agent_acc, demo_acc\n\n    def _fetch_amp_obs_demo(self, num_samples):\n        amp_obs_demo = self.vec_env.env.fetch_amp_obs_demo(num_samples)\n        return amp_obs_demo\n\n    def _build_amp_buffers(self):\n        batch_shape = self.experience_buffer.obs_base_shape\n        self.experience_buffer.tensor_dict['amp_obs'] = torch.zeros(batch_shape + self._amp_observation_space.shape,\n                                                                    device=self.ppo_device)\n        self.experience_buffer.tensor_dict['rand_action_mask'] = torch.zeros(batch_shape, dtype=torch.float32, device=self.ppo_device)\n        \n        amp_obs_demo_buffer_size = int(self.config['amp_obs_demo_buffer_size'])\n        self._amp_obs_demo_buffer = replay_buffer.ReplayBuffer(amp_obs_demo_buffer_size, self.ppo_device)\n\n        self._amp_replay_keep_prob = self.config['amp_replay_keep_prob']\n        replay_buffer_size = 
int(self.config['amp_replay_buffer_size'])\n        self._amp_replay_buffer = replay_buffer.ReplayBuffer(replay_buffer_size, self.ppo_device)\n        \n        self._build_rand_action_probs()\n        \n        self.tensor_list += ['amp_obs', 'rand_action_mask']\n        return\n\n    def _init_amp_demo_buf(self):\n        buffer_size = self._amp_obs_demo_buffer.get_buffer_size()\n        num_batches = int(np.ceil(buffer_size / self._amp_batch_size))\n\n        for i in range(num_batches):\n            curr_samples = self._fetch_amp_obs_demo(self._amp_batch_size)\n            self._amp_obs_demo_buffer.store({'amp_obs': curr_samples})\n\n        return\n    \n    def _update_amp_demos(self):\n        new_amp_obs_demo = self._fetch_amp_obs_demo(self._amp_batch_size)\n        self._amp_obs_demo_buffer.store({'amp_obs': new_amp_obs_demo})\n        return\n\n    def _preproc_amp_obs(self, amp_obs):\n        if self._normalize_amp_input:\n            amp_obs = self._amp_input_mean_std(amp_obs)\n        return amp_obs\n\n    def _combine_rewards(self, task_rewards, amp_rewards):\n        disc_r = amp_rewards['disc_rewards']\n        \n        combined_rewards = self._task_reward_w * task_rewards + self._disc_reward_w * disc_r\n        return combined_rewards\n\n    def _eval_disc(self, amp_obs):\n        proc_amp_obs = self._preproc_amp_obs(amp_obs)\n        return self.model.a2c_network.eval_disc(proc_amp_obs)\n    \n    def _calc_advs(self, batch_dict):\n        returns = batch_dict['returns']\n        values = batch_dict['values']\n        rand_action_mask = batch_dict['rand_action_mask']\n\n        advantages = returns - values\n        advantages = torch.sum(advantages, axis=1)\n        if self.normalize_advantage:\n            advantages = torch_ext.normalization_with_masks(advantages, rand_action_mask)\n\n        return advantages\n\n    def _calc_amp_rewards(self, amp_obs):\n        disc_r = self._calc_disc_rewards(amp_obs)\n        
output = {\n            'disc_rewards': disc_r\n        }\n        return output\n\n    def _calc_disc_rewards(self, amp_obs):\n        with torch.no_grad():\n            disc_logits = self._eval_disc(amp_obs)\n            prob = 1 / (1 + torch.exp(-disc_logits)) \n            disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.ppo_device)))\n            disc_r *= self._disc_reward_scale\n\n        return disc_r\n\n    def _store_replay_amp_obs(self, amp_obs):\n        buf_size = self._amp_replay_buffer.get_buffer_size()\n        buf_total_count = self._amp_replay_buffer.get_total_count()\n        if (buf_total_count > buf_size):\n            keep_probs = to_torch(np.array([self._amp_replay_keep_prob] * amp_obs.shape[0]), device=self.ppo_device)\n            keep_mask = torch.bernoulli(keep_probs) == 1.0\n            amp_obs = amp_obs[keep_mask]\n\n        if (amp_obs.shape[0] > buf_size):\n            rand_idx = torch.randperm(amp_obs.shape[0])\n            rand_idx = rand_idx[:buf_size]\n            amp_obs = amp_obs[rand_idx]\n\n        self._amp_replay_buffer.store({'amp_obs': amp_obs})\n        return\n\n    \n    def _record_train_batch_info(self, batch_dict, train_info):\n        super()._record_train_batch_info(batch_dict, train_info)\n        train_info['disc_rewards'] = batch_dict['disc_rewards']\n        return\n\n    def _log_train_info(self, train_info, frame):\n        super()._log_train_info(train_info, frame)\n\n        self.writer.add_scalar('losses/disc_loss', torch_ext.mean_list(train_info['disc_loss']).item(), frame)\n\n        self.writer.add_scalar('info/disc_agent_acc', torch_ext.mean_list(train_info['disc_agent_acc']).item(), frame)\n        self.writer.add_scalar('info/disc_demo_acc', torch_ext.mean_list(train_info['disc_demo_acc']).item(), frame)\n        self.writer.add_scalar('info/disc_agent_logit', torch_ext.mean_list(train_info['disc_agent_logit']).item(), frame)\n        
self.writer.add_scalar('info/disc_demo_logit', torch_ext.mean_list(train_info['disc_demo_logit']).item(), frame)\n        self.writer.add_scalar('info/disc_grad_penalty', torch_ext.mean_list(train_info['disc_grad_penalty']).item(), frame)\n        self.writer.add_scalar('info/disc_logit_loss', torch_ext.mean_list(train_info['disc_logit_loss']).item(), frame)\n\n        disc_reward_std, disc_reward_mean = torch.std_mean(train_info['disc_rewards'])\n        self.writer.add_scalar('info/disc_reward_mean', disc_reward_mean.item(), frame)\n        self.writer.add_scalar('info/disc_reward_std', disc_reward_std.item(), frame)\n        return\n\n    def _amp_debug(self, info):\n        with torch.no_grad():\n            amp_obs = info['amp_obs']\n            amp_obs = amp_obs[0:1]\n            disc_pred = self._eval_disc(amp_obs)\n            amp_rewards = self._calc_amp_rewards(amp_obs)\n            disc_reward = amp_rewards['disc_rewards']\n\n            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]\n            disc_reward = disc_reward.cpu().numpy()[0, 0]\n            print(\"disc_pred: \", disc_pred, disc_reward)\n        return"
  },
  {
    "path": "timechamber/ase/utils/amp_datasets.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\nfrom rl_games.common import datasets\n\nclass AMPDataset(datasets.PPODataset):\n    def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len):\n        super().__init__(batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len)\n        self._idx_buf = torch.randperm(batch_size)\n        return\n    \n    def update_mu_sigma(self, mu, sigma):\t  \n        raise NotImplementedError()\n        return\n\n    def _get_item(self, idx):\n        start = idx * self.minibatch_size\n        end = (idx + 1) * self.minibatch_size\n        sample_idx = self._idx_buf[start:end]\n\n        input_dict = {}\n        for k,v in self.values_dict.items():\n            if k not in self.special_names and v is not None:\n                input_dict[k] = v[sample_idx]\n                \n        if (end >= self.batch_size):\n            self._shuffle_idx_buf()\n\n        return input_dict\n\n    def _shuffle_idx_buf(self):\n        self._idx_buf[:] = torch.randperm(self.batch_size)\n        return"
  },
  {
    "path": "timechamber/ase/utils/amp_models.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch.nn as nn\nfrom rl_games.algos_torch.models import ModelA2CContinuousLogStd\n\n\nclass ModelAMPContinuous(ModelA2CContinuousLogStd):\n    def __init__(self, network):\n        super().__init__(network)\n        return\n\n    def build(self, config):\n        net = self.network_builder.build('amp', **config)\n        for name, _ in net.named_parameters():\n            print(name)\n        # print(f\"AMP config: {config}\")\n        obs_shape = config['input_shape']\n        normalize_value = config.get('normalize_value', False)\n        normalize_input = config.get('normalize_input', False)\n        value_size = config.get('value_size', 1)\n\n        return ModelAMPContinuous.Network(net, obs_shape=obs_shape, normalize_value=normalize_value,\n                                          normalize_input=normalize_input, value_size=value_size)\n\n    class Network(ModelA2CContinuousLogStd.Network):\n        def __init__(self, a2c_network, obs_shape, normalize_value, normalize_input, value_size):\n            super().__init__(a2c_network, obs_shape=obs_shape, \n                             normalize_value=normalize_value,\n                             normalize_input=normalize_input, \n                             value_size=value_size)\n            return\n\n        def forward(self, input_dict):\n            is_train = input_dict.get('is_train', True)\n            result = super().forward(input_dict)\n\n            if (is_train):\n  
              amp_obs = input_dict['amp_obs']\n                disc_agent_logit = self.a2c_network.eval_disc(amp_obs)\n                result[\"disc_agent_logit\"] = disc_agent_logit\n\n                amp_obs_replay = input_dict['amp_obs_replay']\n                disc_agent_replay_logit = self.a2c_network.eval_disc(amp_obs_replay)\n                result[\"disc_agent_replay_logit\"] = disc_agent_replay_logit\n\n                amp_demo_obs = input_dict['amp_obs_demo']\n                disc_demo_logit = self.a2c_network.eval_disc(amp_demo_obs)\n                result[\"disc_demo_logit\"] = disc_demo_logit\n\n            return result\n    \n        def eval_actor(self, obs):\n            processed_obs = self.norm_obs(obs)\n            mu, sigma = self.a2c_network.eval_actor(obs=processed_obs)\n            return mu, sigma\n\n        def eval_critic(self, obs):\n            processed_obs = self.norm_obs(obs)\n            value = self.a2c_network.eval_critic(processed_obs)\n            return value"
  },
  {
    "path": "timechamber/ase/utils/amp_network_builder.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import layers\nfrom rl_games.algos_torch import network_builder\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nDISC_LOGIT_INIT_SCALE = 1.0\n\nclass AMPBuilder(network_builder.A2CBuilder):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        return\n\n    class Network(network_builder.A2CBuilder.Network):\n        def __init__(self, params, **kwargs):\n            super().__init__(params, **kwargs)\n\n            if self.is_continuous:\n                if (not self.space_config['learn_sigma']):\n                    actions_num = kwargs.get('actions_num')\n                    sigma_init = self.init_factory.create(**self.space_config['sigma_init'])\n                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=False, dtype=torch.float32), requires_grad=False)\n                    sigma_init(self.sigma)\n\n            amp_input_shape = kwargs.get('amp_input_shape')\n            self._build_disc(amp_input_shape)\n\n            return\n\n        def load(self, params):\n            super().load(params)\n\n            self._disc_units = params['disc']['units']\n            self._disc_activation = params['disc']['activation']\n            self._disc_initializer = params['disc']['initializer']\n            return\n\n        def forward(self, obs_dict):\n            obs = obs_dict['obs']\n            states = 
obs_dict.get('rnn_states', None)\n\n            actor_outputs = self.eval_actor(obs)\n            value = self.eval_critic(obs)\n\n            output = actor_outputs + (value, states)\n\n            return output\n\n        def eval_actor(self, obs):\n            a_out = self.actor_cnn(obs)\n            a_out = a_out.contiguous().view(a_out.size(0), -1)\n            a_out = self.actor_mlp(a_out)\n                     \n            if self.is_discrete:\n                logits = self.logits(a_out)\n                return logits\n\n            if self.is_multi_discrete:\n                logits = [logit(a_out) for logit in self.logits]\n                return logits\n\n            if self.is_continuous:\n                mu = self.mu_act(self.mu(a_out))\n                if self.space_config['fixed_sigma']:\n                    sigma = mu * 0.0 + self.sigma_act(self.sigma)\n                else:\n                    sigma = self.sigma_act(self.sigma(a_out))\n\n                return mu, sigma\n            return\n\n        def eval_critic(self, obs):\n            c_out = self.critic_cnn(obs)\n            c_out = c_out.contiguous().view(c_out.size(0), -1)\n            c_out = self.critic_mlp(c_out)              \n            value = self.value_act(self.value(c_out))\n            return value\n\n        def eval_disc(self, amp_obs):\n            disc_mlp_out = self._disc_mlp(amp_obs)\n            disc_logits = self._disc_logits(disc_mlp_out)\n            return disc_logits\n\n        def get_disc_logit_weights(self):\n            return torch.flatten(self._disc_logits.weight)\n\n        def get_disc_weights(self):\n            weights = []\n            for m in self._disc_mlp.modules():\n                if isinstance(m, nn.Linear):\n                    weights.append(torch.flatten(m.weight))\n\n            weights.append(torch.flatten(self._disc_logits.weight))\n            return weights\n\n        def _build_disc(self, input_shape):\n            self._disc_mlp = 
nn.Sequential()\n\n            mlp_args = {\n                'input_size' : input_shape[0], \n                'units' : self._disc_units, \n                'activation' : self._disc_activation, \n                'dense_func' : torch.nn.Linear\n            }\n            self._disc_mlp = self._build_mlp(**mlp_args)\n\n            mlp_out_size = self._disc_units[-1]\n            self._disc_logits = torch.nn.Linear(mlp_out_size, 1)\n\n            mlp_init = self.init_factory.create(**self._disc_initializer)\n            for m in self._disc_mlp.modules():\n                if isinstance(m, nn.Linear):\n                    mlp_init(m.weight)\n                    if getattr(m, \"bias\", None) is not None:\n                        torch.nn.init.zeros_(m.bias) \n\n            torch.nn.init.uniform_(self._disc_logits.weight, -DISC_LOGIT_INIT_SCALE, DISC_LOGIT_INIT_SCALE)\n            torch.nn.init.zeros_(self._disc_logits.bias) \n\n            return\n\n    def build(self, name, **kwargs):\n        net = AMPBuilder.Network(self.params, **kwargs)\n        return net"
  },
  {
    "path": "timechamber/ase/utils/amp_players.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch \n\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\n\nimport timechamber.ase.utils.common_player as common_player\n\nclass AMPPlayerContinuous(common_player.CommonPlayer):\n    def __init__(self, params):\n        config = params['config']\n        self._normalize_amp_input = config.get('normalize_amp_input', True)\n        self._disc_reward_scale = config['disc_reward_scale']\n        \n        super().__init__(params)\n        return\n\n    def restore(self, fn):\n        if (fn != 'Base'):\n            super().restore(fn)\n            if self._normalize_amp_input:\n                checkpoint = torch_ext.load_checkpoint(fn)\n                self._amp_input_mean_std.load_state_dict(checkpoint['amp_input_mean_std'])\n        return\n    \n    def _build_net(self, config):\n        super()._build_net(config)\n        \n        if self._normalize_amp_input:\n            self._amp_input_mean_std = RunningMeanStd(config['amp_input_shape']).to(self.device)\n            self._amp_input_mean_std.eval()  \n        \n        return\n\n    def _post_step(self, info):\n        super()._post_step(info)\n        if (self.env.task.viewer):\n            self._amp_debug(info)\n        return\n\n    def _build_net_config(self):\n        config = super()._build_net_config()\n        if (hasattr(self, 'env')) and self.env is not None:\n            config['amp_input_shape'] = 
self.env.amp_observation_space.shape\n        else:\n            config['amp_input_shape'] = self.env_info['amp_observation_space']\n        return config\n\n    def _amp_debug(self, info):\n        with torch.no_grad():\n            amp_obs = info['amp_obs']\n            amp_obs = amp_obs[0:1]\n            disc_pred = self._eval_disc(amp_obs)\n            amp_rewards = self._calc_amp_rewards(amp_obs)\n            disc_reward = amp_rewards['disc_rewards']\n\n            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]\n            disc_reward = disc_reward.cpu().numpy()[0, 0]\n            print(\"disc_pred: \", disc_pred, disc_reward)\n\n        return\n\n    def _preproc_amp_obs(self, amp_obs):\n        if self._normalize_amp_input:\n            amp_obs = self._amp_input_mean_std(amp_obs)\n        return amp_obs\n\n    def _eval_disc(self, amp_obs):\n        proc_amp_obs = self._preproc_amp_obs(amp_obs)\n        return self.model.a2c_network.eval_disc(proc_amp_obs)\n\n    def _calc_amp_rewards(self, amp_obs):\n        disc_r = self._calc_disc_rewards(amp_obs)\n        output = {\n            'disc_rewards': disc_r\n        }\n        return output\n\n    def _calc_disc_rewards(self, amp_obs):\n        with torch.no_grad():\n            disc_logits = self._eval_disc(amp_obs)\n            prob = 1 / (1 + torch.exp(-disc_logits)) \n            disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device)))\n            disc_r *= self._disc_reward_scale\n        return disc_r\n"
  },
  {
    "path": "timechamber/ase/utils/common_agent.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport copy\nfrom datetime import datetime\nfrom gym import spaces\nimport numpy as np\nimport os\nimport time\nimport yaml\n\nfrom rl_games.algos_torch import a2c_continuous\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import central_value\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.common import a2c_common\nfrom rl_games.common import datasets\nfrom rl_games.common import schedulers\nfrom rl_games.common import vecenv\n\nimport torch\nfrom torch import optim\n\nimport timechamber.ase.utils.amp_datasets as amp_datasets\nfrom timechamber.utils.utils import load_check, load_checkpoint\n\nfrom tensorboardX import SummaryWriter\n\nclass CommonAgent(a2c_continuous.A2CAgent):\n    def __init__(self, base_name, params):\n        a2c_common.A2CBase.__init__(self, base_name, params)\n        self.config = config = params['config']\n        self._load_config_params(config)\n\n        self.is_discrete = False\n        self._setup_action_space()\n        self.bounds_loss_coef = config.get('bounds_loss_coef', None)\n        self.clip_actions = config.get('clip_actions', True)\n        self._save_intermediate = config.get('save_intermediate', False)\n\n        net_config = self._build_net_config()\n        self.model = self.network.build(net_config)\n        self.model.to(self.ppo_device)\n        self.states = None\n\n        self.init_rnn_from_model(self.model)\n        self.last_lr = 
float(self.last_lr)\n\n        self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)\n\n        if self.normalize_input:\n            obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)\n            self.running_mean_std = RunningMeanStd(obs_shape).to(self.ppo_device)\n        if self.normalize_value:\n            self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std\n\n        if self.has_central_value:\n            cv_config = {\n                'state_shape' : torch_ext.shape_whc_to_cwh(self.state_shape), \n                'value_size' : self.value_size,\n                'ppo_device' : self.ppo_device, \n                'num_agents' : self.num_agents, \n                'horizon_length' : self.horizon_length, \n                'num_actors' : self.num_actors, \n                'num_actions' : self.actions_num, \n                'seq_len' : self.seq_len, \n                'model' : self.central_value_config['network'],\n                'config' : self.central_value_config, \n                'writter' : self.writer,\n                'multi_gpu' : self.multi_gpu\n            }\n            self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)\n\n        self.use_experimental_cv = self.config.get('use_experimental_cv', True)\n        self.dataset = amp_datasets.AMPDataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)\n        self.algo_observer.after_init(self)\n\n        return\n\n    def init_tensors(self):\n        super().init_tensors()\n        self.experience_buffer.tensor_dict['next_obses'] = torch.zeros_like(self.experience_buffer.tensor_dict['obses'])\n        self.experience_buffer.tensor_dict['next_values'] = torch.zeros_like(self.experience_buffer.tensor_dict['values'])\n\n        self.tensor_list += ['next_obses']\n        return\n\n  
  def train(self):\n        self.init_tensors()\n        self.last_mean_rewards = -100500\n        start_time = time.time()\n        total_time = 0\n        rep_count = 0\n        self.frame = 0\n        self.obs = self.env_reset()\n        self.curr_frames = self.batch_size_envs\n\n        model_output_file = os.path.join(self.nn_dir, self.config['name'])\n\n        if self.multi_gpu:\n            self.hvd.setup_algo(self)\n\n        self._init_train()\n\n        while True:\n            epoch_num = self.update_epoch()\n            train_info = self.train_epoch()\n\n            sum_time = train_info['total_time']\n            total_time += sum_time\n            frame = self.frame\n            if self.multi_gpu:\n                self.hvd.sync_stats(self)\n\n            if self.rank == 0:\n                scaled_time = sum_time\n                scaled_play_time = train_info['play_time']\n                curr_frames = self.curr_frames\n                self.frame += curr_frames\n                if self.print_stats:\n                    fps_step = curr_frames / scaled_play_time\n                    fps_total = curr_frames / scaled_time\n                    print(f'fps step: {fps_step:.1f} fps total: {fps_total:.1f}')\n\n                self.writer.add_scalar('performance/total_fps', curr_frames / scaled_time, frame)\n                self.writer.add_scalar('performance/step_fps', curr_frames / scaled_play_time, frame)\n                self.writer.add_scalar('info/epochs', epoch_num, frame)\n                self._log_train_info(train_info, frame)\n\n                self.algo_observer.after_print_stats(frame, epoch_num, total_time)\n                \n                if self.game_rewards.current_size > 0:\n                    mean_rewards = self._get_mean_rewards()\n                    mean_lengths = self.game_lengths.get_mean()\n\n                    for i in range(self.value_size):\n                        self.writer.add_scalar('rewards{0}/frame'.format(i), 
mean_rewards[i], frame)\n                        self.writer.add_scalar('rewards{0}/iter'.format(i), mean_rewards[i], epoch_num)\n                        self.writer.add_scalar('rewards{0}/time'.format(i), mean_rewards[i], total_time)\n\n                    self.writer.add_scalar('episode_lengths/frame', mean_lengths, frame)\n                    self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)\n\n                    if self.has_self_play_config:\n                        self.self_play_manager.update(self)\n\n                if self.save_freq > 0:\n                    if (epoch_num % self.save_freq == 0):\n                        self.save(model_output_file)\n\n                        if (self._save_intermediate):\n                            int_model_output_file = model_output_file + '_' + str(epoch_num).zfill(8)\n                            self.save(int_model_output_file)\n\n                if epoch_num > self.max_epochs:\n                    self.save(model_output_file)\n                    print('MAX EPOCHS NUM!')\n                    return self.last_mean_rewards, epoch_num\n\n                update_time = 0\n        return\n\n    def set_full_state_weights(self, weights):\n        self.set_weights(weights)\n        self.epoch_num = weights['epoch']\n        if self.has_central_value:\n            self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])\n        self.optimizer.load_state_dict(weights['optimizer'])\n        self.frame = weights.get('frame', 0)\n        self.last_mean_rewards = weights.get('last_mean_rewards', -100500)\n\n        if self.vec_env is not None:\n            env_state = weights.get('env_state', None)\n            self.vec_env.set_env_state(env_state)\n\n        return\n\n    def restore(self, fn):\n        checkpoint = load_checkpoint(fn, device=self.device)\n        checkpoint = load_check(checkpoint=checkpoint,\n                                normalize_input=self.normalize_input,\n             
                   normalize_value=self.normalize_value)\n        self.set_full_state_weights(checkpoint)\n\n    def train_epoch(self):\n        play_time_start = time.time()\n        with torch.no_grad():\n            if self.is_rnn:\n                batch_dict = self.play_steps_rnn()\n            else:\n                batch_dict = self.play_steps() \n\n        play_time_end = time.time()\n        update_time_start = time.time()\n        rnn_masks = batch_dict.get('rnn_masks', None)\n\n        self.set_train()\n\n        self.curr_frames = batch_dict.pop('played_frames')\n        self.prepare_dataset(batch_dict)\n        self.algo_observer.after_steps()\n\n        if self.has_central_value:\n            self.train_central_value()\n\n        train_info = None\n\n        if self.is_rnn:\n            frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement())\n            print(frames_mask_ratio)\n\n        for _ in range(0, self.mini_epochs_num):\n            ep_kls = []\n            for i in range(len(self.dataset)):\n                curr_train_info = self.train_actor_critic(self.dataset[i])\n                \n                if self.schedule_type == 'legacy':  \n                    if self.multi_gpu:\n                        curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls')\n                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item())\n                    self.update_lr(self.last_lr)\n\n                if (train_info is None):\n                    train_info = dict()\n                    for k, v in curr_train_info.items():\n                        train_info[k] = [v]\n                else:\n                    for k, v in curr_train_info.items():\n                        train_info[k].append(v)\n            \n            av_kls = torch_ext.mean_list(train_info['kl'])\n\n            if self.schedule_type == 'standard':\n           
     if self.multi_gpu:\n                    av_kls = self.hvd.average_value(av_kls, 'ep_kls')\n                self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())\n                self.update_lr(self.last_lr)\n\n        if self.schedule_type == 'standard_epoch':\n            if self.multi_gpu:\n                # average the KL collected over all mini-epochs (train_info['kl'])\n                av_kls = self.hvd.average_value(torch_ext.mean_list(train_info['kl']), 'ep_kls')\n            self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())\n            self.update_lr(self.last_lr)\n\n        update_time_end = time.time()\n        play_time = play_time_end - play_time_start\n        update_time = update_time_end - update_time_start\n        total_time = update_time_end - play_time_start\n\n        # play_steps() does not record 'step_time', so fall back to 0.0\n        train_info['step_time'] = batch_dict.get('step_time', 0.0)\n        train_info['play_time'] = play_time\n        train_info['update_time'] = update_time\n        train_info['total_time'] = total_time\n        self._record_train_batch_info(batch_dict, train_info)\n\n        return train_info\n\n    def play_steps(self):\n        self.set_eval()\n\n        epinfos = []\n        done_indices = []\n        update_list = self.update_list\n\n        for n in range(self.horizon_length):\n            self.obs = self.env_reset(done_indices)\n            self.experience_buffer.update_data('obses', n, self.obs['obs'])\n\n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, masks)\n            else:\n                res_dict = self.get_action_values(self.obs)\n\n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k])\n\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            self.obs, rewards, self.dones, infos = 
self.env_step(res_dict['actions'])\n            shaped_rewards = self.rewards_shaper(rewards)\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n\n            terminated = infos['terminate'].float()\n            terminated = terminated.unsqueeze(-1)\n            next_vals = self._eval_critic(self.obs)\n            next_vals *= (1.0 - terminated)\n            self.experience_buffer.update_data('next_values', n, next_vals)\n\n            self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            done_indices = all_done_indices[::self.num_agents]\n  \n            self.game_rewards.update(self.current_rewards[done_indices])\n            self.game_lengths.update(self.current_lengths[done_indices])\n            self.algo_observer.process_infos(infos, done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n\n            done_indices = done_indices[:, 0]\n\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_next_values = self.experience_buffer.tensor_dict['next_values']\n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        \n        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n\n        return 
batch_dict\n\n    def prepare_dataset(self, batch_dict):\n        obses = batch_dict['obses']\n        returns = batch_dict['returns']\n        dones = batch_dict['dones']\n        values = batch_dict['values']\n        actions = batch_dict['actions']\n        neglogpacs = batch_dict['neglogpacs']\n        mus = batch_dict['mus']\n        sigmas = batch_dict['sigmas']\n        rnn_states = batch_dict.get('rnn_states', None)\n        rnn_masks = batch_dict.get('rnn_masks', None)\n        \n        advantages = self._calc_advs(batch_dict)\n\n        if self.normalize_value:\n            self.value_mean_std.train()\n            values = self.value_mean_std(values)\n            returns = self.value_mean_std(returns)\n            self.value_mean_std.eval()\n\n        dataset_dict = {}\n        dataset_dict['old_values'] = values\n        dataset_dict['old_logp_actions'] = neglogpacs\n        dataset_dict['advantages'] = advantages\n        dataset_dict['returns'] = returns\n        dataset_dict['actions'] = actions\n        dataset_dict['obs'] = obses\n        dataset_dict['rnn_states'] = rnn_states\n        dataset_dict['rnn_masks'] = rnn_masks\n        dataset_dict['mu'] = mus\n        dataset_dict['sigma'] = sigmas\n\n        self.dataset.update_values_dict(dataset_dict)\n\n        if self.has_central_value:\n            dataset_dict = {}\n            dataset_dict['old_values'] = values\n            dataset_dict['advantages'] = advantages\n            dataset_dict['returns'] = returns\n            dataset_dict['actions'] = actions\n            dataset_dict['obs'] = batch_dict['states']\n            dataset_dict['rnn_masks'] = rnn_masks\n            self.central_value_net.update_dataset(dataset_dict)\n\n        return\n\n    def calc_gradients(self, input_dict):\n        self.set_train()\n\n        value_preds_batch = input_dict['old_values']\n        old_action_log_probs_batch = input_dict['old_logp_actions']\n        advantage = input_dict['advantages']\n        
old_mu_batch = input_dict['mu']\n        old_sigma_batch = input_dict['sigma']\n        return_batch = input_dict['returns']\n        actions_batch = input_dict['actions']\n        obs_batch = input_dict['obs']\n        obs_batch = self._preproc_obs(obs_batch)\n\n        lr = self.last_lr\n        kl = 1.0\n        lr_mul = 1.0\n        curr_e_clip = lr_mul * self.e_clip\n\n        batch_dict = {\n            'is_train': True,\n            'prev_actions': actions_batch, \n            'obs' : obs_batch\n        }\n\n        rnn_masks = None\n        if self.is_rnn:\n            rnn_masks = input_dict['rnn_masks']\n            batch_dict['rnn_states'] = input_dict['rnn_states']\n            batch_dict['seq_length'] = self.seq_len\n\n        with torch.cuda.amp.autocast(enabled=self.mixed_precision):\n            res_dict = self.model(batch_dict)\n            action_log_probs = res_dict['prev_neglogp']\n            values = res_dict['values']\n            entropy = res_dict['entropy']\n            mu = res_dict['mus']\n            sigma = res_dict['sigmas']\n\n            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)\n            a_loss = a_info['actor_loss']\n\n            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)\n            c_loss = c_info['critic_loss']\n\n            b_loss = self.bound_loss(mu)\n            \n            a_loss = torch.mean(a_loss)\n            c_loss = torch.mean(c_loss)\n            b_loss = torch.mean(b_loss)\n            entropy = torch.mean(entropy)\n\n            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss\n            \n            a_clip_frac = torch.mean(a_info['actor_clipped'].float())\n            \n            a_info['actor_loss'] = a_loss\n            a_info['actor_clip_frac'] = a_clip_frac\n\n            if self.multi_gpu:\n                
self.optimizer.zero_grad()\n            else:\n                for param in self.model.parameters():\n                    param.grad = None\n\n        self.scaler.scale(loss).backward()\n        self.scaler.step(self.optimizer)\n        self.scaler.update()\n\n        with torch.no_grad():\n            reduce_kl = not self.is_rnn\n            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)\n                    \n        self.train_result = {\n            'entropy': entropy,\n            'kl': kl_dist,\n            'last_lr': self.last_lr, \n            'lr_mul': lr_mul, \n            'b_loss': b_loss\n        }\n        self.train_result.update(a_info)\n        self.train_result.update(c_info)\n\n        return\n\n    def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_values):\n        lastgaelam = 0\n        mb_advs = torch.zeros_like(mb_rewards)\n\n        for t in reversed(range(self.horizon_length)):\n            not_done = 1.0 - mb_fdones[t]\n            not_done = not_done.unsqueeze(1)\n\n            delta = mb_rewards[t] + self.gamma * mb_next_values[t] - mb_values[t]\n            lastgaelam = delta + self.gamma * self.tau * not_done * lastgaelam\n            mb_advs[t] = lastgaelam\n\n        return mb_advs\n\n    def env_reset(self, env_ids=None):\n        obs = self.vec_env.reset(env_ids)\n        obs = self.obs_to_tensors(obs)\n        return obs\n\n    def bound_loss(self, mu):\n        if self.bounds_loss_coef is not None:\n            soft_bound = 1.0\n            mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2\n            mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2\n            b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)\n        else:\n            b_loss = 0\n        return b_loss\n\n    def _get_mean_rewards(self):\n        return self.game_rewards.get_mean()\n\n    def _load_config_params(self, config):\n        self.last_lr = config['learning_rate']\n 
       return\n\n    def _build_net_config(self):\n        obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)\n        config = {\n            'actions_num' : self.actions_num,\n            'input_shape' : obs_shape,\n            'num_seqs' : self.num_actors * self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value' : self.normalize_value,\n            'normalize_input': self.normalize_input,\n        }\n        return config\n\n    def _setup_action_space(self):\n        action_space = self.env_info['action_space']\n        self.actions_num = action_space.shape[0]\n\n        # todo introduce device instead of cuda()\n        self.actions_low = torch.from_numpy(action_space.low.copy()).float().to(self.ppo_device)\n        self.actions_high = torch.from_numpy(action_space.high.copy()).float().to(self.ppo_device)\n        return\n\n    def _init_train(self):\n        return\n\n    def _eval_critic(self, obs_dict):\n        self.model.eval()\n        obs = obs_dict['obs']\n        processed_obs = self._preproc_obs(obs)\n        value = self.model.eval_critic(processed_obs)\n\n        return value\n\n    def _actor_loss(self, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip):\n        ratio = torch.exp(old_action_log_probs_batch - action_log_probs)\n        surr1 = advantage * ratio\n        surr2 = advantage * torch.clamp(ratio, 1.0 - curr_e_clip,\n                                    1.0 + curr_e_clip)\n        a_loss = torch.max(-surr1, -surr2)\n\n        clipped = torch.abs(ratio - 1.0) > curr_e_clip\n        clipped = clipped.detach()\n        \n        info = {\n            'actor_loss': a_loss,\n            'actor_clipped': clipped.detach()\n        }\n        return info\n\n    def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value):\n        if clip_value:\n            value_pred_clipped = value_preds_batch + \\\n                    (values - 
value_preds_batch).clamp(-curr_e_clip, curr_e_clip)\n            value_losses = (values - return_batch)**2\n            value_losses_clipped = (value_pred_clipped - return_batch)**2\n            c_loss = torch.max(value_losses, value_losses_clipped)\n        else:\n            c_loss = (return_batch - values)**2\n\n        info = {\n            'critic_loss': c_loss\n        }\n        return info\n    \n    def _calc_advs(self, batch_dict):\n        returns = batch_dict['returns']\n        values = batch_dict['values']\n\n        advantages = returns - values\n        advantages = torch.sum(advantages, axis=1)\n\n        if self.normalize_advantage:\n            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)\n\n        return advantages\n\n    def _record_train_batch_info(self, batch_dict, train_info):\n        return\n\n    def _log_train_info(self, train_info, frame):\n        self.writer.add_scalar('performance/update_time', train_info['update_time'], frame)\n        self.writer.add_scalar('performance/play_time', train_info['play_time'], frame)\n        self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(train_info['actor_loss']).item(), frame)\n        self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(train_info['critic_loss']).item(), frame)\n        \n        self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(train_info['b_loss']).item(), frame)\n        self.writer.add_scalar('losses/entropy', torch_ext.mean_list(train_info['entropy']).item(), frame)\n        self.writer.add_scalar('info/last_lr', train_info['last_lr'][-1] * train_info['lr_mul'][-1], frame)\n        self.writer.add_scalar('info/lr_mul', train_info['lr_mul'][-1], frame)\n        self.writer.add_scalar('info/e_clip', self.e_clip * train_info['lr_mul'][-1], frame)\n        self.writer.add_scalar('info/clip_frac', torch_ext.mean_list(train_info['actor_clip_frac']).item(), frame)\n        self.writer.add_scalar('info/kl', 
torch_ext.mean_list(train_info['kl']).item(), frame)\n        return\n"
  },
  {
    "path": "timechamber/ase/utils/common_player.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport time  # used by run() when rendering (time.sleep)\n\nimport torch\n\nfrom rl_games.algos_torch import players\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.common.player import BasePlayer\nfrom timechamber.utils.utils import load_check, load_checkpoint\n\nimport numpy as np\n\nclass CommonPlayer(players.PpoPlayerContinuous):\n    def __init__(self, params):\n        config = params['config']\n        BasePlayer.__init__(self, params)\n        self.network = config['network']\n        \n        self._setup_action_space()\n        self.mask = [False]\n\n        self.normalize_input = self.config['normalize_input']\n        self.normalize_value = self.config.get('normalize_value', False)\n\n        net_config = self._build_net_config()\n        self._build_net(net_config)   \n        \n        return\n\n    def run(self):\n        n_games = self.games_num\n        render = self.render_env\n        n_game_life = self.n_game_life\n        is_determenistic = self.is_determenistic\n        sum_rewards = 0\n        sum_steps = 0\n        sum_game_res = 0\n        n_games = n_games * n_game_life\n        games_played = 0\n        has_masks = False\n        has_masks_func = getattr(self.env, \"has_action_mask\", None) is not None\n\n        op_agent = getattr(self.env, \"create_agent\", None)\n        if op_agent:\n            agent_inited = True\n\n        if has_masks_func:\n            has_masks = 
self.env.has_action_mask()\n\n        need_init_rnn = self.is_rnn\n        for _ in range(n_games):\n            if games_played >= n_games:\n                break\n\n            obs_dict = self.env_reset()\n            batch_size = 1\n            batch_size = self.get_batch_size(obs_dict['obs'], batch_size)\n\n            if need_init_rnn:\n                self.init_rnn()\n                need_init_rnn = False\n\n            cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device)\n            steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device)\n\n            print_game_res = False\n\n            done_indices = []\n\n            for n in range(self.max_steps):\n                # obs_dict = self.env_reset(done_indices)\n\n                if has_masks:\n                    masks = self.env.get_action_mask()\n                    action = self.get_masked_action(obs_dict, masks, is_determenistic)\n                else:\n                    action = self.get_action(obs_dict, is_determenistic)\n                obs_dict, r, done, info =  self.env_step(self.env, action)\n                obs_dict = {'obs': obs_dict}\n                # print('obs_dict shape: ', obs_dict.shape)\n                cr += r\n                steps += 1\n  \n                self._post_step(info)\n\n                if render:\n                    self.env.render(mode = 'human')\n                    time.sleep(self.render_sleep)\n\n                all_done_indices = done.nonzero(as_tuple=False)\n                done_indices = all_done_indices[::self.num_agents]\n                done_count = len(done_indices)\n                games_played += done_count\n\n                if done_count > 0:\n                    if self.is_rnn:\n                        for s in self.states:\n                            s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0\n\n                    cur_rewards = cr[done_indices].sum().item()\n                    cur_steps = 
steps[done_indices].sum().item()\n\n                    cr = cr * (1.0 - done.float())\n                    steps = steps * (1.0 - done.float())\n                    sum_rewards += cur_rewards\n                    sum_steps += cur_steps\n\n                    game_res = 0.0\n                    if isinstance(info, dict):\n                        if 'battle_won' in info:\n                            print_game_res = True\n                            game_res = info.get('battle_won', 0.5)\n                        if 'scores' in info:\n                            print_game_res = True\n                            game_res = info.get('scores', 0.5)\n                    if self.print_stats:\n                        if print_game_res:\n                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count, 'w:', game_res)\n                        else:\n                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count)\n\n                    sum_game_res += game_res\n                    if batch_size//self.num_agents == 1 or games_played >= n_games:\n                        break\n                \n                done_indices = done_indices[:, 0]\n\n        print(sum_rewards)\n        if print_game_res:\n            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life)\n        else:\n            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life)\n\n        return\n\n    def get_action(self, obs_dict, is_determenistic = False):\n        output = super().get_action(obs_dict['obs'], is_determenistic)\n        return output\n\n    def env_step(self, env, actions):\n        if not self.is_tensor_obses:\n            actions = actions.cpu().numpy()\n        obs, rewards, dones, infos = env.step(actions)\n\n        if hasattr(obs, 
'dtype') and obs.dtype == np.float64:\n            obs = np.float32(obs)\n        if self.value_size > 1:\n            rewards = rewards[0]\n        if self.is_tensor_obses:\n            return obs, rewards.to(self.device), dones.to(self.device), infos\n        else:\n            if np.isscalar(dones):\n                rewards = np.expand_dims(np.asarray(rewards), 0)\n                dones = np.expand_dims(np.asarray(dones), 0)\n            return self.obs_to_torch(obs), torch.from_numpy(rewards), torch.from_numpy(dones), infos\n\n    def _build_net(self, config):\n        self.model = self.network.build(config)\n        self.model.to(self.device)\n        self.model.eval()\n        self.is_rnn = self.model.is_rnn()\n        if self.normalize_input:\n            obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)\n            self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)\n            self.running_mean_std.eval() \n        return\n\n    def env_reset(self, env_ids=None):\n        obs = self.env.reset(env_ids)\n        return self.obs_to_torch(obs)\n\n    def _post_step(self, info):\n        return\n\n    def _build_net_config(self):\n        obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)\n        config = {\n            'actions_num' : self.actions_num,\n            'input_shape' : obs_shape,\n            'num_seqs' : self.num_agents,\n            'normalize_input': self.normalize_input,\n            'normalize_value' : self.normalize_value,\n        }\n        return config\n\n    def restore(self, fn):\n        checkpoint = load_checkpoint(fn, device=self.device)\n        checkpoint = load_check(checkpoint=checkpoint,\n                                normalize_input=self.normalize_input,\n                                normalize_value=self.normalize_value)\n        self.model.load_state_dict(checkpoint['model'])\n\n        if self.normalize_input and 'running_mean_std' in checkpoint:\n            
self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])\n\n    def _setup_action_space(self):\n        self.actions_num = self.action_space.shape[0] \n        self.actions_low = torch.from_numpy(self.action_space.low.copy()).float().to(self.device)\n        self.actions_high = torch.from_numpy(self.action_space.high.copy()).float().to(self.device)\n        return"
  },
  {
    "path": "timechamber/ase/utils/replay_buffer.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nclass ReplayBuffer():\n    def __init__(self, buffer_size, device):\n        self._head = 0\n        self._total_count = 0\n        self._buffer_size = buffer_size\n        self._device = device\n        self._data_buf = None\n        self._sample_idx = torch.randperm(buffer_size)\n        self._sample_head = 0\n\n        return\n\n    def reset(self):\n        self._head = 0\n        self._total_count = 0\n        self._reset_sample_idx()\n        return\n\n    def get_buffer_size(self):\n        return self._buffer_size\n\n    def get_total_count(self):\n        return self._total_count\n\n    def store(self, data_dict):\n        if (self._data_buf is None):\n            self._init_data_buf(data_dict)\n\n        n = next(iter(data_dict.values())).shape[0]\n        buffer_size = self.get_buffer_size()\n        assert(n <= buffer_size)\n\n        for key, curr_buf in self._data_buf.items():\n            curr_n = data_dict[key].shape[0]\n            assert(n == curr_n)\n\n            store_n = min(curr_n, buffer_size - self._head)\n            curr_buf[self._head:(self._head + store_n)] = data_dict[key][:store_n]    \n        \n            remainder = n - store_n\n            if (remainder > 0):\n                curr_buf[0:remainder] = data_dict[key][store_n:]  \n\n        self._head = (self._head + n) % buffer_size\n        self._total_count += n\n\n        return\n\n    def sample(self, n):\n        total_count = 
self.get_total_count()\n        buffer_size = self.get_buffer_size()\n\n        idx = torch.arange(self._sample_head, self._sample_head + n)\n        idx = idx % buffer_size\n        rand_idx = self._sample_idx[idx]\n        if (total_count < buffer_size):\n            rand_idx = rand_idx % self._head\n\n        samples = dict()\n        for k, v in self._data_buf.items():\n            samples[k] = v[rand_idx]\n\n        self._sample_head += n\n        if (self._sample_head >= buffer_size):\n            self._reset_sample_idx()\n\n        return samples\n\n    def _reset_sample_idx(self):\n        buffer_size = self.get_buffer_size()\n        self._sample_idx[:] = torch.randperm(buffer_size)\n        self._sample_head = 0\n        return\n\n    def _init_data_buf(self, data_dict):\n        buffer_size = self.get_buffer_size()\n        self._data_buf = dict()\n\n        for k, v in data_dict.items():\n            v_shape = v.shape[1:]\n            self._data_buf[k] = torch.zeros((buffer_size,) + v_shape, device=self._device)\n\n        return"
  },
  {
    "path": "timechamber/cfg/config.yaml",
    "content": "# Task name - used to pick the class to load\ntask_name: ${task.name}\n# experiment name. defaults to name of training config\nexperiment: ''\n\n# if set to positive integer, overrides the default number of environments\nnum_envs: ''\n\n# seed - set to -1 to choose random seed\nseed: 42\n# set to True for deterministic performance\ntorch_deterministic: False\n\n# set the maximum number of learning iterations to train for. overrides default per-environment setting\nmax_iterations: ''\n\n# set minibatch_size\nminibatch_size: 32768\n\n## Device config\n#  'physx' or 'flex'\nphysics_engine: 'physx'\n# whether to use cpu or gpu pipeline\npipeline: 'gpu'\nuse_gpu: True\nuse_gpu_pipeline: True\n# device for running physics simulation\nsim_device: 'cuda:0'\n# device to run RL\nrl_device: 'cuda:0'\ngraphics_device_id: 0\ndevice_type: cuda\n\n## PhysX arguments\nnum_threads: 4 # Number of worker threads per scene used by PhysX - for CPU PhysX only.\nsolver_type: 1 # 0: pgs, 1: tgs\nnum_subscenes: 4 # Splits the simulation into N physics scenes and runs each one in a separate thread\n\n# RLGames Arguments\n# test - if set, run policy in inference mode (requires setting checkpoint to load)\ntest: False\n# used to set checkpoint path\ncheckpoint: ''\nop_checkpoint: ''\nplayer_pool_type: ''\nnum_agents: 2\n\n# HRL Arguments\nmotion_file: 'tasks/data/motions/reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy'\n\n# set to True to use multi-gpu horovod training\nmulti_gpu: False\n\nwandb_activate: False\nwandb_group: ''\nwandb_name: ${train.params.config.name}\nwandb_entity: ''\nwandb_project: 'timechamber'\ncapture_video: False\ncapture_video_freq: 1464\ncapture_video_len: 100\nforce_render: True\n\n# disables rendering\nheadless: True\n\n# set default task and default training config based on task\ndefaults:\n  - task: MA_Humanoid_Strike\n  - train: ${task}HRL\n  - hydra/job_logging: disabled\n\n# set the directory where the output files get saved\nhydra:\n  
output_subdir: null\n  run:\n    dir: .\n\n"
  },
  {
    "path": "timechamber/cfg/task/MA_Ant_Battle.yaml",
    "content": "# used to create the object\nname: MA_Ant_Battle\n\nphysics_engine: ${..physics_engine}\n\n# if given, will override the device setting in gym.\nenv:\n  #  numEnvs: ${...num_envs}\n  numEnvs: ${resolve_default:4096,${...num_envs}}\n  numAgents: ${...num_agents}\n  # rgb color of Ant body\n  color: [ [ 0.97, 0.38, 0.06 ],[ 0.24, 0.38, 0.06 ],[ 0.56, 0.85, 0.25 ],[ 0.56, 0.85, 0.25 ],[ 0.14, 0.97, 0.24 ],[ 0.63, 0.2, 0.87 ] ]\n  envSpacing: 6\n  borderlineSpace: 3\n  episodeLength: 1000\n  enableDebugVis: False\n  controlFrequencyInv: 1\n  clipActions: 1.0\n  clipObservations: 5.0\n  actionScale: 0.5\n  control:\n    # PD Drive parameters:\n    stiffness: 85.0  # [N*m/rad]\n    damping: 2.0     # [N*m*s/rad]\n    actionScale: 0.5\n    controlFrequencyInv: 1 # 60 Hz\n\n  # reward parameters\n  headingWeight: 0.5\n  upWeight: 0.1\n\n  # cost parameters\n  terminationHeight: 0.31\n  dofVelocityScale: 0.2\n  jointsAtLimitCost: -0.1\n\n  plane:\n    staticFriction: 1.0\n    dynamicFriction: 1.0\n    restitution: 0.0\n\n  asset:\n    assetFileName: \"mjcf/nv_ant.xml\"\n\n  # set to True if you use camera sensors in the environment\n  enableCameraSensors: False\n\nsim:\n  dt: 0.0166 # 1/60 s\n  substeps: 2\n  up_axis: \"z\"\n  use_gpu_pipeline: ${eq:${...pipeline},\"gpu\"}\n  gravity: [ 0.0, 0.0, -9.81 ]\n  physx:\n    num_threads: ${....num_threads}\n    solver_type: ${....solver_type}\n    use_gpu: ${contains:\"cuda\",${....sim_device}} # set to False to run on CPU\n    num_position_iterations: 4\n    num_velocity_iterations: 0\n    contact_offset: 0.02\n    rest_offset: 0.0\n    bounce_threshold_velocity: 0.2\n    max_depenetration_velocity: 10.0\n    default_buffer_size_multiplier: 5.0\n    max_gpu_contact_pairs: 8388608 # 8*1024*1024\n    num_subscenes: ${....num_subscenes}\n    contact_collection: 0 # 0: CC_NEVER (don't collect contact info), 1: CC_LAST_SUBSTEP (collect only contacts on last substep), 2: CC_ALL_SUBSTEPS (default - all 
contacts)\n\ntask:\n  randomize: False\n  randomization_params:\n    # specify which attributes to randomize for each actor type and property\n    frequency: 600   # Define how many environment steps between generating new randomizations\n    observations:\n      range: [ 0, .002 ] # range for the white noise\n      operation: \"additive\"\n      distribution: \"gaussian\"\n    actions:\n      range: [ 0., .02 ]\n      operation: \"additive\"\n      distribution: \"gaussian\"\n    actor_params:\n      ant:\n        color: True\n        rigid_body_properties:\n          mass:\n            range: [ 0.5, 1.5 ]\n            operation: \"scaling\"\n            distribution: \"uniform\"\n            setup_only: True # Property will only be randomized once before simulation is started. See Domain Randomization Documentation for more info.\n        dof_properties:\n          damping:\n            range: [ 0.5, 1.5 ]\n            operation: \"scaling\"\n            distribution: \"uniform\"\n          stiffness:\n            range: [ 0.5, 1.5 ]\n            operation: \"scaling\"\n            distribution: \"uniform\"\n          lower:\n            range: [ 0, 0.01 ]\n            operation: \"additive\"\n            distribution: \"gaussian\"\n          upper:\n            range: [ 0, 0.01 ]\n            operation: \"additive\"\n            distribution: \"gaussian\"\n"
  },
  {
    "path": "timechamber/cfg/task/MA_Ant_Sumo.yaml",
    "content": "# used to create the object\nname: MA_Ant_Sumo\n\nphysics_engine: ${..physics_engine}\n\n# if given, will override the device setting in gym.\nenv:\n#  numEnvs: ${...num_envs}\n  numEnvs: ${resolve_default:4096,${...num_envs}}\n  numAgents: ${...num_agents}\n  envSpacing: 6\n  borderlineSpace: 3\n  episodeLength: 1000\n  enableDebugVis: False\n  controlFrequencyInv: 1\n  clipActions: 1.0\n  clipObservations: 5.0\n  actionScale: 0.5\n  control:\n    # PD Drive parameters:\n    stiffness: 85.0  # [N*m/rad]\n    damping: 2.0     # [N*m*s/rad]\n    actionScale: 0.5\n    controlFrequencyInv: 1 # 60 Hz\n\n  # reward parameters\n  headingWeight: 0.5\n  upWeight: 0.1\n\n  # cost parameters\n  terminationHeight: 0.31\n  dofVelocityScale: 0.2\n  jointsAtLimitCost: -0.1\n\n  plane:\n    staticFriction: 1.0\n    dynamicFriction: 1.0\n    restitution: 0.0\n\n  asset:\n    assetFileName: \"mjcf/nv_ant.xml\"\n\n# set to True if you use camera sensors in the environment\nenableCameraSensors: False\n\nsim:\n  dt: 0.0166 # 1/60 s\n  substeps: 2\n  up_axis: \"z\"\n  use_gpu_pipeline: ${eq:${...pipeline},\"gpu\"}\n  gravity: [0.0, 0.0, -9.81]\n  physx:\n    num_threads: ${....num_threads}\n    solver_type: ${....solver_type}\n    use_gpu: ${contains:\"cuda\",${....sim_device}} # set to False to run on CPU\n    num_position_iterations: 4\n    num_velocity_iterations: 0\n    contact_offset: 0.02\n    rest_offset: 0.0\n    bounce_threshold_velocity: 0.2\n    max_depenetration_velocity: 10.0\n    default_buffer_size_multiplier: 5.0\n    max_gpu_contact_pairs: 8388608 # 8*1024*1024\n    num_subscenes: ${....num_subscenes}\n    contact_collection: 0 # 0: CC_NEVER (don't collect contact info), 1: CC_LAST_SUBSTEP (collect only contacts on last substep), 2: CC_ALL_SUBSTEPS (default - all contacts)\n\ntask:\n  randomize: False\n  randomization_params:\n    # specify which attributes to randomize for each actor type and property\n    frequency: 600   # Define how many environment 
steps between generating new randomizations\n    observations:\n      range: [0, .002] # range for the white noise\n      operation: \"additive\"\n      distribution: \"gaussian\"\n    actions:\n      range: [0., .02]\n      operation: \"additive\"\n      distribution: \"gaussian\"\n    actor_params:\n      ant:\n        color: True\n        rigid_body_properties:\n          mass:\n            range: [0.5, 1.5]\n            operation: \"scaling\"\n            distribution: \"uniform\"\n            setup_only: True # Property will only be randomized once before simulation is started. See Domain Randomization Documentation for more info.\n        dof_properties:\n          damping:\n            range: [0.5, 1.5]\n            operation: \"scaling\"\n            distribution: \"uniform\"\n          stiffness:\n            range: [0.5, 1.5]\n            operation: \"scaling\"\n            distribution: \"uniform\"\n          lower:\n            range: [0, 0.01]\n            operation: \"additive\"\n            distribution: \"gaussian\"\n          upper:\n            range: [0, 0.01]\n            operation: \"additive\"\n            distribution: \"gaussian\"\n"
  },
  {
    "path": "timechamber/cfg/task/MA_Humanoid_Strike.yaml",
    "content": "name: MA_Humanoid_Strike\n\nphysics_engine: ${..physics_engine}\n\n# if given, will override the device setting in gym. \nenv: \n  numEnvs: ${resolve_default:4096,${...num_envs}}\n  envSpacing: 6\n  episodeLength: 1500\n  borderlineSpace: 3.0\n  numAgents: 2\n  isFlagrun: False\n  enableDebugVis: False\n  \n  pdControl: True\n  powerScale: 1.0\n  controlFrequencyInv: 2 # 30 Hz\n  stateInit: \"Default\"\n  hybridInitProb: 0.5\n  numAMPObsSteps: 10\n  \n  localRootObs: True\n  keyBodies: [\"right_hand\", \"left_hand\", \"right_foot\", \"left_foot\", \"sword\", \"shield\"]\n  contactBodies: [\"right_foot\", \"left_foot\"]\n  # forceBodies: [\"torso\", \"right_upper_arm\", \"right_thigh\", \"right_shin\", \"left_thigh\", \"left_shin\"]\n  forceBodies: [\"torso\", \"right_thigh\", \"right_shin\", \"left_thigh\", \"left_shin\"]\n  terminationHeight: 0.15\n  enableEarlyTermination: True\n\n  strikeBodyNames: [\"sword\", \"shield\", \"right_hand\", \"right_lower_arm\", \"left_hand\", \"left_lower_arm\"]\n  enableTaskObs: True\n  \n  asset:\n    assetRoot: \"tasks/data/assets\"\n    assetFileName: \"mjcf/amp_humanoid_sword_shield.xml\"\n\n  plane:\n    staticFriction: 1.0\n    dynamicFriction: 1.0\n    restitution: 0.0\n\nsim:\n  substeps: 2\n  physx:\n    num_threads: 4\n    solver_type: 1  # 0: pgs, 1: tgs\n    num_position_iterations: 4\n    num_velocity_iterations: 0\n    contact_offset: 0.02\n    rest_offset: 0.0\n    bounce_threshold_velocity: 0.2\n    max_depenetration_velocity: 10.0\n    default_buffer_size_multiplier: 10.0\n\n  flex:\n    num_inner_iterations: 10\n    warm_start: 0.25\n"
  },
  {
    "path": "timechamber/cfg/train/MA_Ant_BattlePPO.yaml",
    "content": "params:\n  seed: ${...seed}\n\n  algo:\n    name: self_play_continuous\n\n  model:\n    name: continuous_a2c_logstd\n\n  network:\n    name: actor_critic\n    separate: False\n    space:\n      continuous:\n        mu_activation: None\n        sigma_activation: None\n        mu_init:\n          name: default\n        sigma_init:\n          name: const_initializer\n          val: 0\n        fixed_sigma: True\n    mlp:\n      units: [ 256, 128, 64 ]\n      activation: elu\n      d2rl: False\n\n      initializer:\n        name: default\n\n  player_pool_type: ${...player_pool_type}\n  load_checkpoint: ${if:${...checkpoint},True,False} # flag which sets whether to load the checkpoint\n  load_path: ${...checkpoint} # path to the checkpoint to load\n  op_load_path: ${if:${...op_checkpoint},${...op_checkpoint},${...checkpoint}} # default play with myself\n  num_agents: ${...num_agents}\n  update_win_rate: 0.7\n  player_pool_length: 4\n  games_to_check: 400\n  max_update_steps: 5000\n\n  device: ${...rl_device}\n  config:\n    name: ${resolve_default:MA_Ant_1v1,${....experiment}}\n    env_name: rlgpu\n    multi_gpu: ${....multi_gpu}\n    ppo: True\n    mixed_precision: False\n    normalize_input: True\n    normalize_value: True\n    value_bootstrap: True\n    num_actors: ${....task.env.numEnvs}\n    reward_shaper:\n      scale_value: 0.01\n    normalize_advantage: True\n    gamma: 0.99\n    tau: 0.95\n    learning_rate: 3e-4\n    lr_schedule: adaptive\n    schedule_type: standard\n    kl_threshold: 0.008\n    score_to_win: 20000\n    max_epochs: ${resolve_default:2000,${....max_iterations}}\n    save_best_after: 200\n    save_frequency: 1000\n    grad_norm: 1.0\n    entropy_coef: 0.0\n    truncate_grads: True\n    e_clip: 0.2\n    horizon_length: 64\n    minibatch_size: ${resolve_default:32768,${....minibatch_size}}\n    mini_epochs: 4\n    critic_coef: 2\n    clip_value: True\n    use_smooth_clamp: True\n    bounds_loss_coef: 0.0000\n    player:\n      
games_num: 4000\n      record_elo: True\n      init_elo: 400"
  },
  {
    "path": "timechamber/cfg/train/MA_Ant_SumoPPO.yaml",
    "content": "params:\n  seed: ${...seed}\n\n  algo:\n    name: self_play_continuous\n\n  model:\n    name: continuous_a2c_logstd\n\n  network:\n    name: actor_critic\n    separate: False\n    space:\n      continuous:\n        mu_activation: None\n        sigma_activation: None\n        mu_init:\n          name: default\n        sigma_init:\n          name: const_initializer\n          val: 0\n        fixed_sigma: True\n    mlp:\n      units: [ 256, 128, 64 ]\n      activation: elu\n      d2rl: False\n\n      initializer:\n        name: default\n  # self play agent related\n  player_pool_type: ${...player_pool_type}\n  load_checkpoint: ${if:${...checkpoint},True,False} # flag which sets whether to load the checkpoint\n  load_path: ${...checkpoint} # path to the checkpoint to load\n  op_load_path: ${if:${...op_checkpoint},${...op_checkpoint},${...checkpoint}} # default play with myself\n  num_agents: ${...num_agents}\n\n  update_win_rate: 0.7\n  player_pool_length: 2\n  games_to_check: 400\n  max_update_steps: 5000\n  device: ${...rl_device}\n  config:\n    name: ${resolve_default:MA_Ant_1v1,${....experiment}}\n    env_name: rlgpu\n    multi_gpu: ${....multi_gpu}\n    ppo: True\n    mixed_precision: False\n    normalize_input: True\n    normalize_value: True\n    value_bootstrap: True\n    num_actors: ${....task.env.numEnvs}\n    reward_shaper:\n      scale_value: 0.01\n    normalize_advantage: True\n    gamma: 0.99\n    tau: 0.95\n    learning_rate: 3e-4\n    lr_schedule: adaptive\n    schedule_type: standard\n    kl_threshold: 0.008\n    score_to_win: 20000\n    max_epochs: ${resolve_default:100000,${....max_iterations}}\n    save_best_after: 200\n    save_frequency: 500\n    grad_norm: 1.0\n    entropy_coef: 0.0\n    truncate_grads: True\n    e_clip: 0.2\n    horizon_length: 64\n    minibatch_size: ${resolve_default:32768,${....minibatch_size}}\n    mini_epochs: 4\n    critic_coef: 2\n    clip_value: True\n    use_smooth_clamp: True\n    bounds_loss_coef: 
0.0000\n    player:\n      games_num: 4000\n      record_elo: True\n      init_elo: 400"
  },
  {
    "path": "timechamber/cfg/train/MA_Humanoid_StrikeHRL.yaml",
    "content": "params:\n  seed: ${...seed}\n\n  algo:\n    name: self_play_hrl\n\n  model:\n    name: hrl\n\n  network:\n    name: hrl\n    separate: True\n\n    space:\n      continuous:\n        mu_activation: None\n        sigma_activation: None\n        mu_init:\n          name: default\n        sigma_init:\n          name: const_initializer\n          val: -2.3\n        fixed_sigma: True\n        learn_sigma: False\n\n    mlp:\n      units: [1024, 512]\n      activation: relu\n      d2rl: False\n\n      initializer:\n        name: default\n      regularizer:\n        name: None\n\n  # self play agent related\n  player_pool_type: ${...player_pool_type}\n  load_checkpoint: ${if:${...checkpoint},True,False} # flag which sets whether to load the checkpoint\n  load_path: ${...checkpoint} # path to the checkpoint to load\n  op_load_path: ${if:${...op_checkpoint},${...op_checkpoint},${...checkpoint}} # default play with myself\n  num_agents: ${...num_agents}\n\n  update_win_rate: 0.8\n  player_pool_length: 4\n  games_to_check: 400\n  max_update_steps: 5000\n  device: ${...rl_device}\n\n  config:\n    name: Humanoid\n    env_name: rlgpu\n    multi_gpu: False\n    ppo: True\n    mixed_precision: False\n    normalize_input: True\n    normalize_value: True\n    num_actors: ${....task.env.numEnvs}\n    reward_shaper:\n      scale_value: 1\n    normalize_advantage: True\n    gamma: 0.99\n    tau: 0.95\n    learning_rate: 2e-5\n    lr_schedule: constant\n    score_to_win: 20000000\n    max_epochs: ${resolve_default:100000,${....max_iterations}}\n    save_best_after: 10\n    save_frequency: 50\n    print_stats: True\n    grad_norm: 1.0\n    entropy_coef: 0.0\n    truncate_grads: False\n    e_clip: 0.2\n    horizon_length: 64\n    minibatch_size: ${resolve_default:64,${....minibatch_size}}\n    mini_epochs: 6\n    critic_coef: 5\n    clip_value: False\n    seq_len: 4\n    bounds_loss_coef: 10\n    \n    task_reward_w: 0.9\n    disc_reward_w: 0.1\n\n    player:\n      
determenistic: False\n      games_num: 4000\n      record_elo: True\n      init_elo: 400\n\n    llc_steps: 5\n    llc_config: cfg/train/base/ase_humanoid_hrl.yaml\n    llc_checkpoint: tasks/data/models/llc_reallusion_sword_shield.pth\n"
  },
  {
    "path": "timechamber/cfg/train/base/ase_humanoid_hrl.yaml",
    "content": "params:\n  seed: -1\n\n  algo:\n    name: ase\n\n  model:\n    name: ase\n\n  network:\n    name: ase\n    separate: True\n\n    space:\n      continuous:\n        mu_activation: None\n        sigma_activation: None\n        mu_init:\n          name: default\n        sigma_init:\n          name: const_initializer\n          val: -2.9\n        fixed_sigma: True\n        learn_sigma: False\n\n    mlp:\n      units: [1024, 1024, 512]\n      activation: relu\n      d2rl: False\n\n      initializer:\n        name: default\n      regularizer:\n        name: None\n\n    disc:\n      units: [1024, 1024, 512]\n      activation: relu\n\n      initializer:\n        name: default\n\n    enc:\n      units: [1024, 512]\n      activation: relu\n      separate: False\n\n      initializer:\n        name: default\n\n  load_checkpoint: False\n\n  config:\n    name: Humanoid\n    env_name: rlgpu\n    multi_gpu: False\n    ppo: True\n    mixed_precision: False\n    normalize_input: True\n    normalize_value: True\n    reward_shaper:\n      scale_value: 1\n    normalize_advantage: True\n    gamma: 0.99\n    tau: 0.95\n    learning_rate: 2e-5\n    lr_schedule: constant\n    score_to_win: 20000\n    max_epochs: 100000\n    save_best_after: 50\n    save_frequency: 50\n    print_stats: True\n    grad_norm: 1.0\n    entropy_coef: 0.0\n    truncate_grads: False\n    ppo: True\n    e_clip: 0.2\n    horizon_length: 32\n    minibatch_size: 1\n    mini_epochs: 6\n    critic_coef: 5\n    clip_value: False\n    seq_len: 4\n    bounds_loss_coef: 10\n    amp_obs_demo_buffer_size: 200000\n    amp_replay_buffer_size: 200000\n    amp_replay_keep_prob: 0.01\n    amp_batch_size: 32\n    amp_minibatch_size: 1\n    disc_coef: 5\n    disc_logit_reg: 0.01\n    disc_grad_penalty: 5\n    disc_reward_scale: 2\n    disc_weight_decay: 0.0001\n    normalize_amp_input: True\n    enable_eps_greedy: False\n\n    latent_dim: 64\n    latent_steps_min: 1\n    latent_steps_max: 150\n    \n    
amp_latent_grad_bonus: 0.00\n    amp_latent_grad_bonus_max: 100.0\n    amp_diversity_bonus: 0.01\n    amp_diversity_tar: 1.0\n    \n    enc_coef: 5\n    enc_weight_decay: 0.0000\n    enc_reward_scale: 1\n    enc_grad_penalty: 0\n\n    task_reward_w: 0.0\n    disc_reward_w: 0.5\n    enc_reward_w: 0.5\n"
  },
  {
    "path": "timechamber/learning/common_agent.py",
    "content": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\nimport copy\nfrom datetime import datetime\nfrom gym import spaces\nimport numpy as np\nimport os\nimport time\nimport yaml\n\nfrom rl_games.algos_torch import a2c_continuous\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import central_value\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.common import a2c_common\nfrom rl_games.common import datasets\nfrom rl_games.common import schedulers\nfrom rl_games.common import vecenv\n\nimport torch\nfrom torch import optim\n\n\nfrom tensorboardX import SummaryWriter\n\n\nclass CommonAgent(a2c_continuous.A2CAgent):\n\n    def __init__(self, base_name, params):\n    \n        a2c_common.A2CBase.__init__(self, base_name, params)\n\n        config = params['config']\n        self._load_config_params(config)\n\n        self.is_discrete = False\n        self._setup_action_space()\n        self.bounds_loss_coef = config.get('bounds_loss_coef', None)\n        self.clip_actions = config.get('clip_actions', True)\n\n        self.network_path = config.get('network_path', \"./runs\")\n        self.network_path = os.path.join(self.network_path, self.config['name'])\n        self.network_path = os.path.join(self.network_path, 'nn')\n        \n        net_config = self._build_net_config()\n        self.model = self.network.build(net_config)\n        self.model.to(self.ppo_device)\n        self.states = None\n\n        self.init_rnn_from_model(self.model)\n        self.last_lr = float(self.last_lr)\n\n        self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)\n\n        if self.has_central_value:\n            cv_config = {\n                'state_shape' : torch_ext.shape_whc_to_cwh(self.state_shape), \n                'value_size' : self.value_size,\n                'ppo_device' : self.ppo_device, \n                'num_agents' : 
self.num_agents, \n                'num_steps' : self.horizon_length, \n                'num_actors' : self.num_actors, \n                'num_actions' : self.actions_num, \n                'seq_len' : self.seq_len, \n                'model' : self.central_value_config['network'],\n                'config' : self.central_value_config, \n                'writter' : self.writer,\n                'multi_gpu' : self.multi_gpu\n            }\n            self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)\n\n        self.use_experimental_cv = self.config.get('use_experimental_cv', True)\n        self.algo_observer.after_init(self)\n        \n        return\n\n    def init_tensors(self):\n        super().init_tensors()\n        self.experience_buffer.tensor_dict['next_obses'] = torch.zeros_like(self.experience_buffer.tensor_dict['obses'])\n        self.experience_buffer.tensor_dict['next_values'] = torch.zeros_like(self.experience_buffer.tensor_dict['values'])\n\n        self.tensor_list += ['next_obses']\n        return\n\n    def train(self):\n        self.init_tensors()\n        self.last_mean_rewards = -100500\n        start_time = time.time()\n        total_time = 0\n        rep_count = 0\n        self.frame = 0\n        self.obs = self.env_reset()\n        self.curr_frames = self.batch_size_envs\n\n        self.model_output_file = os.path.join(self.network_path, self.config['name'])\n\n        if self.multi_gpu:\n            self.hvd.setup_algo(self)\n\n        self._init_train()\n\n        while True:\n            epoch_num = self.update_epoch()\n            train_info = self.train_epoch()\n\n            sum_time = train_info['total_time']\n            total_time += sum_time\n            frame = self.frame\n            if self.multi_gpu:\n                self.hvd.sync_stats(self)\n\n            if self.rank == 0:\n                scaled_time = sum_time\n                scaled_play_time = train_info['play_time']\n               
 curr_frames = self.curr_frames\n                self.frame += curr_frames\n                if self.print_stats:\n                    fps_step = curr_frames / scaled_play_time\n                    fps_total = curr_frames / scaled_time\n                    print(f'fps step: {fps_step:.1f} fps total: {fps_total:.1f}')\n\n                self.writer.add_scalar('performance/total_fps', curr_frames / scaled_time, frame)\n                self.writer.add_scalar('performance/step_fps', curr_frames / scaled_play_time, frame)\n                self.writer.add_scalar('info/epochs', epoch_num, frame)\n                self._log_train_info(train_info, frame)\n\n                self.algo_observer.after_print_stats(frame, epoch_num, total_time)\n                \n                if self.game_rewards.current_size > 0:\n                    mean_rewards = self.game_rewards.get_mean()\n                    mean_lengths = self.game_lengths.get_mean()\n\n                    for i in range(self.value_size):\n                        self.writer.add_scalar('rewards/frame'.format(i), mean_rewards[i], frame)\n                        self.writer.add_scalar('rewards/iter'.format(i), mean_rewards[i], epoch_num)\n                        self.writer.add_scalar('rewards/time'.format(i), mean_rewards[i], total_time)\n\n                    self.writer.add_scalar('episode_lengths/frame', mean_lengths, frame)\n                    self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)\n\n                    if self.has_self_play_config:\n                        self.self_play_manager.update(self)\n\n                if self.save_freq > 0:\n                    if (epoch_num % self.save_freq == 0):\n                        self.save(self.model_output_file + \"_\" + str(epoch_num))\n\n                if epoch_num > self.max_epochs:\n                    self.save(self.model_output_file)\n                    print('MAX EPOCHS NUM!')\n                    return self.last_mean_rewards, 
epoch_num\n\n                update_time = 0\n        return\n\n    def train_epoch(self):\n        play_time_start = time.time()\n        with torch.no_grad():\n            if self.is_rnn:\n                batch_dict = self.play_steps_rnn()\n            else:\n                batch_dict = self.play_steps()\n\n        play_time_end = time.time()\n        update_time_start = time.time()\n        rnn_masks = batch_dict.get('rnn_masks', None)\n        \n        self.set_train()\n\n        self.curr_frames = batch_dict.pop('played_frames')\n        self.prepare_dataset(batch_dict)\n        self.algo_observer.after_steps()\n\n        if self.has_central_value:\n            self.train_central_value()\n\n        train_info = None\n\n        if self.is_rnn:\n            frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement())\n            print(frames_mask_ratio)\n\n        for _ in range(0, self.mini_epochs_num):\n            ep_kls = []\n            for i in range(len(self.dataset)):\n                curr_train_info = self.train_actor_critic(self.dataset[i])\n\n                if self.schedule_type == 'legacy':\n                    if self.multi_gpu:\n                        curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls')\n                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item())\n                    self.update_lr(self.last_lr)\n\n                if (train_info is None):\n                    train_info = dict()\n                    for k, v in curr_train_info.items():\n                        train_info[k] = [v]\n                else:\n                    for k, v in curr_train_info.items():\n                        train_info[k].append(v)\n            \n            av_kls = torch_ext.mean_list(train_info['kl'])\n\n            if self.schedule_type == 'standard':\n  
              if self.multi_gpu:\n                    av_kls = self.hvd.average_value(av_kls, 'ep_kls')\n                self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())\n                self.update_lr(self.last_lr)\n\n        if self.schedule_type == 'standard_epoch':\n            if self.multi_gpu:\n                av_kls = self.hvd.average_value(torch_ext.mean_list(train_info['kl']), 'ep_kls')\n            self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())\n            self.update_lr(self.last_lr)\n\n        update_time_end = time.time()\n        play_time = play_time_end - play_time_start\n        update_time = update_time_end - update_time_start\n        total_time = update_time_end - play_time_start\n\n        train_info['play_time'] = play_time\n        train_info['update_time'] = update_time\n        train_info['total_time'] = total_time\n        self._record_train_batch_info(batch_dict, train_info)\n\n        return train_info\n\n    def play_steps(self):\n        self.set_eval()\n        \n        epinfos = []\n        update_list = self.update_list\n\n        for n in range(self.horizon_length):\n            self.obs, done_env_ids = self._env_reset_done()\n            self.experience_buffer.update_data('obses', n, self.obs['obs'])\n\n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, masks)\n            else:\n                res_dict = self.get_action_values(self.obs)\n\n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k])\n\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])\n            
shaped_rewards = self.rewards_shaper(rewards)\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n\n            terminated = infos['terminate'].float()\n            terminated = terminated.unsqueeze(-1)\n            next_vals = self._eval_critic(self.obs)\n            next_vals *= (1.0 - terminated)\n            self.experience_buffer.update_data('next_values', n, next_vals)\n\n            self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            done_indices = all_done_indices[::self.num_agents]\n  \n            self.game_rewards.update(self.current_rewards[done_indices])\n            self.game_lengths.update(self.current_lengths[done_indices])\n            self.algo_observer.process_infos(infos, done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_next_values = self.experience_buffer.tensor_dict['next_values']\n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        \n        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n\n        return batch_dict\n\n    def calc_gradients(self, input_dict):\n        self.set_train()\n\n        
value_preds_batch = input_dict['old_values']\n        old_action_log_probs_batch = input_dict['old_logp_actions']\n        advantage = input_dict['advantages']\n        old_mu_batch = input_dict['mu']\n        old_sigma_batch = input_dict['sigma']\n        return_batch = input_dict['returns']\n        actions_batch = input_dict['actions']\n        obs_batch = input_dict['obs']\n        obs_batch = self._preproc_obs(obs_batch)\n\n        lr = self.last_lr\n        kl = 1.0\n        lr_mul = 1.0\n        curr_e_clip = lr_mul * self.e_clip\n\n        batch_dict = {\n            'is_train': True,\n            'prev_actions': actions_batch, \n            'obs' : obs_batch\n        }\n\n        rnn_masks = None\n        if self.is_rnn:\n            rnn_masks = input_dict['rnn_masks']\n            batch_dict['rnn_states'] = input_dict['rnn_states']\n            batch_dict['seq_length'] = self.seq_len\n\n        with torch.cuda.amp.autocast(enabled=self.mixed_precision):\n            res_dict = self.model(batch_dict)\n            action_log_probs = res_dict['prev_neglogp']\n            values = res_dict['value']\n            entropy = res_dict['entropy']\n            mu = res_dict['mu']\n            sigma = res_dict['sigma']\n\n            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)\n            a_loss = a_info['actor_loss']\n\n            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)\n            c_loss = c_info['critic_loss']\n\n            b_loss = self.bound_loss(mu)\n\n            losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks)\n            a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]\n            \n            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss\n            \n            if 
self.multi_gpu:\n                self.optimizer.zero_grad()\n            else:\n                for param in self.model.parameters():\n                    param.grad = None\n\n        self.scaler.scale(loss).backward()\n        #TODO: Refactor this ugliest code of the year\n        if self.truncate_grads:\n            if self.multi_gpu:\n                self.optimizer.synchronize()\n                self.scaler.unscale_(self.optimizer)\n                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)\n                with self.optimizer.skip_synchronize():\n                    self.scaler.step(self.optimizer)\n                    self.scaler.update()\n            else:\n                self.scaler.unscale_(self.optimizer)\n                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)\n                self.scaler.step(self.optimizer)\n                self.scaler.update()    \n        else:\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n        with torch.no_grad():\n            reduce_kl = not self.is_rnn\n            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)\n            if self.is_rnn:\n                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()  #/ sum_mask\n                    \n        self.train_result = {\n            'entropy': entropy,\n            'kl': kl_dist,\n            'last_lr': self.last_lr, \n            'lr_mul': lr_mul, \n            'b_loss': b_loss\n        }\n        self.train_result.update(a_info)\n        self.train_result.update(c_info)\n\n        return\n\n    def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_values):\n        lastgaelam = 0\n        mb_advs = torch.zeros_like(mb_rewards)\n\n        for t in reversed(range(self.horizon_length)):\n            not_done = 1.0 - mb_fdones[t]\n            not_done = not_done.unsqueeze(1)\n\n            delta = mb_rewards[t] + 
self.gamma * mb_next_values[t] - mb_values[t]\n            lastgaelam = delta + self.gamma * self.tau * not_done * lastgaelam\n            mb_advs[t] = lastgaelam\n\n        return mb_advs\n\n    def bound_loss(self, mu):\n        if self.bounds_loss_coef is not None:\n            soft_bound = 1.0\n            mu_loss_high = torch.maximum(mu - soft_bound, torch.tensor(0, device=self.ppo_device))**2\n            mu_loss_low = torch.minimum(mu + soft_bound, torch.tensor(0, device=self.ppo_device))**2\n            b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)\n        else:\n            b_loss = 0\n        return b_loss\n\n    def _load_config_params(self, config):\n        self.last_lr = config['learning_rate']\n        return\n\n    def _build_net_config(self):\n        obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)\n        config = {\n            'actions_num' : self.actions_num,\n            'input_shape' : obs_shape,\n            'num_seqs' : self.num_actors * self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value' : self.normalize_value,\n            'normalize_input': self.normalize_input,\n        }\n        return config\n\n    def _setup_action_space(self):\n        action_space = self.env_info['action_space']\n        self.actions_num = action_space.shape[0]\n\n        # todo introduce device instead of cuda()\n        self.actions_low = torch.from_numpy(action_space.low.copy()).float().to(self.ppo_device)\n        self.actions_high = torch.from_numpy(action_space.high.copy()).float().to(self.ppo_device)\n        return\n\n    def _init_train(self):\n        return\n\n    def _env_reset_done(self):\n        obs, done_env_ids = self.vec_env.reset_done()\n        return self.obs_to_tensors(obs), done_env_ids\n\n    def _eval_critic(self, obs_dict):\n        self.model.eval()\n        obs = obs_dict['obs']\n\n        processed_obs = self._preproc_obs(obs)\n        if self.normalize_input:\n   
         processed_obs = self.model.norm_obs(processed_obs)\n        value = self.model.a2c_network.eval_critic(processed_obs)\n\n        if self.normalize_value:\n            value = self.value_mean_std(value, True)\n        return value\n\n    def _actor_loss(self, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip):\n        clip_frac = None\n        if (self.ppo):\n            ratio = torch.exp(old_action_log_probs_batch - action_log_probs)\n            surr1 = advantage * ratio\n            surr2 = advantage * torch.clamp(ratio, 1.0 - curr_e_clip,\n                                    1.0 + curr_e_clip)\n            a_loss = torch.max(-surr1, -surr2)\n\n            clipped = torch.abs(ratio - 1.0) > curr_e_clip\n            clip_frac = torch.mean(clipped.float())\n            clip_frac = clip_frac.detach()\n        else:\n            a_loss = (action_log_probs * advantage)\n    \n        info = {\n            'actor_loss': a_loss,\n            'actor_clip_frac': clip_frac\n        }\n        return info\n\n    def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value):\n        if clip_value:\n            value_pred_clipped = value_preds_batch + \\\n                    (values - value_preds_batch).clamp(-curr_e_clip, curr_e_clip)\n            value_losses = (values - return_batch)**2\n            value_losses_clipped = (value_pred_clipped - return_batch)**2\n            c_loss = torch.max(value_losses, value_losses_clipped)\n        else:\n            c_loss = (return_batch - values)**2\n\n        info = {\n            'critic_loss': c_loss\n        }\n        return info\n    \n    def _record_train_batch_info(self, batch_dict, train_info):\n        return\n\n    def _log_train_info(self, train_info, frame):\n        self.writer.add_scalar('performance/update_time', train_info['update_time'], frame)\n        self.writer.add_scalar('performance/play_time', train_info['play_time'], frame)\n        
self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(train_info['actor_loss']).item(), frame)\n        self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(train_info['critic_loss']).item(), frame)\n        \n        self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(train_info['b_loss']).item(), frame)\n        self.writer.add_scalar('losses/entropy', torch_ext.mean_list(train_info['entropy']).item(), frame)\n        self.writer.add_scalar('info/last_lr', train_info['last_lr'][-1] * train_info['lr_mul'][-1], frame)\n        self.writer.add_scalar('info/lr_mul', train_info['lr_mul'][-1], frame)\n        self.writer.add_scalar('info/e_clip', self.e_clip * train_info['lr_mul'][-1], frame)\n        self.writer.add_scalar('info/clip_frac', torch_ext.mean_list(train_info['actor_clip_frac']).item(), frame)\n        self.writer.add_scalar('info/kl', torch_ext.mean_list(train_info['kl']).item(), frame)\n        return\n"
  },
  {
    "path": "timechamber/learning/common_player.py",
    "content": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\n\nimport torch \n\nfrom rl_games.algos_torch import players\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd\nfrom rl_games.common.player import BasePlayer\n\n\nclass CommonPlayer(players.PpoPlayerContinuous):\n\n    def __init__(self, params):\n        BasePlayer.__init__(self, params)\n        self.network = self.config['network']\n\n        self.normalize_input = self.config['normalize_input']\n        self.normalize_value = self.config['normalize_value']\n        \n        self._setup_action_space()\n        self.mask = [False]\n        \n        net_config = self._build_net_config()\n        self._build_net(net_config)   \n        \n        return\n\n    def run(self):\n        n_games = self.games_num\n        render = self.render_env\n        n_game_life = self.n_game_life\n        is_determenistic = self.is_determenistic\n        sum_rewards = 0\n        sum_steps = 0\n        sum_game_res = 0\n        n_games = n_games * n_game_life\n        games_played = 0\n        has_masks = False\n        has_masks_func = getattr(self.env, \"has_action_mask\", None) is not None\n\n        op_agent = getattr(self.env, \"create_agent\", None)\n        if op_agent:\n            agent_inited = True\n\n        if has_masks_func:\n            has_masks = self.env.has_action_mask()\n\n        need_init_rnn = self.is_rnn\n        for _ in range(n_games):\n            if games_played >= n_games:\n                break\n\n            obs_dict = self.env_reset(self.env)\n            batch_size = 1\n            batch_size = self.get_batch_size(obs_dict['obs'], batch_size)\n\n            if need_init_rnn:\n                self.init_rnn()\n                need_init_rnn = False\n\n            cr = torch.zeros(batch_size, dtype=torch.float32)\n            steps = torch.zeros(batch_size, dtype=torch.float32)\n\n            print_game_res = False\n\n     
       for n in range(self.max_steps):\n                obs_dict, done_env_ids = self._env_reset_done()\n\n                if has_masks:\n                    masks = self.env.get_action_mask()\n                    action = self.get_masked_action(obs_dict, masks, is_determenistic)\n                else:\n                    action = self.get_action(obs_dict, is_determenistic)\n                obs_dict, r, done, info =  self.env_step(self.env, action)\n                cr += r\n                steps += 1\n  \n                self._post_step(info)\n\n                if render:\n                    self.env.render(mode = 'human')\n                    time.sleep(self.render_sleep)\n\n                all_done_indices = done.nonzero(as_tuple=False)\n                done_indices = all_done_indices[::self.num_agents]\n                done_count = len(done_indices)\n                games_played += done_count\n\n                if done_count > 0:\n                    if self.is_rnn:\n                        for s in self.states:\n                            s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0\n\n                    cur_rewards = cr[done_indices].sum().item()\n                    cur_steps = steps[done_indices].sum().item()\n\n                    cr = cr * (1.0 - done.float())\n                    steps = steps * (1.0 - done.float())\n                    sum_rewards += cur_rewards\n                    sum_steps += cur_steps\n\n                    game_res = 0.0\n                    if isinstance(info, dict):\n                        if 'battle_won' in info:\n                            print_game_res = True\n                            game_res = info.get('battle_won', 0.5)\n                        if 'scores' in info:\n                            print_game_res = True\n                            game_res = info.get('scores', 0.5)\n                    if self.print_stats:\n                        if print_game_res:\n                            
print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count, 'w:', game_res)\n                        else:\n                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count)\n\n                    sum_game_res += game_res\n                    if batch_size//self.num_agents == 1 or games_played >= n_games:\n                        break\n\n        print(sum_rewards)\n        if print_game_res:\n            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life)\n        else:\n            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life)\n\n        return\n\n    def obs_to_torch(self, obs):\n        obs = super().obs_to_torch(obs)\n        obs_dict = {\n            'obs': obs\n        }\n        return obs_dict\n\n    def get_action(self, obs_dict, is_determenistic = False):\n        output = super().get_action(obs_dict['obs'], is_determenistic)\n        return output\n\n    def _build_net(self, config):\n        self.model = self.network.build(config)\n        self.model.to(self.device)\n        self.model.eval()\n        self.is_rnn = self.model.is_rnn()\n\n        return\n\n    def _env_reset_done(self):\n        obs, done_env_ids = self.env.reset_done()\n        return self.obs_to_torch(obs), done_env_ids\n\n    def _post_step(self, info):\n        return\n\n    def _build_net_config(self):\n        obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)\n        config = {\n            'actions_num' : self.actions_num,\n            'input_shape' : obs_shape,\n            'num_seqs' : self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value': self.normalize_value,\n            'normalize_input': self.normalize_input,\n        } \n        return config\n\n    def 
_setup_action_space(self):\n        self.actions_num = self.action_space.shape[0] \n        self.actions_low = torch.from_numpy(self.action_space.low.copy()).float().to(self.device)\n        self.actions_high = torch.from_numpy(self.action_space.high.copy()).float().to(self.device)\n        return"
  },
  {
    "path": "timechamber/learning/hrl_sp_agent.py",
    "content": "import copy\nfrom collections import OrderedDict\nfrom datetime import datetime\nfrom gym import spaces\nimport numpy as np\nimport os\nimport time\nfrom .pfsp_player_pool import PFSPPlayerPool, SinglePlayer, PFSPPlayerThreadPool, PFSPPlayerProcessPool, \\\n    PFSPPlayerVectorizedPool\nfrom rl_games.common.a2c_common import swap_and_flatten01\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import central_value\nfrom isaacgym.torch_utils import *\nimport torch\nfrom torch import optim\nfrom tensorboardX import SummaryWriter\nimport torch.distributed as dist\nimport timechamber.ase.hrl_agent as hrl_agent\nfrom timechamber.utils.utils import load_check, load_checkpoint\n\nclass HRLSPAgent(hrl_agent.HRLAgent):\n    def __init__(self, base_name, params):\n        params['config']['device'] = params['device']\n        super().__init__(base_name, params)\n        self.player_pool_type = params['player_pool_type']\n        self.base_model_config = {\n            'actions_num': self.actions_num,\n            'input_shape': self.obs_shape,\n            'num_seqs': self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value': self.normalize_value,\n            'normalize_input': self.normalize_input,\n        }\n        self.max_his_player_num = params['player_pool_length']\n\n        if params['op_load_path']:\n            self.init_op_model = self.create_model()\n            self.restore_op(params['op_load_path'])\n        else:\n            self.init_op_model = self.model\n        self.players_dir = os.path.join(self.experiment_dir, 'policy_dir')\n        os.makedirs(self.players_dir, exist_ok=True)\n        self.update_win_rate = params['update_win_rate']\n        self.num_opponent_agents = params['num_agents'] - 1\n        self.player_pool = self._build_player_pool(params)\n\n        self.games_to_check = params['games_to_check']\n        self.now_update_steps = 0\n        
self.max_update_steps = params['max_update_steps']\n        self.update_op_num = 0\n        self.update_player_pool(self.init_op_model, player_idx=self.update_op_num)\n        self.resample_op(torch.arange(end=self.num_actors, device=self.device, dtype=torch.long))\n\n        assert self.num_actors % self.max_his_player_num == 0\n\n    def _build_player_pool(self, params):\n        if self.player_pool_type == 'vectorized':\n            vector_model_config = self.base_model_config\n            vector_model_config['num_envs'] = self.num_actors * self.num_opponent_agents\n            vector_model_config['population_size'] = self.max_his_player_num\n\n            return PFSPPlayerVectorizedPool(max_length=self.max_his_player_num, device=self.device,\n                                            vector_model_config=vector_model_config, params=params)\n        else:\n            return PFSPPlayerPool(max_length=self.max_his_player_num, device=self.device)\n\n    def play_steps(self):\n        self.set_eval()\n\n        env_done_indices = torch.tensor([], device=self.device, dtype=torch.long)\n        update_list = self.update_list\n        step_time = 0.0\n\n        for n in range(self.horizon_length):\n            self.obs = self.env_reset(env_done_indices)\n            self.experience_buffer.update_data('obses', n, self.obs['obs'])\n            \n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, masks)\n            else:\n                res_dict_op = self.get_action_values(self.obs, is_op=True)\n                res_dict = self.get_action_values(self.obs)\n            \n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k])\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            if self.player_pool_type == 'multi_thread':\n     
           self.player_pool.thread_pool.shutdown()\n            step_time_start = time.time()\n            \n            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'],\n                                                                 res_dict_op['actions'])\n            step_time_end = time.time()\n            step_time += (step_time_end - step_time_start)\n\n            shaped_rewards = self.rewards_shaper(rewards)\n            if self.value_bootstrap and 'time_outs' in infos:\n                shaped_rewards += self.gamma * res_dict['values'] * self.cast_obs(infos['time_outs']).unsqueeze(\n                    1).float()\n\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n            self.experience_buffer.update_data('disc_rewards', n, infos['disc_rewards'])\n\n            terminated = infos['terminate'].float()\n            terminated = terminated.unsqueeze(-1)\n            next_vals = self._eval_critic(self.obs)\n            next_vals *= (1.0 - terminated)\n            self.experience_buffer.update_data('next_values', n, next_vals)\n\n            self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)\n\n            self.game_rewards.update(self.current_rewards[env_done_indices])\n            self.game_lengths.update(self.current_lengths[env_done_indices])\n            self.algo_observer.process_infos(infos, env_done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n\n            
self.player_pool.update_player_metric(infos=infos)\n            self.resample_op(all_done_indices.flatten())\n            \n            env_done_indices = env_done_indices[:, 0]\n\n        last_values = self.get_values(self.obs)\n\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_next_values = self.experience_buffer.tensor_dict['next_values']\n\n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        mb_disc_rewards = self.experience_buffer.tensor_dict['disc_rewards']\n        mb_rewards = self._combine_rewards(mb_rewards, mb_disc_rewards)\n\n        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n        batch_dict['step_time'] = step_time\n        return batch_dict\n\n    def env_step(self, ego_actions, op_actions):\n        ego_actions = self.preprocess_actions(ego_actions)\n        op_actions = self.preprocess_actions(op_actions)\n        obs = self.obs['obs']\n        obs_op = self.obs['obs_op']\n\n        rewards = 0.0\n        disc_rewards = 0.0\n        done_count = 0.0\n        terminate_count = 0.0\n        win_count = 0.0\n        lose_count = 0.0\n        draw_count = 0.0\n\n        for t in range(self._llc_steps):\n            llc_ego_actions = self._compute_llc_action(obs, ego_actions)\n            llc_op_actions = self._compute_llc_action(obs_op, op_actions)\n            llc_actions = torch.cat((llc_ego_actions, llc_op_actions), dim=0)\n\n            obs_dict, curr_rewards, curr_dones, infos = self.vec_env.step(llc_actions)\n\n            rewards += curr_rewards\n            done_count += curr_dones\n            terminate_count += 
infos['terminate']\n            win_count += infos['win']\n            lose_count += infos['lose']\n            draw_count += infos['draw']\n\n            amp_obs = infos['amp_obs']\n            curr_disc_reward = self._calc_disc_reward(amp_obs)\n            disc_rewards += curr_disc_reward\n\n            obs = obs_dict['obs'][:self.num_actors]\n            obs_op = obs_dict['obs'][self.num_actors:]\n\n        rewards /= self._llc_steps\n        disc_rewards /= self._llc_steps\n\n        dones = torch.zeros_like(done_count)\n        dones[done_count > 0] = 1.0\n        terminate = torch.zeros_like(terminate_count)\n        terminate[terminate_count > 0] = 1.0\n        infos['terminate'] = terminate\n        infos['disc_rewards'] = disc_rewards\n        \n        wins = torch.zeros_like(win_count)\n        wins[win_count > 0] = 1.0\n        infos['win'] = wins\n        \n        loses = torch.zeros_like(lose_count)\n        loses[lose_count > 0] = 1.0\n        infos['lose'] = loses\n        \n        draws = torch.zeros_like(draw_count)\n        draws[draw_count > 0] = 1.0\n        infos['draw'] = draws\n\n        obs_dict = {}\n        obs_dict['obs'] = obs\n        obs_dict['obs_op'] = obs_op\n\n        if self.is_tensor_obses:\n            if self.value_size == 1:\n                rewards = rewards.unsqueeze(1)\n            return self.obs_to_tensors(obs_dict), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos\n        else:\n            if self.value_size == 1:\n                rewards = np.expand_dims(rewards, axis=1)\n            return self.obs_to_tensors(obs_dict), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos\n\n    def env_reset(self, env_ids=None):\n        obs = self.vec_env.reset(env_ids)\n        obs = self.obs_to_tensors(obs)\n        obs['obs_op'] = obs['obs'][self.num_actors:]\n        obs['obs'] = obs['obs'][:self.num_actors]\n\n        return obs\n\n    def train(self):\n      
  self.init_tensors()\n        self.mean_rewards = self.last_mean_rewards = -100500\n        start_time = time.time()\n        total_time = 0\n        rep_count = 0\n        # self.frame = 0  # loading from checkpoint\n        self.obs = self.env_reset()\n\n        if self.multi_gpu:\n            torch.cuda.set_device(self.rank)\n            print(\"====================broadcasting parameters\")\n            model_params = [self.model.state_dict()]\n            dist.broadcast_object_list(model_params, 0)\n            self.model.load_state_dict(model_params[0])\n\n        self._init_train()\n\n        while True:\n            epoch_num = self.update_epoch()\n            train_info = self.train_epoch()\n            print(f\"epoch num: {epoch_num}\")\n            sum_time = train_info['total_time']\n            step_time = train_info['step_time']\n            play_time = train_info['play_time']\n            update_time = train_info['update_time']\n            a_losses = train_info['actor_loss']\n            c_losses = train_info['critic_loss']\n            entropies = train_info['entropy']\n            kls = train_info['kl']\n            last_lr = train_info['last_lr'][-1]\n            lr_mul = train_info['lr_mul'][-1]\n\n            # cleaning memory to optimize space\n            self.dataset.update_values_dict(None)\n            total_time += sum_time\n            curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames\n            self.frame += curr_frames\n            should_exit = False\n\n            if self.rank == 0:\n                self.diagnostics.epoch(self, current_epoch=epoch_num)\n                scaled_time = self.num_agents * sum_time\n                scaled_play_time = self.num_agents * play_time\n\n                frame = self.frame // self.num_agents\n\n                if self.print_stats:\n                    step_time = max(step_time, 1e-6)\n                    fps_step = curr_frames / step_time\n                  
  fps_step_inference = curr_frames / scaled_play_time\n                    fps_total = curr_frames / scaled_time\n                    print(\n                        f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')\n\n                self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses,\n                                 entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)\n\n                self.algo_observer.after_print_stats(frame, epoch_num, total_time)\n\n                if self.game_rewards.current_size > 0:\n                    mean_rewards = self.game_rewards.get_mean()\n                    mean_lengths = self.game_lengths.get_mean()\n                    self.mean_rewards = mean_rewards[0]\n\n                    for i in range(self.value_size):\n                        rewards_name = 'rewards' if i == 0 else 'rewards{0}'.format(i)\n                        self.writer.add_scalar(rewards_name + '/step'.format(i), mean_rewards[i], frame)\n                        self.writer.add_scalar(rewards_name + '/iter'.format(i), mean_rewards[i], epoch_num)\n                        self.writer.add_scalar(rewards_name + '/time'.format(i), mean_rewards[i], total_time)\n\n                    self.writer.add_scalar('episode_lengths/step', mean_lengths, frame)\n                    self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)\n                    self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time)\n\n                    # removed equal signs (i.e. 
\"rew=\") from the checkpoint name since it messes with hydra CLI parsing\n                    checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])\n\n                    if self.save_freq > 0:\n                        if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards):\n                            self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))\n\n                    if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:\n                        print('saving next best rewards: ', mean_rewards)\n                        self.last_mean_rewards = mean_rewards[0]\n                        self.save(os.path.join(self.nn_dir, self.config['name']))\n\n                        if 'score_to_win' in self.config:\n                            if self.last_mean_rewards > self.config['score_to_win']:\n                                print('Network won!')\n                                self.save(os.path.join(self.nn_dir, checkpoint_name))\n                                should_exit = True\n\n                if epoch_num >= self.max_epochs:\n                    if self.game_rewards.current_size == 0:\n                        print('WARNING: Max epochs reached before any env terminated at least once')\n                        mean_rewards = -np.inf\n\n                    self.save(os.path.join(self.nn_dir,\n                                           'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(\n                                               mean_rewards)))\n                    print('MAX EPOCHS NUM!')\n                    should_exit = True\n                self.update_metric()\n                update_time = 0\n\n            if self.multi_gpu:\n                should_exit_t = torch.tensor(should_exit, device=self.device).float()\n                dist.broadcast(should_exit_t, 0)\n                should_exit = should_exit_t.bool().item()\n   
         if should_exit:\n                return self.last_mean_rewards, epoch_num\n\n    def update_metric(self):\n        tot_win_rate = 0\n        tot_games_num = 0\n        self.now_update_steps += 1\n        # self_player process\n        for player in self.player_pool.players:\n            win_rate = player.win_rate()\n            games = player.games_num()\n            self.writer.add_scalar(f'rate/win_rate_player_{player.player_idx}', win_rate, self.epoch_num)\n            tot_win_rate += win_rate * games\n            tot_games_num += games\n        win_rate = tot_win_rate / tot_games_num\n        if tot_games_num > self.games_to_check:\n            self.check_update_opponent(win_rate)\n        self.writer.add_scalar('rate/win_rate', win_rate, self.epoch_num)\n\n    def get_action_values(self, obs, is_op=False):\n        processed_obs = self._preproc_obs(obs['obs_op'] if is_op else obs['obs'])\n        if not is_op:\n            self.model.eval()\n        input_dict = {\n            'is_train': False,\n            'prev_actions': None,\n            'obs': processed_obs,\n            'rnn_states': self.rnn_states\n        }\n        with torch.no_grad():\n            if is_op:\n                res_dict = {\n                    \"actions\": torch.zeros((self.num_actors * self.num_opponent_agents, self.actions_num),\n                                           device=self.device),\n                    \"values\": torch.zeros((self.num_actors * self.num_opponent_agents, 1), device=self.device)\n                }\n                self.player_pool.inference(input_dict, res_dict, processed_obs)\n            else:\n                res_dict = self.model(input_dict)\n            if self.has_central_value:\n                states = obs['states']\n                input_dict = {\n                    'is_train': False,\n                    'states': states,\n                }\n                value = self.get_central_value(input_dict)\n                res_dict['values'] = 
value\n        return res_dict\n\n    def restore(self, fn):\n        checkpoint = load_checkpoint(fn, device=self.device)\n        checkpoint = load_check(checkpoint=checkpoint,\n                                normalize_input=self.normalize_input,\n                                normalize_value=self.normalize_value)\n        self.set_full_state_weights(checkpoint)\n\n    def resample_op(self, resample_indices):\n        for op_idx in range(self.num_opponent_agents):\n            for player in self.player_pool.players:\n                player.remove_envs(resample_indices + op_idx * self.num_actors)\n        for op_idx in range(self.num_opponent_agents):\n            for env_idx in resample_indices:\n                player = self.player_pool.sample_player()\n                player.add_envs(env_idx + op_idx * self.num_actors)\n        for player in self.player_pool.players:\n            player.reset_envs()\n\n    def resample_batch(self):\n        env_indices = torch.arange(end=self.num_actors * self.num_opponent_agents,\n                                   device=self.device, dtype=torch.long,\n                                   requires_grad=False)\n        step = self.num_actors // 32\n        for player in self.player_pool.players:\n            player.clear_envs()\n        for i in range(0, self.num_actors, step):\n            player = self.player_pool.sample_player()\n            player.add_envs(env_indices[i:i + step])\n        print(\"resample done\")\n\n    def restore_op(self, fn):\n        checkpoint = load_checkpoint(fn, device=self.device)\n        checkpoint = load_check(checkpoint, normalize_input=self.normalize_input,\n                                normalize_value=self.normalize_value)\n        self.init_op_model.load_state_dict(checkpoint['model'])\n        if self.normalize_input and 'running_mean_std' in checkpoint:\n            self.init_op_model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])\n\n    def 
check_update_opponent(self, win_rate):\n        if win_rate > self.update_win_rate or self.now_update_steps > self.max_update_steps:\n            print(f'winrate:{win_rate},add opponent to player pool')\n            self.update_op_num += 1\n            self.now_update_steps = 0\n            self.update_player_pool(self.model, player_idx=self.update_op_num)\n            self.player_pool.clear_player_metric()\n            self.resample_op(torch.arange(end=self.num_actors, device=self.device, dtype=torch.long))\n            self.save(os.path.join(self.players_dir, f'policy_{self.update_op_num}'))\n\n    def create_model(self):\n        model = self.network.build(self.base_model_config)\n        model.to(self.device)\n        return model\n\n    def update_player_pool(self, model, player_idx):\n        new_model = self.create_model()\n        new_model.load_state_dict(copy.deepcopy(model.state_dict()))\n        if hasattr(model, 'running_mean_std'):\n            new_model.running_mean_std.load_state_dict(copy.deepcopy(model.running_mean_std.state_dict()))\n        player = SinglePlayer(player_idx, new_model, self.device, self.num_actors * self.num_opponent_agents)\n        self.player_pool.add_player(player)\n"
  },
  {
    "path": "timechamber/learning/hrl_sp_player.py",
    "content": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\nimport os\nimport time\nimport torch\nimport numpy as np\nfrom rl_games.algos_torch import players\nimport random\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.common.tr_helpers import unsqueeze_obs\nfrom timechamber.ase import hrl_players\nfrom timechamber.utils.utils import load_check, load_checkpoint\nfrom .pfsp_player_pool import PFSPPlayerPool, PFSPPlayerVectorizedPool, PFSPPlayerThreadPool, PFSPPlayerProcessPool, \\\n    SinglePlayer\nimport matplotlib.pyplot as plt\n\nfrom multielo import MultiElo\n\n\nclass HRLSPPlayer(hrl_players.HRLPlayer):\n    def __init__(self, params):\n        params['config']['device_name'] = params['device']\n        super().__init__(params)\n        print(f'params:{params}')\n        self.network = self.config['network']\n        self.mask = [False]\n        self.is_rnn = False\n        self.normalize_input = self.config['normalize_input']\n        self.normalize_value = self.config.get('normalize_value', False)\n        self.base_model_config = {\n            'actions_num': self.actions_num,\n            'input_shape': self.obs_shape,\n            'num_seqs': self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value': self.normalize_value,\n            'normalize_input': self.normalize_input,\n        }\n        self.policy_timestep = []\n        self.policy_op_timestep = []\n        self.params = params\n        self.record_elo = self.player_config.get('record_elo', False)\n        self.init_elo = self.player_config.get('init_elo', 400)\n        self.num_actors = params['config']['num_actors']\n        self.player_pool_type = params['player_pool_type']\n        self.player_pool = None\n        self.op_player_pool = None\n        self.num_opponents = params['num_agents'] - 1\n        self.max_steps = 1000\n        self.update_op_num = 0\n        self.players_per_env = []\n        self.elo = 
MultiElo()\n\n    def restore(self, load_dir):\n        if os.path.isdir(load_dir):\n            self.player_pool = self._build_player_pool(params=self.params, player_num=len(os.listdir(load_dir)))\n            print('dir:', load_dir)\n            sorted_players = []\n            for idx, policy_check_checkpoint in enumerate(os.listdir(load_dir)):\n                model_timestep = os.path.getmtime(load_dir + '/' + str(policy_check_checkpoint))\n                self.policy_timestep.append(model_timestep)\n                model = self.load_model(load_dir + '/' + str(policy_check_checkpoint))\n                new_player = SinglePlayer(player_idx=model_timestep, model=model, device=self.device,\n                                          rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n                sorted_players.append(new_player)\n            sorted_players.sort(key=lambda player: player.player_idx)\n            for idx, player in enumerate(sorted_players):\n                player.player_idx = idx\n                self.player_pool.add_player(player)\n            self.policy_timestep.sort()\n        else:\n            self.player_pool = self._build_player_pool(params=self.params, player_num=1)\n            self.policy_timestep.append(os.path.getmtime(load_dir))\n            model = self.load_model(load_dir)\n            new_player = SinglePlayer(player_idx=0, model=model, device=self.device,\n                                      rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n            self.player_pool.add_player(new_player)\n        self.restore_op(self.params['op_load_path'])\n        self._norm_policy_timestep()\n        self._alloc_env_indices()\n\n    def restore_op(self, load_dir):\n        if os.path.isdir(load_dir):\n            self.op_player_pool = self._build_player_pool(params=self.params, player_num=len(os.listdir(load_dir)))\n            sorted_players = []\n            for idx, 
policy_check_checkpoint in enumerate(os.listdir(load_dir)):\n                model_timestep = os.path.getmtime(load_dir + '/' + str(policy_check_checkpoint))\n                self.policy_op_timestep.append(model_timestep)\n                model = self.load_model(load_dir + '/' + str(policy_check_checkpoint))\n                new_player = SinglePlayer(player_idx=model_timestep, model=model, device=self.device,\n                                          rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n                sorted_players.append(new_player)\n            sorted_players.sort(key=lambda player: player.player_idx)\n            for idx, player in enumerate(sorted_players):\n                player.player_idx = idx\n                self.op_player_pool.add_player(player)\n            self.policy_op_timestep.sort()\n        else:\n            self.op_player_pool = self._build_player_pool(params=self.params, player_num=1)\n            self.policy_op_timestep.append(os.path.getmtime(load_dir))\n            model = self.load_model(load_dir)\n            new_player = SinglePlayer(player_idx=0, model=model, device=self.device,\n                                      rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n            self.op_player_pool.add_player(new_player)\n\n    def _alloc_env_indices(self):\n        for idx in range(self.num_actors):\n            player_idx = random.randint(0, len(self.player_pool.players) - 1)\n            self.player_pool.players[player_idx].add_envs(torch.tensor([idx], dtype=torch.long, device=self.device))\n            env_player = [self.player_pool.players[player_idx]]\n            for op_idx in range(self.num_opponents):\n                op_player_idx = random.randint(0, len(self.op_player_pool.players) - 1)\n                self.op_player_pool.players[op_player_idx].add_envs(\n                    torch.tensor([idx + op_idx * self.num_actors], dtype=torch.long, device=self.device))\n            
    env_player.append(self.op_player_pool.players[op_player_idx])\n            self.players_per_env.append(env_player)\n        for player in self.player_pool.players:\n            player.reset_envs()\n        for player in self.op_player_pool.players:\n            player.reset_envs()\n\n    def _build_player_pool(self, params, player_num):\n        # route each pool type to its matching implementation\n        if self.player_pool_type == 'multi_thread':\n            return PFSPPlayerThreadPool(max_length=player_num,\n                                        device=self.device)\n        elif self.player_pool_type == 'multi_process':\n            return PFSPPlayerProcessPool(max_length=player_num,\n                                         device=self.device)\n        elif self.player_pool_type == 'vectorized':\n            vector_model_config = self.base_model_config\n            vector_model_config['num_envs'] = self.num_actors * self.num_opponents\n            vector_model_config['population_size'] = player_num\n\n            return PFSPPlayerVectorizedPool(max_length=player_num, device=self.device,\n                                            vector_model_config=vector_model_config, params=params)\n        else:\n            return PFSPPlayerPool(max_length=player_num, device=self.device)\n\n    def _update_rating(self, info, env_indices):\n        for env_idx in env_indices:\n            if self.num_opponents == 1:\n                player = self.players_per_env[env_idx][0]\n                op_player = self.players_per_env[env_idx][1]\n                if info['win'][env_idx]:\n                    player.rating, op_player.rating = self.elo.get_new_ratings([player.rating, op_player.rating])\n                elif info['lose'][env_idx]:\n                    op_player.rating, player.rating = self.elo.get_new_ratings([op_player.rating, player.rating])\n                elif info['draw'][env_idx]:\n                    player.rating, op_player.rating = self.elo.get_new_ratings([player.rating, op_player.rating],\n                     
                                                          result_order=[1, 1])\n            else:\n                ranks = info['ranks'][env_idx].cpu().numpy()\n                players_sorted_by_rank = sorted(enumerate(self.players_per_env[env_idx]), key=lambda x: ranks[x[0]])\n                sorted_ranks = sorted(ranks)\n                now_ratings = [player.rating for idx, player in players_sorted_by_rank]\n                new_ratings = self.elo.get_new_ratings(now_ratings, result_order=sorted_ranks)\n                for idx, new_rating in enumerate(new_ratings):\n                    players_sorted_by_rank[idx][1].rating = new_rating\n\n    def run(self):\n        n_games = self.games_num\n        render = self.render_env\n        n_game_life = self.n_game_life\n        is_determenistic = self.is_determenistic\n        sum_rewards = 0\n        sum_steps = 0\n        sum_game_res = 0\n        n_games = n_games * n_game_life\n        games_played = 0\n        has_masks = False\n        has_masks_func = getattr(self.env, \"has_action_mask\", None) is not None\n\n        if has_masks_func:\n            has_masks = self.env.has_action_mask()\n        print(f'games_num:{n_games}')\n        need_init_rnn = self.is_rnn\n        for _ in range(n_games):\n            if games_played >= n_games:\n                break\n\n            obses = self.env_reset(self.env)\n            batch_size = 1\n            batch_size = self.get_batch_size(obses['obs'], batch_size)\n\n            if need_init_rnn:\n                self.init_rnn()\n                need_init_rnn = False\n\n            cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device)\n            steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device)\n\n            print_game_res = False\n            done_indices = torch.tensor([], device=self.device, dtype=torch.long)\n\n            for n in range(self.max_steps):\n                obses = self.env_reset(self.env, done_indices)\n           
     if has_masks:\n                    masks = self.env.get_action_mask()\n                    action = self.get_masked_action(\n                        obses, masks, is_determenistic)\n                else:\n                    action = self.get_action(obses['obs'], is_determenistic)\n                    action_op = self.get_action(obses['obs_op'], is_determenistic, is_op=True)\n                obses, r, done, info = self.env_step(self.env, obses, action, action_op)\n                cr += r\n                steps += 1\n\n                if render:\n                    self.env.render(mode='human')\n                    time.sleep(self.render_sleep)\n\n                all_done_indices = done.nonzero(as_tuple=False)\n                done_indices = all_done_indices[::self.num_agents]\n                done_count = len(done_indices)\n                games_played += done_count\n                if self.record_elo:\n                    self._update_rating(info, all_done_indices.flatten())\n                if done_count > 0:\n                    if self.is_rnn:\n                        for s in self.states:\n                            s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0\n\n                    cur_rewards = cr[done_indices].sum().item()\n                    cur_steps = steps[done_indices].sum().item()\n\n                    cr = cr * (1.0 - done.float())\n                    steps = steps * (1.0 - done.float())\n                    sum_rewards += cur_rewards\n                    sum_steps += cur_steps\n\n                    game_res = 0.0\n                    if isinstance(info, dict):\n                        if 'battle_won' in info:\n                            print_game_res = True\n                            game_res = info.get('battle_won', 0.5)\n                        if 'scores' in info:\n                            print_game_res = True\n                            game_res = info.get('scores', 0.5)\n                    if 
self.print_stats:\n                        if print_game_res:\n                            print('reward:', cur_rewards / done_count,\n                                  'steps:', cur_steps / done_count, 'w:', game_res)\n                        else:\n                            print('reward:', cur_rewards / done_count,\n                                  'steps:', cur_steps / done_count)\n\n                    sum_game_res += game_res\n                    if batch_size // self.num_agents == 1 or games_played >= n_games:\n                        break\n                done_indices = done_indices[:, 0]\n        if self.record_elo:\n            self._plot_elo_curve()\n\n    def _plot_elo_curve(self):\n        x = np.array(self.policy_timestep)\n        x_op = np.array(self.policy_op_timestep)\n        # use float arrays here: np.arange would create integer arrays and truncate the Elo ratings\n        y = np.zeros(len(self.player_pool.players), dtype=np.float64)\n        y_op = np.zeros(len(self.op_player_pool.players), dtype=np.float64)\n        for player in self.player_pool.players:\n            y[player.player_idx] = player.rating\n        for player in self.op_player_pool.players:\n            y_op[player.player_idx] = player.rating\n        if self.params['load_path'] != self.params['op_load_path']:\n            plt.plot(x, y, 'b--', label='policy')\n            plt.plot(x_op, y_op, 'r--', label='policy_op')\n            plt.plot(x, y, 'b^-', x_op, y_op, 'ro-')\n        else:\n            plt.plot(x, y, 'b--', label='policy')\n            plt.plot(x, y, 'b^-')\n        plt.title('Elo Curve')\n        plt.xlabel('timestep (days)')\n        plt.ylabel('Elo rating')\n        plt.legend()\n        plt.savefig(self.params['load_path'] + '/../elo.jpg')\n\n    def get_action(self, obs, is_determenistic=False, is_op=False):\n        if not self.has_batch_dimension:\n            obs = unsqueeze_obs(obs)\n        obs = self._preproc_obs(obs)\n        input_dict = {\n            'is_train': False,\n            'prev_actions': None,\n            
'obs': obs,\n            'rnn_states': self.states\n        }\n        with torch.no_grad():\n            data_len = self.num_actors * self.num_opponents if is_op else self.num_actors\n            res_dict = {\n                \"actions\": torch.zeros((data_len, self.actions_num), device=self.device),\n                \"values\": torch.zeros((data_len, 1), device=self.device),\n                \"mus\": torch.zeros((data_len, self.actions_num), device=self.device)\n            }\n            if is_op:\n                self.op_player_pool.inference(input_dict, res_dict, obs)\n            else:\n                self.player_pool.inference(input_dict, res_dict, obs)\n        mu = res_dict['mus']\n        action = res_dict['actions']\n        if is_determenistic:\n            current_action = mu\n        else:\n            current_action = action\n\n        current_action = torch.squeeze(current_action.detach())\n        return torch.clamp(current_action, -1.0, 1.0)\n\n    def _norm_policy_timestep(self):\n        self.policy_op_timestep.sort()\n        self.policy_timestep.sort()\n        for idx in range(1, len(self.policy_op_timestep)):\n            self.policy_op_timestep[idx] -= self.policy_op_timestep[0]\n            self.policy_op_timestep[idx] /= 3600 * 24\n        for idx in range(1, len(self.policy_timestep)):\n            self.policy_timestep[idx] -= self.policy_timestep[0]\n            self.policy_timestep[idx] /= 3600 * 24\n        self.policy_timestep[0] = 0\n        if len(self.policy_op_timestep):\n            self.policy_op_timestep[0] = 0\n\n    def env_reset(self, env, env_ids=None):\n        obs = env.reset(env_ids)\n        obs_dict = {}\n        obs_dict['obs_op'] = obs[self.num_actors:]\n        obs_dict['obs'] = obs[:self.num_actors]\n        return obs_dict\n\n    def env_step(self, env, obs_dict, ego_actions, op_actions):\n        obs = obs_dict['obs']\n        obs_op = obs_dict['obs_op']\n        rewards = 0.0\n        done_count = 0.0\n        
disc_rewards = 0.0\n        terminate_count = 0.0\n        win_count = 0.0\n        lose_count = 0.0\n        draw_count = 0.0\n\n        for t in range(self._llc_steps):\n            llc_ego_actions = self._compute_llc_action(obs, ego_actions)\n            llc_op_actions = self._compute_llc_action(obs_op, op_actions)\n            llc_actions = torch.cat((llc_ego_actions, llc_op_actions), dim=0)\n            obs_all, curr_rewards, curr_dones, infos = env.step(llc_actions)\n\n            rewards += curr_rewards\n            done_count += curr_dones\n\n            terminate_count += infos['terminate']\n            win_count += infos['win']\n            lose_count += infos['lose']\n            draw_count += infos['draw']\n\n            amp_obs = infos['amp_obs']\n            curr_disc_reward = self._calc_disc_reward(amp_obs)\n            curr_disc_reward = curr_disc_reward[0, 0].cpu().numpy()\n            disc_rewards += curr_disc_reward\n\n            obs = obs_all[:self.num_actors]\n            obs_op = obs_all[self.num_actors:]\n\n        rewards /= self._llc_steps\n        disc_rewards /= self._llc_steps\n        dones = torch.zeros_like(done_count)\n        dones[done_count > 0] = 1.0\n        terminate = torch.zeros_like(terminate_count)\n        terminate[terminate_count > 0] = 1.0\n        infos['terminate'] = terminate\n        infos['disc_rewards'] = disc_rewards\n\n        wins = torch.zeros_like(win_count)\n        wins[win_count > 0] = 1.0\n        infos['win'] = wins\n        \n        loses = torch.zeros_like(lose_count)\n        loses[lose_count > 0] = 1.0\n        infos['lose'] = loses\n        \n        draws = torch.zeros_like(draw_count)\n        draws[draw_count > 0] = 1.0\n        infos['draw'] = draws\n\n        next_obs_dict = {}\n        next_obs_dict['obs_op'] = obs_op\n        next_obs_dict['obs'] = obs\n\n        if self.value_size > 1:\n            rewards = rewards[0]\n        if self.is_tensor_obses:\n            return 
self.obs_to_torch(next_obs_dict), rewards.cpu(), dones.cpu(), infos\n        else:\n            if np.isscalar(dones):\n                rewards = np.expand_dims(np.asarray(rewards), 0)\n                dones = np.expand_dims(np.asarray(dones), 0)\n            return next_obs_dict, rewards, dones, infos\n\n    def create_model(self):\n        model = self.network.build(self.base_model_config)\n        model.to(self.device)\n        return model\n\n    def load_model(self, fn):\n        model = self.create_model()\n        checkpoint = load_checkpoint(fn, device=self.device)\n        checkpoint = load_check(checkpoint, normalize_input=self.normalize_input,\n                                normalize_value=self.normalize_value)\n\n        model.load_state_dict(checkpoint['model'])\n\n        if self.normalize_input and 'running_mean_std' in checkpoint:\n            model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])\n\n        return model\n"
  },
  {
    "path": "timechamber/learning/pfsp_player_pool.py",
    "content": "import collections\n\nimport random\nimport torch\nimport torch.multiprocessing as mp\nimport dill\n# import time\nfrom rl_games.algos_torch import model_builder\nfrom concurrent.futures import ThreadPoolExecutor, as_completed, wait, ALL_COMPLETED\n\n\ndef player_inference_thread(model, input_dict, res_dict, env_indices, processed_obs):\n    if len(env_indices) == 0:\n        return None\n    input_dict['obs'] = processed_obs[env_indices]\n    out_dict = model(input_dict)\n    for key in res_dict:\n        res_dict[key][env_indices] = out_dict[key]\n    return out_dict\n\n\ndef player_inference_process(pipe, queue, barrier):\n    input_dict = {\n        'is_train': False,\n        'prev_actions': None,\n        'obs': None,\n        'rnn_states': None,\n    }\n    model = None\n    barrier.wait()\n    while True:\n        msg = pipe.recv()\n        task = msg['task']\n        if task == 'init':\n            if model is not None:\n                del model\n            model = queue.get()\n            model = dill.loads(model)\n            barrier.wait()\n        elif task == 'forward':\n            obs, actions, values, env_indices = queue.get()\n            input_dict['obs'] = obs[env_indices]\n            out_dict = model(input_dict)\n            actions[env_indices] = out_dict['actions']\n            values[env_indices] = out_dict['values']\n            barrier.wait()\n            del obs, actions, values, env_indices\n        elif task == 'terminate':\n            break\n        else:\n            barrier.wait()\n\n\nclass SinglePlayer:\n    def __init__(self, player_idx, model, device, obs_batch_len=0, rating=None):\n        self.model = model\n        if model:\n            self.model.eval()\n        self.player_idx = player_idx\n        self._games = torch.tensor(0, device=device, dtype=torch.float)\n        self._wins = torch.tensor(0, device=device, dtype=torch.float)\n        self._loses = torch.tensor(0, device=device, 
dtype=torch.float)\n        self._draws = torch.tensor(0, device=device, dtype=torch.float)\n        self._decay = 0.998\n        self._has_env = torch.zeros((obs_batch_len,), device=device, dtype=torch.bool)\n        self.device = device\n        self.env_indices = torch.tensor([], device=device, dtype=torch.long, requires_grad=False)\n        if rating:\n            self.rating = rating\n\n    def __call__(self, input_dict):\n        return self.model(input_dict)\n\n    def reset_envs(self):\n        self.env_indices = self._has_env.nonzero(as_tuple=True)\n\n    def remove_envs(self, env_indices):\n        self._has_env[env_indices] = False\n\n    def add_envs(self, env_indices):\n        self._has_env[env_indices] = True\n\n    def clear_envs(self):\n        self.env_indices = torch.tensor([], device=self.device, dtype=torch.long, requires_grad=False)\n\n    def update_metric(self, wins, loses, draws):\n        win_count = torch.sum(wins[self.env_indices])\n        lose_count = torch.sum(loses[self.env_indices])\n        draw_count = torch.sum(draws[self.env_indices])\n        for stats in (self._games, self._wins, self._loses, self._draws):\n            stats *= self._decay\n        self._games += win_count + lose_count + draw_count\n        self._wins += win_count\n        self._loses += lose_count\n        self._draws += draw_count\n\n    def clear_metric(self):\n        self._games = torch.tensor(0, device=self.device, dtype=torch.float)\n        self._wins = torch.tensor(0, device=self.device, dtype=torch.float)\n        self._loses = torch.tensor(0, device=self.device, dtype=torch.float)\n        self._draws = torch.tensor(0, device=self.device, dtype=torch.float)\n\n    def win_rate(self):\n        if self.model is None:\n            return 0\n        elif self._games == 0:\n            return 0.5\n        return (self._wins + 0.5 * self._draws) / self._games\n\n    def games_num(self):\n        return self._games\n\n\nclass PFSPPlayerPool:\n    def 
__init__(self, max_length, device):\n        assert max_length > 0\n        self.players = []\n        self.max_length = max_length\n        self.idx = 0\n        self.device = device\n        self.weightings = {\n            \"variance\": lambda x: x * (1 - x),\n            \"linear\": lambda x: 1 - x,\n            \"squared\": lambda x: (1 - x) ** 2,\n        }\n\n    def add_player(self, player):\n        if len(self.players) < self.max_length:\n            self.players.append(player)\n        else:\n            self.players[self.idx] = player\n        self.idx += 1\n        self.idx %= self.max_length\n\n    def sample_player(self, weight='linear'):\n        weight_func = self.weightings[weight]\n        player = \\\n            random.choices(self.players, weights=[weight_func(player.win_rate()) for player in self.players])[0]\n        return player\n\n    def update_player_metric(self, infos):\n        for player in self.players:\n            player.update_metric(infos['win'], infos['lose'], infos['draw'])\n\n    def clear_player_metric(self):\n        for player in self.players:\n            player.clear_metric()\n\n    def inference(self, input_dict, res_dict, processed_obs):\n        for i, player in enumerate(self.players):\n            if len(player.env_indices[0]) == 0:\n                continue\n            input_dict['obs'] = processed_obs[player.env_indices]\n            out_dict = player(input_dict)\n            for key in res_dict:\n                res_dict[key][player.env_indices] = out_dict[key]\n\n\nclass PFSPPlayerVectorizedPool(PFSPPlayerPool):\n    def __init__(self, max_length, device, vector_model_config, params):\n        super(PFSPPlayerVectorizedPool, self).__init__(max_length, device)\n        params['model']['name'] = 'vectorized_a2c'\n        params['network']['name'] = 'vectorized_a2c'\n        builder = model_builder.ModelBuilder()\n        self.vectorized_network = builder.load(params)\n        self.vectorized_model = 
self.vectorized_network.build(vector_model_config)\n        self.vectorized_model.to(self.device)\n        self.vectorized_model.eval()\n        self.obs = torch.zeros(\n            (self.max_length, vector_model_config[\"num_envs\"], vector_model_config['input_shape'][0]),\n            dtype=torch.float32, device=self.device)\n        for idx in range(max_length):\n            self.add_player(SinglePlayer(idx, None, self.device, vector_model_config[\"num_envs\"]))\n\n    def inference(self, input_dict, res_dict, processed_obs):\n        for i, player in enumerate(self.players):\n            self.obs[i][player.env_indices] = processed_obs[player.env_indices]\n        input_dict['obs'] = self.obs\n        out_dict = self.vectorized_model(input_dict)\n        for i, player in enumerate(self.players):\n            if len(player.env_indices) == 0:\n                continue\n            for key in res_dict:\n                res_dict[key][player.env_indices] = out_dict[key][i][player.env_indices]\n\n    def add_player(self, player):\n        if player.model:\n            self.vectorized_model.update(self.idx, player.model)\n        super().add_player(player)\n\n\nclass PFSPPlayerThreadPool(PFSPPlayerPool):\n    def __init__(self, max_length, device):\n        super().__init__(max_length, device)\n        self.thread_pool = ThreadPoolExecutor(max_workers=self.max_length)\n\n    def inference(self, input_dict, res_dict, processed_obs):\n        # submit one task per player and block until all complete: Executor.map\n        # returns a lazy iterator, so res_dict could otherwise be read before the\n        # worker threads finish. Each task gets its own shallow copy of input_dict\n        # so concurrent workers do not clobber each other's 'obs' entry.\n        futures = [self.thread_pool.submit(player_inference_thread, player.model, dict(input_dict),\n                                           res_dict, player.env_indices, processed_obs)\n                   for player in self.players]\n        wait(futures, return_when=ALL_COMPLETED)\n\n\nclass PFSPPlayerProcessPool(PFSPPlayerPool):\n    def __init__(self, max_length, device):\n        super(PFSPPlayerProcessPool, 
self).__init__(max_length, device)\n        self.inference_processes = []\n        self.queues = []\n        self.producer_pipes = []\n        self.consumer_pipes = []\n        self.barrier = mp.Barrier(self.max_length + 1)\n        mp.set_start_method(method='spawn', force=True)\n        self._init_inference_processes()\n\n    def _init_inference_processes(self):\n        for _ in range(self.max_length):\n            queue = mp.Queue()\n            self.queues.append(queue)\n            pipe_read, pipe_write = mp.Pipe(duplex=False)\n            self.producer_pipes.append(pipe_write)\n            self.consumer_pipes.append(pipe_read)\n            process = mp.Process(target=player_inference_process,\n                                 args=(pipe_read, queue, self.barrier),\n                                 daemon=True)\n            self.inference_processes.append(process)\n            process.start()\n        self.barrier.wait()\n\n    def add_player(self, player):\n        with torch.no_grad():\n            model = dill.dumps(player.model)\n            for i in range(self.max_length):\n                if i == self.idx:\n                    self.producer_pipes[i].send({'task': 'init'})\n                    self.queues[i].put(model)\n                else:\n                    self.producer_pipes[i].send({'task': 'continue'})\n            self.barrier.wait()\n            if len(self.players) < self.max_length:\n                self.players.append(player)\n            else:\n                self.players[self.idx] = player\n            self.idx += 1\n            self.idx %= self.max_length\n\n    def inference(self, input_dict, res_dict, processed_obs):\n\n        for i in range(self.max_length):\n            if i < len(self.players) and len(self.players[i].env_indices):\n                self.producer_pipes[i].send({'task': 'forward'})\n                self.queues[i].put(\n                    (processed_obs, res_dict['actions'],\n                     res_dict['values'], 
self.players[i].env_indices))\n            else:\n                self.producer_pipes[i].send({'task': 'continue'})\n        # every worker process waits on the barrier once its task is handled,\n        # so the parent must join the barrier too; otherwise the workers block\n        # forever and res_dict is read before the results are written\n        self.barrier.wait()\n\n    def __del__(self):\n        for pipe in self.producer_pipes:\n            pipe.send({'task': 'terminate'})\n        for process in self.inference_processes:\n            process.join()\n"
  },
  {
    "path": "timechamber/learning/ppo_sp_agent.py",
    "content": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\n\nimport copy\nfrom datetime import datetime\nfrom gym import spaces\nimport numpy as np\nimport os\nimport time\nfrom .pfsp_player_pool import PFSPPlayerPool, SinglePlayer, PFSPPlayerThreadPool, PFSPPlayerProcessPool, \\\n    PFSPPlayerVectorizedPool\nfrom timechamber.utils.utils import load_checkpoint\nfrom rl_games.algos_torch import a2c_continuous\nfrom rl_games.common.a2c_common import swap_and_flatten01\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch import central_value\nimport torch\nfrom torch import optim\nfrom tensorboardX import SummaryWriter\nimport torch.distributed as dist\n\n\nclass SPAgent(a2c_continuous.A2CAgent):\n    def __init__(self, base_name, params):\n        params['config']['device'] = params['device']\n        super().__init__(base_name, params)\n        self.player_pool_type = params['player_pool_type']\n        self.base_model_config = {\n            'actions_num': self.actions_num,\n            'input_shape': self.obs_shape,\n            'num_seqs': self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value': self.normalize_value,\n            'normalize_input': self.normalize_input,\n        }\n        self.max_his_player_num = params['player_pool_length']\n\n        if params['op_load_path']:\n            self.init_op_model = self.create_model()\n            self.restore_op(params['op_load_path'])\n        else:\n            self.init_op_model = self.model\n        self.players_dir = os.path.join(self.experiment_dir, 'policy_dir')\n        os.makedirs(self.players_dir, exist_ok=True)\n        self.update_win_rate = params['update_win_rate']\n        self.num_opponent_agents = params['num_agents'] - 1\n        self.player_pool = self._build_player_pool(params)\n\n        self.games_to_check = params['games_to_check']\n        self.now_update_steps = 0\n        self.max_update_steps = 
params['max_update_steps']\n        self.update_op_num = 0\n        self.update_player_pool(self.init_op_model, player_idx=self.update_op_num)\n        self.resample_op(torch.arange(end=self.num_actors, device=self.device, dtype=torch.long))\n\n        assert self.num_actors % self.max_his_player_num == 0\n\n    def _build_player_pool(self, params):\n        # route each pool type to its matching implementation\n        if self.player_pool_type == 'multi_thread':\n            return PFSPPlayerThreadPool(max_length=self.max_his_player_num,\n                                        device=self.device)\n        elif self.player_pool_type == 'multi_process':\n            return PFSPPlayerProcessPool(max_length=self.max_his_player_num,\n                                         device=self.device)\n        elif self.player_pool_type == 'vectorized':\n            vector_model_config = self.base_model_config\n            vector_model_config['num_envs'] = self.num_actors * self.num_opponent_agents\n            vector_model_config['population_size'] = self.max_his_player_num\n\n            return PFSPPlayerVectorizedPool(max_length=self.max_his_player_num, device=self.device,\n                                            vector_model_config=vector_model_config, params=params)\n        else:\n            return PFSPPlayerPool(max_length=self.max_his_player_num, device=self.device)\n\n    def play_steps(self):\n        update_list = self.update_list\n        step_time = 0.0\n        env_done_indices = torch.tensor([], device=self.device, dtype=torch.long)\n\n        for n in range(self.horizon_length):\n            self.obs = self.env_reset(env_done_indices)\n            if self.use_action_masks:\n                masks = self.vec_env.get_action_masks()\n                res_dict = self.get_masked_action_values(self.obs, masks)\n            else:\n                res_dict_op = self.get_action_values(self.obs, is_op=True)\n                res_dict = self.get_action_values(self.obs)\n            
self.experience_buffer.update_data('obses', n, self.obs['obs'])\n            self.experience_buffer.update_data('dones', n, self.dones)\n            for k in update_list:\n                self.experience_buffer.update_data(k, n, res_dict[k])\n            if self.has_central_value:\n                self.experience_buffer.update_data('states', n, self.obs['states'])\n\n            step_time_start = time.time()\n            self.obs, rewards, self.dones, infos = self.env_step(\n                torch.cat((res_dict['actions'], res_dict_op['actions']), dim=0))\n            step_time_end = time.time()\n            step_time += (step_time_end - step_time_start)\n\n            shaped_rewards = self.rewards_shaper(rewards)\n            if self.value_bootstrap and 'time_outs' in infos:\n                shaped_rewards += self.gamma * res_dict['values'] * self.cast_obs(infos['time_outs']).unsqueeze(\n                    1).float()\n\n            self.experience_buffer.update_data('rewards', n, shaped_rewards)\n\n            self.current_rewards += rewards\n            self.current_lengths += 1\n            all_done_indices = self.dones.nonzero(as_tuple=False)\n            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)\n            self.game_rewards.update(self.current_rewards[env_done_indices])\n            self.game_lengths.update(self.current_lengths[env_done_indices])\n            self.algo_observer.process_infos(infos, env_done_indices)\n\n            not_dones = 1.0 - self.dones.float()\n\n            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)\n            self.current_lengths = self.current_lengths * not_dones\n\n            
self.player_pool.update_player_metric(infos=infos)\n            self.resample_op(all_done_indices.flatten())\n\n            env_done_indices = env_done_indices[:, 0]\n\n        last_values = self.get_values(self.obs)\n\n        fdones = self.dones.float()\n        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()\n        mb_values = self.experience_buffer.tensor_dict['values']\n        mb_rewards = self.experience_buffer.tensor_dict['rewards']\n        mb_advs = self.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards)\n        mb_returns = mb_advs + mb_values\n\n        batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list)\n        batch_dict['returns'] = swap_and_flatten01(mb_returns)\n        batch_dict['played_frames'] = self.batch_size\n        batch_dict['step_time'] = step_time\n        return batch_dict\n\n    def env_step(self, actions):\n        actions = self.preprocess_actions(actions)\n        obs, rewards, dones, infos = self.vec_env.step(actions)\n        obs['obs_op'] = obs['obs'][self.num_actors:]\n        obs['obs'] = obs['obs'][:self.num_actors]\n        if self.is_tensor_obses:\n            if self.value_size == 1:\n                rewards = rewards.unsqueeze(1)\n            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos\n        else:\n            if self.value_size == 1:\n                rewards = np.expand_dims(rewards, axis=1)\n            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(\n                dones).to(self.ppo_device), infos\n\n    def env_reset(self, env_ids=None):\n        obs = self.vec_env.reset(env_ids)\n        obs = self.obs_to_tensors(obs)\n        obs['obs_op'] = obs['obs'][self.num_actors:]\n        obs['obs'] = obs['obs'][:self.num_actors]\n        return obs\n\n    def train(self):\n        self.init_tensors()\n        self.mean_rewards = 
self.last_mean_rewards = -100500\n        start_time = time.time()\n        total_time = 0\n        rep_count = 0\n        # self.frame = 0  # loading from checkpoint\n        self.obs = self.env_reset()\n\n        if self.multi_gpu:\n            torch.cuda.set_device(self.rank)\n            print(\"====================broadcasting parameters\")\n            model_params = [self.model.state_dict()]\n            dist.broadcast_object_list(model_params, 0)\n            self.model.load_state_dict(model_params[0])\n\n        while True:\n            epoch_num = self.update_epoch()\n            step_time, play_time, update_time, sum_time, a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()\n            # cleaning memory to optimize space\n            self.dataset.update_values_dict(None)\n            total_time += sum_time\n            curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames\n            self.frame += curr_frames\n            should_exit = False\n\n            if self.rank == 0:\n                self.diagnostics.epoch(self, current_epoch=epoch_num)\n                scaled_time = self.num_agents * sum_time\n                scaled_play_time = self.num_agents * play_time\n\n                frame = self.frame // self.num_agents\n\n                if self.print_stats:\n                    step_time = max(step_time, 1e-6)\n                    fps_step = curr_frames / step_time\n                    fps_step_inference = curr_frames / scaled_play_time\n                    fps_total = curr_frames / scaled_time\n                    print(\n                        f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')\n\n                self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses,\n                                 entropies, kls, last_lr, lr_mul, frame, 
scaled_time, scaled_play_time, curr_frames)\n\n                self.algo_observer.after_print_stats(frame, epoch_num, total_time)\n\n                if self.game_rewards.current_size > 0:\n                    mean_rewards = self.game_rewards.get_mean()\n                    mean_lengths = self.game_lengths.get_mean()\n                    self.mean_rewards = mean_rewards[0]\n\n                    for i in range(self.value_size):\n                        rewards_name = 'rewards' if i == 0 else 'rewards{0}'.format(i)\n                        self.writer.add_scalar(rewards_name + '/step'.format(i), mean_rewards[i], frame)\n                        self.writer.add_scalar(rewards_name + '/iter'.format(i), mean_rewards[i], epoch_num)\n                        self.writer.add_scalar(rewards_name + '/time'.format(i), mean_rewards[i], total_time)\n\n                    self.writer.add_scalar('episode_lengths/step', mean_lengths, frame)\n                    self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)\n                    self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time)\n\n                    # removed equal signs (i.e. 
\"rew=\") from the checkpoint name since it messes with hydra CLI parsing\n                    checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])\n\n                    if self.save_freq > 0:\n                        if (epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):\n                            self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))\n\n                    if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:\n                        print('saving next best rewards: ', mean_rewards)\n                        self.last_mean_rewards = mean_rewards[0]\n                        self.save(os.path.join(self.nn_dir, self.config['name']))\n\n                        if 'score_to_win' in self.config:\n                            if self.last_mean_rewards > self.config['score_to_win']:\n                                print('Network won!')\n                                self.save(os.path.join(self.nn_dir, checkpoint_name))\n                                should_exit = True\n\n                if epoch_num >= self.max_epochs:\n                    if self.game_rewards.current_size == 0:\n                        print('WARNING: Max epochs reached before any env terminated at least once')\n                        mean_rewards = -np.inf\n\n                    self.save(os.path.join(self.nn_dir,\n                                           'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(\n                                               mean_rewards)))\n                    print('MAX EPOCHS NUM!')\n                    should_exit = True\n                self.update_metric()\n                update_time = 0\n\n            if self.multi_gpu:\n                should_exit_t = torch.tensor(should_exit, device=self.device).float()\n                dist.broadcast(should_exit_t, 0)\n                should_exit = should_exit_t.bool().item()\n   
         if should_exit:\n                return self.last_mean_rewards, epoch_num\n\n    def update_metric(self):\n        tot_win_rate = 0\n        tot_games_num = 0\n        self.now_update_steps += 1\n        # self_player process\n        for player in self.player_pool.players:\n            win_rate = player.win_rate()\n            games = player.games_num()\n            self.writer.add_scalar(f'rate/win_rate_player_{player.player_idx}', win_rate, self.epoch_num)\n            tot_win_rate += win_rate * games\n            tot_games_num += games\n        win_rate = tot_win_rate / tot_games_num\n        if tot_games_num > self.games_to_check:\n            self.check_update_opponent(win_rate)\n        self.writer.add_scalar('rate/win_rate', win_rate, self.epoch_num)\n\n    def get_action_values(self, obs, is_op=False):\n        processed_obs = self._preproc_obs(obs['obs_op'] if is_op else obs['obs'])\n        if not is_op:\n            self.model.eval()\n        input_dict = {\n            'is_train': False,\n            'prev_actions': None,\n            'obs': processed_obs,\n            'rnn_states': self.rnn_states\n        }\n        with torch.no_grad():\n            if is_op:\n                res_dict = {\n                    \"actions\": torch.zeros((self.num_actors * self.num_opponent_agents, self.actions_num),\n                                           device=self.device),\n                    \"values\": torch.zeros((self.num_actors * self.num_opponent_agents, 1), device=self.device)\n                }\n                self.player_pool.inference(input_dict, res_dict, processed_obs)\n            else:\n                res_dict = self.model(input_dict)\n            if self.has_central_value:\n                states = obs['states']\n                input_dict = {\n                    'is_train': False,\n                    'states': states,\n                }\n                value = self.get_central_value(input_dict)\n                res_dict['values'] = 
value\n        return res_dict\n\n    def resample_op(self, resample_indices):\n        for op_idx in range(self.num_opponent_agents):\n            for player in self.player_pool.players:\n                player.remove_envs(resample_indices + op_idx * self.num_actors)\n        for op_idx in range(self.num_opponent_agents):\n            for env_idx in resample_indices:\n                player = self.player_pool.sample_player()\n                player.add_envs(env_idx + op_idx * self.num_actors)\n        for player in self.player_pool.players:\n            player.reset_envs()\n\n    def resample_batch(self):\n        env_indices = torch.arange(end=self.num_actors * self.num_opponent_agents,\n                                   device=self.device, dtype=torch.long,\n                                   requires_grad=False)\n        step = self.num_actors // 32\n        for player in self.player_pool.players:\n            player.clear_envs()\n        for i in range(0, self.num_actors, step):\n            player = self.player_pool.sample_player()\n            player.add_envs(env_indices[i:i + step])\n        print(\"resample done\")\n\n    def restore_op(self, fn):\n        checkpoint = load_checkpoint(fn, device=self.device)\n        self.init_op_model.load_state_dict(checkpoint['model'])\n        if self.normalize_input and 'running_mean_std' in checkpoint:\n            self.init_op_model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])\n\n    def check_update_opponent(self, win_rate):\n        if win_rate > self.update_win_rate or self.now_update_steps > self.max_update_steps:\n            print(f'winrate:{win_rate},add opponent to player pool')\n            self.update_op_num += 1\n            self.now_update_steps = 0\n            self.update_player_pool(self.model, player_idx=self.update_op_num)\n            self.player_pool.clear_player_metric()\n            self.resample_op(torch.arange(end=self.num_actors, device=self.device, dtype=torch.long))\n   
         self.save(os.path.join(self.players_dir, f'policy_{self.update_op_num}'))\n\n    def create_model(self):\n        model = self.network.build(self.base_model_config)\n        model.to(self.device)\n        return model\n\n    def update_player_pool(self, model, player_idx):\n        new_model = self.create_model()\n        new_model.load_state_dict(copy.deepcopy(model.state_dict()))\n        if hasattr(model, 'running_mean_std'):\n            new_model.running_mean_std.load_state_dict(copy.deepcopy(model.running_mean_std.state_dict()))\n        player = SinglePlayer(player_idx, new_model, self.device, self.num_actors * self.num_opponent_agents)\n        self.player_pool.add_player(player)\n"
  },
  {
    "path": "timechamber/learning/ppo_sp_player.py",
    "content": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\nimport os\nimport time\nimport torch\nimport numpy as np\nfrom rl_games.algos_torch import players\nimport random\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.common.tr_helpers import unsqueeze_obs\nfrom rl_games.common.player import BasePlayer\nfrom .pfsp_player_pool import PFSPPlayerPool, PFSPPlayerVectorizedPool, PFSPPlayerThreadPool, PFSPPlayerProcessPool, \\\n    SinglePlayer\nimport matplotlib.pyplot as plt\n\nfrom multielo import MultiElo\n\n\ndef rescale_actions(low, high, action):\n    d = (high - low) / 2.0\n    m = (high + low) / 2.0\n    scaled_action = action * d + m\n    return scaled_action\n\n\nclass SPPlayer(BasePlayer):\n    def __init__(self, params):\n        params['config']['device_name'] = params['device']\n        super().__init__(params)\n        print(f'params:{params}')\n        self.network = self.config['network']\n        self.actions_num = self.action_space.shape[0]\n        self.actions_low = torch.from_numpy(self.action_space.low.copy()).float().to(self.device)\n        self.actions_high = torch.from_numpy(self.action_space.high.copy()).float().to(self.device)\n        self.mask = [False]\n        self.is_rnn = False\n        self.normalize_input = self.config['normalize_input']\n        self.normalize_value = self.config.get('normalize_value', False)\n        self.base_model_config = {\n            'actions_num': self.actions_num,\n            'input_shape': self.obs_shape,\n            'num_seqs': self.num_agents,\n            'value_size': self.env_info.get('value_size', 1),\n            'normalize_value': self.normalize_value,\n            'normalize_input': self.normalize_input,\n        }\n        self.policy_timestep = []\n        self.policy_op_timestep = []\n        self.params = params\n        self.record_elo = self.player_config.get('record_elo', False)\n        self.init_elo = self.player_config.get('init_elo', 400)\n        
self.num_actors = params['config']['num_actors']\n        self.player_pool_type = params['player_pool_type']\n        self.player_pool = None\n        self.op_player_pool = None\n        self.num_opponents = params['num_agents'] - 1\n        self.max_steps = 1000\n        self.update_op_num = 0\n        self.players_per_env = []\n        self.elo = MultiElo()\n\n    def restore(self, load_dir):\n        if os.path.isdir(load_dir):\n            self.player_pool = self._build_player_pool(params=self.params, player_num=len(os.listdir(load_dir)))\n            print('dir:', load_dir)\n            sorted_players = []\n            for idx, policy_check_checkpoint in enumerate(os.listdir(load_dir)):\n                model_timestep = os.path.getmtime(load_dir + '/' + str(policy_check_checkpoint))\n                self.policy_timestep.append(model_timestep)\n                model = self.load_model(load_dir + '/' + str(policy_check_checkpoint))\n                new_player = SinglePlayer(player_idx=model_timestep, model=model, device=self.device,\n                                          rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n                sorted_players.append(new_player)\n            sorted_players.sort(key=lambda player: player.player_idx)\n            for idx, player in enumerate(sorted_players):\n                player.player_idx = idx\n                self.player_pool.add_player(player)\n            self.policy_timestep.sort()\n        else:\n            self.player_pool = self._build_player_pool(params=self.params, player_num=1)\n            self.policy_timestep.append(os.path.getmtime(load_dir))\n            model = self.load_model(load_dir)\n            new_player = SinglePlayer(player_idx=0, model=model, device=self.device,\n                                      rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n            self.player_pool.add_player(new_player)\n        
self.restore_op(self.params['op_load_path'])\n        self._norm_policy_timestep()\n        self._alloc_env_indices()\n\n    def restore_op(self, load_dir):\n        if os.path.isdir(load_dir):\n            self.op_player_pool = self._build_player_pool(params=self.params, player_num=len(os.listdir(load_dir)))\n            sorted_players = []\n            for idx, policy_check_checkpoint in enumerate(os.listdir(load_dir)):\n                model_timestep = os.path.getmtime(load_dir + '/' + str(policy_check_checkpoint))\n                self.policy_op_timestep.append(model_timestep)\n                model = self.load_model(load_dir + '/' + str(policy_check_checkpoint))\n                new_player = SinglePlayer(player_idx=model_timestep, model=model, device=self.device,\n                                          rating=self.init_elo, obs_batch_len=self.num_actors * self.num_opponents)\n                sorted_players.append(new_player)\n            sorted_players.sort(key=lambda player: player.player_idx)\n            for idx, player in enumerate(sorted_players):\n                player.player_idx = idx\n                self.op_player_pool.add_player(player)\n            self.policy_op_timestep.sort()\n        else:\n            self.op_player_pool = self._build_player_pool(params=self.params, player_num=1)\n            self.policy_op_timestep.append(os.path.getmtime(load_dir))\n            model = self.load_model(load_dir)\n            new_player = SinglePlayer(player_idx=0, model=model, device=self.device,\n                                      rating=400, obs_batch_len=self.num_actors * self.num_opponents)\n            self.op_player_pool.add_player(new_player)\n\n    def _alloc_env_indices(self):\n        for idx in range(self.num_actors):\n            player_idx = random.randint(0, len(self.player_pool.players) - 1)\n            self.player_pool.players[player_idx].add_envs(torch.tensor([idx], dtype=torch.long, device=self.device))\n            env_player = 
[self.player_pool.players[player_idx]]\n            for op_idx in range(self.num_opponents):\n                op_player_idx = random.randint(0, len(self.op_player_pool.players) - 1)\n                self.op_player_pool.players[op_player_idx].add_envs(\n                    torch.tensor([idx + op_idx * self.num_actors], dtype=torch.long, device=self.device))\n                env_player.append(self.op_player_pool.players[op_player_idx])\n            self.players_per_env.append(env_player)\n        for player in self.player_pool.players:\n            player.reset_envs()\n        for player in self.op_player_pool.players:\n            player.reset_envs()\n\n    def _build_player_pool(self, params, player_num):\n        if self.player_pool_type == 'multi_thread':\n            return PFSPPlayerThreadPool(max_length=player_num,\n                                        device=self.device)\n        elif self.player_pool_type == 'multi_process':\n            return PFSPPlayerProcessPool(max_length=player_num,\n                                         device=self.device)\n        elif self.player_pool_type == 'vectorized':\n            vector_model_config = self.base_model_config\n            vector_model_config['num_envs'] = self.num_actors * self.num_opponents\n            vector_model_config['population_size'] = player_num\n\n            return PFSPPlayerVectorizedPool(max_length=player_num, device=self.device,\n                                            vector_model_config=vector_model_config, params=params)\n        else:\n            return PFSPPlayerPool(max_length=player_num, device=self.device)\n\n    def _update_rating(self, info, env_indices):\n        for env_idx in env_indices:\n            if self.num_opponents == 1:\n                player = self.players_per_env[env_idx][0]\n                op_player = self.players_per_env[env_idx][1]\n                if info['win'][env_idx]:\n                    player.rating, op_player.rating = 
self.elo.get_new_ratings([player.rating, op_player.rating])\n                elif info['lose'][env_idx]:\n                    op_player.rating, player.rating = self.elo.get_new_ratings([op_player.rating, player.rating])\n                elif info['draw'][env_idx]:\n                    player.rating, op_player.rating = self.elo.get_new_ratings([player.rating, op_player.rating],\n                                                                               result_order=[1, 1])\n            else:\n                ranks = info['ranks'][env_idx].cpu().numpy()\n                players_sorted_by_rank = sorted(enumerate(self.players_per_env[env_idx]), key=lambda x: ranks[x[0]])\n                sorted_ranks = sorted(ranks)\n                now_ratings = [player.rating for idx, player in players_sorted_by_rank]\n                new_ratings = self.elo.get_new_ratings(now_ratings, result_order=sorted_ranks)\n                # print(now_ratings, new_ratings)\n                # assert new_ratings[0] > 0 and new_ratings[1] > 0 and new_ratings[2] > 0\n                for idx, new_rating in enumerate(new_ratings):\n                    players_sorted_by_rank[idx][1].rating = new_rating\n\n    def run(self):\n        n_games = self.games_num\n        render = self.render_env\n        n_game_life = self.n_game_life\n        is_determenistic = self.is_determenistic\n        sum_rewards = 0\n        sum_steps = 0\n        sum_game_res = 0\n        n_games = n_games * n_game_life\n        games_played = 0\n        has_masks = False\n        has_masks_func = getattr(self.env, \"has_action_mask\", None) is not None\n\n        if has_masks_func:\n            has_masks = self.env.has_action_mask()\n        print(f'games_num:{n_games}')\n        need_init_rnn = self.is_rnn\n        for _ in range(n_games):\n            if games_played >= n_games:\n                break\n\n            obses = self.env_reset(self.env)\n            batch_size = 1\n            batch_size = 
self.get_batch_size(obses['obs'], batch_size)\n\n            if need_init_rnn:\n                self.init_rnn()\n                need_init_rnn = False\n\n            cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device)\n            steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device)\n\n            print_game_res = False\n            done_indices = torch.tensor([], device=self.device, dtype=torch.long)\n            for n in range(self.max_steps):\n                obses = self.env_reset(self.env, done_indices)\n                if has_masks:\n                    masks = self.env.get_action_mask()\n                    action = self.get_masked_action(\n                        obses, masks, is_determenistic)\n                else:\n                    action = self.get_action(obses['obs'], is_determenistic)\n                    action_op = self.get_action(obses['obs_op'], is_determenistic, is_op=True)\n                obses, r, done, info = self.env_step(self.env, torch.cat((action, action_op), dim=0))\n                cr += r\n                steps += 1\n\n                if render:\n                    self.env.render(mode='human')\n                    time.sleep(self.render_sleep)\n\n                all_done_indices = done.nonzero(as_tuple=False)\n                done_indices = all_done_indices[::self.num_agents]\n                done_count = len(done_indices)\n                games_played += done_count\n                if self.record_elo:\n                    self._update_rating(info, all_done_indices.flatten())\n                if done_count > 0:\n                    if self.is_rnn:\n                        for s in self.states:\n                            s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0\n\n                    cur_rewards = cr[done_indices].sum().item()\n                    cur_steps = steps[done_indices].sum().item()\n\n                    cr = cr * (1.0 - done.float())\n                    steps = 
steps * (1.0 - done.float())\n                    sum_rewards += cur_rewards\n                    sum_steps += cur_steps\n\n                    game_res = 0.0\n                    if isinstance(info, dict):\n                        if 'battle_won' in info:\n                            print_game_res = True\n                            game_res = info.get('battle_won', 0.5)\n                        if 'scores' in info:\n                            print_game_res = True\n                            game_res = info.get('scores', 0.5)\n                    if self.print_stats:\n                        if print_game_res:\n                            print('reward:', cur_rewards / done_count,\n                                  'steps:', cur_steps / done_count, 'w:', game_res)\n                        else:\n                            print('reward:', cur_rewards / done_count,\n                                  'steps:', cur_steps / done_count)\n\n                    sum_game_res += game_res\n                    if batch_size // self.num_agents == 1 or games_played >= n_games:\n                        print(f\"games_played: {games_played}\")\n                        break\n                done_indices = done_indices[:, 0]\n\n        if self.record_elo:\n            self._plot_elo_curve()\n\n    def _plot_elo_curve(self):\n        x = np.array(self.policy_timestep)\n        # float arrays: np.arange would truncate the Elo ratings to integers\n        y = np.zeros(len(self.player_pool.players))\n        x_op = np.array(self.policy_op_timestep)\n        y_op = np.zeros(len(self.op_player_pool.players))\n        for player in self.player_pool.players:\n            idx = player.player_idx\n            y[idx] = player.rating\n        for player in self.op_player_pool.players:\n            idx = player.player_idx\n            y_op[idx] = player.rating\n        if self.params['load_path'] != self.params['op_load_path']:\n            l1 = 
plt.plot(x, y, 'b--', label='policy')\n            l2 = plt.plot(x_op, y_op, 'r--', label='policy_op')\n            plt.plot(x, y, 'b^-', x_op, y_op, 'ro-')\n        else:\n            l1 = plt.plot(x, y, 'b--', label='policy')\n            plt.plot(x, y, 'b^-')\n        plt.title('ELO Curve')\n        plt.xlabel('time (days)')\n        plt.ylabel('ELO')\n        plt.legend()\n        parent_path = os.path.dirname(self.params['load_path'])\n        plt.savefig(os.path.join(parent_path, 'elo.jpg'))\n\n    def get_action(self, obs, is_determenistic=False, is_op=False):\n        if not self.has_batch_dimension:\n            obs = unsqueeze_obs(obs)\n        obs = self._preproc_obs(obs)\n        input_dict = {\n            'is_train': False,\n            'prev_actions': None,\n            'obs': obs,\n            'rnn_states': self.states\n        }\n        with torch.no_grad():\n            data_len = self.num_actors * self.num_opponents if is_op else self.num_actors\n            res_dict = {\n                \"actions\": torch.zeros((data_len, self.actions_num), device=self.device),\n                \"values\": torch.zeros((data_len, 1), device=self.device),\n                \"mus\": torch.zeros((data_len, self.actions_num), device=self.device)\n            }\n            if is_op:\n                self.op_player_pool.inference(input_dict, res_dict, obs)\n            else:\n                self.player_pool.inference(input_dict, res_dict, obs)\n        mu = res_dict['mus']\n        action = res_dict['actions']\n        # self.states = res_dict['rnn_states']\n        if is_determenistic:\n            current_action = mu\n        else:\n            current_action = action\n        if not self.has_batch_dimension:\n            current_action = torch.squeeze(current_action.detach())\n\n        if self.clip_actions:\n            return rescale_actions(self.actions_low, self.actions_high, torch.clamp(current_action, -1.0, 1.0))\n        else:\n            
return current_action\n\n    def _norm_policy_timestep(self):\n        self.policy_op_timestep.sort()\n        self.policy_timestep.sort()\n        for idx in range(1, len(self.policy_op_timestep)):\n            self.policy_op_timestep[idx] -= self.policy_op_timestep[0]\n            self.policy_op_timestep[idx] /= 3600 * 24\n        for idx in range(1, len(self.policy_timestep)):\n            self.policy_timestep[idx] -= self.policy_timestep[0]\n            self.policy_timestep[idx] /= 3600 * 24\n        self.policy_timestep[0] = 0\n        if len(self.policy_op_timestep):\n            self.policy_op_timestep[0] = 0\n\n    def env_reset(self, env, done_indices=None):\n        obs = env.reset(done_indices)\n        obs_dict = {}\n        obs_dict['obs_op'] = obs[self.num_actors:]\n        obs_dict['obs'] = obs[:self.num_actors]\n        return obs_dict\n\n    def env_step(self, env, actions):\n        obs, rewards, dones, infos = env.step(actions)\n        if hasattr(obs, 'dtype') and obs.dtype == np.float64:\n            obs = np.float32(obs)\n        obs_dict = {}\n        obs_dict['obs_op'] = obs[self.num_actors:]\n        obs_dict['obs'] = obs[:self.num_actors]\n        if self.value_size > 1:\n            rewards = rewards[0]\n        if self.is_tensor_obses:\n            return self.obs_to_torch(obs_dict), rewards.cpu(), dones.cpu(), infos\n        else:\n            if np.isscalar(dones):\n                rewards = np.expand_dims(np.asarray(rewards), 0)\n                dones = np.expand_dims(np.asarray(dones), 0)\n            return obs_dict, rewards, dones, infos\n\n    def create_model(self):\n        model = self.network.build(self.base_model_config)\n        model.to(self.device)\n        return model\n\n    def load_model(self, fn):\n        model = self.create_model()\n        checkpoint = torch_ext.safe_filesystem_op(torch.load, fn, map_location=self.device)\n        model.load_state_dict(checkpoint['model'])\n        if self.normalize_input and 
'running_mean_std' in checkpoint:\n            model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])\n        return model\n"
  },
  {
    "path": "timechamber/learning/replay_buffer.py",
    "content": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\n\nimport torch\n\n\nclass ReplayBuffer():\n    def __init__(self, buffer_size, device):\n        self._head = 0\n        self._total_count = 0\n        self._buffer_size = buffer_size\n        self._device = device\n        self._data_buf = None\n        self._sample_idx = torch.randperm(buffer_size)\n        self._sample_head = 0\n\n        return\n\n    def reset(self):\n        self._head = 0\n        self._total_count = 0\n        self._reset_sample_idx()\n        return\n\n    def get_buffer_size(self):\n        return self._buffer_size\n\n    def get_total_count(self):\n        return self._total_count\n\n    def store(self, data_dict):\n        if (self._data_buf is None):\n            self._init_data_buf(data_dict)\n\n        n = next(iter(data_dict.values())).shape[0]\n        buffer_size = self.get_buffer_size()\n        assert (n < buffer_size)\n\n        for key, curr_buf in self._data_buf.items():\n            curr_n = data_dict[key].shape[0]\n            assert (n == curr_n)\n\n            store_n = min(curr_n, buffer_size - self._head)\n            curr_buf[self._head:(self._head + store_n)] = data_dict[key][:store_n]\n\n            remainder = n - store_n\n            if (remainder > 0):\n                curr_buf[0:remainder] = data_dict[key][store_n:]\n\n        self._head = (self._head + n) % buffer_size\n        self._total_count += n\n\n        return\n\n    def sample(self, n):\n        total_count = self.get_total_count()\n        buffer_size = self.get_buffer_size()\n\n        idx = torch.arange(self._sample_head, self._sample_head + n)\n        idx = idx % buffer_size\n        rand_idx = self._sample_idx[idx]\n        if (total_count < buffer_size):\n            rand_idx = rand_idx % self._head\n\n        samples = dict()\n        for k, v in self._data_buf.items():\n            samples[k] = v[rand_idx]\n\n        self._sample_head += n\n        if (self._sample_head >= 
buffer_size):\n            self._reset_sample_idx()\n\n        return samples\n\n    def _reset_sample_idx(self):\n        buffer_size = self.get_buffer_size()\n        self._sample_idx[:] = torch.randperm(buffer_size)\n        self._sample_head = 0\n        return\n\n    def _init_data_buf(self, data_dict):\n        buffer_size = self.get_buffer_size()\n        self._data_buf = dict()\n\n        for k, v in data_dict.items():\n            v_shape = v.shape[1:]\n            self._data_buf[k] = torch.zeros((buffer_size,) + v_shape, device=self._device)\n\n        return\n"
  },
  {
    "path": "timechamber/learning/vectorized_models.py",
    "content": "import torch\nimport torch.nn as nn\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs\nfrom rl_games.algos_torch import torch_ext\nfrom rl_games.algos_torch.models import ModelA2CContinuousLogStd\n\n\nclass VectorizedRunningMeanStd(RunningMeanStd):\n    def __init__(self, insize, population_size, epsilon=1e-05, per_channel=False, norm_only=False, is_training=False):\n        # input shape: population_size*batch_size*(insize)\n        super(VectorizedRunningMeanStd, self).__init__(population_size, epsilon, per_channel, norm_only)\n        self.insize = insize\n        self.epsilon = epsilon\n        self.population_size = population_size\n        self.training = is_training\n        self.norm_only = norm_only\n        self.per_channel = per_channel\n        if per_channel:\n            if len(self.insize) == 3:\n                self.axis = [1, 3, 4]\n            if len(self.insize) == 2:\n                self.axis = [1, 3]\n            if len(self.insize) == 1:\n                self.axis = [1]\n            in_size = self.insize[1]\n        else:\n            self.axis = [1]\n            in_size = insize\n        # print(in_size)\n        self.register_buffer(\"running_mean\", torch.zeros((population_size, *in_size), dtype=torch.float32))\n        self.register_buffer(\"running_var\", torch.ones((population_size, *in_size), dtype=torch.float32))\n        self.register_buffer(\"count\", torch.ones((population_size, 1), dtype=torch.float32))\n\n    def _update_mean_var_count_from_moments(self, mean, var, count, batch_mean, batch_var, batch_count):\n        delta = batch_mean - mean\n        tot_count = count + batch_count\n        new_mean = mean + delta * batch_count / tot_count\n        m_a = var * count\n        m_b = batch_var * batch_count\n        M2 = m_a + m_b + delta ** 2 * count * batch_count / tot_count\n        new_var = M2 / tot_count\n        new_count = tot_count\n        return new_mean, new_var, 
new_count\n\n    def forward(self, input, unnorm=False, mask=None):\n        if self.training:\n            if mask is not None:\n                mean, var = torch_ext.get_mean_std_with_masks(input, mask)\n            else:\n                mean = input.mean(self.axis)  # along channel axis\n                var = input.var(self.axis)\n            self.running_mean, self.running_var, self.count = self._update_mean_var_count_from_moments(\n                self.running_mean, self.running_var, self.count,\n                mean, var, input.size()[1])\n\n        # change shape\n        if self.per_channel:\n            if len(self.insize) == 3:\n                current_mean = self.running_mean.view([self.population_size, 1, self.insize[0], 1, 1]).expand_as(input)\n                current_var = self.running_var.view([self.population_size, 1, self.insize[0], 1, 1]).expand_as(input)\n            if len(self.insize) == 2:\n                current_mean = self.running_mean.view([self.population_size, 1, self.insize[0], 1]).expand_as(input)\n                current_var = self.running_var.view([self.population_size, 1, self.insize[0], 1]).expand_as(input)\n            if len(self.insize) == 1:\n                current_mean = self.running_mean.view([self.population_size, 1, self.insize[0]]).expand_as(input)\n                current_var = self.running_var.view([self.population_size, 1, self.insize[0]]).expand_as(input)\n        else:\n            current_mean = self.running_mean\n            current_var = self.running_var\n        # get output\n\n        if unnorm:\n            y = torch.clamp(input, min=-5.0, max=5.0)\n            y = torch.sqrt(torch.unsqueeze(current_var.float(), 1) + self.epsilon) * y + torch.unsqueeze(\n                current_mean.float(), 1)\n        else:\n            if self.norm_only:\n                y = input / torch.sqrt(current_var.float() + self.epsilon)\n            else:\n                y = (input - torch.unsqueeze(current_mean.float(), 1)) / 
torch.sqrt(\n                    torch.unsqueeze(current_var.float(), 1) + self.epsilon)\n                y = torch.clamp(y, min=-5.0, max=5.0)\n        return y\n\n\nclass ModelVectorizedA2C(ModelA2CContinuousLogStd):\n    def __init__(self, network):\n        super().__init__(network)\n        return\n\n    def build(self, config):\n        net = self.network_builder.build('vectorized_a2c', **config)\n        for name, _ in net.named_parameters():\n            print(name)\n\n        obs_shape = config['input_shape']\n        population_size = config['population_size']\n        normalize_value = config.get('normalize_value', False)\n        normalize_input = config.get('normalize_input', False)\n        value_size = config.get('value_size', 1)\n\n        return self.Network(net, population_size, obs_shape=obs_shape,\n                            normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size, )\n\n    class Network(ModelA2CContinuousLogStd.Network):\n        def __init__(self, a2c_network, population_size, obs_shape, normalize_value, normalize_input, value_size):\n            self.population_size = population_size\n            super().__init__(a2c_network, obs_shape=obs_shape,\n                             normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size)\n            if normalize_value:\n                self.value_mean_std = VectorizedRunningMeanStd((self.value_size,), self.population_size)\n            if normalize_input:\n                if isinstance(obs_shape, dict):\n                    self.running_mean_std = RunningMeanStdObs(obs_shape)\n                else:\n                    self.running_mean_std = VectorizedRunningMeanStd(obs_shape, self.population_size)\n\n        def update(self, population_idx, network):\n            for key in self.state_dict():\n                param1 = self.state_dict()[key]\n                param2 = network.state_dict()[key]\n                if 
len(param1.shape) == len(param2.shape):\n                    # copy in place; rebinding the key in a fresh state_dict() has no effect\n                    param1.copy_(param2)\n                elif len(param2.shape) == 1:\n                    if len(param1.shape) == 3:\n                        param1[population_idx] = torch.unsqueeze(param2, dim=0)\n                    else:\n                        param1[population_idx] = param2\n                elif len(param2.shape) == 2:\n                    param1[population_idx] = torch.transpose(param2, 0, 1)\n"
  },
  {
    "path": "timechamber/learning/vectorized_network_builder.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\nfrom rl_games.algos_torch import network_builder\n\n\nclass VectorizedLinearLayer(torch.nn.Module):\n    \"\"\"Vectorized version of torch.nn.Linear.\"\"\"\n\n    def __init__(\n            self,\n            population_size: int,\n            in_features: int,\n            out_features: int,\n            use_layer_norm: bool = False,\n    ):\n        super().__init__()\n        self._population_size = population_size\n        self._in_features = in_features\n        self._out_features = out_features\n\n        self.weight = torch.nn.Parameter(\n            torch.empty(self._population_size, self._in_features, self._out_features),\n            requires_grad=True,\n        )\n        self.bias = torch.nn.Parameter(\n            torch.empty(self._population_size, 1, self._out_features),\n            requires_grad=True,\n        )\n\n        for member_id in range(population_size):\n            torch.nn.init.kaiming_uniform_(self.weight[member_id], a=math.sqrt(5))\n        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight[0])\n        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n        torch.nn.init.uniform_(self.bias, -bound, bound)\n\n        self._layer_norm = (\n            torch.nn.LayerNorm(self._out_features, self._population_size)\n            if use_layer_norm\n            else None\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        assert x.shape[0] == self._population_size\n        if self._layer_norm is not None:\n            return self._layer_norm(x.matmul(self.weight) + self.bias)\n        return x.matmul(self.weight) + self.bias\n\n\nclass VectorizedA2CBuilder(network_builder.A2CBuilder):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        return\n\n    class Network(network_builder.A2CBuilder.Network):\n        def __init__(self, params, **kwargs):\n            self.population_size = 
kwargs.get('population_size')\n            super().__init__(params, **kwargs)\n\n            self.value = VectorizedLinearLayer(population_size=self.population_size,\n                                               in_features=self.units[-1],\n                                               out_features=self.value_size)\n            actions_num = kwargs.get('actions_num')\n            self.mu = VectorizedLinearLayer(self.population_size, self.units[-1], actions_num)\n            if self.fixed_sigma:\n                self.sigma = nn.Parameter(\n                    torch.zeros((self.population_size, 1, actions_num), requires_grad=True, dtype=torch.float32),\n                    requires_grad=True)\n            else:\n                self.sigma = VectorizedLinearLayer(self.population_size, self.units[-1], actions_num)\n\n        def _build_vectorized_mlp(self,\n                                  input_size,\n                                  units,\n                                  activation,\n                                  norm_func_name=None):\n            print(f'build vectorized mlp:{self.population_size}x{input_size}')\n            in_size = input_size\n            layers = []\n            for unit in units:\n                layers.append(\n                    VectorizedLinearLayer(self.population_size, in_size, unit, norm_func_name == 'layer_norm'))\n                layers.append(self.activations_factory.create(activation))\n                in_size = unit\n            return nn.Sequential(*layers)\n\n        def _build_mlp(self,\n                       input_size,\n                       units,\n                       activation,\n                       dense_func,\n                       norm_only_first_layer=False,\n                       norm_func_name=None,\n                       d2rl=False):\n            return self._build_vectorized_mlp(input_size, units, activation, norm_func_name=norm_func_name)\n\n        def forward(self, obs_dict):  # implement 
the continuous-action case\n            obs = obs_dict['obs']\n            states = obs_dict.get('rnn_states', None)\n            out = self.actor_mlp(obs)\n            value = self.value_act(self.value(out))\n            mu = self.mu_act(self.mu(out))\n            if self.fixed_sigma:\n                sigma = self.sigma_act(self.sigma)\n            else:\n                sigma = self.sigma_act(self.sigma(out))\n            return mu, mu * 0 + sigma, value, states\n\n        def load(self, params):\n            super().load(params)\n\n    def build(self, name, **kwargs):\n        net = VectorizedA2CBuilder.Network(self.params, **kwargs)\n        return net\n"
  },
  {
    "path": "timechamber/tasks/__init__.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nfrom .ma_ant_sumo import MA_Ant_Sumo\nfrom .ma_ant_battle import MA_Ant_Battle\nfrom .ma_humanoid_strike import HumanoidStrike\n\n# Mappings from strings to environments\nisaacgym_task_map = {\n    \"MA_Ant_Sumo\": MA_Ant_Sumo,\n    \"MA_Ant_Battle\": MA_Ant_Battle,\n    \"MA_Humanoid_Strike\": HumanoidStrike\n}\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/base_task.py",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport sys\nimport os\nimport operator\nfrom copy import deepcopy\nimport random\n\nfrom isaacgym import gymapi\nfrom isaacgym.gymutil import get_property_setter_map, get_property_getter_map, get_default_setter_args, apply_random_samples, check_buckets, generate_random_samples\n\nimport numpy as np\nimport torch\n\n\n# Base class for RL tasks\nclass BaseTask():\n\n    def __init__(self, cfg, enable_camera_sensors=False):\n        self.gym = gymapi.acquire_gym()\n\n        self.device_type = cfg.get(\"device_type\", \"cuda\")\n        self.device_id = cfg.get(\"device_id\", 0)\n\n        self.device = \"cpu\"\n        if self.device_type == \"cuda\" or self.device_type == \"GPU\":\n            self.device = \"cuda\" + \":\" + str(self.device_id)\n\n        self.headless = cfg[\"headless\"]\n        self.num_agents = cfg[\"env\"].get(\"numAgents\", 1)  # used for multi-agent environments\n\n        # double check!\n        self.graphics_device_id = self.device_id\n        if enable_camera_sensors == False and self.headless == True:\n            self.graphics_device_id = -1\n\n        self.num_envs = cfg[\"env\"][\"numEnvs\"]\n        self.num_obs = cfg[\"env\"][\"numObservations\"]\n        self.num_states = cfg[\"env\"].get(\"numStates\", 0)\n        self.num_actions = cfg[\"env\"][\"numActions\"]\n\n        self.control_freq_inv = cfg[\"env\"].get(\"controlFrequencyInv\", 1)\n\n        # optimization flags for pytorch JIT\n        torch._C._jit_set_profiling_mode(False)\n        torch._C._jit_set_profiling_executor(False)\n\n 
       # allocate buffers\n        self.obs_buf = torch.zeros(\n            (self.num_envs, self.num_obs), device=self.device, dtype=torch.float)\n        self.states_buf = torch.zeros(\n            (self.num_envs, self.num_states), device=self.device, dtype=torch.float)\n        self.rew_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.float)\n        self.reset_buf = torch.ones(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.progress_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.randomize_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.extras = {}\n\n        self.original_props = {}\n        self.dr_randomizations = {}\n        self.first_randomization = True\n        self.actor_params_generator = None\n        self.extern_actor_params = {}\n        for env_id in range(self.num_envs):\n            self.extern_actor_params[env_id] = None\n\n        self.last_step = -1\n        self.last_rand_step = -1\n\n        # create envs, sim and viewer\n        self.create_sim()\n        self.gym.prepare_sim(self.sim)\n\n        # todo: read from config\n        self.enable_viewer_sync = True\n        self.viewer = None\n\n        # if running with a viewer, set up keyboard shortcuts and camera\n        if self.headless == False:\n            # subscribe to keyboard shortcuts\n            self.viewer = self.gym.create_viewer(\n                self.sim, gymapi.CameraProperties())\n            self.gym.subscribe_viewer_keyboard_event(\n                self.viewer, gymapi.KEY_ESCAPE, \"QUIT\")\n            self.gym.subscribe_viewer_keyboard_event(\n                self.viewer, gymapi.KEY_V, \"toggle_viewer_sync\")\n\n            # set the camera position based on up axis\n            sim_params = self.gym.get_sim_params(self.sim)\n            if sim_params.up_axis == gymapi.UP_AXIS_Z:\n                cam_pos = 
gymapi.Vec3(20.0, 25.0, 3.0)\n                cam_target = gymapi.Vec3(10.0, 15.0, 0.0)\n            else:\n                cam_pos = gymapi.Vec3(20.0, 3.0, 25.0)\n                cam_target = gymapi.Vec3(10.0, 0.0, 15.0)\n\n            self.gym.viewer_camera_look_at(\n                self.viewer, None, cam_pos, cam_target)\n\n    # set gravity based on up axis and return axis index\n    def set_sim_params_up_axis(self, sim_params, axis):\n        if axis == 'z':\n            sim_params.up_axis = gymapi.UP_AXIS_Z\n            sim_params.gravity.x = 0\n            sim_params.gravity.y = 0\n            sim_params.gravity.z = -9.81\n            return 2\n        return 1\n\n    def create_sim(self, compute_device, graphics_device, physics_engine, sim_params):\n        sim = self.gym.create_sim(compute_device, graphics_device, physics_engine, sim_params)\n        if sim is None:\n            print(\"*** Failed to create sim\")\n            quit()\n\n        return sim\n\n    def step(self, actions):\n        if self.dr_randomizations.get('actions', None):\n            actions = self.dr_randomizations['actions']['noise_lambda'](actions)\n\n        # apply actions\n        self.pre_physics_step(actions)\n\n        # step physics and render each frame\n        self._physics_step()\n\n        # to fix!\n        if self.device == 'cpu':\n            self.gym.fetch_results(self.sim, True)\n\n        # compute observations, rewards, resets, ...\n        self.post_physics_step()\n\n        if self.dr_randomizations.get('observations', None):\n            self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf)\n\n    def get_states(self):\n        return self.states_buf\n\n    def render(self, sync_frame_time=False):\n        if self.viewer:\n            # check for window closed\n            if self.gym.query_viewer_has_closed(self.viewer):\n                sys.exit()\n\n            # check for keyboard events\n            for evt in 
self.gym.query_viewer_action_events(self.viewer):\n                if evt.action == \"QUIT\" and evt.value > 0:\n                    sys.exit()\n                elif evt.action == \"toggle_viewer_sync\" and evt.value > 0:\n                    self.enable_viewer_sync = not self.enable_viewer_sync\n\n            # fetch results\n            if self.device != 'cpu':\n                self.gym.fetch_results(self.sim, True)\n\n            # step graphics\n            if self.enable_viewer_sync:\n                self.gym.step_graphics(self.sim)\n                self.gym.draw_viewer(self.viewer, self.sim, True)\n            else:\n                self.gym.poll_viewer_events(self.viewer)\n\n    def get_actor_params_info(self, dr_params, env):\n        \"\"\"Returns a flat array of actor params, their names and ranges.\"\"\"\n        if \"actor_params\" not in dr_params:\n            return None\n        params = []\n        names = []\n        lows = []\n        highs = []\n        param_getters_map = get_property_getter_map(self.gym)\n        for actor, actor_properties in dr_params[\"actor_params\"].items():\n            handle = self.gym.find_actor_handle(env, actor)\n            for prop_name, prop_attrs in actor_properties.items():\n                if prop_name == 'color':\n                    continue  # this is set randomly\n                props = param_getters_map[prop_name](env, handle)\n                if not isinstance(props, list):\n                    props = [props]\n                for prop_idx, prop in enumerate(props):\n                    for attr, attr_randomization_params in prop_attrs.items():\n                        name = prop_name+'_'+str(prop_idx)+'_'+attr\n                        lo_hi = attr_randomization_params['range']\n                        distr = attr_randomization_params['distribution']\n                        if 'uniform' not in distr:\n                            lo_hi = (-1.0*float('Inf'), float('Inf'))\n                        if 
isinstance(prop, np.ndarray):\n                            for attr_idx in range(prop[attr].shape[0]):\n                                params.append(prop[attr][attr_idx])\n                                names.append(name+'_'+str(attr_idx))\n                                lows.append(lo_hi[0])\n                                highs.append(lo_hi[1])\n                        else:\n                            params.append(getattr(prop, attr))\n                            names.append(name)\n                            lows.append(lo_hi[0])\n                            highs.append(lo_hi[1])\n        return params, names, lows, highs\n\n    # Apply randomizations only on resets, due to current PhysX limitations\n    def apply_randomizations(self, dr_params):\n        # If we don't have a randomization frequency, randomize every step\n        rand_freq = dr_params.get(\"frequency\", 1)\n\n        # First, determine what to randomize:\n        #   - non-environment parameters when > frequency steps have passed since the last non-environment\n        #   - physical environments in the reset buffer, which have exceeded the randomization frequency threshold\n        #   - on the first call, randomize everything\n        self.last_step = self.gym.get_frame_count(self.sim)\n        if self.first_randomization:\n            do_nonenv_randomize = True\n            env_ids = list(range(self.num_envs))\n        else:\n            do_nonenv_randomize = (self.last_step - self.last_rand_step) >= rand_freq\n            rand_envs = torch.where(self.randomize_buf >= rand_freq, torch.ones_like(self.randomize_buf), torch.zeros_like(self.randomize_buf))\n            rand_envs = torch.logical_and(rand_envs, self.reset_buf)\n            env_ids = torch.nonzero(rand_envs, as_tuple=False).squeeze(-1).tolist()\n            self.randomize_buf[rand_envs] = 0\n\n        if do_nonenv_randomize:\n            self.last_rand_step = self.last_step\n\n        param_setters_map = 
get_property_setter_map(self.gym)\n        param_setter_defaults_map = get_default_setter_args(self.gym)\n        param_getters_map = get_property_getter_map(self.gym)\n\n        # On first iteration, check the number of buckets\n        if self.first_randomization:\n            check_buckets(self.gym, self.envs, dr_params)\n\n        for nonphysical_param in [\"observations\", \"actions\"]:\n            if nonphysical_param in dr_params and do_nonenv_randomize:\n                dist = dr_params[nonphysical_param][\"distribution\"]\n                op_type = dr_params[nonphysical_param][\"operation\"]\n                sched_type = dr_params[nonphysical_param][\"schedule\"] if \"schedule\" in dr_params[nonphysical_param] else None\n                sched_step = dr_params[nonphysical_param][\"schedule_steps\"] if \"schedule\" in dr_params[nonphysical_param] else None\n                op = operator.add if op_type == 'additive' else operator.mul\n\n                if sched_type == 'linear':\n                    sched_scaling = 1.0 / sched_step * \\\n                        min(self.last_step, sched_step)\n                elif sched_type == 'constant':\n                    sched_scaling = 0 if self.last_step < sched_step else 1\n                else:\n                    sched_scaling = 1\n\n                if dist == 'gaussian':\n                    mu, var = dr_params[nonphysical_param][\"range\"]\n                    mu_corr, var_corr = dr_params[nonphysical_param].get(\"range_correlated\", [0., 0.])\n\n                    if op_type == 'additive':\n                        mu *= sched_scaling\n                        var *= sched_scaling\n                        mu_corr *= sched_scaling\n                        var_corr *= sched_scaling\n                    elif op_type == 'scaling':\n                        var = var * sched_scaling  # scale up var over time\n                        mu = mu * sched_scaling + 1.0 * \\\n                            (1.0 - sched_scaling) 
 # linearly interpolate\n\n                        var_corr = var_corr * sched_scaling  # scale up var over time\n                        mu_corr = mu_corr * sched_scaling + 1.0 * \\\n                            (1.0 - sched_scaling)  # linearly interpolate\n\n                    def noise_lambda(tensor, param_name=nonphysical_param):\n                        params = self.dr_randomizations[param_name]\n                        corr = params.get('corr', None)\n                        if corr is None:\n                            corr = torch.randn_like(tensor)\n                            params['corr'] = corr\n                        corr = corr * params['var_corr'] + params['mu_corr']\n                        return op(\n                            tensor, corr + torch.randn_like(tensor) * params['var'] + params['mu'])\n\n                    self.dr_randomizations[nonphysical_param] = {'mu': mu, 'var': var, 'mu_corr': mu_corr, 'var_corr': var_corr, 'noise_lambda': noise_lambda}\n\n                elif dist == 'uniform':\n                    lo, hi = dr_params[nonphysical_param][\"range\"]\n                    lo_corr, hi_corr = dr_params[nonphysical_param].get(\"range_correlated\", [0., 0.])\n\n                    if op_type == 'additive':\n                        lo *= sched_scaling\n                        hi *= sched_scaling\n                        lo_corr *= sched_scaling\n                        hi_corr *= sched_scaling\n                    elif op_type == 'scaling':\n                        lo = lo * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        hi = hi * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        lo_corr = lo_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        hi_corr = hi_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)\n\n                    def noise_lambda(tensor, param_name=nonphysical_param):\n                        params = self.dr_randomizations[param_name]\n                
        corr = params.get('corr', None)\n                        if corr is None:\n                            corr = torch.randn_like(tensor)\n                            params['corr'] = corr\n                        corr = corr * (params['hi_corr'] - params['lo_corr']) + params['lo_corr']\n                        return op(tensor, corr + torch.rand_like(tensor) * (params['hi'] - params['lo']) + params['lo'])\n\n                    self.dr_randomizations[nonphysical_param] = {'lo': lo, 'hi': hi, 'lo_corr': lo_corr, 'hi_corr': hi_corr, 'noise_lambda': noise_lambda}\n\n        if \"sim_params\" in dr_params and do_nonenv_randomize:\n            prop_attrs = dr_params[\"sim_params\"]\n            prop = self.gym.get_sim_params(self.sim)\n\n            if self.first_randomization:\n                self.original_props[\"sim_params\"] = {\n                    attr: getattr(prop, attr) for attr in dir(prop)}\n\n            for attr, attr_randomization_params in prop_attrs.items():\n                apply_random_samples(\n                    prop, self.original_props[\"sim_params\"], attr, attr_randomization_params, self.last_step)\n\n            self.gym.set_sim_params(self.sim, prop)\n\n        # If self.actor_params_generator is initialized: use it to\n        # sample actor simulation params. This gives users the\n        # freedom to generate samples from arbitrary distributions,\n        # e.g. 
use full-covariance distributions instead of the DR's\n        # default of treating each simulation parameter independently.\n        extern_offsets = {}\n        if self.actor_params_generator is not None:\n            for env_id in env_ids:\n                self.extern_actor_params[env_id] = \\\n                    self.actor_params_generator.sample()\n                extern_offsets[env_id] = 0\n\n        for actor, actor_properties in dr_params[\"actor_params\"].items():\n            for env_id in env_ids:\n                env = self.envs[env_id]\n                handle = self.gym.find_actor_handle(env, actor)\n                extern_sample = self.extern_actor_params[env_id]\n\n                for prop_name, prop_attrs in actor_properties.items():\n                    if prop_name == 'color':\n                        num_bodies = self.gym.get_actor_rigid_body_count(\n                            env, handle)\n                        for n in range(num_bodies):\n                            self.gym.set_rigid_body_color(env, handle, n, gymapi.MESH_VISUAL,\n                                                          gymapi.Vec3(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))\n                        continue\n                    if prop_name == 'scale':\n                        attr_randomization_params = prop_attrs\n                        sample = generate_random_samples(attr_randomization_params, 1,\n                                                         self.last_step, None)\n                        og_scale = 1\n                        if attr_randomization_params['operation'] == 'scaling':\n                            new_scale = og_scale * sample\n                        elif attr_randomization_params['operation'] == 'additive':\n                            new_scale = og_scale + sample\n                        self.gym.set_actor_scale(env, handle, new_scale)\n                        continue\n\n                    prop = 
param_getters_map[prop_name](env, handle)\n                    if isinstance(prop, list):\n                        if self.first_randomization:\n                            self.original_props[prop_name] = [\n                                {attr: getattr(p, attr) for attr in dir(p)} for p in prop]\n                        for p, og_p in zip(prop, self.original_props[prop_name]):\n                            for attr, attr_randomization_params in prop_attrs.items():\n                                smpl = None\n                                if self.actor_params_generator is not None:\n                                    smpl, extern_offsets[env_id] = get_attr_val_from_sample(\n                                        extern_sample, extern_offsets[env_id], p, attr)\n                                apply_random_samples(\n                                    p, og_p, attr, attr_randomization_params,\n                                    self.last_step, smpl)\n                    else:\n                        if self.first_randomization:\n                            self.original_props[prop_name] = deepcopy(prop)\n                        for attr, attr_randomization_params in prop_attrs.items():\n                            smpl = None\n                            if self.actor_params_generator is not None:\n                                smpl, extern_offsets[env_id] = get_attr_val_from_sample(\n                                    extern_sample, extern_offsets[env_id], prop, attr)\n                            apply_random_samples(\n                                prop, self.original_props[prop_name], attr,\n                                attr_randomization_params, self.last_step, smpl)\n\n                    setter = param_setters_map[prop_name]\n                    default_args = param_setter_defaults_map[prop_name]\n                    setter(env, handle, prop, *default_args)\n\n        if self.actor_params_generator is not None:\n            for env_id in env_ids: 
 # check that we used all dims in sample\n                if extern_offsets[env_id] > 0:\n                    extern_sample = self.extern_actor_params[env_id]\n                    if extern_offsets[env_id] != extern_sample.shape[0]:\n                        print('env_id', env_id,\n                              'extern_offset', extern_offsets[env_id],\n                              'vs extern_sample.shape', extern_sample.shape)\n                        raise Exception(\"Invalid extern_sample size\")\n\n        self.first_randomization = False\n\n    def pre_physics_step(self, actions):\n        raise NotImplementedError\n\n    def _physics_step(self):\n        for i in range(self.control_freq_inv):\n            self.render()\n            self.gym.simulate(self.sim)\n        return\n\n    def post_physics_step(self):\n        raise NotImplementedError\n\n\ndef get_attr_val_from_sample(sample, offset, prop, attr):\n    \"\"\"Retrieves param value for the given prop and attr from the sample.\"\"\"\n    if sample is None:\n        return None, 0\n    if isinstance(prop, np.ndarray):\n        smpl = sample[offset:offset+prop[attr].shape[0]]\n        return smpl, offset+prop[attr].shape[0]\n    else:\n        return sample[offset], offset+1\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/humanoid.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport numpy as np\nimport os\nimport torch\n\nfrom isaacgym import gymtorch\nfrom isaacgym import gymapi\nfrom isaacgym.torch_utils import *\n\nfrom timechamber.utils import torch_utils\nfrom timechamber.utils.utils import print_actor_info, print_asset_info\nfrom timechamber.tasks.ase_humanoid_base.base_task import BaseTask\n\nclass Humanoid(BaseTask):\n    def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless):\n        self.cfg = cfg\n        self.sim_params = sim_params\n        self.physics_engine = physics_engine\n\n        ##\n        self.borderline_space = self.cfg[\"env\"][\"borderlineSpace\"]\n        self.num_agents = self.cfg[\"env\"].get(\"numAgents\", 1)\n        \n        self._pd_control = self.cfg[\"env\"][\"pdControl\"]\n        self.power_scale = self.cfg[\"env\"][\"powerScale\"]\n\n        self.debug_viz = self.cfg[\"env\"][\"enableDebugVis\"]\n        self.plane_static_friction = self.cfg[\"env\"][\"plane\"][\"staticFriction\"]\n        self.plane_dynamic_friction = self.cfg[\"env\"][\"plane\"][\"dynamicFriction\"]\n        self.plane_restitution = self.cfg[\"env\"][\"plane\"][\"restitution\"]\n\n        self.max_episode_length = self.cfg[\"env\"][\"episodeLength\"]\n        self._local_root_obs = self.cfg[\"env\"][\"localRootObs\"]\n        self._root_height_obs = self.cfg[\"env\"].get(\"rootHeightObs\", True)\n        self._enable_early_termination = 
self.cfg[\"env\"][\"enableEarlyTermination\"]\n        \n        key_bodies = self.cfg[\"env\"][\"keyBodies\"]\n        self._setup_character_props(key_bodies)\n\n        self.cfg[\"env\"][\"numObservations\"] = self.get_obs_size()\n        self.cfg[\"env\"][\"numActions\"] = self.get_action_size()\n\n        self.cfg[\"device_type\"] = device_type\n        self.cfg[\"device_id\"] = device_id\n        self.cfg[\"headless\"] = headless\n         \n        super().__init__(cfg=self.cfg)\n        \n        self.dt = self.control_freq_inv * sim_params.dt\n\n        # get gym GPU state tensors\n        actor_root_state = self.gym.acquire_actor_root_state_tensor(self.sim)\n        dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim)\n        # print(f\"dof_state_tensor shape: {dof_state_tensor.shape}\")\n        sensor_tensor = self.gym.acquire_force_sensor_tensor(self.sim)\n        rigid_body_state = self.gym.acquire_rigid_body_state_tensor(self.sim)\n        contact_force_tensor = self.gym.acquire_net_contact_force_tensor(self.sim)\n\n        sensors_per_env = 2\n        self.vec_sensor_tensor = gymtorch.wrap_tensor(sensor_tensor).view(self.num_envs * self.num_agents, sensors_per_env * 6)\n\n        dof_force_tensor = self.gym.acquire_dof_force_tensor(self.sim)\n        self.dof_force_tensor = gymtorch.wrap_tensor(dof_force_tensor).view(self.num_envs * self.num_agents, self.num_dof)\n\n        self.gym.refresh_dof_state_tensor(self.sim)\n        self.gym.refresh_actor_root_state_tensor(self.sim)\n        self.gym.refresh_rigid_body_state_tensor(self.sim)\n        self.gym.refresh_net_contact_force_tensor(self.sim)\n\n        self._root_states = gymtorch.wrap_tensor(actor_root_state)\n        # print(f'root_states:{self._root_states.shape}')\n        num_actors = self.get_num_actors_per_env()\n        # print(f\"num actors: {num_actors}\")\n\n        self._humanoid_root_states = self._root_states\n        # print(f\"humanoid_root_states shape: 
{self._humanoid_root_states.shape}\") # (num_envs*2, 13)\n        self._initial_humanoid_root_states = self._humanoid_root_states.clone()\n        self._initial_humanoid_root_states[:, 7:13] = 0 # zero for linear vel and angular vel\n\n        self._humanoid_actor_ids = num_actors * torch.arange(self.num_envs, device=self.device, dtype=torch.int32)\n        # print(f\"humanoid_actor_ids: {self._humanoid_actor_ids}\") # 0, 2, 4, 6...\n        # print(f\"humanoid indices: {self.humanoid_indices}\") # 0, 2, 4, 6...\n        # print(f\"humanoid op indices: {self.humanoid_indices_op}\") # 1, 3, 5, 7...\n\n        # create some wrapper tensors for different slices\n        self._dof_state = gymtorch.wrap_tensor(dof_state_tensor)\n        dofs_per_env = self._dof_state.shape[0] // self.num_envs\n        self._dof_pos = self._dof_state.view(self.num_envs, dofs_per_env, 2)[..., :self.num_dof, 0]\n        self._dof_vel = self._dof_state.view(self.num_envs, dofs_per_env, 2)[..., :self.num_dof, 1]\n        # op\n        self._dof_pos_op = self._dof_state.view(self.num_envs, dofs_per_env, 2)[..., self.num_dof:, 0]\n        self._dof_vel_op = self._dof_state.view(self.num_envs, dofs_per_env, 2)[..., self.num_dof:, 1]\n\n        self._initial_dof_pos = torch.zeros_like(self._dof_pos, device=self.device, dtype=torch.float)\n        self._initial_dof_vel = torch.zeros_like(self._dof_vel, device=self.device, dtype=torch.float)\n        # op\n        self._initial_dof_pos_op = torch.zeros_like(self._dof_pos_op, device=self.device, dtype=torch.float)\n        self._initial_dof_vel_op = torch.zeros_like(self._dof_vel_op, device=self.device, dtype=torch.float)\n\n        self._rigid_body_state = gymtorch.wrap_tensor(rigid_body_state)\n\n        bodies_per_env = self._rigid_body_state.shape[0] // self.num_envs\n        rigid_body_state_reshaped = self._rigid_body_state.view(self.num_envs, bodies_per_env, 13)\n\n        self._rigid_body_pos = rigid_body_state_reshaped[..., :self.num_bodies, 
0:3]\n        self._rigid_body_rot = rigid_body_state_reshaped[..., :self.num_bodies, 3:7]\n        self._rigid_body_vel = rigid_body_state_reshaped[..., :self.num_bodies, 7:10]\n        self._rigid_body_ang_vel = rigid_body_state_reshaped[..., :self.num_bodies, 10:13]\n        # op\n        self._rigid_body_pos_op = rigid_body_state_reshaped[..., self.num_bodies:, 0:3]\n        self._rigid_body_rot_op = rigid_body_state_reshaped[..., self.num_bodies:, 3:7]\n        self._rigid_body_vel_op = rigid_body_state_reshaped[..., self.num_bodies:, 7:10]\n        self._rigid_body_ang_vel_op = rigid_body_state_reshaped[..., self.num_bodies:, 10:13]\n\n        contact_force_tensor = gymtorch.wrap_tensor(contact_force_tensor)\n        self._contact_forces = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., :self.num_bodies, :]\n        self._contact_forces_op = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., self.num_bodies:, :]\n\n        self._terminate_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)\n\n        self._build_termination_heights()\n\n        contact_bodies = self.cfg[\"env\"][\"contactBodies\"]\n        self._key_body_ids = self._build_key_body_ids_tensor(key_bodies)\n        self._contact_body_ids = self._build_contact_body_ids_tensor(contact_bodies)\n        self.allocate_buffers()\n\n        return\n\n    def get_obs_size(self):\n        return self._num_obs\n\n    def get_action_size(self):\n        return self._num_actions\n\n    def get_num_actors_per_env(self):\n        num_actors = self._root_states.shape[0] // self.num_envs\n        return num_actors\n\n    def _add_circle_borderline(self, env):\n        lines = []\n        borderline_height = 0.01\n        for height in range(20):\n            for angle in range(360):\n                begin_point = [np.cos(np.radians(angle)), np.sin(np.radians(angle)), borderline_height * height]\n                end_point = [np.cos(np.radians(angle + 1)), 
np.sin(np.radians(angle + 1)), borderline_height * height]\n                lines.append(begin_point)\n                lines.append(end_point)\n        lines = np.array(lines, dtype=np.float32) * self.borderline_space\n        colors = np.array([[1, 0, 0]] * int(len(lines) / 2), dtype=np.float32)\n        self.gym.add_lines(self.viewer, env, int(len(lines) / 2), lines, colors)\n\n    def _add_rectangle_borderline(self, env):\n        lines = []\n        colors = np.zeros((90*60, 3), dtype=np.float32)\n        for k in range(4):\n            for height in range(10):\n                lines1 = []\n                lines2 = []\n                lines3 = []\n                lines4 = []\n                for i in range(90):\n                    begin_point1 = [-self.borderline_space + i * self.borderline_space / 45,\n                                self.borderline_space,\n                                height*0.01+ k*0.25]\n                    end_point1 = [-self.borderline_space + (i+1) * self.borderline_space / 45,\n                                self.borderline_space,\n                                height*0.01+ k*0.25]\n                    begin_point2 = [self.borderline_space,\n                                self.borderline_space - i * self.borderline_space / 45,\n                                height*0.01+ k*0.25]\n                    end_point2 = [self.borderline_space,\n                                self.borderline_space - (i+1) * self.borderline_space / 45,\n                                height*0.01+ k*0.25]\n                    begin_point3 = [self.borderline_space - i * self.borderline_space / 45,\n                                -self.borderline_space,\n                                height*0.01+ k*0.25]\n                    end_point3 = [self.borderline_space - (i+1) * self.borderline_space / 45,\n                                -self.borderline_space,\n                                height*0.01+ k*0.25]\n                    begin_point4 = 
[-self.borderline_space ,\n                                -self.borderline_space + i * self.borderline_space / 45,\n                                height*0.01+ k*0.25]\n                    end_point4 = [-self.borderline_space,\n                                -self.borderline_space + (i+1) * self.borderline_space / 45,\n                                height*0.01+ k*0.25]\n                    lines1.append(begin_point1)\n                    lines1.append(end_point1)\n                    lines2.append(begin_point2)\n                    lines2.append(end_point2)\n                    lines3.append(begin_point3)\n                    lines3.append(end_point3)\n                    lines4.append(begin_point4)\n                    lines4.append(end_point4)\n                lines.extend(lines1)\n                lines.extend(lines2)\n                lines.extend(lines3)\n                lines.extend(lines4)\n\n        lines = np.array(lines, dtype=np.float32)\n\n        colors = np.array([[1, 0, 0]] * int(len(lines) / 2), dtype=np.float32)\n        self.gym.add_lines(self.viewer, env, int(len(lines) / 2), lines, colors)\n\n    def allocate_buffers(self):\n        self.obs_buf = torch.zeros((self.num_agents * self.num_envs, self.num_obs), device=self.device,\n                                   dtype=torch.float)\n        self.states_buf = torch.zeros(\n            (self.num_envs, self.num_states), device=self.device, dtype=torch.float)\n        self.rew_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.float)\n        self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)\n        self.timeout_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.progress_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.randomize_buf = torch.zeros(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.long)\n        
self.extras = {\n            'win': torch.zeros(((self.num_agents - 1) * self.num_envs,), device=self.device, dtype=torch.bool),\n            'lose': torch.zeros(((self.num_agents - 1) * self.num_envs,), device=self.device, dtype=torch.bool),\n            'draw': torch.zeros(((self.num_agents - 1) * self.num_envs,), device=self.device, dtype=torch.bool)}\n        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat((2 * self.num_envs, 1))\n        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat((2 * self.num_envs, 1))\n        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat((2 * self.num_envs, 1))\n\n    def create_sim(self):\n        self.up_axis_idx = self.set_sim_params_up_axis(self.sim_params, 'z')\n        self.sim = super().create_sim(self.device_id, self.graphics_device_id, self.physics_engine, self.sim_params)\n\n        self._create_ground_plane()\n        self._create_envs(self.num_envs, self.cfg[\"env\"]['envSpacing'], int(np.sqrt(self.num_envs)))\n        return\n\n    def reset(self, env_ids=None):\n        if (env_ids is None):\n            env_ids = to_torch(np.arange(self.num_envs), device=self.device, dtype=torch.long)\n        self._reset_envs(env_ids)\n        return\n\n    def set_char_color(self, col, env_ids):\n        for env_id in env_ids:\n            env_ptr = self.envs[env_id]\n            handle = self.humanoid_handles[env_id]\n\n            for j in range(self.num_bodies):\n                self.gym.set_rigid_body_color(env_ptr, handle, j, gymapi.MESH_VISUAL,\n                                              gymapi.Vec3(col[0], col[1], col[2]))\n\n        return\n\n    def _reset_envs(self, env_ids):\n        if (len(env_ids) > 0):\n            self._reset_actors(env_ids)\n            self._reset_env_tensors(env_ids)\n            self._refresh_sim_tensors()\n            self._compute_observations()\n        return\n\n    def 
_reset_env_tensors(self, env_ids):\n        # env_ids_int32 = self._humanoid_actor_ids[env_ids]\n        env_ids_int32 = (torch.cat((self.humanoid_indices[env_ids],\n                                    self.humanoid_indices_op[env_ids]))).to(dtype=torch.int32)\n        self.gym.set_actor_root_state_tensor_indexed(self.sim,\n                                                     gymtorch.unwrap_tensor(self._root_states),\n                                                     gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))\n        self.gym.set_dof_state_tensor_indexed(self.sim,\n                                              gymtorch.unwrap_tensor(self._dof_state),\n                                              gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))\n\n        self.progress_buf[env_ids] = 0\n        self.reset_buf[env_ids] = 0\n        self._terminate_buf[env_ids] = 0\n        \n        return\n\n    def _create_ground_plane(self):\n        plane_params = gymapi.PlaneParams()\n        plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0)\n        plane_params.static_friction = self.plane_static_friction\n        plane_params.dynamic_friction = self.plane_dynamic_friction\n        plane_params.restitution = self.plane_restitution\n        self.gym.add_ground(self.sim, plane_params)\n        return\n\n    def _setup_character_props(self, key_bodies):\n        asset_file = self.cfg[\"env\"][\"asset\"][\"assetFileName\"]\n        num_key_bodies = len(key_bodies)\n\n        if (asset_file == \"mjcf/amp_humanoid.xml\"):\n            self._dof_body_ids = [1, 2, 3, 4, 6, 7, 9, 10, 11, 12, 13, 14]\n            self._dof_offsets = [0, 3, 6, 9, 10, 13, 14, 17, 18, 21, 24, 25, 28]\n            self._dof_obs_size = 72\n            self._num_actions = 28\n            self._num_obs = 1 + 15 * (3 + 6 + 3 + 3) - 3\n\n        elif (asset_file == \"mjcf/amp_humanoid_sword_shield.xml\"):\n            self._dof_body_ids = [1, 2, 3, 4, 5, 7, 8, 11, 12, 13, 14, 15, 
16]\n            self._dof_offsets = [0, 3, 6, 9, 10, 13, 16, 17, 20, 21, 24, 27, 28, 31]\n            self._dof_obs_size = 78\n            self._num_actions = 31\n            self._num_obs = 1 + 17 * (3 + 6 + 3 + 3) - 3\n\n        else:\n            print(\"Unsupported character config file: {}\".format(asset_file))\n            assert(False)\n\n        return\n\n    def _build_termination_heights(self):\n        head_term_height = 0.3\n        shield_term_height = 0.32\n\n        termination_height = self.cfg[\"env\"][\"terminationHeight\"]\n        self._termination_heights = np.array([termination_height] * self.num_bodies)\n\n        head_id = self.gym.find_actor_rigid_body_handle(self.envs[0], self.humanoid_handles[0], \"head\")\n        self._termination_heights[head_id] = max(head_term_height, self._termination_heights[head_id])\n\n        asset_file = self.cfg[\"env\"][\"asset\"][\"assetFileName\"]\n        if (asset_file == \"mjcf/amp_humanoid_sword_shield.xml\"):\n            left_arm_id = self.gym.find_actor_rigid_body_handle(self.envs[0], self.humanoid_handles[0], \"left_lower_arm\")\n            self._termination_heights[left_arm_id] = max(shield_term_height, self._termination_heights[left_arm_id])\n\n        self._termination_heights = to_torch(self._termination_heights, device=self.device)\n        return\n\n    def _create_envs(self, num_envs, spacing, num_per_row):\n        lower = gymapi.Vec3(-spacing, -spacing, 0.0)\n        upper = gymapi.Vec3(spacing, spacing, spacing)\n\n        asset_root = self.cfg[\"env\"][\"asset\"][\"assetRoot\"]\n        asset_file = self.cfg[\"env\"][\"asset\"][\"assetFileName\"]\n\n        asset_path = os.path.join(asset_root, asset_file)\n        asset_root = os.path.dirname(asset_path)\n        asset_file = os.path.basename(asset_path)\n\n        asset_options = gymapi.AssetOptions()\n        asset_options.angular_damping = 0.01\n        asset_options.max_angular_velocity = 100.0\n        
asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE\n        #asset_options.fix_base_link = True\n        humanoid_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options)\n        humanoid_asset_op = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options)\n\n        actuator_props = self.gym.get_asset_actuator_properties(humanoid_asset)\n        motor_efforts = [prop.motor_effort for prop in actuator_props]\n\n        # create force sensors at the feet\n        right_foot_idx = self.gym.find_asset_rigid_body_index(humanoid_asset, \"right_foot\")\n        left_foot_idx = self.gym.find_asset_rigid_body_index(humanoid_asset, \"left_foot\")\n\n        # op\n        right_foot_idx_op = self.gym.find_asset_rigid_body_index(humanoid_asset_op, \"right_foot\")\n        left_foot_idx_op = self.gym.find_asset_rigid_body_index(humanoid_asset_op, \"left_foot\")\n\n        sensor_pose = gymapi.Transform()\n        sensor_pose_op = gymapi.Transform()\n\n        self.gym.create_asset_force_sensor(humanoid_asset, right_foot_idx, sensor_pose)\n        self.gym.create_asset_force_sensor(humanoid_asset, left_foot_idx, sensor_pose)\n\n        # op\n        self.gym.create_asset_force_sensor(humanoid_asset_op, right_foot_idx_op, sensor_pose_op)\n        self.gym.create_asset_force_sensor(humanoid_asset_op, left_foot_idx_op, sensor_pose_op)\n\n        self.max_motor_effort = max(motor_efforts)\n        self.motor_efforts = to_torch(motor_efforts, device=self.device)\n\n        self.torso_index = 0\n\n        # 17 bodies\n        self.num_bodies = self.gym.get_asset_rigid_body_count(humanoid_asset)\n\n        # 31 dofs\n        self.num_dof = self.gym.get_asset_dof_count(humanoid_asset)\n\n        # 34 joints\n        self.num_joints = self.gym.get_asset_joint_count(humanoid_asset)\n\n        self.humanoid_handles = []\n        self.humanoid_handles_op = []\n        self.humanoid_indices = []\n        self.humanoid_indices_op = []\n        
self.envs = []\n        self.dof_limits_lower = []\n        self.dof_limits_upper = []\n\n        for i in range(self.num_envs):\n            # create env instance\n            env_ptr = self.gym.create_env(self.sim, lower, upper, num_per_row)\n            self._build_env(i, env_ptr, humanoid_asset, humanoid_asset_op)\n            self.envs.append(env_ptr)\n\n        dof_prop = self.gym.get_actor_dof_properties(self.envs[0], self.humanoid_handles[0])\n        for j in range(self.num_dof):\n            if dof_prop['lower'][j] > dof_prop['upper'][j]:\n                self.dof_limits_lower.append(dof_prop['upper'][j])\n                self.dof_limits_upper.append(dof_prop['lower'][j])\n            else:\n                self.dof_limits_lower.append(dof_prop['lower'][j])\n                self.dof_limits_upper.append(dof_prop['upper'][j])\n\n        self.dof_limits_lower = to_torch(self.dof_limits_lower, device=self.device)\n        self.dof_limits_upper = to_torch(self.dof_limits_upper, device=self.device)\n        self.humanoid_indices = to_torch(self.humanoid_indices, dtype=torch.long, device=self.device)\n        self.humanoid_indices_op = to_torch(self.humanoid_indices_op, dtype=torch.long, device=self.device)\n        \n        if (self._pd_control):\n            self._build_pd_action_offset_scale()\n\n        return\n\n    def _build_env(self, env_id, env_ptr, humanoid_asset, humanoid_asset_op):\n        col_group = env_id\n        col_filter = self._get_humanoid_collision_filter()\n        segmentation_id = 0\n\n        start_pose = gymapi.Transform()\n        start_pose_op = gymapi.Transform()\n        # asset_file = self.cfg[\"env\"][\"asset\"][\"assetFileName\"]\n        # char_h = 0.89\n\n        start_pose.p = gymapi.Vec3(-self.borderline_space + 2, -self.borderline_space + 2, 0.89)\n        start_pose.r = gymapi.Quat(0.0, 0.0, 0.0, 1.0)\n\n        start_pose_op.p = gymapi.Vec3(self.borderline_space - 2, self.borderline_space - 2, 0.89)\n        # 
start_pose_op.p = gymapi.Vec3(0, 0, 0.89)\n        start_pose_op.r = gymapi.Quat(0.0, 0.0, 0.0, 1.0)\n\n        humanoid_handle = self.gym.create_actor(env_ptr, humanoid_asset, start_pose, \"humanoid\", col_group, col_filter, segmentation_id)\n        humanoid_index = self.gym.get_actor_index(env_ptr, humanoid_handle, gymapi.DOMAIN_SIM)\n\n        humanoid_handle_op = self.gym.create_actor(env_ptr, humanoid_asset_op, start_pose_op, \"humanoid\", col_group, col_filter, segmentation_id)\n        humanoid_index_op = self.gym.get_actor_index(env_ptr, humanoid_handle_op, gymapi.DOMAIN_SIM)\n\n        self.gym.enable_actor_dof_force_sensors(env_ptr, humanoid_handle)\n        self.gym.enable_actor_dof_force_sensors(env_ptr, humanoid_handle_op)\n\n        for j in range(self.num_bodies):\n            self.gym.set_rigid_body_color(env_ptr, humanoid_handle, j, gymapi.MESH_VISUAL, gymapi.Vec3(0.54, 0.85, 0.2))\n            self.gym.set_rigid_body_color(env_ptr, humanoid_handle_op, j, gymapi.MESH_VISUAL, gymapi.Vec3(0.97, 0.38, 0.06))\n\n        if (self._pd_control):\n            dof_prop = self.gym.get_asset_dof_properties(humanoid_asset)\n            dof_prop[\"driveMode\"] = gymapi.DOF_MODE_POS\n            self.gym.set_actor_dof_properties(env_ptr, humanoid_handle, dof_prop)\n\n            dof_prop_op = self.gym.get_asset_dof_properties(humanoid_asset_op)\n            dof_prop_op[\"driveMode\"] = gymapi.DOF_MODE_POS\n            self.gym.set_actor_dof_properties(env_ptr, humanoid_handle_op, dof_prop_op)\n\n        self.humanoid_handles.append(humanoid_handle)\n        self.humanoid_indices.append(humanoid_index)\n        self.humanoid_handles_op.append(humanoid_handle_op)\n        self.humanoid_indices_op.append(humanoid_index_op)\n\n        return\n\n    def _build_pd_action_offset_scale(self):\n        num_joints = len(self._dof_offsets) - 1\n\n        lim_low = self.dof_limits_lower.cpu().numpy()\n        lim_high = self.dof_limits_upper.cpu().numpy()\n\n        for j 
in range(num_joints):\n            dof_offset = self._dof_offsets[j]\n            dof_size = self._dof_offsets[j + 1] - self._dof_offsets[j]\n\n            if (dof_size == 3):\n                curr_low = lim_low[dof_offset:(dof_offset + dof_size)]\n                curr_high = lim_high[dof_offset:(dof_offset + dof_size)]\n                curr_low = np.max(np.abs(curr_low))\n                curr_high = np.max(np.abs(curr_high))\n                curr_scale = max([curr_low, curr_high])\n                curr_scale = 1.2 * curr_scale\n                curr_scale = min([curr_scale, np.pi])\n\n                lim_low[dof_offset:(dof_offset + dof_size)] = -curr_scale\n                lim_high[dof_offset:(dof_offset + dof_size)] = curr_scale\n                \n                #lim_low[dof_offset:(dof_offset + dof_size)] = -np.pi\n                #lim_high[dof_offset:(dof_offset + dof_size)] = np.pi\n\n\n            elif (dof_size == 1):\n                curr_low = lim_low[dof_offset]\n                curr_high = lim_high[dof_offset]\n                curr_mid = 0.5 * (curr_high + curr_low)\n                \n                # extend the action range to be a bit beyond the joint limits so that the motors\n                # don't lose their strength as they approach the joint limits\n                curr_scale = 0.7 * (curr_high - curr_low)\n                curr_low = curr_mid - curr_scale\n                curr_high = curr_mid + curr_scale\n\n                lim_low[dof_offset] = curr_low\n                lim_high[dof_offset] =  curr_high\n\n        self._pd_action_offset = 0.5 * (lim_high + lim_low)\n        self._pd_action_scale = 0.5 * (lim_high - lim_low)\n        self._pd_action_offset = to_torch(self._pd_action_offset, device=self.device)\n        self._pd_action_scale = to_torch(self._pd_action_scale, device=self.device)\n        return\n\n    def _get_humanoid_collision_filter(self):\n        return 0\n\n    def _compute_reward(self, actions):\n        self.rew_buf[:] = 
compute_humanoid_reward(self.obs_buf)\n        return\n\n    def _compute_reset(self):\n        self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_reset(self.reset_buf, self.progress_buf,\n                                                   self._contact_forces, self._contact_body_ids,\n                                                   self._rigid_body_pos, self.max_episode_length,\n                                                   self._enable_early_termination, self._termination_heights)\n        return\n\n    def _refresh_sim_tensors(self):\n        self.gym.refresh_dof_state_tensor(self.sim)\n        self.gym.refresh_actor_root_state_tensor(self.sim)\n        self.gym.refresh_rigid_body_state_tensor(self.sim)\n\n        self.gym.refresh_force_sensor_tensor(self.sim)\n        self.gym.refresh_dof_force_tensor(self.sim)\n        self.gym.refresh_net_contact_force_tensor(self.sim)\n        return\n\n    def _compute_observations(self):\n        obs, obs_op = self._compute_humanoid_obs()\n\n        self.obs_buf[:self.num_envs] = obs\n        self.obs_buf[self.num_envs:] = obs_op\n\n        return\n\n    def _compute_humanoid_obs(self):\n        body_pos = self._rigid_body_pos\n        body_rot = self._rigid_body_rot\n        body_vel = self._rigid_body_vel\n        body_ang_vel = self._rigid_body_ang_vel\n\n        body_pos_op = self._rigid_body_pos_op\n        body_rot_op = self._rigid_body_rot_op\n        body_vel_op = self._rigid_body_vel_op\n        body_ang_vel_op = self._rigid_body_ang_vel_op\n        \n        obs = compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, self._local_root_obs,\n                                                self._root_height_obs)\n        \n        obs_op = compute_humanoid_observations_max(body_pos_op, body_rot_op, body_vel_op, body_ang_vel_op, self._local_root_obs,\n                                                self._root_height_obs)\n        \n        return obs, obs_op\n\n    def 
_reset_actors(self, env_ids):\n        agent_env_ids = expand_env_ids(env_ids, 2)\n        self._humanoid_root_states[agent_env_ids] = self._initial_humanoid_root_states[agent_env_ids]\n        self._dof_pos[env_ids] = self._initial_dof_pos[env_ids]\n        self._dof_vel[env_ids] = self._initial_dof_vel[env_ids]\n        self._dof_pos_op[env_ids] = self._initial_dof_pos_op[env_ids]\n        self._dof_vel_op[env_ids] = self._initial_dof_vel_op[env_ids]\n        return\n\n    def pre_physics_step(self, actions):\n        self.actions = actions.to(self.device).clone()\n        ego_actions = self.actions[:self.num_envs]\n        op_actions = self.actions[self.num_envs:]\n        if (self._pd_control):\n            pd_tar_ego = self._action_to_pd_targets(ego_actions)\n            pd_tar_op = self._action_to_pd_targets(op_actions)\n            pd_tar = torch.cat([pd_tar_ego, pd_tar_op], dim=-1)\n            pd_tar_tensor = gymtorch.unwrap_tensor(pd_tar)\n\n            self.gym.set_dof_position_target_tensor(self.sim, pd_tar_tensor)\n        else:\n            forces = self.actions * self.motor_efforts.unsqueeze(0) * self.power_scale\n            force_tensor = gymtorch.unwrap_tensor(forces)\n            self.gym.set_dof_actuation_force_tensor(self.sim, force_tensor)\n\n        return\n\n    def post_physics_step(self):\n        self.progress_buf += 1\n\n        self._refresh_sim_tensors()\n        self._compute_observations()\n        self._compute_reward(self.actions)\n        self._compute_reset()\n\n        self.extras[\"terminate\"] = self._terminate_buf\n\n        # debug viz\n        if self.viewer and self.debug_viz:\n            self._update_debug_viz()\n\n        return\n\n    def render(self, sync_frame_time=False):\n\n        super().render(sync_frame_time)\n        return\n\n    def _build_key_body_ids_tensor(self, key_body_names):\n        env_ptr = self.envs[0]\n        actor_handle = self.humanoid_handles[0]\n        body_ids = []\n\n        for body_name 
in key_body_names:\n            body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name)\n            assert(body_id != -1)\n            body_ids.append(body_id)\n\n        body_ids = to_torch(body_ids, device=self.device, dtype=torch.long)\n        return body_ids\n\n    def _build_contact_body_ids_tensor(self, contact_body_names):\n        env_ptr = self.envs[0]\n        actor_handle = self.humanoid_handles[0]\n        body_ids = []\n\n        for body_name in contact_body_names:\n            body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name)\n            assert(body_id != -1)\n            body_ids.append(body_id)\n\n        body_ids = to_torch(body_ids, device=self.device, dtype=torch.long)\n        return body_ids\n\n    def _action_to_pd_targets(self, action):\n        pd_tar = self._pd_action_offset + self._pd_action_scale * action\n        return pd_tar\n\n    def _update_debug_viz(self):\n        self.gym.clear_lines(self.viewer)\n        return\n\n#####################################################################\n###=========================jit functions=========================###\n#####################################################################\n\n@torch.jit.script\ndef dof_to_obs(pose, dof_obs_size, dof_offsets):\n    # type: (Tensor, int, List[int]) -> Tensor\n    joint_obs_size = 6\n    num_joints = len(dof_offsets) - 1\n\n    dof_obs_shape = pose.shape[:-1] + (dof_obs_size,)\n    dof_obs = torch.zeros(dof_obs_shape, device=pose.device)\n    dof_obs_offset = 0\n\n    for j in range(num_joints):\n        dof_offset = dof_offsets[j]\n        dof_size = dof_offsets[j + 1] - dof_offsets[j]\n        joint_pose = pose[:, dof_offset:(dof_offset + dof_size)]\n\n        # assume this is a spherical joint\n        if (dof_size == 3):\n            joint_pose_q = torch_utils.exp_map_to_quat(joint_pose)\n        elif (dof_size == 1):\n            axis = torch.tensor([0.0, 1.0, 0.0], 
dtype=joint_pose.dtype, device=pose.device)\n            joint_pose_q = quat_from_angle_axis(joint_pose[..., 0], axis)\n        else:\n            joint_pose_q = None\n            assert(False), \"Unsupported joint type\"\n\n        joint_dof_obs = torch_utils.quat_to_tan_norm(joint_pose_q)\n        dof_obs[:, (j * joint_obs_size):((j + 1) * joint_obs_size)] = joint_dof_obs\n\n    assert((num_joints * joint_obs_size) == dof_obs_size)\n\n    return dof_obs\n\n@torch.jit.script\ndef compute_humanoid_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos,\n                                  local_root_obs, root_height_obs, dof_obs_size, dof_offsets):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, int, List[int]) -> Tensor\n    root_h = root_pos[:, 2:3]\n    heading_rot = torch_utils.calc_heading_quat_inv(root_rot)\n\n    if (local_root_obs):\n        root_rot_obs = quat_mul(heading_rot, root_rot)\n    else:\n        root_rot_obs = root_rot\n    root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs)\n    \n    if (not root_height_obs):\n        root_h_obs = torch.zeros_like(root_h)\n    else:\n        root_h_obs = root_h\n    \n    local_root_vel = quat_rotate(heading_rot, root_vel)\n    local_root_ang_vel = quat_rotate(heading_rot, root_ang_vel)\n\n    root_pos_expand = root_pos.unsqueeze(-2)\n    local_key_body_pos = key_body_pos - root_pos_expand\n\n    heading_rot_expand = heading_rot.unsqueeze(-2)\n    heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1))\n    flat_end_pos = local_key_body_pos.view(local_key_body_pos.shape[0] * local_key_body_pos.shape[1], local_key_body_pos.shape[2])\n    flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], \n                                               heading_rot_expand.shape[2])\n    local_end_pos = quat_rotate(flat_heading_rot, flat_end_pos)\n    flat_local_key_pos = 
local_end_pos.view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2])\n\n    dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets)\n\n    obs = torch.cat((root_h_obs, root_rot_obs, local_root_vel, local_root_ang_vel, dof_obs, dof_vel, flat_local_key_pos), dim=-1)\n    return obs\n\n@torch.jit.script\ndef compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, local_root_obs, root_height_obs):\n    # type: (Tensor, Tensor, Tensor, Tensor, bool, bool) -> Tensor\n    root_pos = body_pos[:, 0, :] # 0: pelvis, root\n    root_rot = body_rot[:, 0, :]\n\n    root_h = root_pos[:, 2:3] # 1. Height of the root from the ground\n    heading_rot = torch_utils.calc_heading_quat_inv(root_rot)\n\n    if (not root_height_obs):\n        root_h_obs = torch.zeros_like(root_h)\n    else:\n        root_h_obs = root_h\n\n    heading_rot_expand = heading_rot.unsqueeze(-2) # num_envs, 1, 4\n    # num_envs, body_pos.shape[1], 4\n    heading_rot_expand = heading_rot_expand.repeat((1, body_pos.shape[1], 1))\n    flat_heading_rot = heading_rot_expand.reshape(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], \n                                               heading_rot_expand.shape[2])\n\n    root_pos_expand = root_pos.unsqueeze(-2)\n    local_body_pos = body_pos - root_pos_expand\n    flat_local_body_pos = local_body_pos.reshape(local_body_pos.shape[0] * local_body_pos.shape[1], local_body_pos.shape[2])\n    flat_local_body_pos = quat_rotate(flat_heading_rot, flat_local_body_pos)\n    local_body_pos = flat_local_body_pos.reshape(local_body_pos.shape[0], local_body_pos.shape[1] * local_body_pos.shape[2])\n    local_body_pos = local_body_pos[..., 3:] # remove root pos\n\n    flat_body_rot = body_rot.reshape(body_rot.shape[0] * body_rot.shape[1], body_rot.shape[2])\n    flat_local_body_rot = quat_mul(flat_heading_rot, flat_body_rot)\n    flat_local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot)\n    
local_body_rot_obs = flat_local_body_rot_obs.reshape(body_rot.shape[0], body_rot.shape[1] * flat_local_body_rot_obs.shape[1])\n\n    if (local_root_obs):\n        root_rot_obs = torch_utils.quat_to_tan_norm(root_rot)\n        local_body_rot_obs[..., 0:6] = root_rot_obs\n\n    flat_body_vel = body_vel.reshape(body_vel.shape[0] * body_vel.shape[1], body_vel.shape[2])\n    flat_local_body_vel = quat_rotate(flat_heading_rot, flat_body_vel)\n    local_body_vel = flat_local_body_vel.reshape(body_vel.shape[0], body_vel.shape[1] * body_vel.shape[2])\n\n    flat_body_ang_vel = body_ang_vel.reshape(body_ang_vel.shape[0] * body_ang_vel.shape[1], body_ang_vel.shape[2])\n    flat_local_body_ang_vel = quat_rotate(flat_heading_rot, flat_body_ang_vel)\n    local_body_ang_vel = flat_local_body_ang_vel.reshape(body_ang_vel.shape[0], body_ang_vel.shape[1] * body_ang_vel.shape[2])\n\n    obs = torch.cat((root_h_obs, local_body_pos, local_body_rot_obs, local_body_vel, local_body_ang_vel), dim=-1)\n    return obs\n\n\n@torch.jit.script\ndef expand_env_ids(env_ids, n_agents):\n    # type: (Tensor, int) -> Tensor\n    device = env_ids.device\n    agent_env_ids = torch.zeros((n_agents * len(env_ids)), device=device, dtype=torch.long)\n    for idx in range(n_agents):\n        agent_env_ids[idx::n_agents] = env_ids * n_agents + idx\n    return agent_env_ids\n\n@torch.jit.script\ndef compute_humanoid_reward(obs_buf):\n    # type: (Tensor) -> Tensor\n    reward = torch.ones_like(obs_buf[:, 0])\n    return reward\n\n@torch.jit.script\ndef compute_humanoid_reset(reset_buf, progress_buf, contact_buf, contact_body_ids, rigid_body_pos,\n                           max_episode_length, enable_early_termination, termination_heights):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, float, bool, Tensor) -> Tuple[Tensor, Tensor]\n    terminated = torch.zeros_like(reset_buf)\n\n    if (enable_early_termination):\n        masked_contact_buf = contact_buf.clone()\n        masked_contact_buf[:, 
contact_body_ids, :] = 0\n        fall_contact = torch.any(torch.abs(masked_contact_buf) > 0.1, dim=-1)\n        fall_contact = torch.any(fall_contact, dim=-1)\n\n        body_height = rigid_body_pos[..., 2]\n        fall_height = body_height < termination_heights\n        fall_height[:, contact_body_ids] = False\n        fall_height = torch.any(fall_height, dim=-1)\n\n        has_fallen = torch.logical_and(fall_contact, fall_height)\n\n        # first timestep can sometimes still have nonzero contact forces\n        # so only check after first couple of steps\n        has_fallen *= (progress_buf > 1)\n        terminated = torch.where(has_fallen, torch.ones_like(reset_buf), terminated)\n    \n    reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), terminated)\n\n    return reset, terminated\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/humanoid_amp.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom enum import Enum\nimport numpy as np\nimport torch\n\nfrom isaacgym import gymapi\nfrom isaacgym import gymtorch\n\nfrom timechamber.tasks.ase_humanoid_base.humanoid import Humanoid, dof_to_obs\nfrom timechamber.utils import gym_util\nfrom timechamber.utils.motion_lib import MotionLib\nfrom isaacgym.torch_utils import *\n\nfrom utils import torch_utils\n\nclass HumanoidAMP(Humanoid):\n    class StateInit(Enum):\n        Default = 0\n        Start = 1\n        Random = 2\n        Hybrid = 3\n\n    def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless):\n        state_init = cfg[\"env\"][\"stateInit\"]\n        self._state_init = HumanoidAMP.StateInit[state_init]\n        self._hybrid_init_prob = cfg[\"env\"][\"hybridInitProb\"]\n        self._num_amp_obs_steps = cfg[\"env\"][\"numAMPObsSteps\"]\n        assert(self._num_amp_obs_steps >= 2)\n\n        self._reset_default_env_ids = []\n        self._reset_ref_env_ids = []\n\n        super().__init__(cfg=cfg,\n                         sim_params=sim_params,\n                         physics_engine=physics_engine,\n                         device_type=device_type,\n                         device_id=device_id,\n                         headless=headless)\n\n        motion_file = cfg['env']['motion_file']\n        self._load_motion(motion_file)\n\n        self._amp_obs_buf = torch.zeros((self.num_envs, self._num_amp_obs_steps, self._num_amp_obs_per_step), 
device=self.device, dtype=torch.float)\n        self._curr_amp_obs_buf = self._amp_obs_buf[:, 0]\n        self._hist_amp_obs_buf = self._amp_obs_buf[:, 1:]\n        \n        self._amp_obs_demo_buf = None\n\n        return\n\n    def post_physics_step(self):\n        super().post_physics_step()\n        \n        self._update_hist_amp_obs()\n        self._compute_amp_observations()\n\n        amp_obs_flat = self._amp_obs_buf.view(-1, self.get_num_amp_obs())\n        self.extras[\"amp_obs\"] = amp_obs_flat\n\n        return\n\n    def get_num_amp_obs(self):\n        return self._num_amp_obs_steps * self._num_amp_obs_per_step\n\n    def fetch_amp_obs_demo(self, num_samples):\n\n        if (self._amp_obs_demo_buf is None):\n            self._build_amp_obs_demo_buf(num_samples)\n        else:\n            assert(self._amp_obs_demo_buf.shape[0] == num_samples)\n        \n        motion_ids = self._motion_lib.sample_motions(num_samples)\n        motion_times0 = self._motion_lib.sample_time(motion_ids)\n        amp_obs_demo = self.build_amp_obs_demo(motion_ids, motion_times0)\n        self._amp_obs_demo_buf[:] = amp_obs_demo.view(self._amp_obs_demo_buf.shape)\n        amp_obs_demo_flat = self._amp_obs_demo_buf.view(-1, self.get_num_amp_obs())\n\n        return amp_obs_demo_flat\n\n    def build_amp_obs_demo(self, motion_ids, motion_times0):\n        dt = self.dt\n\n        motion_ids = torch.tile(motion_ids.unsqueeze(-1), [1, self._num_amp_obs_steps])\n        motion_times = motion_times0.unsqueeze(-1)\n        time_steps = -dt * torch.arange(0, self._num_amp_obs_steps, device=self.device)\n        motion_times = motion_times + time_steps\n\n        motion_ids = motion_ids.view(-1)\n        motion_times = motion_times.view(-1)\n        root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \\\n               = self._motion_lib.get_motion_state(motion_ids, motion_times)\n        amp_obs_demo = build_amp_observations(root_pos, root_rot, root_vel, 
root_ang_vel,\n                                              dof_pos, dof_vel, key_pos,\n                                              self._local_root_obs, self._root_height_obs,\n                                              self._dof_obs_size, self._dof_offsets)\n        return amp_obs_demo\n\n    def _build_amp_obs_demo_buf(self, num_samples):\n        self._amp_obs_demo_buf = torch.zeros((num_samples, self._num_amp_obs_steps, self._num_amp_obs_per_step), device=self.device, dtype=torch.float32)\n        return\n        \n    def _setup_character_props(self, key_bodies):\n        super()._setup_character_props(key_bodies)\n\n        asset_file = self.cfg[\"env\"][\"asset\"][\"assetFileName\"]\n        num_key_bodies = len(key_bodies)\n\n        if (asset_file == \"mjcf/amp_humanoid.xml\"):\n            self._num_amp_obs_per_step = 13 + self._dof_obs_size + 28 + 3 * num_key_bodies # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos]\n        elif (asset_file == \"mjcf/amp_humanoid_sword_shield.xml\"):\n            self._num_amp_obs_per_step = 13 + self._dof_obs_size + 31 + 3 * num_key_bodies # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos]\n        else:\n            print(\"Unsupported character config file: {:s}\".format(asset_file))\n            assert(False)\n\n        return\n\n    def _load_motion(self, motion_file):\n        assert(self._dof_offsets[-1] == self.num_dof)\n        self._motion_lib = MotionLib(motion_file=motion_file,\n                                     dof_body_ids=self._dof_body_ids,\n                                     dof_offsets=self._dof_offsets,\n                                     key_body_ids=self._key_body_ids.cpu().numpy(), \n                                     device=self.device)\n        return\n\n    def _reset_envs(self, env_ids):\n        self._reset_default_env_ids = []\n        self._reset_ref_env_ids = []\n        super()._reset_envs(env_ids)\n        
self._init_amp_obs(env_ids)\n\n        return\n\n    def _reset_actors(self, env_ids):\n        if (self._state_init == HumanoidAMP.StateInit.Default):\n            self._reset_default(env_ids)\n        elif (self._state_init == HumanoidAMP.StateInit.Start\n              or self._state_init == HumanoidAMP.StateInit.Random):\n            self._reset_ref_state_init(env_ids)\n        elif (self._state_init == HumanoidAMP.StateInit.Hybrid):\n            self._reset_hybrid_state_init(env_ids)\n        else:\n            assert(False), \"Unsupported state initialization strategy: {:s}\".format(str(self._state_init))\n        return\n\n    def _reset_default(self, env_ids):\n        super()._reset_actors(env_ids)\n        # self._humanoid_root_states[env_ids] = self._initial_humanoid_root_states[env_ids]\n        # self._dof_pos[env_ids] = self._initial_dof_pos[env_ids]\n        # self._dof_vel[env_ids] = self._initial_dof_vel[env_ids]\n        # self._reset_default_env_ids = env_ids\n        return\n\n    def _reset_ref_state_init(self, env_ids):\n        num_envs = env_ids.shape[0]\n        motion_ids = self._motion_lib.sample_motions(num_envs)\n        \n        if (self._state_init == HumanoidAMP.StateInit.Random\n            or self._state_init == HumanoidAMP.StateInit.Hybrid):\n            motion_times = self._motion_lib.sample_time(motion_ids)\n        elif (self._state_init == HumanoidAMP.StateInit.Start):\n            motion_times = torch.zeros(num_envs, device=self.device)\n        else:\n            assert(False), \"Unsupported state initialization strategy: {:s}\".format(str(self._state_init))\n\n        root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \\\n               = self._motion_lib.get_motion_state(motion_ids, motion_times)\n\n        self._set_env_state(env_ids=env_ids, \n                            root_pos=root_pos, \n                            root_rot=root_rot, \n                            dof_pos=dof_pos, \n                 
           root_vel=root_vel, \n                            root_ang_vel=root_ang_vel, \n                            dof_vel=dof_vel)\n\n        self._reset_ref_env_ids = env_ids\n        self._reset_ref_motion_ids = motion_ids\n        self._reset_ref_motion_times = motion_times\n        return\n\n    def _reset_hybrid_state_init(self, env_ids):\n        num_envs = env_ids.shape[0]\n        ref_probs = to_torch(np.array([self._hybrid_init_prob] * num_envs), device=self.device)\n        ref_init_mask = torch.bernoulli(ref_probs) == 1.0\n\n        ref_reset_ids = env_ids[ref_init_mask]\n        if (len(ref_reset_ids) > 0):\n            self._reset_ref_state_init(ref_reset_ids)\n\n        default_reset_ids = env_ids[torch.logical_not(ref_init_mask)]\n        if (len(default_reset_ids) > 0):\n            self._reset_default(default_reset_ids)\n\n        return\n\n    def _init_amp_obs(self, env_ids):\n        self._compute_amp_observations(env_ids)\n        \n        if (len(self._reset_default_env_ids) > 0):\n            self._init_amp_obs_default(self._reset_default_env_ids)\n\n        if (len(self._reset_ref_env_ids) > 0):\n            self._init_amp_obs_ref(self._reset_ref_env_ids, self._reset_ref_motion_ids,\n                                   self._reset_ref_motion_times)\n        \n        return\n\n    def _init_amp_obs_default(self, env_ids):\n        curr_amp_obs = self._curr_amp_obs_buf[env_ids].unsqueeze(-2)\n        self._hist_amp_obs_buf[env_ids] = curr_amp_obs\n        return\n\n    def _init_amp_obs_ref(self, env_ids, motion_ids, motion_times):\n        dt = self.dt\n        # tile per-env motion ids so they pair with the per-env history timesteps below\n        motion_ids = torch.tile(motion_ids.unsqueeze(-1), [1, self._num_amp_obs_steps - 1])\n        motion_times = motion_times.unsqueeze(-1)\n        time_steps = -dt * (torch.arange(0, self._num_amp_obs_steps - 1, device=self.device) + 1)\n        motion_times = motion_times + time_steps\n\n        motion_ids = motion_ids.view(-1)\n        motion_times = motion_times.view(-1)\n        root_pos, 
root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \\\n               = self._motion_lib.get_motion_state(motion_ids, motion_times)\n        amp_obs_demo = build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, \n                                              dof_pos, dof_vel, key_pos, \n                                              self._local_root_obs, self._root_height_obs, \n                                              self._dof_obs_size, self._dof_offsets)\n        self._hist_amp_obs_buf[env_ids] = amp_obs_demo.view(self._hist_amp_obs_buf[env_ids].shape)\n        return\n    \n    def _set_env_state(self, env_ids, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel):\n        self._humanoid_root_states[env_ids, 0:3] = root_pos\n        self._humanoid_root_states[env_ids, 3:7] = root_rot\n        self._humanoid_root_states[env_ids, 7:10] = root_vel\n        self._humanoid_root_states[env_ids, 10:13] = root_ang_vel\n        \n        self._dof_pos[env_ids] = dof_pos\n        self._dof_vel[env_ids] = dof_vel\n        return\n\n    def _update_hist_amp_obs(self, env_ids=None):\n        if (env_ids is None):\n            self._hist_amp_obs_buf[:] = self._amp_obs_buf[:, 0:(self._num_amp_obs_steps - 1)]\n        else:\n            self._hist_amp_obs_buf[env_ids] = self._amp_obs_buf[env_ids, 0:(self._num_amp_obs_steps - 1)]\n        return\n\n    def _compute_amp_observations(self, env_ids=None):\n        key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :]\n        if (env_ids is None):\n            self._curr_amp_obs_buf[:] = build_amp_observations(self._rigid_body_pos[:, 0, :],\n                                                               self._rigid_body_rot[:, 0, :],\n                                                               self._rigid_body_vel[:, 0, :],\n                                                               self._rigid_body_ang_vel[:, 0, :],\n                                                               
self._dof_pos, self._dof_vel, key_body_pos,\n                                                               self._local_root_obs, self._root_height_obs, \n                                                               self._dof_obs_size, self._dof_offsets)\n        else:\n            self._curr_amp_obs_buf[env_ids] = build_amp_observations(self._rigid_body_pos[env_ids][:, 0, :],\n                                                                   self._rigid_body_rot[env_ids][:, 0, :],\n                                                                   self._rigid_body_vel[env_ids][:, 0, :],\n                                                                   self._rigid_body_ang_vel[env_ids][:, 0, :],\n                                                                   self._dof_pos[env_ids], self._dof_vel[env_ids], key_body_pos[env_ids],\n                                                                   self._local_root_obs, self._root_height_obs, \n                                                                   self._dof_obs_size, self._dof_offsets)\n        return\n\n\n#####################################################################\n###=========================jit functions=========================###\n#####################################################################\n\n@torch.jit.script\ndef build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, \n                           local_root_obs, root_height_obs, dof_obs_size, dof_offsets):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, int, List[int]) -> Tensor\n    root_h = root_pos[:, 2:3]\n    heading_rot = torch_utils.calc_heading_quat_inv(root_rot)\n\n    if (local_root_obs):\n        root_rot_obs = quat_mul(heading_rot, root_rot)\n    else:\n        root_rot_obs = root_rot\n    root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs)\n    \n    if (not root_height_obs):\n        root_h_obs = torch.zeros_like(root_h)\n    
else:\n        root_h_obs = root_h\n    \n    local_root_vel = quat_rotate(heading_rot, root_vel)\n    local_root_ang_vel = quat_rotate(heading_rot, root_ang_vel)\n\n    root_pos_expand = root_pos.unsqueeze(-2)\n    local_key_body_pos = key_body_pos - root_pos_expand\n    \n    heading_rot_expand = heading_rot.unsqueeze(-2)\n    heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1))\n    flat_end_pos = local_key_body_pos.view(local_key_body_pos.shape[0] * local_key_body_pos.shape[1], local_key_body_pos.shape[2])\n    flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], \n                                               heading_rot_expand.shape[2])\n    local_end_pos = quat_rotate(flat_heading_rot, flat_end_pos)\n    flat_local_key_pos = local_end_pos.view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2])\n    \n    dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets)\n    obs = torch.cat((root_h_obs, root_rot_obs, local_root_vel, local_root_ang_vel, dof_obs, dof_vel, flat_local_key_pos), dim=-1)\n    return obs"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/humanoid_amp_task.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nimport timechamber.tasks.ase_humanoid_base.humanoid_amp as humanoid_amp\n\nclass HumanoidAMPTask(humanoid_amp.HumanoidAMP):\n    def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless):\n        self._enable_task_obs = cfg[\"env\"][\"enableTaskObs\"]\n\n        super().__init__(cfg=cfg,\n                         sim_params=sim_params,\n                         physics_engine=physics_engine,\n                         device_type=device_type,\n                         device_id=device_id,\n                         headless=headless)\n        return\n\n    def get_obs_size(self):\n        obs_size = super().get_obs_size()\n        if (self._enable_task_obs):\n            task_obs_size = self.get_task_obs_size()\n            obs_size += task_obs_size\n        return obs_size\n\n    def get_task_obs_size(self):\n        return 0\n\n    def pre_physics_step(self, actions):\n        super().pre_physics_step(actions)\n        self._update_task()\n        return\n\n    def render(self, sync_frame_time=False):\n        super().render(sync_frame_time)\n\n        if self.viewer:\n            self._draw_task()\n        return\n\n    def _update_task(self):\n        return\n\n    def _reset_envs(self, env_ids):\n        super()._reset_envs(env_ids)\n        self._reset_task(env_ids)\n        return\n\n    def _reset_task(self, env_ids):\n        return\n\n    def _compute_observations(self):\n        # 
humanoid obs for both agents; ego and opponent observations are stacked along dim 0 of obs_buf\n        obs, obs_op = self._compute_humanoid_obs()\n        if (self._enable_task_obs):\n            task_obs = self._compute_task_obs(env_ids=None)\n            obs = torch.cat([obs, task_obs], dim=-1)\n\n        self.obs_buf[:self.num_envs] = obs\n        self.obs_buf[self.num_envs:] = obs_op\n\n        return\n\n    def _compute_task_obs(self, env_ids=None):\n        return NotImplemented\n\n    def _compute_reward(self, actions):\n        return NotImplemented\n\n    def _draw_task(self):\n        return"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/README.md",
    "content": "# poselib\n\n`poselib` is a library for loading, manipulating, and retargeting skeleton poses and motions. It is separated into three modules: `poselib.core` for basic data loading and tensor operations, `poselib.skeleton` for higher-level skeleton operations, and `poselib.visualization` for displaying skeleton poses. This library is built on top of the PyTorch framework and requires data to be in PyTorch tensors.\n\n## poselib.core\n- `poselib.core.rotation3d`: A set of Torch JIT functions for computing quaternions, transforms, and rotation/transformation matrices.\n    - `quat_*` manipulate and create quaternions in [x, y, z, w] format (where w is the real component).\n    - `transform_*` handle 7D transforms in [quat, pos] format.\n    - `rot_matrix_*` handle 3x3 rotation matrices.\n    - `euclidean_*` handle 4x4 Euclidean transformation matrices.\n- `poselib.core.tensor_utils`: Provides loading and saving functions for PyTorch tensors.\n\n## poselib.skeleton\n- `poselib.skeleton.skeleton3d`: Utilities for loading and manipulating skeleton poses, and retargeting poses to different skeletons.\n    - `SkeletonTree` is a class that stores a skeleton as a tree structure. 
This describes the skeleton topology and joints.\n    - `SkeletonState` describes the static state of a skeleton, and provides both global and local joint angles.\n    - `SkeletonMotion` describes a time-series of skeleton states and provides utilities for computing joint velocities.\n\n## poselib.visualization\n- `poselib.visualization.common`: Functions used for visualizing skeletons interactively in `matplotlib`.\n    - In SkeletonState visualization, press `q` to quit the window.\n    - In interactive SkeletonMotion visualization, you can use the following key commands:\n        - `w` - loop animation\n        - `x` - play/pause animation\n        - `z` - previous frame\n        - `c` - next frame\n        - `n` - quit the window\n\n## Key Features\nPoselib provides several key features for working with animation data. We list some of the frequently used ones here, and provide instructions and examples on their usage.\n\n### Importing from FBX\nPoselib supports importing skeletal animation sequences from the .fbx format into a SkeletonMotion representation. To use this functionality, you will first need to set up the Python FBX SDK on your machine using the following instructions.\n\nThis package is necessary to read data from .fbx files, a proprietary file format owned by Autodesk. The latest FBX SDK tested was FBX SDK 2020.2.1 for Python 3.7, which can be found on the Autodesk website: https://www.autodesk.com/developer-network/platform-technologies/fbx-sdk-2020-2-1.\n\nFollow the instructions at https://help.autodesk.com/view/FBX/2020/ENU/?guid=FBX_Developer_Help_scripting_with_python_fbx_installing_python_fbx_html to download, install, and set up the FBX Python SDK.\n\nThis repo provides an example script `fbx_importer.py` that demonstrates importing a .fbx file. Note that `SkeletonMotion.from_fbx()` takes an optional parameter `root_joint`, which can be used to specify a joint in the skeleton tree as the root joint. 
If `root_joint` is not specified, we default to using the first node in the FBX scene that contains animation data. \n\n### Importing from MJCF\nMJCF is a robotics file format supported by Isaac Gym. For convenience, we provide an API for importing MJCF assets into SkeletonTree definitions to represent the skeleton topology. An example script `mjcf_importer.py` is provided to show usage of this.\n\nThis can be helpful if motion sequences need to be retargeted to your simulation skeleton that has been created in MJCF format. Importing the file into SkeletonTree format allows you to generate T-poses or other poses that can be used for retargeting. We also show an example of creating a T-pose for our AMP Humanoid asset in `generate_amp_humanoid_tpose.py`.\n\n### Retargeting Motions\nRetargeting motions is important when your source data uses skeletons with different morphologies than your target skeletons. We provide APIs for retargeting motion sequences in our SkeletonState and SkeletonMotion classes.\n\nTo use the retargeting API, users must provide the following information:\n  - source_motion: a SkeletonMotion npy representation of a motion sequence. The motion clip should use the same skeleton as the source T-pose skeleton.\n  - target_motion_path: path to save the retargeted motion to\n  - source_tpose: a SkeletonState npy representation of the source skeleton in its T-pose state\n  - target_tpose: a SkeletonState npy representation of the target skeleton in its T-pose state (pose should match the source T-pose)\n  - joint_mapping: mapping of joint names from source to target\n  - rotation: root rotation offset from source to target skeleton (for transforming across different orientation axes), represented as a quaternion in XYZW order.\n  - scale: scale offset from source to target skeleton\n\nWe provide an example script `retarget_motion.py` to demonstrate usage of the retargeting API for the CMU Motion Capture Database.
Note that the retargeting data for this script is stored in `data/configs/retarget_cmu_to_amp.json`.\n\nA SkeletonState T-pose file and a retargeting config file are also provided for the SFU Motion Capture Database, at `data/sfu_tpose.npy` and `data/configs/retarget_sfu_to_amp.json`.\n\n### Documentation\nThe functions and classes in poselib are documented in docstrings alongside the APIs. Please check them for more details.\n
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/data/configs/retarget_cmu_to_amp.json",
"content": "{\n    \"source_motion\": \"data/01_01_cmu.npy\",\n    \"target_motion_path\": \"data/01_01_cmu_amp.npy\",\n    \"source_tpose\": \"data/cmu_tpose.npy\",\n    \"target_tpose\": \"data/amp_humanoid_tpose.npy\",\n    \"joint_mapping\": {\n         \"Hips\": \"pelvis\",\n         \"LeftUpLeg\": \"left_thigh\",\n         \"LeftLeg\": \"left_shin\",\n         \"LeftFoot\": \"left_foot\",\n         \"RightUpLeg\": \"right_thigh\",\n         \"RightLeg\": \"right_shin\",\n         \"RightFoot\": \"right_foot\",\n         \"Spine1\": \"torso\",\n         \"Head\": \"head\",\n         \"LeftArm\": \"left_upper_arm\",\n         \"LeftForeArm\": \"left_lower_arm\",\n         \"LeftHand\": \"left_hand\",\n         \"RightArm\": \"right_upper_arm\",\n         \"RightForeArm\": \"right_lower_arm\",\n         \"RightHand\": \"right_hand\"\n    },\n    \"rotation\": [0, 0, 0.7071068, 0.7071068],\n    \"scale\": 0.056444,\n    \"root_height_offset\": 0.05,\n    \"trim_frame_beg\": 75,\n    \"trim_frame_end\": 372\n}"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/data/configs/retarget_sfu_to_amp.json",
    "content": "{\n    \"source_motion\": \"data/0005_Jogging001.npy\",\n    \"target_motion_path\": \"data/0005_Jogging001_amp.npy\",\n    \"source_tpose\": \"data/sfu_tpose.npy\",\n    \"target_tpose\": \"data/amp_humanoid_tpose.npy\",\n    \"joint_mapping\": {\n         \"Hips\": \"pelvis\",\n         \"LeftUpLeg\": \"left_thigh\",\n         \"LeftLeg\": \"left_shin\",\n         \"LeftFoot\": \"left_foot\",\n         \"RightUpLeg\": \"right_thigh\",\n         \"RightLeg\": \"right_shin\",\n         \"RightFoot\": \"right_foot\",\n         \"Spine1\": \"torso\",\n         \"Head\": \"head\",\n         \"LeftArm\": \"left_upper_arm\",\n         \"LeftForeArm\": \"left_lower_arm\",\n         \"LeftHand\": \"left_hand\",\n         \"RightArm\": \"right_upper_arm\",\n         \"RightForeArm\": \"right_lower_arm\",\n         \"RightHand\": \"right_hand\"\n    },\n    \"rotation\": [0.5, 0.5, 0.5, 0.5],\n    \"scale\": 0.01,\n    \"root_height_offset\": 0.0,\n    \"trim_frame_beg\": 0,\n    \"trim_frame_end\": 100\n}"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/fbx_importer.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nimport os\nimport json\n\nfrom poselib.skeleton.skeleton3d import SkeletonTree, SkeletonState, SkeletonMotion\nfrom poselib.visualization.common import plot_skeleton_state, plot_skeleton_motion_interactive\n\n# source fbx file path\nfbx_file = \"data/01_01_cmu.fbx\"\n\n# import fbx file - make sure to provide a valid joint name for root_joint\nmotion = SkeletonMotion.from_fbx(\n    fbx_file_path=fbx_file,\n    root_joint=\"Hips\",\n    fps=60\n)\n\n# save motion in npy format\nmotion.to_file(\"data/01_01_cmu.npy\")\n\n# visualize motion\nplot_skeleton_motion_interactive(motion)\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/generate_amp_humanoid_tpose.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nimport torch\n\nfrom poselib.core.rotation3d import *\nfrom poselib.skeleton.skeleton3d import SkeletonTree, SkeletonState\nfrom poselib.visualization.common import plot_skeleton_state\n\n\"\"\"\nThis script imports an MJCF XML file and converts the skeleton into a SkeletonTree format.\nIt then generates a zero-rotation pose and adjusts it into a T-pose.\n\"\"\"\n\n# import MJCF file\nxml_path = \"../../../../assets/mjcf/amp_humanoid.xml\"\nskeleton = SkeletonTree.from_mjcf(xml_path)\n\n# generate zero rotation pose\nzero_pose = SkeletonState.zero_pose(skeleton)\n\n# adjust pose into a T-pose\nlocal_rotation = zero_pose.local_rotation\nlocal_rotation[skeleton.index(\"left_upper_arm\")] = quat_mul(\n    quat_from_angle_axis(angle=torch.tensor([90.0]), axis=torch.tensor([1.0, 0.0, 0.0]), degree=True), \n    local_rotation[skeleton.index(\"left_upper_arm\")]\n)\nlocal_rotation[skeleton.index(\"right_upper_arm\")] = quat_mul(\n    quat_from_angle_axis(angle=torch.tensor([-90.0]), axis=torch.tensor([1.0, 0.0, 0.0]), degree=True), \n    local_rotation[skeleton.index(\"right_upper_arm\")]\n)\ntranslation = zero_pose.root_translation\ntranslation += torch.tensor([0, 0, 0.9])\n\n# save and visualize T-pose\nzero_pose.to_file(\"data/amp_humanoid_tpose.npy\")\nplot_skeleton_state(zero_pose)"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/mjcf_importer.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nfrom poselib.skeleton.skeleton3d import SkeletonTree, SkeletonState\nfrom poselib.visualization.common import plot_skeleton_state\n\n# load in XML mjcf file and save zero rotation pose in npy format\nxml_path = \"../../../../assets/mjcf/nv_humanoid.xml\"\nskeleton = SkeletonTree.from_mjcf(xml_path)\nzero_pose = SkeletonState.zero_pose(skeleton)\nzero_pose.to_file(\"data/nv_humanoid.npy\")\n\n# visualize zero rotation pose\nplot_skeleton_state(zero_pose)"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/__init__.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n__version__ = \"0.0.1\"\n\nfrom .core import *\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/__init__.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom .tensor_utils import *\nfrom .rotation3d import *\nfrom .backend import Serializable, logger\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/__init__.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom .abstract import Serializable\n\nfrom .logger import logger\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/abstract.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nfrom abc import ABCMeta, abstractmethod, abstractclassmethod\nfrom collections import OrderedDict\nimport json\n\nimport numpy as np\nimport os\n\nTENSOR_CLASS = {}\n\n\ndef register(name):\n    global TENSOR_CLASS\n\n    def core(tensor_cls):\n        TENSOR_CLASS[name] = tensor_cls\n        return tensor_cls\n\n    return core\n\n\ndef _get_cls(name):\n    global TENSOR_CLASS\n    return TENSOR_CLASS[name]\n\n\nclass NumpyEncoder(json.JSONEncoder):\n    \"\"\" Special json encoder for numpy types \"\"\"\n\n    def default(self, obj):\n        if isinstance(\n            obj,\n            (\n                np.int_,\n                np.intc,\n                np.intp,\n                np.int8,\n                np.int16,\n                np.int32,\n                np.int64,\n                np.uint8,\n                np.uint16,\n                np.uint32,\n                np.uint64,\n            ),\n        ):\n            return int(obj)\n        elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):\n            return float(obj)\n        elif isinstance(obj, (np.ndarray,)):\n            return dict(__ndarray__=obj.tolist(), dtype=str(obj.dtype), shape=obj.shape)\n        return json.JSONEncoder.default(self, obj)\n\n\ndef json_numpy_obj_hook(dct):\n    if isinstance(dct, dict) and \"__ndarray__\" in dct:\n        data = np.asarray(dct[\"__ndarray__\"], dtype=dct[\"dtype\"])\n        return data.reshape(dct[\"shape\"])\n 
   return dct\n\n\nclass Serializable:\n    \"\"\" Implementation to read/write to file.\n    Every class that inherits from this class needs to implement to_dict() and\n    from_dict()\n    \"\"\"\n\n    @abstractclassmethod\n    def from_dict(cls, dict_repr, *args, **kwargs):\n        \"\"\" Read the object from an ordered dictionary\n\n        :param dict_repr: the ordered dictionary that is used to construct the object\n        :type dict_repr: OrderedDict\n        :param args, kwargs: the arguments that need to be passed into from_dict()\n        :type args, kwargs: additional arguments\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def to_dict(self):\n        \"\"\" Construct an ordered dictionary from the object\n        \n        :rtype: OrderedDict\n        \"\"\"\n        pass\n\n    @classmethod\n    def from_file(cls, path, *args, **kwargs):\n        \"\"\" Read the object from a file (either .npy or .json)\n\n        :param path: path of the file\n        :type path: string\n        :param args, kwargs: the arguments that need to be passed into from_dict()\n        :type args, kwargs: additional arguments\n        \"\"\"\n        if path.endswith(\".json\"):\n            with open(path, \"r\") as f:\n                d = json.load(f, object_hook=json_numpy_obj_hook)\n        elif path.endswith(\".npy\"):\n            d = np.load(path, allow_pickle=True).item()\n        else:\n            assert False, \"failed to load {} from {}\".format(cls.__name__, path)\n        assert d[\"__name__\"] == cls.__name__, \"the file belongs to {}, not {}\".format(\n            d[\"__name__\"], cls.__name__\n        )\n        return cls.from_dict(d, *args, **kwargs)\n\n    def to_file(self, path: str) -> None:\n        \"\"\" Write the object to a file (either .npy or .json)\n\n        :param path: path of the file\n        :type path: string\n        \"\"\"\n        if os.path.dirname(path) != \"\" and not os.path.exists(os.path.dirname(path)):\n          
 os.makedirs(os.path.dirname(path))\n        d = self.to_dict()\n        d[\"__name__\"] = self.__class__.__name__\n        if path.endswith(\".json\"):\n            with open(path, \"w\") as f:\n                json.dump(d, f, cls=NumpyEncoder, indent=4)\n        elif path.endswith(\".npy\"):\n            np.save(path, d)\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/logger.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nimport logging\n\nlogger = logging.getLogger(\"poselib\")\nlogger.setLevel(logging.INFO)\n\nif not len(logger.handlers):\n    formatter = logging.Formatter(\n        fmt=\"%(asctime)-15s - %(levelname)s - %(module)s - %(message)s\"\n    )\n    handler = logging.StreamHandler()\n    handler.setFormatter(formatter)\n    logger.addHandler(handler)\n    logger.info(\"logger initialized\")\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/rotation3d.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nfrom typing import List, Optional\n\nimport math\nimport torch\n\n\n@torch.jit.script\ndef quat_mul(a, b):\n    \"\"\"\n    quaternion multiplication\n    \"\"\"\n    x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n    x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n\n    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2\n    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2\n    y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2\n    z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2\n\n    return torch.stack([x, y, z, w], dim=-1)\n\n\n@torch.jit.script\ndef quat_pos(x):\n    \"\"\"\n    flip signs so the real part of each quaternion is positive\n    \"\"\"\n    q = x\n    z = (q[..., 3:] < 0).float()\n    q = (1 - 2 * z) * q\n    return q\n\n\n@torch.jit.script\ndef quat_abs(x):\n    \"\"\"\n    quaternion norm (a unit quaternion, representing a 3D rotation, has norm 1)\n    \"\"\"\n    x = x.norm(p=2, dim=-1)\n    return x\n\n\n@torch.jit.script\ndef quat_unit(x):\n    \"\"\"\n    normalize the quaternion to unit norm\n    \"\"\"\n    norm = quat_abs(x).unsqueeze(-1)\n    return x / (norm.clamp(min=1e-9))\n\n\n@torch.jit.script\ndef quat_conjugate(x):\n    \"\"\"\n    quaternion with its imaginary part negated\n    \"\"\"\n    return torch.cat([-x[..., :3], x[..., 3:]], dim=-1)\n\n\n@torch.jit.script\ndef quat_real(x):\n    \"\"\"\n    real component of the quaternion\n    \"\"\"\n    return x[..., 3]\n\n\n@torch.jit.script\ndef 
quat_imaginary(x):\n    \"\"\"\n    imaginary components of the quaternion\n    \"\"\"\n    return x[..., :3]\n\n\n@torch.jit.script\ndef quat_norm_check(x):\n    \"\"\"\n    verify that a quaternion has norm 1\n    \"\"\"\n    assert bool(\n        (abs(x.norm(p=2, dim=-1) - 1) < 1e-3).all()\n    ), \"the quaternion has non-unit norm: {}\".format(abs(x.norm(p=2, dim=-1) - 1))\n    assert bool((x[..., 3] >= 0).all()), \"the quaternion has negative real part\"\n\n\n@torch.jit.script\ndef quat_normalize(q):\n    \"\"\"\n    Construct 3D rotation from quaternion (the quaternion need not be normalized).\n    \"\"\"\n    q = quat_unit(quat_pos(q))  # normalized to positive and unit quaternion\n    return q\n\n\n@torch.jit.script\ndef quat_from_xyz(xyz):\n    \"\"\"\n    Construct 3D rotation from the imaginary component\n    \"\"\"\n    # for a unit quaternion, w = sqrt(1 - |xyz|^2), computed per batch element\n    w_sq = 1.0 - xyz.pow(2).sum(dim=-1, keepdim=True)\n    assert bool((w_sq >= 0).all()), \"xyz has its norm greater than 1\"\n    return torch.cat([xyz, w_sq.sqrt()], dim=-1)\n\n\n@torch.jit.script\ndef quat_identity(shape: List[int]):\n    \"\"\"\n    Construct 3D identity rotation given shape\n    \"\"\"\n    w = torch.ones(shape + [1])\n    xyz = torch.zeros(shape + [3])\n    q = torch.cat([xyz, w], dim=-1)\n    return quat_normalize(q)\n\n\n@torch.jit.script\ndef quat_from_angle_axis(angle, axis, degree: bool = False):\n    \"\"\" Create a 3D rotation from angle and axis of rotation.
The rotation is counter-clockwise\n    along the axis.\n\n    The rotation can be interpreted as a_R_b where frame \"b\" is the new frame that\n    gets rotated counter-clockwise along the axis from frame \"a\"\n\n    :param angle: angle of rotation\n    :type angle: Tensor\n    :param axis: axis of rotation\n    :type axis: Tensor\n    :param degree: set to True if the angle is given in degrees\n    :type degree: bool, optional, default=False\n    \"\"\"\n    if degree:\n        angle = angle / 180.0 * math.pi\n    theta = (angle / 2).unsqueeze(-1)\n    axis = axis / (axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9))\n    xyz = axis * theta.sin()\n    w = theta.cos()\n    return quat_normalize(torch.cat([xyz, w], dim=-1))\n\n\n@torch.jit.script\ndef quat_from_rotation_matrix(m):\n    \"\"\"\n    Construct a 3D rotation from valid 3x3 rotation matrices.\n    Reference can be found here:\n    http://www.cg.info.hiroshima-cu.ac.jp/~miyazaki/knowledge/teche52.html\n\n    :param m: 3x3 orthogonal rotation matrices.\n    :type m: Tensor\n\n    :rtype: Tensor\n    \"\"\"\n    m = m.unsqueeze(0)\n    diag0 = m[..., 0, 0]\n    diag1 = m[..., 1, 1]\n    diag2 = m[..., 2, 2]\n\n    # Recover the magnitude of each quaternion component from the trace terms.\n    w = (((diag0 + diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5\n    x = (((diag0 - diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5\n    y = (((-diag0 + diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5\n    z = (((-diag0 - diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5\n\n    # Only modify quaternions where w > x, y, z.\n    c0 = (w >= x) & (w >= y) & (w >= z)\n    x[c0] *= (m[..., 2, 1][c0] - m[..., 1, 2][c0]).sign()\n    y[c0] *= (m[..., 0, 2][c0] - m[..., 2, 0][c0]).sign()\n    z[c0] *= (m[..., 1, 0][c0] - m[..., 0, 1][c0]).sign()\n\n    # Only modify quaternions where x > w, y, z\n    c1 = (x >= w) & (x >= y) & (x >= z)\n    w[c1] *= (m[..., 2, 1][c1] - m[..., 1, 2][c1]).sign()\n    y[c1] *= (m[..., 1, 0][c1] + m[..., 0, 1][c1]).sign()\n    z[c1] *= (m[..., 0, 2][c1] + m[..., 2, 0][c1]).sign()\n\n    # Only modify quaternions where y > w, x, z.\n    c2 = (y >= w) & (y >= x) & (y >= z)\n    w[c2] *= (m[..., 0, 2][c2] - m[..., 2, 0][c2]).sign()\n    x[c2] *= (m[..., 1, 0][c2] + m[..., 0, 1][c2]).sign()\n    z[c2] *= (m[..., 2, 1][c2] + m[..., 1, 2][c2]).sign()\n\n    # Only modify quaternions where z > w, x, y.\n    c3 = (z >= w) & (z >= x) & (z >= y)\n    w[c3] *= (m[..., 1, 0][c3] - m[..., 0, 1][c3]).sign()\n    x[c3] *= (m[..., 2, 0][c3] + m[..., 0, 2][c3]).sign()\n    y[c3] *= (m[..., 2, 1][c3] + m[..., 1, 2][c3]).sign()\n\n    return quat_normalize(torch.stack([x, y, z, w], dim=-1)).squeeze(0)\n\n\n@torch.jit.script\ndef quat_mul_norm(x, y):\n    \"\"\"\n    Combine two sets of 3D rotations using the \* operator. The shapes need to be\n    broadcastable\n    \"\"\"\n    return quat_normalize(quat_mul(x, y))\n\n\n@torch.jit.script\ndef quat_rotate(rot, vec):\n    \"\"\"\n    Rotate a 3D vector with the 3D rotation\n    \"\"\"\n    other_q = torch.cat([vec, torch.zeros_like(vec[..., :1])], dim=-1)\n    return quat_imaginary(quat_mul(quat_mul(rot, other_q), quat_conjugate(rot)))\n\n\n@torch.jit.script\ndef quat_inverse(x):\n    \"\"\"\n    The inverse of the rotation\n    \"\"\"\n    return quat_conjugate(x)\n\n\n@torch.jit.script\ndef quat_identity_like(x):\n    \"\"\"\n    Construct identity 3D rotation with the same shape\n    \"\"\"\n    return quat_identity(x.shape[:-1])\n\n\n@torch.jit.script\ndef quat_angle_axis(x):\n    \"\"\"\n    The (angle, axis) representation of the rotation.
The axis is normalized to unit length.\n    The angle is guaranteed to be between [0, pi].\n    \"\"\"\n    s = 2 * (x[..., 3] ** 2) - 1\n    angle = s.clamp(-1, 1).arccos()  # just to be safe\n    # divide out of place so the input quaternion is not modified\n    axis = x[..., :3] / x[..., :3].norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9)\n    return angle, axis\n\n\n@torch.jit.script\ndef quat_yaw_rotation(x, z_up: bool = True):\n    \"\"\"\n    Yaw rotation (rotation along z-axis)\n    \"\"\"\n    q = x\n    if z_up:\n        q = torch.cat([torch.zeros_like(q[..., 0:2]), q[..., 2:3], q[..., 3:]], dim=-1)\n    else:\n        q = torch.cat(\n            [\n                torch.zeros_like(q[..., 0:1]),\n                q[..., 1:2],\n                torch.zeros_like(q[..., 2:3]),\n                q[..., 3:4],\n            ],\n            dim=-1,\n        )\n    return quat_normalize(q)\n\n\n@torch.jit.script\ndef transform_from_rotation_translation(\n    r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None\n):\n    \"\"\"\n    Construct a transform from a quaternion and 3D translation.
Only one of them can be None.\n    \"\"\"\n    assert r is not None or t is not None, \"rotation and translation can't both be None\"\n    if r is None:\n        assert t is not None\n        r = quat_identity(list(t.shape[:-1]))\n    if t is None:\n        t = torch.zeros(list(r.shape[:-1]) + [3])\n    return torch.cat([r, t], dim=-1)\n\n\n@torch.jit.script\ndef transform_identity(shape: List[int]):\n    \"\"\"\n    Identity transformation with given shape\n    \"\"\"\n    r = quat_identity(shape)\n    t = torch.zeros(shape + [3])\n    return transform_from_rotation_translation(r, t)\n\n\n@torch.jit.script\ndef transform_rotation(x):\n    \"\"\"Get rotation from transform\"\"\"\n    return x[..., :4]\n\n\n@torch.jit.script\ndef transform_translation(x):\n    \"\"\"Get translation from transform\"\"\"\n    return x[..., 4:]\n\n\n@torch.jit.script\ndef transform_inverse(x):\n    \"\"\"\n    Inverse transformation\n    \"\"\"\n    inv_so3 = quat_inverse(transform_rotation(x))\n    return transform_from_rotation_translation(\n        r=inv_so3, t=quat_rotate(inv_so3, -transform_translation(x))\n    )\n\n\n@torch.jit.script\ndef transform_identity_like(x):\n    \"\"\"\n    identity transformation with the same shape\n    \"\"\"\n    return transform_identity(x.shape[:-1])\n\n\n@torch.jit.script\ndef transform_mul(x, y):\n    \"\"\"\n    Combine two transformations together\n    \"\"\"\n    z = transform_from_rotation_translation(\n        r=quat_mul_norm(transform_rotation(x), transform_rotation(y)),\n        t=quat_rotate(transform_rotation(x), transform_translation(y))\n        + transform_translation(x),\n    )\n    return z\n\n\n@torch.jit.script\ndef transform_apply(rot, vec):\n    \"\"\"\n    Transform a 3D vector\n    \"\"\"\n    assert isinstance(vec, torch.Tensor)\n    return quat_rotate(transform_rotation(rot), vec) + transform_translation(rot)\n\n\n@torch.jit.script\ndef rot_matrix_det(x):\n    \"\"\"\n    Return the determinant of the 3x3 matrix.
The returned tensor has the same batch shape as the\n    input matrix\n    \"\"\"\n    a, b, c = x[..., 0, 0], x[..., 0, 1], x[..., 0, 2]\n    d, e, f = x[..., 1, 0], x[..., 1, 1], x[..., 1, 2]\n    g, h, i = x[..., 2, 0], x[..., 2, 1], x[..., 2, 2]\n    t1 = a * (e * i - f * h)\n    t2 = b * (d * i - f * g)\n    t3 = c * (d * h - e * g)\n    return t1 - t2 + t3\n\n\n@torch.jit.script\ndef rot_matrix_integrity_check(x):\n    \"\"\"\n    Verify that a rotation matrix has a determinant of one and is orthogonal\n    \"\"\"\n    det = rot_matrix_det(x)\n    assert bool((abs(det - 1) < 1e-3).all()), \"the matrix has non-one determinant\"\n    rtr = x @ x.transpose(-2, -1)  # batched R @ R^T, should be identity\n    rtr_gt = torch.zeros_like(rtr)\n    rtr_gt[..., 0, 0] = 1\n    rtr_gt[..., 1, 1] = 1\n    rtr_gt[..., 2, 2] = 1\n    assert bool(((rtr - rtr_gt).abs() < 1e-3).all()), \"the matrix is not orthogonal\"\n\n\n@torch.jit.script\ndef rot_matrix_from_quaternion(q):\n    \"\"\"\n    Construct rotation matrix from quaternion\n    \"\"\"\n    # Shortcuts for individual elements (using wikipedia's convention)\n    qi, qj, qk, qr = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n\n    # Set individual elements\n    R00 = 1.0 - 2.0 * (qj ** 2 + qk ** 2)\n    R01 = 2 * (qi * qj - qk * qr)\n    R02 = 2 * (qi * qk + qj * qr)\n    R10 = 2 * (qi * qj + qk * qr)\n    R11 = 1.0 - 2.0 * (qi ** 2 + qk ** 2)\n    R12 = 2 * (qj * qk - qi * qr)\n    R20 = 2 * (qi * qk - qj * qr)\n    R21 = 2 * (qj * qk + qi * qr)\n    R22 = 1.0 - 2.0 * (qi ** 2 + qj ** 2)\n\n    R0 = torch.stack([R00, R01, R02], dim=-1)\n    R1 = torch.stack([R10, R11, R12], dim=-1)\n    R2 = torch.stack([R20, R21, R22], dim=-1)\n\n    R = torch.stack([R0, R1, R2], dim=-2)\n\n    return R\n\n\n@torch.jit.script\ndef euclidean_to_rotation_matrix(x):\n    \"\"\"\n    Get the rotation matrix on the top-left corner of a Euclidean transformation matrix\n    \"\"\"\n    return x[..., :3, :3]\n\n\n@torch.jit.script
\ndef euclidean_integrity_check(x):\n    rot_matrix_integrity_check(euclidean_to_rotation_matrix(x))  # check 3d-rotation matrix\n    assert bool((x[..., 3, :3] == 0).all()), \"the last row is illegal\"\n    assert bool((x[..., 3, 3] == 1).all()), \"the last row is illegal\"\n\n\n@torch.jit.script\ndef euclidean_translation(x):\n    \"\"\"\n    Get the translation vector located at the last column of the matrix\n    \"\"\"\n    return x[..., :3, 3]\n\n\n@torch.jit.script\ndef euclidean_inverse(x):\n    \"\"\"\n    Compute the matrix that represents the inverse transformation\n    \"\"\"\n    s = torch.zeros_like(x)\n    irot = quat_inverse(quat_from_rotation_matrix(euclidean_to_rotation_matrix(x)))\n    s[..., :3, :3] = rot_matrix_from_quaternion(irot)\n    s[..., :3, 3] = quat_rotate(irot, -euclidean_translation(x))\n    s[..., 3, 3] = 1.0\n    return s\n\n\n@torch.jit.script\ndef euclidean_to_transform(transformation_matrix):\n    \"\"\"\n    Construct a transform from a Euclidean transformation matrix\n    \"\"\"\n    return transform_from_rotation_translation(\n        r=quat_from_rotation_matrix(\n            m=euclidean_to_rotation_matrix(transformation_matrix)\n        ),\n        t=euclidean_translation(transformation_matrix),\n    )\n\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tensor_utils.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom collections import OrderedDict\nfrom .backend import Serializable\nimport torch\n\n\nclass TensorUtils(Serializable):\n    @classmethod\n    def from_dict(cls, dict_repr, *args, **kwargs):\n        \"\"\" Read the object from an ordered dictionary\n\n        :param dict_repr: the ordered dictionary that is used to construct the object\n        :type dict_repr: OrderedDict\n        :param kwargs: the arguments that need to be passed into from_dict()\n        :type kwargs: additional arguments\n        \"\"\"\n        return torch.from_numpy(dict_repr[\"arr\"].astype(dict_repr[\"context\"][\"dtype\"]))\n\n    def to_dict(self):\n        \"\"\" Construct an ordered dictionary from the object\n        \n        :rtype: OrderedDict\n        \"\"\"\n        return NotImplemented\n\ndef tensor_to_dict(x):\n    \"\"\" Construct an ordered dictionary from the object\n    \n    :rtype: OrderedDict\n    \"\"\"\n    x_np = x.numpy()\n    return {\n        \"arr\": x_np,\n        \"context\": {\n            \"dtype\": x_np.dtype.name\n        }\n    }\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tests/__init__.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tests/test_rotation.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nfrom ..rotation3d import *\nimport numpy as np\nimport torch\n\nq = torch.from_numpy(np.array([[0, 1, 2, 3], [-2, 3, -1, 5]], dtype=np.float32))\nprint(\"q\", q)\nr = quat_normalize(q)\nx = torch.from_numpy(np.array([[1, 0, 0], [0, -1, 0]], dtype=np.float32))\nprint(r)\nprint(quat_rotate(r, x))\n\nangle = torch.from_numpy(np.array(np.random.rand() * 10.0, dtype=np.float32))\naxis = torch.from_numpy(\n    np.array([1, np.random.rand() * 10.0, np.random.rand() * 10.0], dtype=np.float32),\n)\n\nprint(repr(angle))\nprint(repr(axis))\n\nrot = quat_from_angle_axis(angle, axis)\nx = torch.from_numpy(np.random.rand(5, 6, 3))\ny = quat_rotate(quat_inverse(rot), quat_rotate(rot, x))\nprint(x.numpy())\nprint(y.numpy())\nassert np.allclose(x.numpy(), y.numpy())\n\nm = torch.from_numpy(np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32))\nr = quat_from_rotation_matrix(m)\nt = torch.from_numpy(np.array([0, 1, 0], dtype=np.float32))\nse3 = transform_from_rotation_translation(r=r, t=t)\nprint(se3)\nprint(transform_apply(se3, t))\n\nrot = quat_from_angle_axis(\n    torch.from_numpy(np.array([45, -54], dtype=np.float32)),\n    torch.from_numpy(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)),\n    degree=True,\n)\ntrans = torch.from_numpy(np.array([[1, 1, 0], [1, 1, 0]], dtype=np.float32))\ntransform = transform_from_rotation_translation(r=rot, t=trans)\n\nt = transform_mul(transform, transform_inverse(transform))\ngt = np.zeros((2, 
7))\ngt[:, 0] = 1.0\nprint(t.numpy())\nprint(gt)\n# assert np.allclose(t.numpy(), gt)\n\ntransform2 = torch.from_numpy(\n    np.array(\n        [[1, 0, 0, 1], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32\n    ),\n)\ntransform2 = euclidean_to_transform(transform2)\nprint(transform2)\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/__init__.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/__init__.py",
    "content": ""
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/__init__.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited."
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/fbx_backend.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n\"\"\"\nThis script reads an fbx file and returns the joint names, parents, and transforms.\n\nNOTE: It requires the Python FBX package to be installed.\n\"\"\"\n\nimport sys\n\nimport numpy as np\n\ntry:\n    import fbx\n    import FbxCommon\nexcept ImportError as e:\n    print(\"Error: FBX library failed to load - importing FBX data will not succeed. 
Message: {}\".format(e))\n    print(\"FBX tools must be installed from https://help.autodesk.com/view/FBX/2020/ENU/?guid=FBX_Developer_Help_scripting_with_python_fbx_installing_python_fbx_html\")\n\n\ndef fbx_to_npy(file_name_in, root_joint_name, fps):\n    \"\"\"\n    This function reads in an fbx file, and saves the relevant info to a numpy array\n\n    Fbx files have a series of animation curves, each of which has animations at different \n    times. This script assumes that for mocap data, there is only one animation curve that\n    contains all the joints. Otherwise it is unclear how to read in the data.\n\n    If this condition isn't met, then the method throws an error\n\n    :param file_name_in: str, file path in. Should be .fbx file\n    :return: nothing, it just writes a file.\n    \"\"\"\n\n    # Create the fbx scene object and load the .fbx file\n    fbx_sdk_manager, fbx_scene = FbxCommon.InitializeSdkObjects()\n    FbxCommon.LoadScene(fbx_sdk_manager, fbx_scene, file_name_in)\n\n    \"\"\"\n    To read in the animation, we must find the root node of the skeleton.\n    \n    Unfortunately fbx files can have \"scene parents\" and other parts of the tree that are \n    not joints\n    \n    As a crude fix, this reader just takes and finds the first thing which has an \n    animation curve attached\n    \"\"\"\n\n    search_root = (root_joint_name is None or root_joint_name == \"\")\n\n    # Get the root node of the skeleton, which is the child of the scene's root node\n    possible_root_nodes = [fbx_scene.GetRootNode()]\n    found_root_node = False\n    max_key_count = 0\n    root_joint = None\n    while len(possible_root_nodes) > 0:\n        joint = possible_root_nodes.pop(0)\n        if not search_root:\n            if joint.GetName() == root_joint_name:\n                root_joint = joint\n        try:\n            curve = _get_animation_curve(joint, fbx_scene)\n        except RuntimeError:\n            curve = None\n        if curve is not None:\n     
       key_count = curve.KeyGetCount()\n            if key_count > max_key_count:\n                found_root_node = True\n                max_key_count = key_count\n                root_curve = curve\n            if search_root and not root_joint:\n                root_joint = joint\n        for child_index in range(joint.GetChildCount()):\n            possible_root_nodes.append(joint.GetChild(child_index))\n    if not found_root_node:\n        raise RuntimeError(\"No root joint found!! Exiting\")\n\n    joint_list, joint_names, parents = _get_skeleton(root_joint)\n\n    \"\"\"\n    Read in the transformation matrices of the animation, taking the scaling into account\n    \"\"\"\n\n    anim_range, frame_count, frame_rate = _get_frame_count(fbx_scene)\n\n    local_transforms = []\n    #for frame in range(frame_count):\n    time_sec = anim_range.GetStart().GetSecondDouble()\n    time_range_sec = anim_range.GetStop().GetSecondDouble() - time_sec\n    fbx_fps = frame_count / time_range_sec\n    if fps != 120:\n        fbx_fps = fps\n    print(\"FPS: \", fbx_fps)\n    while time_sec < anim_range.GetStop().GetSecondDouble():\n        fbx_time = fbx.FbxTime()\n        fbx_time.SetSecondDouble(time_sec)\n        fbx_time = fbx_time.GetFramedTime()\n        transforms_current_frame = []\n\n        # Fbx has a unique time object which you need\n        #fbx_time = root_curve.KeyGetTime(frame)\n        for joint in joint_list:\n            arr = np.array(_recursive_to_list(joint.EvaluateLocalTransform(fbx_time)))\n            scales = np.array(_recursive_to_list(joint.EvaluateLocalScaling(fbx_time)))\n            if not np.allclose(scales[0:3], scales[0]):\n                raise ValueError(\n                    \"Different X, Y and Z scaling. Unsure how this should be handled. 
\"\n                    \"To solve this, look at this link and try to upgrade the script \"\n                    \"http://help.autodesk.com/view/FBX/2017/ENU/?guid=__files_GUID_10CDD\"\n                    \"63C_79C1_4F2D_BB28_AD2BE65A02ED_htm\"\n                )\n            # Adjust the array for scaling\n            arr /= scales[0]\n            arr[3, 3] = 1.0\n            transforms_current_frame.append(arr)\n        local_transforms.append(transforms_current_frame)\n\n        time_sec += (1.0/fbx_fps)\n\n    local_transforms = np.array(local_transforms)\n    print(\"Frame Count: \", len(local_transforms))\n\n    return joint_names, parents, local_transforms, fbx_fps\n\ndef _get_frame_count(fbx_scene):\n    # Get the animation stacks and layers, in order to pull off animation curves later\n    num_anim_stacks = fbx_scene.GetSrcObjectCount(\n        FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId)\n    )\n    # if num_anim_stacks != 1:\n    #     raise RuntimeError(\n    #         \"More than one animation stack was found. \"\n    #         \"This script must be modified to handle this case. Exiting\"\n    #     )\n    if num_anim_stacks > 1:\n        index = 1\n    else:\n        index = 0\n    anim_stack = fbx_scene.GetSrcObject(\n        FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId), index\n    )\n\n    anim_range = anim_stack.GetLocalTimeSpan()\n    duration = anim_range.GetDuration()\n    fps = duration.GetFrameRate(duration.GetGlobalTimeMode())\n    frame_count = duration.GetFrameCount(True)\n\n    return anim_range, frame_count, fps\n\ndef _get_animation_curve(joint, fbx_scene):\n    # Get the animation stacks and layers, in order to pull off animation curves later\n    num_anim_stacks = fbx_scene.GetSrcObjectCount(\n        FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId)\n    )\n    # if num_anim_stacks != 1:\n    #     raise RuntimeError(\n    #         \"More than one animation stack was found. 
\"\n    #         \"This script must be modified to handle this case. Exiting\"\n    #     )\n    if num_anim_stacks > 1:\n        index = 1\n    else:\n        index = 0\n    anim_stack = fbx_scene.GetSrcObject(\n        FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId), index\n    )\n\n    num_anim_layers = anim_stack.GetSrcObjectCount(\n        FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimLayer.ClassId)\n    )\n    if num_anim_layers != 1:\n        raise RuntimeError(\n            \"More than one animation layer was found. \"\n            \"This script must be modified to handle this case. Exiting\"\n        )\n    animation_layer = anim_stack.GetSrcObject(\n        FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimLayer.ClassId), 0\n    )\n\n    def _check_longest_curve(curve, max_curve_key_count):\n        if curve and curve.KeyGetCount() > max_curve_key_count[0]:\n            max_curve_key_count[0] = curve.KeyGetCount()\n            return True\n\n        return False\n\n    max_curve_key_count = [0]\n    longest_curve = None\n    for c in [\"X\", \"Y\", \"Z\"]:\n        curve = joint.LclTranslation.GetCurve(\n            animation_layer, c\n        )  # sample curve for translation\n        if _check_longest_curve(curve, max_curve_key_count):\n            longest_curve = curve\n\n        curve = joint.LclRotation.GetCurve(\n            animation_layer, c\n        )  # sample curve for rotation\n        if _check_longest_curve(curve, max_curve_key_count):\n            longest_curve = curve\n\n    return longest_curve\n\n\ndef _get_skeleton(root_joint):\n\n    # Do a depth first search of the skeleton to extract all the joints\n    joint_list = [root_joint]\n    joint_names = [root_joint.GetName()]\n    parents = [-1]  # -1 means no parent\n\n    def append_children(joint, pos):\n        \"\"\"\n        Depth first search function\n        :param joint: joint item in the fbx\n        :param pos: position of current element (for 
parenting)\n        :return: Nothing\n        \"\"\"\n        for child_index in range(joint.GetChildCount()):\n            child = joint.GetChild(child_index)\n            joint_list.append(child)\n            joint_names.append(child.GetName())\n            parents.append(pos)\n            append_children(child, len(parents) - 1)\n\n    append_children(root_joint, 0)\n    return joint_list, joint_names, parents\n\n\ndef _recursive_to_list(array):\n    \"\"\"\n    Takes some iterable that might contain iterables and converts it to a list of lists \n    [of lists... etc]\n\n    Mainly used for converting the strange fbx wrappers for c++ arrays into python lists\n    :param array: array to be converted\n    :return: array converted to lists\n    \"\"\"\n    try:\n        return float(array)\n    except TypeError:\n        return [_recursive_to_list(a) for a in array]\n\n\ndef parse_fbx(file_name_in, root_joint_name, fps):\n    return fbx_to_npy(file_name_in, root_joint_name, fps)\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/fbx_read_wrapper.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"\nScript that reads in fbx files from python\n\nThis wrapper calls the fbx parsing backend directly, so the Python FBX bindings\nmust be installed in the current environment\n\"\"\"\n\nimport os\n\nfrom .fbx_backend import parse_fbx\n\n\ndef fbx_to_array(fbx_file_path, root_joint, fps):\n    \"\"\"\n    Reads an fbx file to an array.\n\n    :param fbx_file_path: str, file path to fbx\n    :param root_joint: str, name of the root joint\n    :param fps: int, sampling rate for the animation\n    :return: tuple with joint_names, parents, local_transforms, fbx_fps\n    \"\"\"\n\n    # Ensure the file path is valid\n    fbx_file_path = os.path.abspath(fbx_file_path)\n    assert os.path.exists(fbx_file_path)\n\n    # Parse FBX file\n    joint_names, parents, local_transforms, fbx_fps = parse_fbx(fbx_file_path, root_joint, fps)\n    return joint_names, parents, local_transforms, fbx_fps\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/skeleton3d.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport os\nimport xml.etree.ElementTree as ET\nfrom collections import OrderedDict\nfrom typing import List, Optional, Type, Dict\n\nimport numpy as np\nimport torch\n\nfrom ..core import *\nfrom .backend.fbx.fbx_read_wrapper import fbx_to_array\nimport scipy.ndimage.filters as filters\n\n\nclass SkeletonTree(Serializable):\n    \"\"\"\n    A skeleton tree gives a complete description of a rigid skeleton. It describes a tree structure\n    over a list of nodes with their names indicated by strings. Each edge in the tree has a local\n    translation associated with it which describes the distance between the two nodes that it\n    connects. 
\n\n    Basic Usage:\n        >>> t = SkeletonTree.from_mjcf(SkeletonTree.__example_mjcf_path__)\n        >>> t\n        SkeletonTree(\n            node_names=['torso', 'front_left_leg', 'aux_1', 'front_left_foot', 'front_right_leg', 'aux_2', 'front_right_foot', 'left_back_leg', 'aux_3', 'left_back_foot', 'right_back_leg', 'aux_4', 'right_back_foot'],\n            parent_indices=tensor([-1,  0,  1,  2,  0,  4,  5,  0,  7,  8,  0, 10, 11]),\n            local_translation=tensor([[ 0.0000,  0.0000,  0.7500],\n                    [ 0.0000,  0.0000,  0.0000],\n                    [ 0.2000,  0.2000,  0.0000],\n                    [ 0.2000,  0.2000,  0.0000],\n                    [ 0.0000,  0.0000,  0.0000],\n                    [-0.2000,  0.2000,  0.0000],\n                    [-0.2000,  0.2000,  0.0000],\n                    [ 0.0000,  0.0000,  0.0000],\n                    [-0.2000, -0.2000,  0.0000],\n                    [-0.2000, -0.2000,  0.0000],\n                    [ 0.0000,  0.0000,  0.0000],\n                    [ 0.2000, -0.2000,  0.0000],\n                    [ 0.2000, -0.2000,  0.0000]])\n        )\n        >>> t.node_names\n        ['torso', 'front_left_leg', 'aux_1', 'front_left_foot', 'front_right_leg', 'aux_2', 'front_right_foot', 'left_back_leg', 'aux_3', 'left_back_foot', 'right_back_leg', 'aux_4', 'right_back_foot']\n        >>> t.parent_indices\n        tensor([-1,  0,  1,  2,  0,  4,  5,  0,  7,  8,  0, 10, 11])\n        >>> t.local_translation\n        tensor([[ 0.0000,  0.0000,  0.7500],\n                [ 0.0000,  0.0000,  0.0000],\n                [ 0.2000,  0.2000,  0.0000],\n                [ 0.2000,  0.2000,  0.0000],\n                [ 0.0000,  0.0000,  0.0000],\n                [-0.2000,  0.2000,  0.0000],\n                [-0.2000,  0.2000,  0.0000],\n                [ 0.0000,  0.0000,  0.0000],\n                [-0.2000, -0.2000,  0.0000],\n                [-0.2000, -0.2000,  0.0000],\n                [ 0.0000,  0.0000,  0.0000],\n       
         [ 0.2000, -0.2000,  0.0000],\n                [ 0.2000, -0.2000,  0.0000]])\n        >>> t.parent_of('front_left_leg')\n        'torso'\n        >>> t.index('front_right_foot')\n        6\n        >>> t[2]\n        'aux_1'\n    \"\"\"\n\n    __example_mjcf_path__ = os.path.join(\n        os.path.dirname(os.path.realpath(__file__)), \"tests/ant.xml\"\n    )\n\n    def __init__(self, node_names, parent_indices, local_translation):\n        \"\"\"\n        :param node_names: a list of names for each tree node\n        :type node_names: List[str]\n        :param parent_indices: an int32-typed tensor that represents the edge to its parent.\\\n        -1 represents the root node\n        :type parent_indices: Tensor\n        :param local_translation: a 3d vector that gives local translation information\n        :type local_translation: Tensor\n        \"\"\"\n        ln, lp, ll = len(node_names), len(parent_indices), len(local_translation)\n        assert len(set((ln, lp, ll))) == 1\n        self._node_names = node_names\n        self._parent_indices = parent_indices.long()\n        self._local_translation = local_translation\n        self._node_indices = {self.node_names[i]: i for i in range(len(self))}\n\n    def __len__(self):\n        \"\"\" number of nodes in the skeleton tree \"\"\"\n        return len(self.node_names)\n\n    def __iter__(self):\n        \"\"\" iterator that iterate through the name of each node \"\"\"\n        yield from self.node_names\n\n    def __getitem__(self, item):\n        \"\"\" get the name of the node given the index \"\"\"\n        return self.node_names[item]\n\n    def __repr__(self):\n        return (\n            \"SkeletonTree(\\n    node_names={},\\n    parent_indices={},\"\n            \"\\n    local_translation={}\\n)\".format(\n                self._indent(repr(self.node_names)),\n                self._indent(repr(self.parent_indices)),\n                self._indent(repr(self.local_translation)),\n            )\n      
  )\n\n    def _indent(self, s):\n        return \"\\n    \".join(s.split(\"\\n\"))\n\n    @property\n    def node_names(self):\n        return self._node_names\n\n    @property\n    def parent_indices(self):\n        return self._parent_indices\n\n    @property\n    def local_translation(self):\n        return self._local_translation\n\n    @property\n    def num_joints(self):\n        \"\"\" number of nodes in the skeleton tree \"\"\"\n        return len(self)\n\n    @classmethod\n    def from_dict(cls, dict_repr, *args, **kwargs):\n        return cls(\n            list(map(str, dict_repr[\"node_names\"])),\n            TensorUtils.from_dict(dict_repr[\"parent_indices\"], *args, **kwargs),\n            TensorUtils.from_dict(dict_repr[\"local_translation\"], *args, **kwargs),\n        )\n\n    def to_dict(self):\n        return OrderedDict(\n            [\n                (\"node_names\", self.node_names),\n                (\"parent_indices\", tensor_to_dict(self.parent_indices)),\n                (\"local_translation\", tensor_to_dict(self.local_translation)),\n            ]\n        )\n\n    @classmethod\n    def from_mjcf(cls, path: str) -> \"SkeletonTree\":\n        \"\"\"\n        Parses a mujoco xml scene description file and returns a Skeleton Tree.\n        We use the model attribute at the root as the name of the tree.\n        \n        :param path:\n        :type path: string\n        :return: The skeleton tree constructed from the mjcf file\n        :rtype: SkeletonTree\n        \"\"\"\n        tree = ET.parse(path)\n        xml_doc_root = tree.getroot()\n        xml_world_body = xml_doc_root.find(\"worldbody\")\n        if xml_world_body is None:\n            raise ValueError(\"MJCF parsed incorrectly please verify it.\")\n        # assume this is the root\n        xml_body_root = xml_world_body.find(\"body\")\n        if xml_body_root is None:\n            raise ValueError(\"MJCF parsed incorrectly please verify it.\")\n\n        node_names = []\n    
    parent_indices = []\n        local_translation = []\n\n        # recursively adding all nodes into the skel_tree\n        def _add_xml_node(xml_node, parent_index, node_index):\n            node_name = xml_node.attrib.get(\"name\")\n            # parse the local translation into float list\n            pos = np.fromstring(xml_node.attrib.get(\"pos\"), dtype=float, sep=\" \")\n            node_names.append(node_name)\n            parent_indices.append(parent_index)\n            local_translation.append(pos)\n            curr_index = node_index\n            node_index += 1\n            for next_node in xml_node.findall(\"body\"):\n                node_index = _add_xml_node(next_node, curr_index, node_index)\n            return node_index\n\n        _add_xml_node(xml_body_root, -1, 0)\n\n        return cls(\n            node_names,\n            torch.from_numpy(np.array(parent_indices, dtype=np.int32)),\n            torch.from_numpy(np.array(local_translation, dtype=np.float32)),\n        )\n\n    def parent_of(self, node_name):\n        \"\"\" get the name of the parent of the given node\n\n        :param node_name: the name of the node\n        :type node_name: string\n        :rtype: string\n        \"\"\"\n        return self[int(self.parent_indices[self.index(node_name)].item())]\n\n    def index(self, node_name):\n        \"\"\" get the index of the node\n        \n        :param node_name: the name of the node\n        :type node_name: string\n        :rtype: int\n        \"\"\"\n        return self._node_indices[node_name]\n\n    def drop_nodes_by_names(\n        self, node_names: List[str], pairwise_translation=None\n    ) -> \"SkeletonTree\":\n        new_length = len(self) - len(node_names)\n        new_node_names = []\n        new_local_translation = torch.zeros(\n            new_length, 3, dtype=self.local_translation.dtype\n        )\n        new_parent_indices = torch.zeros(new_length, dtype=self.parent_indices.dtype)\n        parent_indices = 
self.parent_indices.numpy()\n        new_node_indices: dict = {}\n        new_node_index = 0\n        for node_index in range(len(self)):\n            if self[node_index] in node_names:\n                continue\n            tb_node_index = parent_indices[node_index]\n            if tb_node_index != -1:\n                # clone so the in-place accumulation below does not mutate the tree's tensor\n                local_translation = self.local_translation[node_index, :].clone()\n                while tb_node_index != -1 and self[tb_node_index] in node_names:\n                    local_translation += self.local_translation[tb_node_index, :]\n                    tb_node_index = parent_indices[tb_node_index]\n                assert tb_node_index != -1, \"the root node cannot be dropped\"\n\n                if pairwise_translation is not None:\n                    local_translation = pairwise_translation[\n                        tb_node_index, node_index, :\n                    ]\n            else:\n                local_translation = self.local_translation[node_index, :]\n\n            new_node_names.append(self[node_index])\n            new_local_translation[new_node_index, :] = local_translation\n            if tb_node_index == -1:\n                new_parent_indices[new_node_index] = -1\n            else:\n                new_parent_indices[new_node_index] = new_node_indices[\n                    self[tb_node_index]\n                ]\n            new_node_indices[self[node_index]] = new_node_index\n            new_node_index += 1\n\n        return SkeletonTree(new_node_names, new_parent_indices, new_local_translation)\n\n    def keep_nodes_by_names(\n        self, node_names: List[str], pairwise_translation=None\n    ) -> \"SkeletonTree\":\n        nodes_to_drop = list(filter(lambda x: x not in node_names, self))\n        return self.drop_nodes_by_names(nodes_to_drop, pairwise_translation)\n\n\nclass SkeletonState(Serializable):\n    \"\"\"\n    A skeleton state contains all the information needed to describe a static state of a skeleton.\n    It requires a 
skeleton tree, local/global rotation at each joint and the root translation.\n\n    Example:\n        >>> t = SkeletonTree.from_mjcf(SkeletonTree.__example_mjcf_path__)\n        >>> zero_pose = SkeletonState.zero_pose(t)\n        >>> plot_skeleton_state(zero_pose)  # can be imported from `.visualization.common`\n        [plot of the ant at zero pose]\n        >>> local_rotation = zero_pose.local_rotation.clone()\n        >>> local_rotation[2] = torch.tensor([0, 0, 1, 0])\n        >>> new_pose = SkeletonState.from_rotation_and_root_translation(\n        ...             skeleton_tree=t,\n        ...             r=local_rotation,\n        ...             t=zero_pose.root_translation,\n        ...             is_local=True\n        ...         )\n        >>> new_pose.local_rotation\n        tensor([[0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 1., 0., 0.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.]])\n        >>> plot_skeleton_state(new_pose)  # you should see that one of the ant's legs is bent\n        [plot of the ant with the new pose]\n        >>> new_pose.global_rotation  # the local rotation is propagated to the global rotation at joint #3\n        tensor([[0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 1., 0., 0.],\n                [0., 1., 0., 0.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.],\n                [0., 0., 0., 1.]])\n\n    Global/Local Representation (cont. 
from the previous example)\n        >>> new_pose.is_local\n        True\n        >>> new_pose.tensor  # this will return the local rotation followed by the root translation\n        tensor([0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,\n                0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,\n                0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,\n                0.])\n        >>> new_pose.tensor.shape  # 4 * 13 (joint rotation) + 3 (root translation)\n        torch.Size([55])\n        >>> new_pose.global_repr().is_local\n        False\n        >>> new_pose.global_repr().tensor  # this will return the global rotation followed by the root translation instead\n        tensor([0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,\n                0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,\n                0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,\n                0.])\n        >>> new_pose.global_repr().tensor.shape  # 4 * 13 (joint rotation) + 3 (root translation)\n        torch.Size([55])\n    \"\"\"\n\n    def __init__(self, tensor_backend, skeleton_tree, is_local):\n        self._skeleton_tree = skeleton_tree\n        self._is_local = is_local\n        self.tensor = tensor_backend.clone()\n\n    def __len__(self):\n        return self.tensor.shape[0]\n\n    @property\n    def rotation(self):\n        if not hasattr(self, \"_rotation\"):\n            self._rotation = self.tensor[..., : self.num_joints * 4].reshape(\n                *(self.tensor.shape[:-1] + (self.num_joints, 4))\n            )\n        return self._rotation\n\n    @property\n    def _local_rotation(self):\n        if self._is_local:\n            return self.rotation\n        else:\n            return None\n\n    @property\n    def _global_rotation(self):\n        if not self._is_local:\n            return self.rotation\n        else:\n 
           return None\n\n    @property\n    def is_local(self):\n        \"\"\" is the rotation represented in local frame? \n        \n        :rtype: bool\n        \"\"\"\n        return self._is_local\n\n    @property\n    def invariant_property(self):\n        return {\"skeleton_tree\": self.skeleton_tree, \"is_local\": self.is_local}\n\n    @property\n    def num_joints(self):\n        \"\"\" number of joints in the skeleton tree \n        \n        :rtype: int\n        \"\"\"\n        return self.skeleton_tree.num_joints\n\n    @property\n    def skeleton_tree(self):\n        \"\"\" skeleton tree \n        \n        :rtype: SkeletonTree\n        \"\"\"\n        return self._skeleton_tree\n\n    @property\n    def root_translation(self):\n        \"\"\" root translation \n        \n        :rtype: Tensor\n        \"\"\"\n        if not hasattr(self, \"_root_translation\"):\n            self._root_translation = self.tensor[\n                ..., self.num_joints * 4 : self.num_joints * 4 + 3\n            ]\n        return self._root_translation\n\n    @property\n    def global_transformation(self):\n        \"\"\" global transformation of each joint (transform from joint frame to global frame) \"\"\"\n        if not hasattr(self, \"_global_transformation\"):\n            local_transformation = self.local_transformation\n            global_transformation = []\n            parent_indices = self.skeleton_tree.parent_indices.numpy()\n            # global_transformation = local_transformation.identity_like()\n            for node_index in range(len(self.skeleton_tree)):\n                parent_index = parent_indices[node_index]\n                if parent_index == -1:\n                    global_transformation.append(\n                        local_transformation[..., node_index, :]\n                    )\n                else:\n                    global_transformation.append(\n                        transform_mul(\n                            
global_transformation[parent_index],\n                            local_transformation[..., node_index, :],\n                        )\n                    )\n            self._global_transformation = torch.stack(global_transformation, axis=-2)\n        return self._global_transformation\n\n    @property\n    def global_rotation(self):\n        \"\"\" global rotation of each joint (quaternion that rotates from the joint's frame of\n        reference to the global frame of reference) \"\"\"\n        if self._global_rotation is None:\n            if not hasattr(self, \"_comp_global_rotation\"):\n                self._comp_global_rotation = transform_rotation(\n                    self.global_transformation\n                )\n            return self._comp_global_rotation\n        else:\n            return self._global_rotation\n\n    @property\n    def global_translation(self):\n        \"\"\" global translation of each joint \"\"\"\n        if not hasattr(self, \"_global_translation\"):\n            self._global_translation = transform_translation(self.global_transformation)\n        return self._global_translation\n\n    @property\n    def global_translation_xy(self):\n        \"\"\" global translation projected onto the xy plane \"\"\"\n        trans_xy_data = torch.zeros_like(self.global_translation)\n        trans_xy_data[..., 0:2] = self.global_translation[..., 0:2]\n        return trans_xy_data\n\n    @property\n    def global_translation_xz(self):\n        \"\"\" global translation projected onto the xz plane \"\"\"\n        trans_xz_data = torch.zeros_like(self.global_translation)\n        trans_xz_data[..., 0:1] = self.global_translation[..., 0:1]\n        trans_xz_data[..., 2:3] = self.global_translation[..., 2:3]\n        return trans_xz_data\n\n    @property\n    def local_rotation(self):\n        \"\"\" the rotation from child frame to parent frame, in the order the child nodes appear\n        in `.skeleton_tree.node_names` \"\"\"\n        if self._local_rotation is None:\n            if not hasattr(self, 
\"_comp_local_rotation\"):\n                local_rotation = quat_identity_like(self.global_rotation)\n                for node_index in range(len(self.skeleton_tree)):\n                    parent_index = self.skeleton_tree.parent_indices[node_index]\n                    if parent_index == -1:\n                        local_rotation[..., node_index, :] = self.global_rotation[\n                            ..., node_index, :\n                        ]\n                    else:\n                        local_rotation[..., node_index, :] = quat_mul_norm(\n                            quat_inverse(self.global_rotation[..., parent_index, :]),\n                            self.global_rotation[..., node_index, :],\n                        )\n                self._comp_local_rotation = local_rotation\n            return self._comp_local_rotation\n        else:\n            return self._local_rotation\n\n    @property\n    def local_transformation(self):\n        \"\"\" local translation + local rotation. It describes the transformation from child frame to \n        parent frame given in the order of child nodes appeared in `.skeleton_tree.node_names` \"\"\"\n        if not hasattr(self, \"_local_transformation\"):\n            self._local_transformation = transform_from_rotation_translation(\n                r=self.local_rotation, t=self.local_translation\n            )\n        return self._local_transformation\n\n    @property\n    def local_translation(self):\n        \"\"\" local translation of the skeleton state. It is identical to the local translation in\n        `.skeleton_tree.local_translation` except the root translation. 
The root translation is\n        identical to `.root_translation` \"\"\"\n        if not hasattr(self, \"_local_translation\"):\n            broadcast_shape = (\n                tuple(self.tensor.shape[:-1])\n                + (len(self.skeleton_tree),)\n                + tuple(self.skeleton_tree.local_translation.shape[-1:])\n            )\n            local_translation = self.skeleton_tree.local_translation.broadcast_to(\n                *broadcast_shape\n            ).clone()\n            local_translation[..., 0, :] = self.root_translation\n            self._local_translation = local_translation\n        return self._local_translation\n\n    # Root Properties\n    @property\n    def root_translation_xy(self):\n        \"\"\" root translation on xy \"\"\"\n        if not hasattr(self, \"_root_translation_xy\"):\n            self._root_translation_xy = self.global_translation_xy[..., 0, :]\n        return self._root_translation_xy\n\n    @property\n    def global_root_rotation(self):\n        \"\"\" root rotation \"\"\"\n        if not hasattr(self, \"_global_root_rotation\"):\n            self._global_root_rotation = self.global_rotation[..., 0, :]\n        return self._global_root_rotation\n\n    @property\n    def global_root_yaw_rotation(self):\n        \"\"\" root yaw rotation \"\"\"\n        if not hasattr(self, \"_global_root_yaw_rotation\"):\n            self._global_root_yaw_rotation = self.global_root_rotation.yaw_rotation()\n        return self._global_root_yaw_rotation\n\n    # Properties relative to root\n    @property\n    def local_translation_to_root(self):\n        \"\"\" The 3D translation from joint frame to the root frame. 
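        A minimal NumPy sketch of this computation, using a made-up 4-joint chain (not
        poselib data) with the root at index 0:

        ```python
        import numpy as np

        # Hypothetical joint positions in the world frame (root at index 0).
        global_translation = np.array([
            [1.0, 2.0, 0.0],   # root
            [1.0, 2.0, 1.0],
            [1.0, 3.0, 1.0],
            [2.0, 3.0, 1.0],
        ])
        root_translation = global_translation[0]

        # Broadcasting the root position over the joint axis gives each joint's
        # offset from the root frame's origin; the root itself maps to zero.
        local_translation_to_root = global_translation - root_translation[None, :]
        print(local_translation_to_root[-1])  # [1. 1. 1.]
        ```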
\"\"\"\n        if not hasattr(self, \"_local_translation_to_root\"):\n            self._local_translation_to_root = (\n                self.global_translation - self.root_translation.unsqueeze(-1)\n            )\n        return self._local_translation_to_root\n\n    @property\n    def local_rotation_to_root(self):\n        \"\"\" The 3D rotation from joint frame to the root frame. It is equivalent to \n        The root_R_world * world_R_node \"\"\"\n        return (\n            quat_inverse(self.global_root_rotation).unsqueeze(-1) * self.global_rotation\n        )\n\n    def compute_forward_vector(\n        self,\n        left_shoulder_index,\n        right_shoulder_index,\n        left_hip_index,\n        right_hip_index,\n        gaussian_filter_width=20,\n    ):\n        \"\"\" Computes forward vector based on cross product of the up vector with \n        average of the right->left shoulder and hip vectors \"\"\"\n        global_positions = self.global_translation\n        # Perpendicular to the forward direction.\n        # Uses the shoulders and hips to find this.\n        side_direction = (\n            global_positions[:, left_shoulder_index].numpy()\n            - global_positions[:, right_shoulder_index].numpy()\n            + global_positions[:, left_hip_index].numpy()\n            - global_positions[:, right_hip_index].numpy()\n        )\n        side_direction = (\n            side_direction\n            / np.sqrt((side_direction ** 2).sum(axis=-1))[..., np.newaxis]\n        )\n\n        # Forward direction obtained by crossing with the up direction.\n        forward_direction = np.cross(side_direction, np.array([[0, 1, 0]]))\n\n        # Smooth the forward direction with a Gaussian.\n        # Axis 0 is the time/frame axis.\n        forward_direction = filters.gaussian_filter1d(\n            forward_direction, gaussian_filter_width, axis=0, mode=\"nearest\"\n        )\n        forward_direction = (\n            forward_direction\n            / 
np.sqrt((forward_direction ** 2).sum(axis=-1))[..., np.newaxis]\n        )\n\n        return torch.from_numpy(forward_direction)\n\n    @staticmethod\n    def _to_state_vector(rot, rt):\n        state_shape = rot.shape[:-2]\n        vr = rot.reshape(*(state_shape + (-1,)))\n        vt = rt.broadcast_to(*state_shape + rt.shape[-1:]).reshape(\n            *(state_shape + (-1,))\n        )\n        v = torch.cat([vr, vt], axis=-1)\n        return v\n\n    @classmethod\n    def from_dict(\n        cls: Type[\"SkeletonState\"], dict_repr: OrderedDict, *args, **kwargs\n    ) -> \"SkeletonState\":\n        rot = TensorUtils.from_dict(dict_repr[\"rotation\"], *args, **kwargs)\n        rt = TensorUtils.from_dict(dict_repr[\"root_translation\"], *args, **kwargs)\n        return cls(\n            SkeletonState._to_state_vector(rot, rt),\n            SkeletonTree.from_dict(dict_repr[\"skeleton_tree\"], *args, **kwargs),\n            dict_repr[\"is_local\"],\n        )\n\n    def to_dict(self) -> OrderedDict:\n        return OrderedDict(\n            [\n                (\"rotation\", tensor_to_dict(self.rotation)),\n                (\"root_translation\", tensor_to_dict(self.root_translation)),\n                (\"skeleton_tree\", self.skeleton_tree.to_dict()),\n                (\"is_local\", self.is_local),\n            ]\n        )\n\n    @classmethod\n    def from_rotation_and_root_translation(cls, skeleton_tree, r, t, is_local=True):\n        \"\"\"\n        Construct a skeleton state from rotation and root translation\n\n        :param skeleton_tree: the skeleton tree\n        :type skeleton_tree: SkeletonTree\n        :param r: rotation (either global or local)\n        :type r: Tensor\n        :param t: root translation\n        :type t: Tensor\n        :param is_local: indicates whether the rotation is local or global\n        :type is_local: bool, optional, default=True\n        \"\"\"\n        assert (\n            r.dim() > 0\n        ), \"the rotation needs to 
have at least 1 dimension (dim = {})\".format(r.dim())\n        return cls(\n            SkeletonState._to_state_vector(r, t),\n            skeleton_tree=skeleton_tree,\n            is_local=is_local,\n        )\n\n    @classmethod\n    def zero_pose(cls, skeleton_tree):\n        \"\"\"\n        Construct a zero-pose skeleton state from the skeleton tree by setting all local\n        rotations to identity and the root translation to zero.\n\n        :param skeleton_tree: the skeleton tree as the rigid body\n        :type skeleton_tree: SkeletonTree\n        \"\"\"\n        return cls.from_rotation_and_root_translation(\n            skeleton_tree=skeleton_tree,\n            r=quat_identity([skeleton_tree.num_joints]),\n            t=torch.zeros(3, dtype=skeleton_tree.local_translation.dtype),\n            is_local=True,\n        )\n\n    def local_repr(self):\n        \"\"\" \n        Convert the skeleton state into the local representation. This only affects the values of\n        `.tensor`. If the skeleton state already has `is_local=True`, this method does nothing. \n\n        :rtype: SkeletonState\n        \"\"\"\n        if self.is_local:\n            return self\n        return SkeletonState.from_rotation_and_root_translation(\n            self.skeleton_tree,\n            r=self.local_rotation,\n            t=self.root_translation,\n            is_local=True,\n        )\n\n    def global_repr(self):\n        \"\"\" \n        Convert the skeleton state into the global representation. This only affects the values of\n        `.tensor`. If the skeleton state already has `is_local=False`, this method does nothing. 
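        The local-to-global conversion is a forward-kinematics pass down the tree. A
        self-contained sketch with a hand-rolled quaternion product in (x, y, z, w) order
        (the order poselib uses); the chain and angles below are invented for illustration:

        ```python
        import numpy as np

        def quat_mul(a, b):
            """Hamilton product for quaternions stored in (x, y, z, w) order."""
            av, aw = a[:3], a[3]
            bv, bw = b[:3], b[3]
            v = aw * bv + bw * av + np.cross(av, bv)
            return np.concatenate([v, [aw * bw - np.dot(av, bv)]])

        # Made-up 3-joint chain (parent_indices = [-1, 0, 1]); every joint holds a
        # 90-degree local rotation about z.
        s = np.sin(np.pi / 4)
        local_rotation = np.array([[0.0, 0.0, s, s]] * 3)
        parent_indices = [-1, 0, 1]

        # Forward kinematics: a joint's global rotation is its parent's global
        # rotation composed with the joint's own local rotation.
        global_rotation = []
        for j, p in enumerate(parent_indices):
            q = local_rotation[j] if p == -1 else quat_mul(global_rotation[p], local_rotation[j])
            global_rotation.append(q)
        # global_rotation[2] is now a 270-degree rotation about z.
        ```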
\n\n        :rtype: SkeletonState\n        \"\"\"\n        if not self.is_local:\n            return self\n        return SkeletonState.from_rotation_and_root_translation(\n            self.skeleton_tree,\n            r=self.global_rotation,\n            t=self.root_translation,\n            is_local=False,\n        )\n\n    def _get_pairwise_average_translation(self):\n        global_transform_inv = transform_inverse(self.global_transformation)\n        p1 = global_transform_inv.unsqueeze(-2)\n        p2 = self.global_transformation.unsqueeze(-3)\n\n        pairwise_translation = (\n            transform_translation(transform_mul(p1, p2))\n            .reshape(-1, len(self.skeleton_tree), len(self.skeleton_tree), 3)\n            .mean(axis=0)\n        )\n        return pairwise_translation\n\n    def _transfer_to(self, new_skeleton_tree: SkeletonTree):\n        old_indices = list(map(self.skeleton_tree.index, new_skeleton_tree))\n        return SkeletonState.from_rotation_and_root_translation(\n            new_skeleton_tree,\n            r=self.global_rotation[..., old_indices, :],\n            t=self.root_translation,\n            is_local=False,\n        )\n\n    def drop_nodes_by_names(\n        self, node_names: List[str], estimate_local_translation_from_states: bool = True\n    ) -> \"SkeletonState\":\n        \"\"\" \n        Drop a list of nodes from the skeleton and re-compute the local rotation to match the \n        original joint position as much as possible. 
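        A pure-Python miniature of the re-parenting rule applied when nodes are dropped
        (the names are invented; the real implementation in the skeleton tree also
        recomputes local translations): a dropped node's children attach to the nearest
        surviving ancestor.

        ```python
        node_names = ["pelvis", "torso", "chest", "head"]
        parent_indices = [-1, 0, 1, 2]
        to_drop = {"chest"}

        new_names, new_parents, new_index = [], [], {}
        for i, name in enumerate(node_names):
            if name in to_drop:
                continue
            p = parent_indices[i]
            # walk up past dropped ancestors to the nearest kept one
            while p != -1 and node_names[p] in to_drop:
                p = parent_indices[p]
            new_parents.append(-1 if p == -1 else new_index[node_names[p]])
            new_index[name] = len(new_names)
            new_names.append(name)

        print(new_names, new_parents)  # ['pelvis', 'torso', 'head'] [-1, 0, 1]
        ```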
\n\n        :param node_names: a list of node names specifying the nodes to be dropped\n        :type node_names: List of strings\n        :param estimate_local_translation_from_states: whether\\\n        or not to re-estimate the local translation from the states (averaged)\n        :type estimate_local_translation_from_states: boolean\n        :rtype: SkeletonState\n        \"\"\"\n        if estimate_local_translation_from_states:\n            pairwise_translation = self._get_pairwise_average_translation()\n        else:\n            pairwise_translation = None\n        new_skeleton_tree = self.skeleton_tree.drop_nodes_by_names(\n            node_names, pairwise_translation\n        )\n        return self._transfer_to(new_skeleton_tree)\n\n    def keep_nodes_by_names(\n        self, node_names: List[str], estimate_local_translation_from_states: bool = True\n    ) -> \"SkeletonState\":\n        \"\"\" \n        Keep a list of nodes and drop all other nodes from the skeleton and re-compute the local \n        rotation to match the original joint position as much as possible. 
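        Keeping nodes simply complements the name list and delegates to the drop path,
        mirroring the `filter` expression in the method body; a tiny sketch with invented
        names:

        ```python
        # Invented example names; only the complement logic is illustrated.
        node_names = ["pelvis", "torso", "chest", "head", "left_hand"]
        to_keep = ["pelvis", "torso", "head"]

        # Everything not explicitly kept is dropped (original order preserved).
        nodes_to_drop = list(filter(lambda x: x not in to_keep, node_names))
        print(nodes_to_drop)  # ['chest', 'left_hand']
        ```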
\n\n        :param node_names: a list of node names specifying the nodes to be kept\n        :type node_names: List of strings\n        :param estimate_local_translation_from_states: whether\\\n        or not to re-estimate the local translation from the states (averaged)\n        :type estimate_local_translation_from_states: boolean\n        :rtype: SkeletonState\n        \"\"\"\n        return self.drop_nodes_by_names(\n            list(filter(lambda x: (x not in node_names), self)),\n            estimate_local_translation_from_states,\n        )\n\n    def _remapped_to(\n        self, joint_mapping: Dict[str, str], target_skeleton_tree: SkeletonTree\n    ):\n        joint_mapping_inv = {target: source for source, target in joint_mapping.items()}\n        reduced_target_skeleton_tree = target_skeleton_tree.keep_nodes_by_names(\n            list(joint_mapping_inv)\n        )\n        n_joints = (\n            len(joint_mapping),\n            len(self.skeleton_tree),\n            len(reduced_target_skeleton_tree),\n        )\n        assert (\n            len(set(n_joints)) == 1\n        ), \"the joint mapping is not consistent with the skeleton trees\"\n        source_indices = list(\n            map(\n                lambda x: self.skeleton_tree.index(joint_mapping_inv[x]),\n                reduced_target_skeleton_tree,\n            )\n        )\n        target_local_rotation = self.local_rotation[..., source_indices, :]\n        return SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree=reduced_target_skeleton_tree,\n            r=target_local_rotation,\n            t=self.root_translation,\n            is_local=True,\n        )\n\n    def retarget_to(\n        self,\n        joint_mapping: Dict[str, str],\n        source_tpose_local_rotation,\n        source_tpose_root_translation: np.ndarray,\n        target_skeleton_tree: SkeletonTree,\n        target_tpose_local_rotation,\n        
target_tpose_root_translation: np.ndarray,\n        rotation_to_target_skeleton,\n        scale_to_target_skeleton: float,\n        z_up: bool = True,\n    ) -> \"SkeletonState\":\n        \"\"\" \n        Retarget the skeleton state to a target skeleton tree. This is a naive retargeting\n        implementation with rough approximations. The function follows the procedures below.\n\n        Steps:\n            1. Drop the joints from the source (self) that do not belong to the joint mapping\\\n            with an implementation that is similar to \"keep_nodes_by_names()\" - take a\\\n            look at the function doc for more details (same for source_tpose)\n            \n            2. Rotate the source state and the source tpose by \"rotation_to_target_skeleton\"\\\n            to align the source with the target orientation\n            \n            3. Extract the root translation and normalize it to match the scale of the target\\\n            skeleton\n            \n            4. Extract the global rotation from source state relative to source tpose and\\\n            re-apply the relative rotation to the target tpose to construct the global\\\n            rotation after retargeting\n            \n            5. Combine the root translation from step 3 and the global rotation from step 4 to\\\n            complete the retargeting.\n            \n            6. 
Put the feet on the ground (adjust global translation z)\n\n        :param joint_mapping: a dictionary that maps joint nodes from the source skeleton to \\\n        the target skeleton\n        :type joint_mapping: Dict[str, str]\n        \n        :param source_tpose_local_rotation: the local rotation of the source skeleton\n        :type source_tpose_local_rotation: Tensor\n        \n        :param source_tpose_root_translation: the root translation of the source tpose\n        :type source_tpose_root_translation: np.ndarray\n        \n        :param target_skeleton_tree: the target skeleton tree\n        :type target_skeleton_tree: SkeletonTree\n        \n        :param target_tpose_local_rotation: the local rotation of the target skeleton\n        :type target_tpose_local_rotation: Tensor\n        \n        :param target_tpose_root_translation: the root translation of the target tpose\n        :type target_tpose_root_translation: Tensor\n        \n        :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\\\n        skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\\\n        the frame of reference of the target skeleton and s is the frame of reference of the source\\\n        skeleton\n        :type rotation_to_target_skeleton: Tensor\n        :param scale_to_target_skeleton: the factor by which distances in the source\\\n        skeleton are multiplied to match the target skeleton. 
For example, to go from `cm` to `m`, the \\\n        factor needs to be 0.01.\n        :type scale_to_target_skeleton: float\n        :rtype: SkeletonState\n        \"\"\"\n\n        # STEP 0: Preprocess\n        source_tpose = SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree=self.skeleton_tree,\n            r=source_tpose_local_rotation,\n            t=source_tpose_root_translation,\n            is_local=True,\n        )\n        target_tpose = SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree=target_skeleton_tree,\n            r=target_tpose_local_rotation,\n            t=target_tpose_root_translation,\n            is_local=True,\n        )\n\n        # STEP 1: Drop the irrelevant joints\n        pairwise_translation = self._get_pairwise_average_translation()\n        node_names = list(joint_mapping)\n        new_skeleton_tree = self.skeleton_tree.keep_nodes_by_names(\n            node_names, pairwise_translation\n        )\n\n        # TODO: combine the following steps before STEP 3\n        source_tpose = source_tpose._transfer_to(new_skeleton_tree)\n        source_state = self._transfer_to(new_skeleton_tree)\n\n        source_tpose = source_tpose._remapped_to(joint_mapping, target_skeleton_tree)\n        source_state = source_state._remapped_to(joint_mapping, target_skeleton_tree)\n\n        # STEP 2: Rotate the source to align with the target\n        new_local_rotation = source_tpose.local_rotation.clone()\n        new_local_rotation[..., 0, :] = quat_mul_norm(\n            rotation_to_target_skeleton, source_tpose.local_rotation[..., 0, :]\n        )\n\n        source_tpose = SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree=source_tpose.skeleton_tree,\n            r=new_local_rotation,\n            t=quat_rotate(rotation_to_target_skeleton, source_tpose.root_translation),\n            is_local=True,\n        )\n\n        new_local_rotation = 
source_state.local_rotation.clone()\n        new_local_rotation[..., 0, :] = quat_mul_norm(\n            rotation_to_target_skeleton, source_state.local_rotation[..., 0, :]\n        )\n        source_state = SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree=source_state.skeleton_tree,\n            r=new_local_rotation,\n            t=quat_rotate(rotation_to_target_skeleton, source_state.root_translation),\n            is_local=True,\n        )\n\n        # STEP 3: Normalize to match the target scale\n        root_translation_diff = (\n            source_state.root_translation - source_tpose.root_translation\n        ) * scale_to_target_skeleton\n        # STEP 4: the global rotation from source state relative to source tpose and\n        # re-apply to the target\n        current_skeleton_tree = source_state.skeleton_tree\n        target_tpose_global_rotation = source_state.global_rotation[0, :].clone()\n        for current_index, name in enumerate(current_skeleton_tree):\n            if name in target_tpose.skeleton_tree:\n                target_tpose_global_rotation[\n                    current_index, :\n                ] = target_tpose.global_rotation[\n                    target_tpose.skeleton_tree.index(name), :\n                ]\n\n        global_rotation_diff = quat_mul_norm(\n            source_state.global_rotation, quat_inverse(source_tpose.global_rotation)\n        )\n        new_global_rotation = quat_mul_norm(\n            global_rotation_diff, target_tpose_global_rotation\n        )\n\n        # STEP 5: Putting 3 and 4 together\n        current_skeleton_tree = source_state.skeleton_tree\n        shape = source_state.global_rotation.shape[:-1]\n        shape = shape[:-1] + target_tpose.global_rotation.shape[-2:-1]\n        new_global_rotation_output = quat_identity(shape)\n        for current_index, name in enumerate(target_skeleton_tree):\n            while name not in current_skeleton_tree:\n                name = 
target_skeleton_tree.parent_of(name)\n            parent_index = current_skeleton_tree.index(name)\n            new_global_rotation_output[:, current_index, :] = new_global_rotation[\n                :, parent_index, :\n            ]\n\n        source_state = SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree=target_skeleton_tree,\n            r=new_global_rotation_output,\n            t=target_tpose.root_translation + root_translation_diff,\n            is_local=False,\n        ).local_repr()\n\n        return source_state\n\n    def retarget_to_by_tpose(\n        self,\n        joint_mapping: Dict[str, str],\n        source_tpose: \"SkeletonState\",\n        target_tpose: \"SkeletonState\",\n        rotation_to_target_skeleton,\n        scale_to_target_skeleton: float,\n    ) -> \"SkeletonState\":\n        \"\"\" \n        Retarget the skeleton state to a target skeleton tree. This is a naive retargeting\n        implementation with rough approximations. See the method `retarget_to()` for more information.\n\n        :param joint_mapping: a dictionary that maps joint nodes from the source skeleton to \\\n        the target skeleton\n        :type joint_mapping: Dict[str, str]\n        \n        :param source_tpose: t-pose of the source skeleton\n        :type source_tpose: SkeletonState\n        \n        :param target_tpose: t-pose of the target skeleton\n        :type target_tpose: SkeletonState\n        \n        :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\\\n        skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\\\n        the frame of reference of the target skeleton and s is the frame of reference of the source\\\n        skeleton\n        :type rotation_to_target_skeleton: Tensor\n        :param scale_to_target_skeleton: the factor that needs to be multiplied from source\\\n        skeleton to target skeleton (unit in distance). 
For example, to go from `cm` to `m`, the \\\n        factor needs to be 0.01.\n        :type scale_to_target_skeleton: float\n        :rtype: SkeletonState\n        \"\"\"\n        assert (\n            len(source_tpose.shape) == 0 and len(target_tpose.shape) == 0\n        ), \"the retargeting script currently doesn't support vectorized operations\"\n        return self.retarget_to(\n            joint_mapping,\n            source_tpose.local_rotation,\n            source_tpose.root_translation,\n            target_tpose.skeleton_tree,\n            target_tpose.local_rotation,\n            target_tpose.root_translation,\n            rotation_to_target_skeleton,\n            scale_to_target_skeleton,\n        )\n\n\nclass SkeletonMotion(SkeletonState):\n    def __init__(self, tensor_backend, skeleton_tree, is_local, fps, *args, **kwargs):\n        self._fps = fps\n        super().__init__(tensor_backend, skeleton_tree, is_local, *args, **kwargs)\n\n    def clone(self):\n        return SkeletonMotion(\n            self.tensor.clone(), self.skeleton_tree, self._is_local, self._fps\n        )\n\n    @property\n    def invariant_property(self):\n        return {\n            \"skeleton_tree\": self.skeleton_tree,\n            \"is_local\": self.is_local,\n            \"fps\": self.fps,\n        }\n\n    @property\n    def global_velocity(self):\n        \"\"\" global velocity \"\"\"\n        curr_index = self.num_joints * 4 + 3\n        return self.tensor[..., curr_index : curr_index + self.num_joints * 3].reshape(\n            *(self.tensor.shape[:-1] + (self.num_joints, 3))\n        )\n\n    @property\n    def global_angular_velocity(self):\n        \"\"\" global angular velocity \"\"\"\n        curr_index = self.num_joints * 7 + 3\n        return self.tensor[..., curr_index : curr_index + self.num_joints * 3].reshape(\n            *(self.tensor.shape[:-1] + (self.num_joints, 3))\n        )\n\n    @property\n    def fps(self):\n        \"\"\" number of frames per 
second \"\"\"\n        return self._fps\n\n    @property\n    def time_delta(self):\n        \"\"\" time between two adjacent frames \"\"\"\n        return 1.0 / self.fps\n\n    @property\n    def global_root_velocity(self):\n        \"\"\" global root velocity \"\"\"\n        return self.global_velocity[..., 0, :]\n\n    @property\n    def global_root_angular_velocity(self):\n        \"\"\" global root angular velocity \"\"\"\n        return self.global_angular_velocity[..., 0, :]\n\n    @classmethod\n    def from_state_vector_and_velocity(\n        cls,\n        skeleton_tree,\n        state_vector,\n        global_velocity,\n        global_angular_velocity,\n        is_local,\n        fps,\n    ):\n        \"\"\"\n        Construct a skeleton motion from a skeleton state vector, global velocity and angular\n        velocity at each joint.\n\n        :param skeleton_tree: the skeleton tree that the motion is based on \n        :type skeleton_tree: SkeletonTree\n        :param state_vector: the state vector obtained from the skeleton state via `.tensor`\n        :type state_vector: Tensor\n        :param global_velocity: the global velocity at each joint\n        :type global_velocity: Tensor\n        :param global_angular_velocity: the global angular velocity at each joint\n        :type global_angular_velocity: Tensor\n        :param is_local: if the rotation in the state vector is given in the local frame\n        :type is_local: boolean\n        :param fps: number of frames per second\n        :type fps: int\n\n        :rtype: SkeletonMotion\n        \"\"\"\n        state_shape = state_vector.shape[:-1]\n        v = global_velocity.reshape(*(state_shape + (-1,)))\n        av = global_angular_velocity.reshape(*(state_shape + (-1,)))\n        new_state_vector = torch.cat([state_vector, v, av], axis=-1)\n        return cls(\n            new_state_vector, skeleton_tree=skeleton_tree, is_local=is_local, fps=fps,\n        )\n\n    @classmethod\n    def from_skeleton_state(\n     
cls: Type[\"SkeletonMotion\"], skeleton_state: SkeletonState, fps: int\n    ):\n        \"\"\"\n        Construct a skeleton motion from a skeleton state. The velocities are estimated using a\n        second-order Gaussian filter along the time axis. The skeleton state must be at least\n        one-dimensional (.dim >= 1).\n\n        :param skeleton_state: the skeleton state that the motion is based on \n        :type skeleton_state: SkeletonState\n        :param fps: number of frames per second\n        :type fps: int\n\n        :rtype: SkeletonMotion\n        \"\"\"\n        assert (\n            type(skeleton_state) == SkeletonState\n        ), \"expected type of {}, got {}\".format(SkeletonState, type(skeleton_state))\n        global_velocity = SkeletonMotion._compute_velocity(\n            p=skeleton_state.global_translation, time_delta=1 / fps\n        )\n        global_angular_velocity = SkeletonMotion._compute_angular_velocity(\n            r=skeleton_state.global_rotation, time_delta=1 / fps\n        )\n        return cls.from_state_vector_and_velocity(\n            skeleton_tree=skeleton_state.skeleton_tree,\n            state_vector=skeleton_state.tensor,\n            global_velocity=global_velocity,\n            global_angular_velocity=global_angular_velocity,\n            is_local=skeleton_state.is_local,\n            fps=fps,\n        )\n\n    @staticmethod\n    def _to_state_vector(rot, rt, vel, avel):\n        state_shape = rot.shape[:-2]\n        skeleton_state_v = SkeletonState._to_state_vector(rot, rt)\n        v = vel.reshape(*(state_shape + (-1,)))\n        av = avel.reshape(*(state_shape + (-1,)))\n        skeleton_motion_v = torch.cat([skeleton_state_v, v, av], axis=-1)\n        return skeleton_motion_v\n\n    @classmethod\n    def from_dict(\n        cls: Type[\"SkeletonMotion\"], dict_repr: OrderedDict, *args, **kwargs\n    ) -> \"SkeletonMotion\":\n        rot = TensorUtils.from_dict(dict_repr[\"rotation\"], *args, **kwargs)\n        rt = 
TensorUtils.from_dict(dict_repr[\"root_translation\"], *args, **kwargs)\n        vel = TensorUtils.from_dict(dict_repr[\"global_velocity\"], *args, **kwargs)\n        avel = TensorUtils.from_dict(\n            dict_repr[\"global_angular_velocity\"], *args, **kwargs\n        )\n        return cls(\n            SkeletonMotion._to_state_vector(rot, rt, vel, avel),\n            skeleton_tree=SkeletonTree.from_dict(\n                dict_repr[\"skeleton_tree\"], *args, **kwargs\n            ),\n            is_local=dict_repr[\"is_local\"],\n            fps=dict_repr[\"fps\"],\n        )\n\n    def to_dict(self) -> OrderedDict:\n        return OrderedDict(\n            [\n                (\"rotation\", tensor_to_dict(self.rotation)),\n                (\"root_translation\", tensor_to_dict(self.root_translation)),\n                (\"global_velocity\", tensor_to_dict(self.global_velocity)),\n                (\"global_angular_velocity\", tensor_to_dict(self.global_angular_velocity)),\n                (\"skeleton_tree\", self.skeleton_tree.to_dict()),\n                (\"is_local\", self.is_local),\n                (\"fps\", self.fps),\n            ]\n        )\n\n    @classmethod\n    def from_fbx(\n        cls: Type[\"SkeletonMotion\"],\n        fbx_file_path,\n        skeleton_tree=None,\n        is_local=True,\n        fps=120,\n        root_joint=\"\",\n        root_trans_index=0,\n        *args,\n        **kwargs,\n    ) -> \"SkeletonMotion\":\n        \"\"\"\n        Construct a skeleton motion from a fbx file (TODO - generalize this). 
If the skeleton tree\n        is not given, it will use the first frame of the mocap to construct the skeleton tree.\n\n        :param fbx_file_path: the path of the fbx file\n        :type fbx_file_path: string\n        :param skeleton_tree: the optional skeleton tree that the rotation will be applied to\n        :type skeleton_tree: SkeletonTree, optional\n        :param is_local: the state vector uses local or global rotation as the representation\n        :type is_local: bool, optional, default=True\n        :param fps: FPS of the FBX animation\n        :type fps: int, optional, default=120\n        :param root_joint: the name of the root joint for the skeleton\n        :type root_joint: string, optional, default=\"\" or the first node in the FBX scene with animation data\n        :param root_trans_index: index of joint to extract root transform from\n        :type root_trans_index: int, optional, default=0 or the root joint in the parsed skeleton\n        :rtype: SkeletonMotion\n        \"\"\"\n        joint_names, joint_parents, transforms, fps = fbx_to_array(\n            fbx_file_path, root_joint, fps\n        )\n        # swap the last two axes to match the convention\n        local_transform = euclidean_to_transform(\n            transformation_matrix=torch.from_numpy(\n                np.swapaxes(np.array(transforms), -1, -2),\n            ).float()\n        )\n        local_rotation = transform_rotation(local_transform)\n        root_translation = transform_translation(local_transform)[..., root_trans_index, :]\n        joint_parents = torch.from_numpy(np.array(joint_parents)).int()\n\n        if skeleton_tree is None:\n            local_translation = transform_translation(local_transform).reshape(\n                -1, len(joint_parents), 3\n            )[0]\n            skeleton_tree = SkeletonTree(joint_names, 
joint_parents, local_translation)\n        skeleton_state = SkeletonState.from_rotation_and_root_translation(\n            skeleton_tree, r=local_rotation, t=root_translation, is_local=True\n        )\n        if not is_local:\n            skeleton_state = skeleton_state.global_repr()\n        return cls.from_skeleton_state(\n            skeleton_state=skeleton_state, fps=fps\n        )\n\n    @staticmethod\n    def _compute_velocity(p, time_delta, guassian_filter=True):\n        # finite difference along the time axis, optionally smoothed with a Gaussian filter\n        velocity = np.gradient(p.numpy(), axis=-3) / time_delta\n        if guassian_filter:\n            velocity = filters.gaussian_filter1d(velocity, 2, axis=-3, mode=\"nearest\")\n        return torch.from_numpy(velocity)\n\n    @staticmethod\n    def _compute_angular_velocity(r, time_delta: float, guassian_filter=True):\n        # assume the second last dimension is the time axis\n        diff_quat_data = quat_identity_like(r)\n        diff_quat_data[..., :-1, :, :] = quat_mul_norm(\n            r[..., 1:, :, :], quat_inverse(r[..., :-1, :, :])\n        )\n        diff_angle, diff_axis = quat_angle_axis(diff_quat_data)\n        angular_velocity = (diff_axis * diff_angle.unsqueeze(-1) / time_delta).numpy()\n        if guassian_filter:\n            angular_velocity = filters.gaussian_filter1d(\n                angular_velocity, 2, axis=-3, mode=\"nearest\"\n            )\n        return torch.from_numpy(angular_velocity)\n\n    def crop(self, start: int, end: int, fps: Optional[int] = None):\n        \"\"\"\n        Crop the motion along its time axis. This is equivalent to slicing the\n        object with [..., start: end: skip_every] where skip_every = old_fps / fps. Note that the\n        new fps provided must be a factor of the original fps. 
\n\n        :param start: the beginning frame index\n        :type start: int\n        :param end: the ending frame index\n        :type end: int\n        :param fps: number of frames per second in the output (if not given the original fps will be used)\n        :type fps: int, optional\n        :rtype: SkeletonMotion\n        \"\"\"\n        if fps is None:\n            new_fps = int(self.fps)\n            old_fps = int(self.fps)\n        else:\n            new_fps = int(fps)\n            old_fps = int(self.fps)\n            assert old_fps % fps == 0, (\n                \"the resampling doesn't support fps with non-integer division \"\n                \"from the original fps: {} => {}\".format(old_fps, fps)\n            )\n        skip_every = old_fps // new_fps\n        return SkeletonMotion.from_skeleton_state(\n            SkeletonState.from_rotation_and_root_translation(\n                skeleton_tree=self.skeleton_tree,\n                t=self.root_translation[start:end:skip_every],\n                r=self.local_rotation[start:end:skip_every],\n                is_local=True,\n            ),\n            fps=new_fps,\n        )\n\n    def retarget_to(\n        self,\n        joint_mapping: Dict[str, str],\n        source_tpose_local_rotation,\n        source_tpose_root_translation: np.ndarray,\n        target_skeleton_tree: \"SkeletonTree\",\n        target_tpose_local_rotation,\n        target_tpose_root_translation: np.ndarray,\n        rotation_to_target_skeleton,\n        scale_to_target_skeleton: float,\n        z_up: bool = True,\n    ) -> \"SkeletonMotion\":\n        \"\"\" \n        Same as the one in :class:`SkeletonState`. This method discards all velocity information before\n        retargeting and re-estimates the velocity after the retargeting. 
The same fps is used in the\n        new retargeted motion.\n\n        :param joint_mapping: a dictionary that maps joint nodes from the source skeleton to \\\n        the target skeleton\n        :type joint_mapping: Dict[str, str]\n        \n        :param source_tpose_local_rotation: the local rotation of the source skeleton\n        :type source_tpose_local_rotation: Tensor\n        \n        :param source_tpose_root_translation: the root translation of the source tpose\n        :type source_tpose_root_translation: np.ndarray\n        \n        :param target_skeleton_tree: the target skeleton tree\n        :type target_skeleton_tree: SkeletonTree\n        \n        :param target_tpose_local_rotation: the local rotation of the target skeleton\n        :type target_tpose_local_rotation: Tensor\n        \n        :param target_tpose_root_translation: the root translation of the target tpose\n        :type target_tpose_root_translation: Tensor\n        \n        :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\\\n        skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\\\n        the frame of reference of the target skeleton and s is the frame of reference of the source\\\n        skeleton\n        :type rotation_to_target_skeleton: Tensor\n        :param scale_to_target_skeleton: the factor that needs to be multiplied from source\\\n        skeleton to target skeleton (unit in distance). 
For example, to go from `cm` to `m`, the \\\n        factor needs to be 0.01.\n        :type scale_to_target_skeleton: float\n        :rtype: SkeletonMotion\n        \"\"\"\n        return SkeletonMotion.from_skeleton_state(\n            super().retarget_to(\n                joint_mapping,\n                source_tpose_local_rotation,\n                source_tpose_root_translation,\n                target_skeleton_tree,\n                target_tpose_local_rotation,\n                target_tpose_root_translation,\n                rotation_to_target_skeleton,\n                scale_to_target_skeleton,\n                z_up,\n            ),\n            self.fps,\n        )\n\n    def retarget_to_by_tpose(\n        self,\n        joint_mapping: Dict[str, str],\n        source_tpose: \"SkeletonState\",\n        target_tpose: \"SkeletonState\",\n        rotation_to_target_skeleton,\n        scale_to_target_skeleton: float,\n        z_up: bool = True,\n    ) -> \"SkeletonMotion\":\n        \"\"\" \n        Same as the one in :class:`SkeletonState`. This method discards all velocity information before\n        retargeting and re-estimates the velocity after the retargeting. The same fps is used in the\n        new retargeted motion.\n\n        :param joint_mapping: a dictionary that maps joint nodes from the source skeleton to \\\n        the target skeleton\n        :type joint_mapping: Dict[str, str]\n        \n        :param source_tpose: t-pose of the source skeleton\n        :type source_tpose: SkeletonState\n        \n        :param target_tpose: t-pose of the target skeleton\n        :type target_tpose: SkeletonState\n        \n        :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\\\n        skeleton to align with the target skeleton. 
Essentially the rotation is t_R_s, where t is\\\n        the frame of reference of the target skeleton and s is the frame of reference of the source\\\n        skeleton\n        :type rotation_to_target_skeleton: Tensor\n        :param scale_to_target_skeleton: the factor that needs to be multiplied from source\\\n        skeleton to target skeleton (unit in distance). For example, to go from `cm` to `m`, the \\\n        factor needs to be 0.01.\n        :type scale_to_target_skeleton: float\n        :rtype: SkeletonMotion\n        \"\"\"\n        return self.retarget_to(\n            joint_mapping,\n            source_tpose.local_rotation,\n            source_tpose.root_translation,\n            target_tpose.skeleton_tree,\n            target_tpose.local_rotation,\n            target_tpose.root_translation,\n            rotation_to_target_skeleton,\n            scale_to_target_skeleton,\n            z_up,\n        )\n\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/__init__.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited."
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/common.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport os\n\nfrom ..core import logger\nfrom .plt_plotter import Matplotlib3DPlotter\nfrom .skeleton_plotter_tasks import Draw3DSkeletonMotion, Draw3DSkeletonState\n\n\ndef plot_skeleton_state(skeleton_state, task_name=\"\"):\n    \"\"\"\n    Visualize a skeleton state\n\n    :param skeleton_state:\n    :param task_name:\n    :type skeleton_state: SkeletonState\n    :type task_name: string, optional\n    \"\"\"\n    logger.info(\"plotting {}\".format(task_name))\n    task = Draw3DSkeletonState(task_name=task_name, skeleton_state=skeleton_state)\n    plotter = Matplotlib3DPlotter(task)\n    plotter.show()\n\n\ndef plot_skeleton_states(skeleton_state, skip_n=1, task_name=\"\"):\n    \"\"\"\n    Visualize a sequence of skeleton state. 
The dimension of the skeleton state must be 1.\n\n    :param skeleton_state:\n    :param task_name:\n    :type skeleton_state: SkeletonState\n    :type task_name: string, optional\n    \"\"\"\n    logger.info(\"plotting {} motion\".format(task_name))\n    assert len(skeleton_state.shape) == 1, \"the state must have only one dimension\"\n    task = Draw3DSkeletonState(task_name=task_name, skeleton_state=skeleton_state[0])\n    plotter = Matplotlib3DPlotter(task)\n    for frame_id in range(skeleton_state.shape[0]):\n        if frame_id % skip_n != 0:\n            continue\n        task.update(skeleton_state[frame_id])\n        plotter.update()\n    plotter.show()\n\n\ndef plot_skeleton_motion(skeleton_motion, skip_n=1, task_name=\"\"):\n    \"\"\"\n    Visualize a skeleton motion along its first dimension.\n\n    :param skeleton_motion:\n    :param task_name:\n    :type skeleton_motion: SkeletonMotion\n    :type task_name: string, optional\n    \"\"\"\n    logger.info(\"plotting {} motion\".format(task_name))\n    task = Draw3DSkeletonMotion(\n        task_name=task_name, skeleton_motion=skeleton_motion, frame_index=0\n    )\n    plotter = Matplotlib3DPlotter(task)\n    for frame_id in range(len(skeleton_motion)):\n        if frame_id % skip_n != 0:\n            continue\n        task.update(frame_id)\n        plotter.update()\n    plotter.show()\n\n\ndef plot_skeleton_motion_interactive_base(skeleton_motion, task_name=\"\"):\n    class PlotParams:\n        def __init__(self, total_num_frames):\n            self.current_frame = 0\n            self.playing = False\n            self.looping = False\n            self.confirmed = False\n            self.playback_speed = 4\n            self.total_num_frames = total_num_frames\n\n        def sync(self, other):\n            self.current_frame = other.current_frame\n            self.playing = other.playing\n            self.looping = other.looping\n            self.confirmed = other.confirmed\n            
self.playback_speed = other.playback_speed\n            self.total_num_frames = other.total_num_frames\n\n    task = Draw3DSkeletonMotion(\n        task_name=task_name, skeleton_motion=skeleton_motion, frame_index=0\n    )\n    plotter = Matplotlib3DPlotter(task)\n\n    plot_params = PlotParams(total_num_frames=len(skeleton_motion))\n    print(\"Entered interactive plot - press 'n' to quit, 'h' for a list of commands\")\n\n    def press(event):\n        if event.key == \"x\":\n            plot_params.playing = not plot_params.playing\n        elif event.key == \"z\":\n            plot_params.current_frame = plot_params.current_frame - 1\n        elif event.key == \"c\":\n            plot_params.current_frame = plot_params.current_frame + 1\n        elif event.key == \"a\":\n            plot_params.current_frame = plot_params.current_frame - 20\n        elif event.key == \"d\":\n            plot_params.current_frame = plot_params.current_frame + 20\n        elif event.key == \"w\":\n            plot_params.looping = not plot_params.looping\n            print(\"Looping: {}\".format(plot_params.looping))\n        elif event.key == \"v\":\n            plot_params.playback_speed *= 2\n            print(\"playback speed: {}\".format(plot_params.playback_speed))\n        elif event.key == \"b\":\n            if plot_params.playback_speed != 1:\n                plot_params.playback_speed //= 2\n            print(\"playback speed: {}\".format(plot_params.playback_speed))\n        elif event.key == \"n\":\n            plot_params.confirmed = True\n        elif event.key == \"h\":\n            rows, columns = os.popen(\"stty size\", \"r\").read().split()\n            columns = int(columns)\n            print(\"=\" * columns)\n            print(\"x: play/pause\")\n            print(\"z: previous frame\")\n            print(\"c: next frame\")\n            print(\"a: jump 20 frames back\")\n            print(\"d: jump 20 frames forward\")\n            print(\"w: 
looping/non-looping\")\n            print(\"v: double speed (this can be applied multiple times)\")\n            print(\"b: half speed (this can be applied multiple times)\")\n            print(\"n: quit\")\n            print(\"h: help\")\n            print(\"=\" * columns)\n\n        print(\n            'current frame index: {}/{} (press \"n\" to quit)'.format(\n                plot_params.current_frame, plot_params.total_num_frames - 1\n            )\n        )\n\n    plotter.fig.canvas.mpl_connect(\"key_press_event\", press)\n    while True:\n        reset_trail = False\n        if plot_params.confirmed:\n            break\n        if plot_params.playing:\n            plot_params.current_frame += plot_params.playback_speed\n        if plot_params.current_frame >= plot_params.total_num_frames:\n            if plot_params.looping:\n                plot_params.current_frame %= plot_params.total_num_frames\n                reset_trail = True\n            else:\n                plot_params.current_frame = plot_params.total_num_frames - 1\n        if plot_params.current_frame < 0:\n            if plot_params.looping:\n                plot_params.current_frame %= plot_params.total_num_frames\n                reset_trail = True\n            else:\n                plot_params.current_frame = 0\n        yield plot_params\n        task.update(plot_params.current_frame, reset_trail)\n        plotter.update()\n\n\ndef plot_skeleton_motion_interactive(skeleton_motion, task_name=\"\"):\n    \"\"\"\n    Visualize a skeleton motion along its first dimension interactively.\n\n    :param skeleton_motion:\n    :param task_name:\n    :type skeleton_motion: SkeletonMotion\n    :type task_name: string, optional\n    \"\"\"\n    for _ in plot_skeleton_motion_interactive_base(skeleton_motion, task_name):\n        pass\n\n\ndef plot_skeleton_motion_interactive_multiple(*callables, sync=True):\n    for _ in zip(*callables):\n        if sync:\n            for p1, p2 in zip(_[:-1], 
_[1:]):\n                p2.sync(p1)\n\n\n# def plot_skeleton_motion_interactive_multiple_same(skeleton_motions, task_name=\"\"):\n\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/core.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\"\"\"\nThe base abstract classes for plotter and the plotting tasks. 
It describes how the plotter\ndeals with tasks in the general case\n\"\"\"\nfrom typing import List\n\n\nclass BasePlotterTask(object):\n    _task_name: str  # unique name of the task\n    _task_type: str  # type of the task, used to dispatch to the matching callable\n\n    def __init__(self, task_name: str, task_type: str) -> None:\n        self._task_name = task_name\n        self._task_type = task_type\n\n    @property\n    def task_name(self):\n        return self._task_name\n\n    @property\n    def task_type(self):\n        return self._task_type\n\n    def get_scoped_name(self, name):\n        return self._task_name + \"/\" + name\n\n    def __iter__(self):\n        \"\"\"Should override this function to return a list of task primitives\n        \"\"\"\n        raise NotImplementedError\n\n\nclass BasePlotterTasks(object):\n    def __init__(self, tasks) -> None:\n        self._tasks = tasks\n\n    def __iter__(self):\n        for task in self._tasks:\n            yield from task\n\n\nclass BasePlotter(object):\n    \"\"\"An abstract plotter which deals with a plotting task. 
Subclasses need to implement\n    the functions to create/update the objects according to the task given\n    \"\"\"\n\n    _task_primitives: List[BasePlotterTask]\n\n    def __init__(self, task: BasePlotterTask) -> None:\n        self._task_primitives = []\n        self.create(task)\n\n    @property\n    def task_primitives(self):\n        return self._task_primitives\n\n    def create(self, task: BasePlotterTask) -> None:\n        \"\"\"Create more task primitives from a task for the plotter\"\"\"\n        new_task_primitives = list(task)  # get all task primitives\n        self._task_primitives += new_task_primitives  # append them\n        self._create_impl(new_task_primitives)\n\n    def update(self) -> None:\n        \"\"\"Update the plotter for any updates in the task primitives\"\"\"\n        self._update_impl(self._task_primitives)\n\n    def _update_impl(self, task_list: List[BasePlotterTask]) -> None:\n        raise NotImplementedError\n\n    def _create_impl(self, task_list: List[BasePlotterTask]) -> None:\n        raise NotImplementedError\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/plt_plotter.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n\"\"\"\nThe matplotlib plotter implementation for all the primitive tasks (in our case: lines and\ndots)\n\"\"\"\nfrom typing import Any, Callable, Dict, List\n\nimport matplotlib.pyplot as plt\nimport mpl_toolkits.mplot3d.axes3d as p3\n\nimport numpy as np\n\nfrom .core import BasePlotter, BasePlotterTask\n\n\nclass Matplotlib2DPlotter(BasePlotter):\n    _fig: plt.figure  # plt figure\n    _ax: plt.axis  # plt axis\n    # stores artist objects for each task (task name as the key)\n    _artist_cache: Dict[str, Any]\n    # callables for each task primitives\n    _create_impl_callables: Dict[str, Callable]\n    _update_impl_callables: Dict[str, Callable]\n\n    def __init__(self, task: \"BasePlotterTask\") -> None:\n        fig, ax = plt.subplots()\n        self._fig = fig\n        self._ax = ax\n        self._artist_cache = {}\n\n        self._create_impl_callables = {\n            \"Draw2DLines\": self._lines_create_impl,\n            \"Draw2DDots\": self._dots_create_impl,\n            \"Draw2DTrail\": self._trail_create_impl,\n        }\n        self._update_impl_callables = {\n            \"Draw2DLines\": self._lines_update_impl,\n            \"Draw2DDots\": self._dots_update_impl,\n            \"Draw2DTrail\": self._trail_update_impl,\n        }\n        self._init_lim()\n        super().__init__(task)\n\n    @property\n    def ax(self):\n        return self._ax\n\n    @property\n    def fig(self):\n        return self._fig\n\n    
def show(self):\n        plt.show()\n\n    def _min(self, x, y):\n        if x is None:\n            return y\n        if y is None:\n            return x\n        return min(x, y)\n\n    def _max(self, x, y):\n        if x is None:\n            return y\n        if y is None:\n            return x\n        return max(x, y)\n\n    def _init_lim(self):\n        self._curr_x_min = None\n        self._curr_y_min = None\n        self._curr_x_max = None\n        self._curr_y_max = None\n\n    def _update_lim(self, xs, ys):\n        self._curr_x_min = self._min(np.min(xs), self._curr_x_min)\n        self._curr_y_min = self._min(np.min(ys), self._curr_y_min)\n        self._curr_x_max = self._max(np.max(xs), self._curr_x_max)\n        self._curr_y_max = self._max(np.max(ys), self._curr_y_max)\n\n    def _set_lim(self):\n        if not (\n            self._curr_x_min is None\n            or self._curr_x_max is None\n            or self._curr_y_min is None\n            or self._curr_y_max is None\n        ):\n            self._ax.set_xlim(self._curr_x_min, self._curr_x_max)\n            self._ax.set_ylim(self._curr_y_min, self._curr_y_max)\n        self._init_lim()\n\n    @staticmethod\n    def _lines_extract_xy_impl(index, lines_task):\n        return lines_task[index, :, 0], lines_task[index, :, 1]\n\n    @staticmethod\n    def _trail_extract_xy_impl(index, trail_task):\n        return (trail_task[index : index + 2, 0], trail_task[index : index + 2, 1])\n\n    def _lines_create_impl(self, lines_task):\n        color = lines_task.color\n        self._artist_cache[lines_task.task_name] = [\n            self._ax.plot(\n                *Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task),\n                color=color,\n                linewidth=lines_task.line_width,\n                alpha=lines_task.alpha\n            )[0]\n            for i in range(len(lines_task))\n        ]\n\n    def _lines_update_impl(self, lines_task):\n        lines_artists = 
self._artist_cache[lines_task.task_name]\n        for i in range(len(lines_task)):\n            artist = lines_artists[i]\n            xs, ys = Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task)\n            artist.set_data(xs, ys)\n            if lines_task.influence_lim:\n                self._update_lim(xs, ys)\n\n    def _dots_create_impl(self, dots_task):\n        color = dots_task.color\n        self._artist_cache[dots_task.task_name] = self._ax.plot(\n            dots_task[:, 0],\n            dots_task[:, 1],\n            c=color,\n            linestyle=\"\",\n            marker=\".\",\n            markersize=dots_task.marker_size,\n            alpha=dots_task.alpha,\n        )[0]\n\n    def _dots_update_impl(self, dots_task):\n        dots_artist = self._artist_cache[dots_task.task_name]\n        dots_artist.set_data(dots_task[:, 0], dots_task[:, 1])\n        if dots_task.influence_lim:\n            self._update_lim(dots_task[:, 0], dots_task[:, 1])\n\n    def _trail_create_impl(self, trail_task):\n        color = trail_task.color\n        trail_length = len(trail_task) - 1\n        self._artist_cache[trail_task.task_name] = [\n            self._ax.plot(\n                *Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task),\n                color=trail_task.color,\n                linewidth=trail_task.line_width,\n                alpha=trail_task.alpha * (1.0 - i / (trail_length - 1))\n            )[0]\n            for i in range(trail_length)\n        ]\n\n    def _trail_update_impl(self, trail_task):\n        trails_artists = self._artist_cache[trail_task.task_name]\n        for i in range(len(trail_task) - 1):\n            artist = trails_artists[i]\n            xs, ys = Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task)\n            artist.set_data(xs, ys)\n            if trail_task.influence_lim:\n                self._update_lim(xs, ys)\n\n    def _create_impl(self, task_list):\n        for task in task_list:\n            
self._create_impl_callables[task.task_type](task)\n        self._draw()\n\n    def _update_impl(self, task_list):\n        for task in task_list:\n            self._update_impl_callables[task.task_type](task)\n        self._draw()\n\n    def _set_aspect_equal_2d(self, zero_centered=True):\n        xlim = self._ax.get_xlim()\n        ylim = self._ax.get_ylim()\n\n        if not zero_centered:\n            xmean = np.mean(xlim)\n            ymean = np.mean(ylim)\n        else:\n            xmean = 0\n            ymean = 0\n\n        plot_radius = max(\n            [\n                abs(lim - mean_)\n                for lims, mean_ in ((xlim, xmean), (ylim, ymean))\n                for lim in lims\n            ]\n        )\n\n        self._ax.set_xlim([xmean - plot_radius, xmean + plot_radius])\n        self._ax.set_ylim([ymean - plot_radius, ymean + plot_radius])\n\n    def _draw(self):\n        self._set_lim()\n        self._set_aspect_equal_2d()\n        self._fig.canvas.draw()\n        self._fig.canvas.flush_events()\n        plt.pause(0.00001)\n\n\nclass Matplotlib3DPlotter(BasePlotter):\n    _fig: plt.figure  # plt figure\n    _ax: p3.Axes3D  # plt 3d axis\n    # stores artist objects for each task (task name as the key)\n    _artist_cache: Dict[str, Any]\n    # callables for each task primitives\n    _create_impl_callables: Dict[str, Callable]\n    _update_impl_callables: Dict[str, Callable]\n\n    def __init__(self, task: \"BasePlotterTask\") -> None:\n        self._fig = plt.figure()\n        self._ax = p3.Axes3D(self._fig)\n        self._artist_cache = {}\n\n        self._create_impl_callables = {\n            \"Draw3DLines\": self._lines_create_impl,\n            \"Draw3DDots\": self._dots_create_impl,\n            \"Draw3DTrail\": self._trail_create_impl,\n        }\n        self._update_impl_callables = {\n            \"Draw3DLines\": self._lines_update_impl,\n            \"Draw3DDots\": self._dots_update_impl,\n            \"Draw3DTrail\": 
self._trail_update_impl,\n        }\n        self._init_lim()\n        super().__init__(task)\n\n    @property\n    def ax(self):\n        return self._ax\n\n    @property\n    def fig(self):\n        return self._fig\n\n    def show(self):\n        plt.show()\n\n    def _min(self, x, y):\n        if x is None:\n            return y\n        if y is None:\n            return x\n        return min(x, y)\n\n    def _max(self, x, y):\n        if x is None:\n            return y\n        if y is None:\n            return x\n        return max(x, y)\n\n    def _init_lim(self):\n        self._curr_x_min = None\n        self._curr_y_min = None\n        self._curr_z_min = None\n        self._curr_x_max = None\n        self._curr_y_max = None\n        self._curr_z_max = None\n\n    def _update_lim(self, xs, ys, zs):\n        self._curr_x_min = self._min(np.min(xs), self._curr_x_min)\n        self._curr_y_min = self._min(np.min(ys), self._curr_y_min)\n        self._curr_z_min = self._min(np.min(zs), self._curr_z_min)\n        self._curr_x_max = self._max(np.max(xs), self._curr_x_max)\n        self._curr_y_max = self._max(np.max(ys), self._curr_y_max)\n        self._curr_z_max = self._max(np.max(zs), self._curr_z_max)\n\n    def _set_lim(self):\n        if not (\n            self._curr_x_min is None\n            or self._curr_x_max is None\n            or self._curr_y_min is None\n            or self._curr_y_max is None\n            or self._curr_z_min is None\n            or self._curr_z_max is None\n        ):\n            self._ax.set_xlim3d(self._curr_x_min, self._curr_x_max)\n            self._ax.set_ylim3d(self._curr_y_min, self._curr_y_max)\n            self._ax.set_zlim3d(self._curr_z_min, self._curr_z_max)\n        self._init_lim()\n\n    @staticmethod\n    def _lines_extract_xyz_impl(index, lines_task):\n        return lines_task[index, :, 0], lines_task[index, :, 1], lines_task[index, :, 2]\n\n    @staticmethod\n    def _trail_extract_xyz_impl(index, trail_task):\n 
       return (\n            trail_task[index : index + 2, 0],\n            trail_task[index : index + 2, 1],\n            trail_task[index : index + 2, 2],\n        )\n\n    def _lines_create_impl(self, lines_task):\n        color = lines_task.color\n        self._artist_cache[lines_task.task_name] = [\n            self._ax.plot(\n                *Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task),\n                color=color,\n                linewidth=lines_task.line_width,\n                alpha=lines_task.alpha\n            )[0]\n            for i in range(len(lines_task))\n        ]\n\n    def _lines_update_impl(self, lines_task):\n        lines_artists = self._artist_cache[lines_task.task_name]\n        for i in range(len(lines_task)):\n            artist = lines_artists[i]\n            xs, ys, zs = Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task)\n            artist.set_data(xs, ys)\n            artist.set_3d_properties(zs)\n            if lines_task.influence_lim:\n                self._update_lim(xs, ys, zs)\n\n    def _dots_create_impl(self, dots_task):\n        color = dots_task.color\n        self._artist_cache[dots_task.task_name] = self._ax.plot(\n            dots_task[:, 0],\n            dots_task[:, 1],\n            dots_task[:, 2],\n            c=color,\n            linestyle=\"\",\n            marker=\".\",\n            markersize=dots_task.marker_size,\n            alpha=dots_task.alpha,\n        )[0]\n\n    def _dots_update_impl(self, dots_task):\n        dots_artist = self._artist_cache[dots_task.task_name]\n        dots_artist.set_data(dots_task[:, 0], dots_task[:, 1])\n        dots_artist.set_3d_properties(dots_task[:, 2])\n        if dots_task.influence_lim:\n            self._update_lim(dots_task[:, 0], dots_task[:, 1], dots_task[:, 2])\n\n    def _trail_create_impl(self, trail_task):\n        color = trail_task.color\n        trail_length = len(trail_task) - 1\n        self._artist_cache[trail_task.task_name] = [\n      
      self._ax.plot(\n                *Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task),\n                color=trail_task.color,\n                linewidth=trail_task.line_width,\n                alpha=trail_task.alpha * (1.0 - i / (trail_length - 1))\n            )[0]\n            for i in range(trail_length)\n        ]\n\n    def _trail_update_impl(self, trail_task):\n        trails_artists = self._artist_cache[trail_task.task_name]\n        for i in range(len(trail_task) - 1):\n            artist = trails_artists[i]\n            xs, ys, zs = Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task)\n            artist.set_data(xs, ys)\n            artist.set_3d_properties(zs)\n            if trail_task.influence_lim:\n                self._update_lim(xs, ys, zs)\n\n    def _create_impl(self, task_list):\n        for task in task_list:\n            self._create_impl_callables[task.task_type](task)\n        self._draw()\n\n    def _update_impl(self, task_list):\n        for task in task_list:\n            self._update_impl_callables[task.task_type](task)\n        self._draw()\n\n    def _set_aspect_equal_3d(self):\n        xlim = self._ax.get_xlim3d()\n        ylim = self._ax.get_ylim3d()\n        zlim = self._ax.get_zlim3d()\n\n        xmean = np.mean(xlim)\n        ymean = np.mean(ylim)\n        zmean = np.mean(zlim)\n\n        plot_radius = max(\n            [\n                abs(lim - mean_)\n                for lims, mean_ in ((xlim, xmean), (ylim, ymean), (zlim, zmean))\n                for lim in lims\n            ]\n        )\n\n        self._ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius])\n        self._ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius])\n        self._ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius])\n\n    def _draw(self):\n        self._set_lim()\n        self._set_aspect_equal_3d()\n        self._fig.canvas.draw()\n        self._fig.canvas.flush_events()\n        plt.pause(0.00001)\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/simple_plotter_tasks.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n\"\"\"\nThis is where all the task primitives are defined\n\"\"\"\nimport numpy as np\n\nfrom .core import BasePlotterTask\n\n\nclass DrawXDLines(BasePlotterTask):\n    _lines: np.ndarray\n    _color: str\n    _line_width: int\n    _alpha: float\n    _influence_lim: bool\n\n    def __init__(\n        self,\n        task_name: str,\n        lines: np.ndarray,\n        color: str = \"blue\",\n        line_width: int = 2,\n        alpha: float = 1.0,\n        influence_lim: bool = True,\n    ) -> None:\n        super().__init__(task_name=task_name, task_type=self.__class__.__name__)\n        self._color = color\n        self._line_width = line_width\n        self._alpha = alpha\n        self._influence_lim = influence_lim\n        self.update(lines)\n\n    @property\n    def influence_lim(self) -> bool:\n        return self._influence_lim\n\n    @property\n    def raw_data(self):\n        return self._lines\n\n    @property\n    def color(self):\n        return self._color\n\n    @property\n    def line_width(self):\n        return self._line_width\n\n    @property\n    def alpha(self):\n        return self._alpha\n\n    @property\n    def dim(self):\n        raise NotImplementedError\n\n    @property\n    def name(self):\n        return \"{}DLines\".format(self.dim)\n\n    def update(self, lines):\n        self._lines = np.array(lines)\n        shape = self._lines.shape\n        assert shape[-1] == self.dim and shape[-2] == 2 and 
len(shape) == 3\n\n    def __getitem__(self, index):\n        return self._lines[index]\n\n    def __len__(self):\n        return self._lines.shape[0]\n\n    def __iter__(self):\n        yield self\n\n\nclass DrawXDDots(BasePlotterTask):\n    _dots: np.ndarray\n    _color: str\n    _marker_size: int\n    _alpha: float\n    _influence_lim: bool\n\n    def __init__(\n        self,\n        task_name: str,\n        dots: np.ndarray,\n        color: str = \"blue\",\n        marker_size: int = 10,\n        alpha: float = 1.0,\n        influence_lim: bool = True,\n    ) -> None:\n        super().__init__(task_name=task_name, task_type=self.__class__.__name__)\n        self._color = color\n        self._marker_size = marker_size\n        self._alpha = alpha\n        self._influence_lim = influence_lim\n        self.update(dots)\n\n    def update(self, dots):\n        self._dots = np.array(dots)\n        shape = self._dots.shape\n        assert shape[-1] == self.dim and len(shape) == 2\n\n    def __getitem__(self, index):\n        return self._dots[index]\n\n    def __len__(self):\n        return self._dots.shape[0]\n\n    def __iter__(self):\n        yield self\n\n    @property\n    def influence_lim(self) -> bool:\n        return self._influence_lim\n\n    @property\n    def raw_data(self):\n        return self._dots\n\n    @property\n    def color(self):\n        return self._color\n\n    @property\n    def marker_size(self):\n        return self._marker_size\n\n    @property\n    def alpha(self):\n        return self._alpha\n\n    @property\n    def dim(self):\n        raise NotImplementedError\n\n    @property\n    def name(self):\n        return \"{}DDots\".format(self.dim)\n\n\nclass DrawXDTrail(DrawXDDots):\n    @property\n    def line_width(self):\n        return self.marker_size\n\n    @property\n    def name(self):\n        return \"{}DTrail\".format(self.dim)\n\n\nclass Draw2DLines(DrawXDLines):\n    @property\n    def dim(self):\n        return 2\n\n\nclass 
Draw3DLines(DrawXDLines):\n    @property\n    def dim(self):\n        return 3\n\n\nclass Draw2DDots(DrawXDDots):\n    @property\n    def dim(self):\n        return 2\n\n\nclass Draw3DDots(DrawXDDots):\n    @property\n    def dim(self):\n        return 3\n\n\nclass Draw2DTrail(DrawXDTrail):\n    @property\n    def dim(self):\n        return 2\n\n\nclass Draw3DTrail(DrawXDTrail):\n    @property\n    def dim(self):\n        return 3\n\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/skeleton_plotter_tasks.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n\"\"\"\nThis is where all skeleton-related complex tasks are defined (skeleton state and skeleton\nmotion)\n\"\"\"\nimport numpy as np\n\nfrom .core import BasePlotterTask\nfrom .simple_plotter_tasks import Draw3DDots, Draw3DLines, Draw3DTrail\n\n\nclass Draw3DSkeletonState(BasePlotterTask):\n    _lines_task: Draw3DLines  # sub-task for drawing lines\n    _dots_task: Draw3DDots  # sub-task for drawing dots\n\n    def __init__(\n        self,\n        task_name: str,\n        skeleton_state,\n        joints_color: str = \"red\",\n        lines_color: str = \"blue\",\n        alpha=1.0,\n    ) -> None:\n        super().__init__(task_name=task_name, task_type=\"3DSkeletonState\")\n        lines, dots = Draw3DSkeletonState._get_lines_and_dots(skeleton_state)\n        # bone lines take lines_color; joint dots take joints_color\n        self._lines_task = Draw3DLines(\n            self.get_scoped_name(\"bodies\"), lines, lines_color, alpha=alpha\n        )\n        self._dots_task = Draw3DDots(\n            self.get_scoped_name(\"joints\"), dots, joints_color, alpha=alpha\n        )\n\n    @property\n    def name(self):\n        return \"3DSkeleton\"\n\n    def update(self, skeleton_state) -> None:\n        self._update(*Draw3DSkeletonState._get_lines_and_dots(skeleton_state))\n\n    @staticmethod\n    def _get_lines_and_dots(skeleton_state):\n        \"\"\"Get all the lines and dots needed to draw the skeleton state\n        \"\"\"\n        assert (\n            len(skeleton_state.tensor.shape) == 1\n        
), \"the state has to be zero dimensional\"\n        dots = skeleton_state.global_translation.numpy()\n        skeleton_tree = skeleton_state.skeleton_tree\n        parent_indices = skeleton_tree.parent_indices.numpy()\n        lines = []\n        for node_index in range(len(skeleton_tree)):\n            parent_index = parent_indices[node_index]\n            if parent_index != -1:\n                lines.append([dots[node_index], dots[parent_index]])\n        lines = np.array(lines)\n        return lines, dots\n\n    def _update(self, lines, dots) -> None:\n        self._lines_task.update(lines)\n        self._dots_task.update(dots)\n\n    def __iter__(self):\n        yield from self._lines_task\n        yield from self._dots_task\n\n\nclass Draw3DSkeletonMotion(BasePlotterTask):\n    def __init__(\n        self,\n        task_name: str,\n        skeleton_motion,\n        frame_index=None,\n        joints_color=\"red\",\n        lines_color=\"blue\",\n        velocity_color=\"green\",\n        angular_velocity_color=\"purple\",\n        trail_color=\"black\",\n        trail_length=10,\n        alpha=1.0,\n    ) -> None:\n        super().__init__(task_name=task_name, task_type=\"3DSkeletonMotion\")\n        self._trail_length = trail_length\n        self._skeleton_motion = skeleton_motion\n        # if frame_index is None:\n        curr_skeleton_motion = self._skeleton_motion.clone()\n        if frame_index is not None:\n            curr_skeleton_motion.tensor = self._skeleton_motion.tensor[frame_index, :]\n        # else:\n        #     curr_skeleton_motion = self._skeleton_motion[frame_index, :]\n        self._skeleton_state_task = Draw3DSkeletonState(\n            self.get_scoped_name(\"skeleton_state\"),\n            curr_skeleton_motion,\n            joints_color=joints_color,\n            lines_color=lines_color,\n            alpha=alpha,\n        )\n        vel_lines, avel_lines = Draw3DSkeletonMotion._get_vel_and_avel(\n            curr_skeleton_motion\n      
  )\n        self._com_pos = curr_skeleton_motion.root_translation.numpy()[\n            np.newaxis, ...\n        ].repeat(trail_length, axis=0)\n        self._vel_task = Draw3DLines(\n            self.get_scoped_name(\"velocity\"),\n            vel_lines,\n            velocity_color,\n            influence_lim=False,\n            alpha=alpha,\n        )\n        self._avel_task = Draw3DLines(\n            self.get_scoped_name(\"angular_velocity\"),\n            avel_lines,\n            angular_velocity_color,\n            influence_lim=False,\n            alpha=alpha,\n        )\n        self._com_trail_task = Draw3DTrail(\n            self.get_scoped_name(\"com_trail\"),\n            self._com_pos,\n            trail_color,\n            marker_size=2,\n            influence_lim=True,\n            alpha=alpha,\n        )\n\n    @property\n    def name(self):\n        return \"3DSkeletonMotion\"\n\n    def update(self, frame_index=None, reset_trail=False, skeleton_motion=None) -> None:\n        if skeleton_motion is not None:\n            self._skeleton_motion = skeleton_motion\n\n        curr_skeleton_motion = self._skeleton_motion.clone()\n        if frame_index is not None:\n            curr_skeleton_motion.tensor = curr_skeleton_motion.tensor[frame_index, :]\n        if reset_trail:\n            self._com_pos = curr_skeleton_motion.root_translation.numpy()[\n                np.newaxis, ...\n            ].repeat(self._trail_length, axis=0)\n        else:\n            self._com_pos = np.concatenate(\n                (\n                    curr_skeleton_motion.root_translation.numpy()[np.newaxis, ...],\n                    self._com_pos[:-1],\n                ),\n                axis=0,\n            )\n        self._skeleton_state_task.update(curr_skeleton_motion)\n        self._com_trail_task.update(self._com_pos)\n        self._update(*Draw3DSkeletonMotion._get_vel_and_avel(curr_skeleton_motion))\n\n    @staticmethod\n    def 
_get_vel_and_avel(skeleton_motion):\n        \"\"\"Get all the velocity and angular velocity lines\n        \"\"\"\n        pos = skeleton_motion.global_translation.numpy()\n        vel = skeleton_motion.global_velocity.numpy()\n        avel = skeleton_motion.global_angular_velocity.numpy()\n\n        vel_lines = np.stack((pos, pos + vel * 0.02), axis=1)\n        avel_lines = np.stack((pos, pos + avel * 0.01), axis=1)\n        return vel_lines, avel_lines\n\n    def _update(self, vel_lines, avel_lines) -> None:\n        self._vel_task.update(vel_lines)\n        self._avel_task.update(avel_lines)\n\n    def __iter__(self):\n        yield from self._skeleton_state_task\n        yield from self._vel_task\n        yield from self._avel_task\n        yield from self._com_trail_task\n\n\nclass Draw3DSkeletonMotions(BasePlotterTask):\n    def __init__(self, skeleton_motion_tasks) -> None:\n        # initialize the base-task fields so task_name/task_type are available\n        super().__init__(task_name=\"3DSkeletonMotions\", task_type=\"3DSkeletonMotions\")\n        self._skeleton_motion_tasks = skeleton_motion_tasks\n\n    @property\n    def name(self):\n        return \"3DSkeletonMotions\"\n\n    def update(self, frame_index) -> None:\n        list(map(lambda x: x.update(frame_index), self._skeleton_motion_tasks))\n\n    def __iter__(self):\n        # yield the primitives of each skeleton motion sub-task\n        for task in self._skeleton_motion_tasks:\n            yield from task\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/tests/__init__.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited."
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/tests/test_plotter.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom typing import cast\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom ..core import BasePlotterTask, BasePlotterTasks\nfrom ..plt_plotter import Matplotlib3DPlotter\nfrom ..simple_plotter_tasks import Draw3DDots, Draw3DLines\n\ntask = Draw3DLines(task_name=\"test\", \n    lines=np.array([[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 0]]]), color=\"blue\")\ntask2 = Draw3DDots(task_name=\"test2\", \n    dots=np.array([[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]]), color=\"red\")\ntask3 = BasePlotterTasks([task, task2])\nplotter = Matplotlib3DPlotter(cast(BasePlotterTask, task3))\nplt.show()\n"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/retarget_motion.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom isaacgym.torch_utils import *\nimport torch\nimport json\nimport numpy as np\n\nfrom poselib.core.rotation3d import *\nfrom poselib.skeleton.skeleton3d import SkeletonTree, SkeletonState, SkeletonMotion\nfrom poselib.visualization.common import plot_skeleton_state, plot_skeleton_motion_interactive\n\n\"\"\"\nThis script shows how to retarget a motion clip from the source skeleton to a target skeleton.\nData required for retargeting are stored in a retarget config dictionary as a json file. This file contains:\n  - source_motion: a SkeletonMotion npy format representation of a motion sequence. 
The motion clip should use the same skeleton as the source T-Pose skeleton.\n  - target_motion_path: path to save the retargeted motion to\n  - source_tpose: a SkeletonState npy format representation of the source skeleton in its T-Pose state\n  - target_tpose: a SkeletonState npy format representation of the target skeleton in its T-Pose state (pose should match source T-Pose)\n  - joint_mapping: mapping of joint names from source to target\n  - rotation: root rotation offset from source to target skeleton (for transforming across different orientation axes), represented as a quaternion in XYZW order.\n  - scale: scale offset from source to target skeleton\n\"\"\"\n\nVISUALIZE = False\n\ndef project_joints(motion):\n    right_upper_arm_id = motion.skeleton_tree._node_indices[\"right_upper_arm\"]\n    right_lower_arm_id = motion.skeleton_tree._node_indices[\"right_lower_arm\"]\n    right_hand_id = motion.skeleton_tree._node_indices[\"right_hand\"]\n    left_upper_arm_id = motion.skeleton_tree._node_indices[\"left_upper_arm\"]\n    left_lower_arm_id = motion.skeleton_tree._node_indices[\"left_lower_arm\"]\n    left_hand_id = motion.skeleton_tree._node_indices[\"left_hand\"]\n    \n    right_thigh_id = motion.skeleton_tree._node_indices[\"right_thigh\"]\n    right_shin_id = motion.skeleton_tree._node_indices[\"right_shin\"]\n    right_foot_id = motion.skeleton_tree._node_indices[\"right_foot\"]\n    left_thigh_id = motion.skeleton_tree._node_indices[\"left_thigh\"]\n    left_shin_id = motion.skeleton_tree._node_indices[\"left_shin\"]\n    left_foot_id = motion.skeleton_tree._node_indices[\"left_foot\"]\n    \n    device = motion.global_translation.device\n\n    # right arm\n    right_upper_arm_pos = motion.global_translation[..., right_upper_arm_id, :]\n    right_lower_arm_pos = motion.global_translation[..., right_lower_arm_id, :]\n    right_hand_pos = motion.global_translation[..., right_hand_id, :]\n    right_shoulder_rot = motion.local_rotation[..., 
right_upper_arm_id, :]\n    right_elbow_rot = motion.local_rotation[..., right_lower_arm_id, :]\n    \n    right_arm_delta0 = right_upper_arm_pos - right_lower_arm_pos\n    right_arm_delta1 = right_hand_pos - right_lower_arm_pos\n    right_arm_delta0 = right_arm_delta0 / torch.norm(right_arm_delta0, dim=-1, keepdim=True)\n    right_arm_delta1 = right_arm_delta1 / torch.norm(right_arm_delta1, dim=-1, keepdim=True)\n    right_elbow_dot = torch.sum(-right_arm_delta0 * right_arm_delta1, dim=-1)\n    right_elbow_dot = torch.clamp(right_elbow_dot, -1.0, 1.0)\n    right_elbow_theta = torch.acos(right_elbow_dot)\n    right_elbow_q = quat_from_angle_axis(-torch.abs(right_elbow_theta), torch.tensor(np.array([[0.0, 1.0, 0.0]]), \n                                            device=device, dtype=torch.float32))\n    \n    right_elbow_local_dir = motion.skeleton_tree.local_translation[right_hand_id]\n    right_elbow_local_dir = right_elbow_local_dir / torch.norm(right_elbow_local_dir)\n    right_elbow_local_dir_tile = torch.tile(right_elbow_local_dir.unsqueeze(0), [right_elbow_rot.shape[0], 1])\n    right_elbow_local_dir0 = quat_rotate(right_elbow_rot, right_elbow_local_dir_tile)\n    right_elbow_local_dir1 = quat_rotate(right_elbow_q, right_elbow_local_dir_tile)\n    right_arm_dot = torch.sum(right_elbow_local_dir0 * right_elbow_local_dir1, dim=-1)\n    right_arm_dot = torch.clamp(right_arm_dot, -1.0, 1.0)\n    right_arm_theta = torch.acos(right_arm_dot)\n    right_arm_theta = torch.where(right_elbow_local_dir0[..., 1] <= 0, right_arm_theta, -right_arm_theta)\n    right_arm_q = quat_from_angle_axis(right_arm_theta, right_elbow_local_dir.unsqueeze(0))\n    right_shoulder_rot = quat_mul(right_shoulder_rot, right_arm_q)\n    \n    # left arm\n    left_upper_arm_pos = motion.global_translation[..., left_upper_arm_id, :]\n    left_lower_arm_pos = motion.global_translation[..., left_lower_arm_id, :]\n    left_hand_pos = motion.global_translation[..., left_hand_id, :]\n    
left_shoulder_rot = motion.local_rotation[..., left_upper_arm_id, :]\n    left_elbow_rot = motion.local_rotation[..., left_lower_arm_id, :]\n    \n    left_arm_delta0 = left_upper_arm_pos - left_lower_arm_pos\n    left_arm_delta1 = left_hand_pos - left_lower_arm_pos\n    left_arm_delta0 = left_arm_delta0 / torch.norm(left_arm_delta0, dim=-1, keepdim=True)\n    left_arm_delta1 = left_arm_delta1 / torch.norm(left_arm_delta1, dim=-1, keepdim=True)\n    left_elbow_dot = torch.sum(-left_arm_delta0 * left_arm_delta1, dim=-1)\n    left_elbow_dot = torch.clamp(left_elbow_dot, -1.0, 1.0)\n    left_elbow_theta = torch.acos(left_elbow_dot)\n    left_elbow_q = quat_from_angle_axis(-torch.abs(left_elbow_theta), torch.tensor(np.array([[0.0, 1.0, 0.0]]), \n                                        device=device, dtype=torch.float32))\n\n    left_elbow_local_dir = motion.skeleton_tree.local_translation[left_hand_id]\n    left_elbow_local_dir = left_elbow_local_dir / torch.norm(left_elbow_local_dir)\n    left_elbow_local_dir_tile = torch.tile(left_elbow_local_dir.unsqueeze(0), [left_elbow_rot.shape[0], 1])\n    left_elbow_local_dir0 = quat_rotate(left_elbow_rot, left_elbow_local_dir_tile)\n    left_elbow_local_dir1 = quat_rotate(left_elbow_q, left_elbow_local_dir_tile)\n    left_arm_dot = torch.sum(left_elbow_local_dir0 * left_elbow_local_dir1, dim=-1)\n    left_arm_dot = torch.clamp(left_arm_dot, -1.0, 1.0)\n    left_arm_theta = torch.acos(left_arm_dot)\n    left_arm_theta = torch.where(left_elbow_local_dir0[..., 1] <= 0, left_arm_theta, -left_arm_theta)\n    left_arm_q = quat_from_angle_axis(left_arm_theta, left_elbow_local_dir.unsqueeze(0))\n    left_shoulder_rot = quat_mul(left_shoulder_rot, left_arm_q)\n    \n    # right leg\n    right_thigh_pos = motion.global_translation[..., right_thigh_id, :]\n    right_shin_pos = motion.global_translation[..., right_shin_id, :]\n    right_foot_pos = motion.global_translation[..., right_foot_id, :]\n    right_hip_rot = 
motion.local_rotation[..., right_thigh_id, :]\n    right_knee_rot = motion.local_rotation[..., right_shin_id, :]\n    \n    right_leg_delta0 = right_thigh_pos - right_shin_pos\n    right_leg_delta1 = right_foot_pos - right_shin_pos\n    right_leg_delta0 = right_leg_delta0 / torch.norm(right_leg_delta0, dim=-1, keepdim=True)\n    right_leg_delta1 = right_leg_delta1 / torch.norm(right_leg_delta1, dim=-1, keepdim=True)\n    right_knee_dot = torch.sum(-right_leg_delta0 * right_leg_delta1, dim=-1)\n    right_knee_dot = torch.clamp(right_knee_dot, -1.0, 1.0)\n    right_knee_theta = torch.acos(right_knee_dot)\n    right_knee_q = quat_from_angle_axis(torch.abs(right_knee_theta), torch.tensor(np.array([[0.0, 1.0, 0.0]]), \n                                        device=device, dtype=torch.float32))\n    \n    right_knee_local_dir = motion.skeleton_tree.local_translation[right_foot_id]\n    right_knee_local_dir = right_knee_local_dir / torch.norm(right_knee_local_dir)\n    right_knee_local_dir_tile = torch.tile(right_knee_local_dir.unsqueeze(0), [right_knee_rot.shape[0], 1])\n    right_knee_local_dir0 = quat_rotate(right_knee_rot, right_knee_local_dir_tile)\n    right_knee_local_dir1 = quat_rotate(right_knee_q, right_knee_local_dir_tile)\n    right_leg_dot = torch.sum(right_knee_local_dir0 * right_knee_local_dir1, dim=-1)\n    right_leg_dot = torch.clamp(right_leg_dot, -1.0, 1.0)\n    right_leg_theta = torch.acos(right_leg_dot)\n    right_leg_theta = torch.where(right_knee_local_dir0[..., 1] >= 0, right_leg_theta, -right_leg_theta)\n    right_leg_q = quat_from_angle_axis(right_leg_theta, right_knee_local_dir.unsqueeze(0))\n    right_hip_rot = quat_mul(right_hip_rot, right_leg_q)\n    \n    # left leg\n    left_thigh_pos = motion.global_translation[..., left_thigh_id, :]\n    left_shin_pos = motion.global_translation[..., left_shin_id, :]\n    left_foot_pos = motion.global_translation[..., left_foot_id, :]\n    left_hip_rot = motion.local_rotation[..., left_thigh_id, :]\n    
left_knee_rot = motion.local_rotation[..., left_shin_id, :]\n    \n    left_leg_delta0 = left_thigh_pos - left_shin_pos\n    left_leg_delta1 = left_foot_pos - left_shin_pos\n    left_leg_delta0 = left_leg_delta0 / torch.norm(left_leg_delta0, dim=-1, keepdim=True)\n    left_leg_delta1 = left_leg_delta1 / torch.norm(left_leg_delta1, dim=-1, keepdim=True)\n    left_knee_dot = torch.sum(-left_leg_delta0 * left_leg_delta1, dim=-1)\n    left_knee_dot = torch.clamp(left_knee_dot, -1.0, 1.0)\n    left_knee_theta = torch.acos(left_knee_dot)\n    left_knee_q = quat_from_angle_axis(torch.abs(left_knee_theta), torch.tensor(np.array([[0.0, 1.0, 0.0]]), \n                                        device=device, dtype=torch.float32))\n    \n    left_knee_local_dir = motion.skeleton_tree.local_translation[left_foot_id]\n    left_knee_local_dir = left_knee_local_dir / torch.norm(left_knee_local_dir)\n    left_knee_local_dir_tile = torch.tile(left_knee_local_dir.unsqueeze(0), [left_knee_rot.shape[0], 1])\n    left_knee_local_dir0 = quat_rotate(left_knee_rot, left_knee_local_dir_tile)\n    left_knee_local_dir1 = quat_rotate(left_knee_q, left_knee_local_dir_tile)\n    left_leg_dot = torch.sum(left_knee_local_dir0 * left_knee_local_dir1, dim=-1)\n    left_leg_dot = torch.clamp(left_leg_dot, -1.0, 1.0)\n    left_leg_theta = torch.acos(left_leg_dot)\n    left_leg_theta = torch.where(left_knee_local_dir0[..., 1] >= 0, left_leg_theta, -left_leg_theta)\n    left_leg_q = quat_from_angle_axis(left_leg_theta, left_knee_local_dir.unsqueeze(0))\n    left_hip_rot = quat_mul(left_hip_rot, left_leg_q)\n    \n\n    new_local_rotation = motion.local_rotation.clone()\n    new_local_rotation[..., right_upper_arm_id, :] = right_shoulder_rot\n    new_local_rotation[..., right_lower_arm_id, :] = right_elbow_q\n    new_local_rotation[..., left_upper_arm_id, :] = left_shoulder_rot\n    new_local_rotation[..., left_lower_arm_id, :] = left_elbow_q\n    \n    new_local_rotation[..., right_thigh_id, :] = 
right_hip_rot\n    new_local_rotation[..., right_shin_id, :] = right_knee_q\n    new_local_rotation[..., left_thigh_id, :] = left_hip_rot\n    new_local_rotation[..., left_shin_id, :] = left_knee_q\n    \n    new_local_rotation[..., left_hand_id, :] = quat_identity([1])\n    new_local_rotation[..., right_hand_id, :] = quat_identity([1])\n\n    new_sk_state = SkeletonState.from_rotation_and_root_translation(motion.skeleton_tree, new_local_rotation, motion.root_translation, is_local=True)\n    new_motion = SkeletonMotion.from_skeleton_state(new_sk_state, fps=motion.fps)\n    \n    return new_motion\n\n\ndef main():\n    # load retarget config\n    retarget_data_path = \"data/configs/retarget_cmu_to_amp.json\"\n    with open(retarget_data_path) as f:\n        retarget_data = json.load(f)\n\n    # load and visualize t-pose files\n    source_tpose = SkeletonState.from_file(retarget_data[\"source_tpose\"])\n    if VISUALIZE:\n        plot_skeleton_state(source_tpose)\n\n    target_tpose = SkeletonState.from_file(retarget_data[\"target_tpose\"])\n    if VISUALIZE:\n        plot_skeleton_state(target_tpose)\n\n    # load and visualize source motion sequence\n    source_motion = SkeletonMotion.from_file(retarget_data[\"source_motion\"])\n    if VISUALIZE:\n        plot_skeleton_motion_interactive(source_motion)\n\n    # parse data from retarget config\n    joint_mapping = retarget_data[\"joint_mapping\"]\n    rotation_to_target_skeleton = torch.tensor(retarget_data[\"rotation\"])\n\n    # run retargeting\n    target_motion = source_motion.retarget_to_by_tpose(\n      joint_mapping=retarget_data[\"joint_mapping\"],\n      source_tpose=source_tpose,\n      target_tpose=target_tpose,\n      rotation_to_target_skeleton=rotation_to_target_skeleton,\n      scale_to_target_skeleton=retarget_data[\"scale\"]\n    )\n\n    # keep frames between [trim_frame_beg, trim_frame_end - 1]\n    frame_beg = retarget_data[\"trim_frame_beg\"]\n    frame_end = retarget_data[\"trim_frame_end\"]\n  
  if (frame_beg == -1):\n        frame_beg = 0\n        \n    if (frame_end == -1):\n        frame_end = target_motion.local_rotation.shape[0]\n        \n    local_rotation = target_motion.local_rotation\n    root_translation = target_motion.root_translation\n    local_rotation = local_rotation[frame_beg:frame_end, ...]\n    root_translation = root_translation[frame_beg:frame_end, ...]\n      \n    new_sk_state = SkeletonState.from_rotation_and_root_translation(target_motion.skeleton_tree, local_rotation, root_translation, is_local=True)\n    target_motion = SkeletonMotion.from_skeleton_state(new_sk_state, fps=target_motion.fps)\n\n    # need to convert some joints from 3D to 1D (e.g. elbows and knees)\n    target_motion = project_joints(target_motion)\n\n    # move the root so that the feet are on the ground\n    local_rotation = target_motion.local_rotation\n    root_translation = target_motion.root_translation\n    tar_global_pos = target_motion.global_translation\n    min_h = torch.min(tar_global_pos[..., 2])\n    root_translation[:, 2] += -min_h\n    \n    # adjust the height of the root to avoid ground penetration\n    root_height_offset = retarget_data[\"root_height_offset\"]\n    root_translation[:, 2] += root_height_offset\n    \n    new_sk_state = SkeletonState.from_rotation_and_root_translation(target_motion.skeleton_tree, local_rotation, root_translation, is_local=True)\n    target_motion = SkeletonMotion.from_skeleton_state(new_sk_state, fps=target_motion.fps)\n\n    # save retargeted motion\n    target_motion.to_file(retarget_data[\"target_motion_path\"])\n\n    # visualize retargeted motion\n    plot_skeleton_motion_interactive(target_motion)\n    \n    return\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "timechamber/tasks/base/__init__.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "timechamber/tasks/base/ma_vec_task.py",
    "content": "# Copyright (c) 2018-2021, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom typing import Dict, Any, Tuple\n\nimport gym\nfrom gym import spaces\n\nfrom isaacgym import gymtorch, gymapi\nfrom isaacgym.torch_utils import to_torch\nfrom isaacgym.gymutil import get_property_setter_map, get_property_getter_map, get_default_setter_args, \\\n    apply_random_samples, check_buckets, generate_random_samples\n\nimport torch\nimport numpy as np\nimport operator, random\nfrom copy import deepcopy\n\nimport sys\n\nimport abc\nfrom .vec_task import Env\n\n\nclass MA_VecTask(Env):\n\n    def __init__(self, config, rl_device, sim_device, graphics_device_id, headless,\n                 virtual_screen_capture: bool = False, force_render: bool = False):\n        \"\"\"Initialise the `MA_VecTask`.\n\n        Args:\n            config: config dictionary for the environment.\n            sim_device: the device to simulate physics on. eg. 
'cuda:0' or 'cpu'\n            graphics_device_id: the device ID to render with.\n            headless: Set to True to disable viewer rendering.\n        \"\"\"\n        super().__init__(config, rl_device, sim_device, graphics_device_id, headless)\n        \n        self.virtual_screen_capture = virtual_screen_capture\n        self.force_render = force_render\n\n        self.sim_params = self.__parse_sim_params(self.cfg[\"physics_engine\"], self.cfg[\"sim\"])\n        if self.cfg[\"physics_engine\"] == \"physx\":\n            self.physics_engine = gymapi.SIM_PHYSX\n        elif self.cfg[\"physics_engine\"] == \"flex\":\n            self.physics_engine = gymapi.SIM_FLEX\n        else:\n            msg = f\"Invalid physics engine backend: {self.cfg['physics_engine']}\"\n            raise ValueError(msg)\n\n        # optimization flags for pytorch JIT\n        torch._C._jit_set_profiling_mode(False)\n        torch._C._jit_set_profiling_executor(False)\n\n        self.gym = gymapi.acquire_gym()\n\n        self.first_randomization = True\n        self.original_props = {}\n        self.dr_randomizations = {}\n        self.actor_params_generator = None\n        self.extern_actor_params = {}\n        self.last_step = -1\n        self.last_rand_step = -1\n        for env_id in range(self.num_envs):\n            self.extern_actor_params[env_id] = None\n\n        # create envs, sim and viewer\n        self.sim_initialized = False\n        self.create_sim()\n        self.gym.prepare_sim(self.sim)\n        self.sim_initialized = True\n\n        self.set_viewer()\n        self.allocate_buffers()\n\n        self.obs_dict = {}\n\n    def set_viewer(self):\n        \"\"\"Create the viewer.\"\"\"\n\n        # todo: read from config\n        self.enable_viewer_sync = True\n        self.viewer = None\n\n        # if running with a viewer, set up keyboard shortcuts and camera\n        if self.headless == False:\n            # subscribe to keyboard shortcuts\n            self.viewer = 
self.gym.create_viewer(\n                self.sim, gymapi.CameraProperties())\n            self.gym.subscribe_viewer_keyboard_event(\n                self.viewer, gymapi.KEY_ESCAPE, \"QUIT\")\n            self.gym.subscribe_viewer_keyboard_event(\n                self.viewer, gymapi.KEY_V, \"toggle_viewer_sync\")\n\n            # set the camera position based on up axis\n            sim_params = self.gym.get_sim_params(self.sim)\n            if sim_params.up_axis == gymapi.UP_AXIS_Z:\n                cam_pos = gymapi.Vec3(20.0, 25.0, 3.0)\n                cam_target = gymapi.Vec3(10.0, 15.0, 0.0)\n            else:\n                cam_pos = gymapi.Vec3(20.0, 3.0, 25.0)\n                cam_target = gymapi.Vec3(10.0, 0.0, 15.0)\n\n            self.gym.viewer_camera_look_at(\n                self.viewer, None, cam_pos, cam_target)\n\n    def allocate_buffers(self):\n        \"\"\"Allocate the observation, states, etc. buffers.\n\n        These are what is used to set observations and states in the environment classes which\n        inherit from this one, and are read in `step` and other related functions.\n\n        \"\"\"\n\n        # allocate buffers\n        self.obs_buf = torch.zeros(\n            (self.num_envs * self.num_agents, self.num_obs), device=self.device, dtype=torch.float)\n        self.states_buf = torch.zeros(\n            (self.num_envs, self.num_states), device=self.device, dtype=torch.float)\n        self.rew_buf = torch.zeros(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.float)\n        self.reset_buf = torch.ones(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.long)\n        self.timeout_buf = torch.zeros(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.long)\n        self.progress_buf = torch.zeros(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.long)\n        self.randomize_buf = torch.zeros(\n            self.num_envs 
* self.num_agents, device=self.device, dtype=torch.long)\n        self.extras = {}\n\n    def set_sim_params_up_axis(self, sim_params: gymapi.SimParams, axis: str) -> int:\n        \"\"\"Set gravity based on up axis and return axis index.\n\n        Args:\n            sim_params: sim params to modify the axis for.\n            axis: axis to set sim params for.\n        Returns:\n            axis index for up axis.\n        \"\"\"\n        if axis == 'z':\n            sim_params.up_axis = gymapi.UP_AXIS_Z\n            sim_params.gravity.x = 0\n            sim_params.gravity.y = 0\n            sim_params.gravity.z = -9.81\n            return 2\n        return 1\n\n    def create_sim(self, compute_device: int, graphics_device: int, physics_engine, sim_params: gymapi.SimParams):\n        \"\"\"Create an Isaac Gym sim object.\n\n        Args:\n            compute_device: ID of compute device to use.\n            graphics_device: ID of graphics device to use.\n            physics_engine: physics engine to use (`gymapi.SIM_PHYSX` or `gymapi.SIM_FLEX`)\n            sim_params: sim params to use.\n        Returns:\n            the Isaac Gym sim object.\n        \"\"\"\n        sim = self.gym.create_sim(compute_device, graphics_device, physics_engine, sim_params)\n        if sim is None:\n            print(\"*** Failed to create sim\")\n            quit()\n\n        return sim\n\n    def get_state(self):\n        \"\"\"Returns the state buffer of the environment (the privileged observations for asymmetric training).\"\"\"\n        return torch.clamp(self.states_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n\n    @abc.abstractmethod\n    def pre_physics_step(self, actions: torch.Tensor):\n        \"\"\"Apply the actions to the environment (e.g. by setting torques, position targets).\n\n        Args:\n            actions: the actions to apply\n        \"\"\"\n\n    @abc.abstractmethod\n    def post_physics_step(self):\n        \"\"\"Compute reward and observations, 
reset any environments that require it.\"\"\"\n\n    def step(self, actions: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor, Dict[str, Any]]:\n        \"\"\"Step the physics of the environment.\n\n        Args:\n            actions: actions to apply\n        Returns:\n            Observations, rewards, resets, info\n            Observations are a dict of observations (currently only one member called 'obs')\n        \"\"\"\n\n        # randomize actions\n        if self.dr_randomizations.get('actions', None):\n            actions = self.dr_randomizations['actions']['noise_lambda'](actions)\n\n        # apply actions\n        self.pre_physics_step(actions)\n\n        # step physics and render each frame\n        for i in range(self.control_freq_inv):\n            if self.force_render:\n                self.render()\n            self.gym.simulate(self.sim)\n\n        # to fix!\n        if self.device == 'cpu':\n            self.gym.fetch_results(self.sim, True)\n\n        # fill time out buffer\n        self.timeout_buf = torch.where(self.progress_buf >= self.max_episode_length - 1,\n                                       torch.ones_like(self.timeout_buf), torch.zeros_like(self.timeout_buf))\n\n        # compute observations, rewards, resets, ...\n        self.post_physics_step()\n\n        # randomize observations\n        if self.dr_randomizations.get('observations', None):\n            self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf)\n\n        self.extras[\"time_outs\"] = self.timeout_buf.to(self.rl_device)\n\n        self.obs_dict[\"obs\"] = torch.clamp(self.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n        # asymmetric actor-critic\n        if self.num_states > 0:\n            self.obs_dict[\"states\"] = self.get_state()\n\n        return self.obs_dict, self.rew_buf.to(self.rl_device), self.reset_buf.to(self.rl_device), self.extras\n\n    def zero_actions(self) -> torch.Tensor:\n        \"\"\"Returns a buffer with zero actions.\n\n        Returns:\n            A buffer of zero torch actions\n        \"\"\"\n        actions = torch.zeros([self.num_envs * self.num_agents, self.num_actions], dtype=torch.float32,\n                              device=self.rl_device)\n\n        return actions\n\n    def reset(self, 
env_ids=None) -> torch.Tensor:\n        \"\"\"Reset the environment.\n        \"\"\"\n        if (env_ids is None):\n            # zero_actions = self.zero_actions()\n            # self.step(zero_actions)\n            env_ids = to_torch(np.arange(self.num_envs), device=self.device, dtype=torch.long)\n            self.reset_idx(env_ids)\n            self.compute_observations()\n            self.pos_before = self.obs_buf[:self.num_envs, :2].clone()\n        else:\n            self._reset_envs(env_ids=env_ids)\n        return\n\n    def _reset_envs(self, env_ids):\n        if (len(env_ids) > 0):\n            self.reset_idx(env_ids)\n            self.compute_observations()\n            self.pos_before = self.obs_buf[:self.num_envs, :2].clone()\n        return\n\n    def reset_done(self):\n        \"\"\"Reset the environment.\n        Returns:\n            Observation dictionary, indices of environments being reset\n        \"\"\"\n        done_env_ids = self.reset_buf.nonzero(as_tuple=False).flatten()\n        if len(done_env_ids) > 0:\n            self.reset_idx(done_env_ids)\n\n        self.obs_dict[\"obs\"] = torch.clamp(self.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n        # asymmetric actor-critic\n        if self.num_states > 0:\n            self.obs_dict[\"states\"] = self.get_state()\n\n        return self.obs_dict, done_env_ids\n\n    def render(self):\n        \"\"\"Draw the frame to the viewer, and check for keyboard events.\"\"\"\n        if self.viewer:\n            # check for window closed\n            if self.gym.query_viewer_has_closed(self.viewer):\n                sys.exit()\n\n            # check for keyboard events\n            for evt in self.gym.query_viewer_action_events(self.viewer):\n                if evt.action == \"QUIT\" and evt.value > 0:\n                    sys.exit()\n                elif evt.action == \"toggle_viewer_sync\" and evt.value > 0:\n                    self.enable_viewer_sync = not 
self.enable_viewer_sync\n\n            # fetch results\n            if self.device != 'cpu':\n                self.gym.fetch_results(self.sim, True)\n\n            # step graphics\n            if self.enable_viewer_sync:\n                self.gym.step_graphics(self.sim)\n                self.gym.draw_viewer(self.viewer, self.sim, True)\n\n                # Wait for dt to elapse in real time.\n                # This synchronizes the physics simulation with the rendering rate.\n                self.gym.sync_frame_time(self.sim)\n\n            else:\n                self.gym.poll_viewer_events(self.viewer)\n\n    def __parse_sim_params(self, physics_engine: str, config_sim: Dict[str, Any]) -> gymapi.SimParams:\n        \"\"\"Parse the config dictionary for physics stepping settings.\n\n        Args:\n            physics_engine: which physics engine to use. \"physx\" or \"flex\"\n            config_sim: dict of sim configuration parameters\n        Returns\n            IsaacGym SimParams object with updated settings.\n        \"\"\"\n        sim_params = gymapi.SimParams()\n\n        # check correct up-axis\n        if config_sim[\"up_axis\"] not in [\"z\", \"y\"]:\n            msg = f\"Invalid physics up-axis: {config_sim['up_axis']}\"\n            print(msg)\n            raise ValueError(msg)\n\n        # assign general sim parameters\n        sim_params.dt = config_sim[\"dt\"]\n        sim_params.num_client_threads = config_sim.get(\"num_client_threads\", 0)\n        sim_params.use_gpu_pipeline = config_sim[\"use_gpu_pipeline\"]\n        sim_params.substeps = config_sim.get(\"substeps\", 2)\n\n        # assign up-axis\n        if config_sim[\"up_axis\"] == \"z\":\n            sim_params.up_axis = gymapi.UP_AXIS_Z\n        else:\n            sim_params.up_axis = gymapi.UP_AXIS_Y\n\n        # assign gravity\n        sim_params.gravity = gymapi.Vec3(*config_sim[\"gravity\"])\n\n        # configure physics parameters\n        if physics_engine == \"physx\":\n            
# set the parameters\n            if \"physx\" in config_sim:\n                for opt in config_sim[\"physx\"].keys():\n                    if opt == \"contact_collection\":\n                        setattr(sim_params.physx, opt, gymapi.ContactCollection(config_sim[\"physx\"][opt]))\n                    else:\n                        setattr(sim_params.physx, opt, config_sim[\"physx\"][opt])\n        else:\n            # set the parameters\n            if \"flex\" in config_sim:\n                for opt in config_sim[\"flex\"].keys():\n                    setattr(sim_params.flex, opt, config_sim[\"flex\"][opt])\n\n        # return the configured params\n        return sim_params\n\n    \"\"\"\n    Domain Randomization methods\n    \"\"\"\n\n    def get_actor_params_info(self, dr_params: Dict[str, Any], env):\n        \"\"\"Generate a flat array of actor params, their names and ranges.\n\n        Returns:\n            Four parallel lists: the parameter values, their names, and their lower and upper bounds.\n        \"\"\"\n\n        if \"actor_params\" not in dr_params:\n            return None\n        params = []\n        names = []\n        lows = []\n        highs = []\n        param_getters_map = get_property_getter_map(self.gym)\n        for actor, actor_properties in dr_params[\"actor_params\"].items():\n            handle = self.gym.find_actor_handle(env, actor)\n            for prop_name, prop_attrs in actor_properties.items():\n                if prop_name == 'color':\n                    continue  # this is set randomly\n                props = param_getters_map[prop_name](env, handle)\n                if not isinstance(props, list):\n                    props = [props]\n                for prop_idx, prop in enumerate(props):\n                    for attr, attr_randomization_params in prop_attrs.items():\n                        name = prop_name + '_' + str(prop_idx) + '_' + attr\n                        lo_hi = attr_randomization_params['range']\n                        distr = attr_randomization_params['distribution']\n                        if 'uniform' not in distr:\n                            lo_hi = (-1.0 * float('Inf'), float('Inf'))\n                        if isinstance(prop, np.ndarray):\n                            for attr_idx in range(prop[attr].shape[0]):\n                                params.append(prop[attr][attr_idx])\n                                names.append(name + '_' + str(attr_idx))\n                                lows.append(lo_hi[0])\n                                highs.append(lo_hi[1])\n                        else:\n                            params.append(getattr(prop, attr))\n                            names.append(name)\n                            lows.append(lo_hi[0])\n                            highs.append(lo_hi[1])\n        return params, names, lows, highs\n\n    def apply_randomizations(self, dr_params):\n        \"\"\"Apply domain randomizations to the environment.\n\n        Note that currently we can only apply randomizations on resets, due to current PhysX limitations.\n\n        Args:\n            dr_params: parameters for domain randomization to use.\n        \"\"\"\n\n        # If we don't have a randomization frequency, randomize every step\n        rand_freq = dr_params.get(\"frequency\", 1)\n\n        # First, determine what to randomize:\n        #   - non-environment parameters when > frequency steps have passed since the last non-environment randomization\n        #   - physical environments in the reset buffer, which have exceeded the randomization frequency threshold\n        #   - on the first call, randomize everything\n        self.last_step = self.gym.get_frame_count(self.sim)\n        if self.first_randomization:\n            do_nonenv_randomize = True\n            env_ids = list(range(self.num_envs))\n        else:\n            do_nonenv_randomize = (self.last_step - self.last_rand_step) >= rand_freq\n            rand_envs = torch.where(self.randomize_buf >= rand_freq, torch.ones_like(self.randomize_buf),\n                             
       torch.zeros_like(self.randomize_buf))\n            rand_envs = torch.logical_and(rand_envs, self.reset_buf)\n            env_ids = torch.nonzero(rand_envs, as_tuple=False).squeeze(-1).tolist()\n            self.randomize_buf[rand_envs] = 0\n\n        if do_nonenv_randomize:\n            self.last_rand_step = self.last_step\n\n        param_setters_map = get_property_setter_map(self.gym)\n        param_setter_defaults_map = get_default_setter_args(self.gym)\n        param_getters_map = get_property_getter_map(self.gym)\n\n        # On first iteration, check the number of buckets\n        if self.first_randomization:\n            check_buckets(self.gym, self.envs, dr_params)\n\n        for nonphysical_param in [\"observations\", \"actions\"]:\n            if nonphysical_param in dr_params and do_nonenv_randomize:\n                dist = dr_params[nonphysical_param][\"distribution\"]\n                op_type = dr_params[nonphysical_param][\"operation\"]\n                sched_type = dr_params[nonphysical_param][\"schedule\"] if \"schedule\" in dr_params[\n                    nonphysical_param] else None\n                sched_step = dr_params[nonphysical_param][\"schedule_steps\"] if \"schedule\" in dr_params[\n                    nonphysical_param] else None\n                op = operator.add if op_type == 'additive' else operator.mul\n\n                if sched_type == 'linear':\n                    sched_scaling = 1.0 / sched_step * \\\n                                    min(self.last_step, sched_step)\n                elif sched_type == 'constant':\n                    sched_scaling = 0 if self.last_step < sched_step else 1\n                else:\n                    sched_scaling = 1\n\n                if dist == 'gaussian':\n                    mu, var = dr_params[nonphysical_param][\"range\"]\n                    mu_corr, var_corr = dr_params[nonphysical_param].get(\"range_correlated\", [0., 0.])\n\n                    if op_type == 'additive':\n        
                mu *= sched_scaling\n                        var *= sched_scaling\n                        mu_corr *= sched_scaling\n                        var_corr *= sched_scaling\n                    elif op_type == 'scaling':\n                        var = var * sched_scaling  # scale up var over time\n                        mu = mu * sched_scaling + 1.0 * \\\n                             (1.0 - sched_scaling)  # linearly interpolate\n\n                        var_corr = var_corr * sched_scaling  # scale up var over time\n                        mu_corr = mu_corr * sched_scaling + 1.0 * \\\n                                  (1.0 - sched_scaling)  # linearly interpolate\n\n                    def noise_lambda(tensor, param_name=nonphysical_param):\n                        params = self.dr_randomizations[param_name]\n                        corr = params.get('corr', None)\n                        if corr is None:\n                            corr = torch.randn_like(tensor)\n                            params['corr'] = corr\n                        corr = corr * params['var_corr'] + params['mu_corr']\n                        return op(\n                            tensor, corr + torch.randn_like(tensor) * params['var'] + params['mu'])\n\n                    self.dr_randomizations[nonphysical_param] = {'mu': mu, 'var': var, 'mu_corr': mu_corr,\n                                                                 'var_corr': var_corr, 'noise_lambda': noise_lambda}\n\n                elif dist == 'uniform':\n                    lo, hi = dr_params[nonphysical_param][\"range\"]\n                    lo_corr, hi_corr = dr_params[nonphysical_param].get(\"range_correlated\", [0., 0.])\n\n                    if op_type == 'additive':\n                        lo *= sched_scaling\n                        hi *= sched_scaling\n                        lo_corr *= sched_scaling\n                        hi_corr *= sched_scaling\n                    elif op_type == 'scaling':\n        
                lo = lo * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        hi = hi * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        lo_corr = lo_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        hi_corr = hi_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)\n\n                    def noise_lambda(tensor, param_name=nonphysical_param):\n                        params = self.dr_randomizations[param_name]\n                        corr = params.get('corr', None)\n                        if corr is None:\n                            corr = torch.randn_like(tensor)\n                            params['corr'] = corr\n                        corr = corr * (params['hi_corr'] - params['lo_corr']) + params['lo_corr']\n                        return op(tensor, corr + torch.rand_like(tensor) * (params['hi'] - params['lo']) + params['lo'])\n\n                    self.dr_randomizations[nonphysical_param] = {'lo': lo, 'hi': hi, 'lo_corr': lo_corr,\n                                                                 'hi_corr': hi_corr, 'noise_lambda': noise_lambda}\n\n        if \"sim_params\" in dr_params and do_nonenv_randomize:\n            prop_attrs = dr_params[\"sim_params\"]\n            prop = self.gym.get_sim_params(self.sim)\n\n            if self.first_randomization:\n                self.original_props[\"sim_params\"] = {\n                    attr: getattr(prop, attr) for attr in dir(prop)}\n\n            for attr, attr_randomization_params in prop_attrs.items():\n                apply_random_samples(\n                    prop, self.original_props[\"sim_params\"], attr, attr_randomization_params, self.last_step)\n\n            self.gym.set_sim_params(self.sim, prop)\n\n        # If self.actor_params_generator is initialized: use it to\n        # sample actor simulation params. This gives users the\n        # freedom to generate samples from arbitrary distributions,\n        # e.g. 
use full-covariance distributions instead of the DR's\n        # default of treating each simulation parameter independently.\n        extern_offsets = {}\n        if self.actor_params_generator is not None:\n            for env_id in env_ids:\n                self.extern_actor_params[env_id] = \\\n                    self.actor_params_generator.sample()\n                extern_offsets[env_id] = 0\n\n        for actor, actor_properties in dr_params[\"actor_params\"].items():\n            for env_id in env_ids:\n                env = self.envs[env_id]\n                handle = self.gym.find_actor_handle(env, actor)\n                extern_sample = self.extern_actor_params[env_id]\n\n                for prop_name, prop_attrs in actor_properties.items():\n                    if prop_name == 'color':\n                        num_bodies = self.gym.get_actor_rigid_body_count(\n                            env, handle)\n                        for n in range(num_bodies):\n                            self.gym.set_rigid_body_color(env, handle, n, gymapi.MESH_VISUAL,\n                                                          gymapi.Vec3(random.uniform(0, 1), random.uniform(0, 1),\n                                                                      random.uniform(0, 1)))\n                        continue\n                    if prop_name == 'scale':\n                        setup_only = prop_attrs.get('setup_only', False)\n                        if (setup_only and not self.sim_initialized) or not setup_only:\n                            attr_randomization_params = prop_attrs\n                            sample = generate_random_samples(attr_randomization_params, 1,\n                                                             self.last_step, None)\n                            og_scale = 1\n                            if attr_randomization_params['operation'] == 'scaling':\n                                new_scale = og_scale * sample\n                            elif 
attr_randomization_params['operation'] == 'additive':\n                                new_scale = og_scale + sample\n                            self.gym.set_actor_scale(env, handle, new_scale)\n                        continue\n\n                    prop = param_getters_map[prop_name](env, handle)\n                    set_random_properties = True\n                    if isinstance(prop, list):\n                        if self.first_randomization:\n                            self.original_props[prop_name] = [\n                                {attr: getattr(p, attr) for attr in dir(p)} for p in prop]\n                        for p, og_p in zip(prop, self.original_props[prop_name]):\n                            for attr, attr_randomization_params in prop_attrs.items():\n                                setup_only = attr_randomization_params.get('setup_only', False)\n                                if (setup_only and not self.sim_initialized) or not setup_only:\n                                    smpl = None\n                                    if self.actor_params_generator is not None:\n                                        smpl, extern_offsets[env_id] = get_attr_val_from_sample(\n                                            extern_sample, extern_offsets[env_id], p, attr)\n                                    apply_random_samples(\n                                        p, og_p, attr, attr_randomization_params,\n                                        self.last_step, smpl)\n                                else:\n                                    set_random_properties = False\n                    else:\n                        if self.first_randomization:\n                            self.original_props[prop_name] = deepcopy(prop)\n                        for attr, attr_randomization_params in prop_attrs.items():\n                            setup_only = attr_randomization_params.get('setup_only', False)\n                            if (setup_only and not 
self.sim_initialized) or not setup_only:\n                                smpl = None\n                                if self.actor_params_generator is not None:\n                                    smpl, extern_offsets[env_id] = get_attr_val_from_sample(\n                                        extern_sample, extern_offsets[env_id], prop, attr)\n                                apply_random_samples(\n                                    prop, self.original_props[prop_name], attr,\n                                    attr_randomization_params, self.last_step, smpl)\n                            else:\n                                set_random_properties = False\n\n                    if set_random_properties:\n                        setter = param_setters_map[prop_name]\n                        default_args = param_setter_defaults_map[prop_name]\n                        setter(env, handle, prop, *default_args)\n\n        if self.actor_params_generator is not None:\n            for env_id in env_ids:  # check that we used all dims in sample\n                if extern_offsets[env_id] > 0:\n                    extern_sample = self.extern_actor_params[env_id]\n                    if extern_offsets[env_id] != extern_sample.shape[0]:\n                        print('env_id', env_id,\n                              'extern_offset', extern_offsets[env_id],\n                              'vs extern_sample.shape', extern_sample.shape)\n                        raise Exception(\"Invalid extern_sample size\")\n\n        self.first_randomization = False\n"
  },
  {
    "path": "timechamber/tasks/base/vec_task.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom typing import Dict, Any, Tuple\n\nimport gym\nfrom gym import spaces\n\nfrom isaacgym import gymtorch, gymapi\nfrom isaacgym.torch_utils import to_torch\nfrom isaacgym.gymutil import get_property_setter_map, get_property_getter_map, get_default_setter_args, apply_random_samples, check_buckets, generate_random_samples\n\nimport torch\nimport numpy as np\nimport operator, random\nfrom copy import deepcopy\nimport sys\n\nimport abc\nfrom abc import ABC\n\nEXISTING_SIM = None\nSCREEN_CAPTURE_RESOLUTION = (1027, 768)\n\ndef _create_sim_once(gym, *args, **kwargs):\n    global EXISTING_SIM\n    if EXISTING_SIM is not None:\n        return EXISTING_SIM\n    else:\n        EXISTING_SIM = gym.create_sim(*args, **kwargs)\n        return EXISTING_SIM\n\n\nclass Env(ABC):\n    def __init__(self, config: Dict[str, Any], rl_device: str, sim_device: str, graphics_device_id: int, headless: bool):\n        \"\"\"Initialise the env.\n\n        Args:\n            config: the configuration dictionary.\n            sim_device: the device to simulate physics on. eg. 
'cuda:0' or 'cpu'\n            graphics_device_id: the device ID to render with.\n            headless: Set to True to disable viewer rendering.\n        \"\"\"\n\n        split_device = sim_device.split(\":\")\n        self.device_type = split_device[0]\n        self.device_id = int(split_device[1]) if len(split_device) > 1 else 0\n\n        self.device = \"cpu\"\n        if config[\"sim\"][\"use_gpu_pipeline\"]:\n            if self.device_type.lower() == \"cuda\" or self.device_type.lower() == \"gpu\":\n                self.device = \"cuda\" + \":\" + str(self.device_id)\n            else:\n                print(\"GPU Pipeline can only be used with GPU simulation. Forcing CPU Pipeline.\")\n                config[\"sim\"][\"use_gpu_pipeline\"] = False\n\n        self.rl_device = rl_device\n\n        # Rendering\n        # if training in a headless mode\n        self.headless = headless\n\n        enable_camera_sensors = config.get(\"enableCameraSensors\", False)\n        self.graphics_device_id = graphics_device_id\n        if enable_camera_sensors == False and self.headless == True:\n            self.graphics_device_id = -1\n\n        self.num_environments = config[\"env\"][\"numEnvs\"]\n        self.num_agents = config[\"env\"].get(\"numAgents\", 1)  # used for multi-agent environments\n        self.num_observations = config[\"env\"][\"numObservations\"]\n        self.num_states = config[\"env\"].get(\"numStates\", 0)\n        self.num_actions = config[\"env\"][\"numActions\"]\n\n        self.control_freq_inv = config[\"env\"].get(\"controlFrequencyInv\", 1)\n\n        self.obs_space = spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf)\n        self.state_space = spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf)\n\n        self.act_space = spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.)\n\n        self.clip_obs = config[\"env\"].get(\"clipObservations\", np.Inf)\n        self.clip_actions = config[\"env\"].get(\"clipActions\", np.Inf)\n\n    @abc.abstractmethod\n    def allocate_buffers(self):\n        \"\"\"Create torch buffers for observations, rewards, actions, dones and any additional data.\"\"\"\n\n    @abc.abstractmethod\n    def step(self, actions: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor, Dict[str, Any]]:\n        \"\"\"Step the physics of the environment.\n\n        Args:\n            actions: actions to apply\n        Returns:\n            Observations, rewards, resets, info\n            Observations are dict of observations (currently only one member called 'obs')\n        \"\"\"\n\n    @abc.abstractmethod\n    def reset(self) -> Dict[str, torch.Tensor]:\n        \"\"\"Reset the environment.\n        Returns:\n            Observation dictionary\n        \"\"\"\n\n    @abc.abstractmethod\n    def reset_idx(self, env_ids: torch.Tensor):\n        \"\"\"Reset environments having the provided indices.\n        Args:\n            env_ids: environments to reset\n        \"\"\"\n\n    @property\n    def observation_space(self) -> gym.Space:\n        \"\"\"Get the environment's observation space.\"\"\"\n        return self.obs_space\n\n    @property\n    def action_space(self) -> gym.Space:\n        \"\"\"Get the environment's action space.\"\"\"\n        return self.act_space\n\n    @property\n    def num_envs(self) -> int:\n        \"\"\"Get the number of environments.\"\"\"\n        return self.num_environments\n\n    @property\n    def num_acts(self) -> int:\n        \"\"\"Get the number of actions in the environment.\"\"\"\n        return self.num_actions\n\n    @property\n    def num_obs(self) -> int:\n        \"\"\"Get the number of observations in the environment.\"\"\"\n        return self.num_observations\n\n\nclass VecTask(Env):\n\n    metadata = {\"render.modes\": [\"human\", \"rgb_array\"], \"video.frames_per_second\": 24}\n\n    def __init__(self, config, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture: bool = False, force_render: bool = False):\n        \"\"\"Initialise the `VecTask`.\n\n        Args:\n            config: config dictionary for the environment.\n            sim_device: the device to simulate physics on. e.g. 'cuda:0' or 'cpu'\n            graphics_device_id: the device ID to render with.\n            headless: Set to True to disable viewer rendering.\n            virtual_screen_capture: Set to True to allow users to get the captured screen as an RGB array via `env.render(mode='rgb_array')`.\n            force_render: Set to True to always force rendering in the steps (if the `control_freq_inv` is greater than 1 we suggest setting this arg to True)\n        \"\"\"\n        super().__init__(config, rl_device, sim_device, graphics_device_id, headless)\n        self.virtual_screen_capture = virtual_screen_capture\n        self.virtual_display = None\n        if self.virtual_screen_capture:\n            from pyvirtualdisplay.smartdisplay import SmartDisplay\n            self.virtual_display = SmartDisplay(size=SCREEN_CAPTURE_RESOLUTION)\n            self.virtual_display.start()\n        self.force_render = force_render\n\n        self.sim_params = self.__parse_sim_params(self.cfg[\"physics_engine\"], self.cfg[\"sim\"])\n        if self.cfg[\"physics_engine\"] == \"physx\":\n            self.physics_engine = gymapi.SIM_PHYSX\n        elif self.cfg[\"physics_engine\"] == \"flex\":\n            self.physics_engine = gymapi.SIM_FLEX\n        else:\n            msg = f\"Invalid physics engine backend: {self.cfg['physics_engine']}\"\n            raise ValueError(msg)\n\n        # optimization flags for pytorch JIT\n        torch._C._jit_set_profiling_mode(False)\n        torch._C._jit_set_profiling_executor(False)\n\n        self.gym = gymapi.acquire_gym()\n\n        self.first_randomization = True\n        self.original_props = {}\n        self.dr_randomizations = {}\n        self.actor_params_generator = 
None\n        self.extern_actor_params = {}\n        self.last_step = -1\n        self.last_rand_step = -1\n        for env_id in range(self.num_envs):\n            self.extern_actor_params[env_id] = None\n\n        # create envs, sim and viewer\n        self.sim_initialized = False\n        self.create_sim()\n        self.gym.prepare_sim(self.sim)\n        self.sim_initialized = True\n\n        self.set_viewer()\n        self.allocate_buffers()\n\n        self.obs_dict = {}\n\n    def set_viewer(self):\n        \"\"\"Create the viewer.\"\"\"\n\n        # todo: read from config\n        self.enable_viewer_sync = True\n        self.viewer = None\n\n        # if running with a viewer, set up keyboard shortcuts and camera\n        if self.headless == False:\n            # subscribe to keyboard shortcuts\n            self.viewer = self.gym.create_viewer(\n                self.sim, gymapi.CameraProperties())\n            self.gym.subscribe_viewer_keyboard_event(\n                self.viewer, gymapi.KEY_ESCAPE, \"QUIT\")\n            self.gym.subscribe_viewer_keyboard_event(\n                self.viewer, gymapi.KEY_V, \"toggle_viewer_sync\")\n\n            # set the camera position based on up axis\n            sim_params = self.gym.get_sim_params(self.sim)\n            if sim_params.up_axis == gymapi.UP_AXIS_Z:\n                cam_pos = gymapi.Vec3(20.0, 25.0, 3.0)\n                cam_target = gymapi.Vec3(10.0, 15.0, 0.0)\n            else:\n                cam_pos = gymapi.Vec3(20.0, 3.0, 25.0)\n                cam_target = gymapi.Vec3(10.0, 0.0, 15.0)\n\n            self.gym.viewer_camera_look_at(\n                self.viewer, None, cam_pos, cam_target)\n\n    def allocate_buffers(self):\n        \"\"\"Allocate the observation, states, etc. 
buffers.\n\n        These are used to set observations and states in the environment classes that\n        inherit from this one, and are read in `step` and other related functions.\n\n        \"\"\"\n\n        # allocate buffers\n        self.obs_buf = torch.zeros(\n            (self.num_envs, self.num_obs), device=self.device, dtype=torch.float)\n        self.states_buf = torch.zeros(\n            (self.num_envs, self.num_states), device=self.device, dtype=torch.float)\n        self.rew_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.float)\n        self.reset_buf = torch.ones(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.timeout_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.progress_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.randomize_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.extras = {}\n\n    def create_sim(self, compute_device: int, graphics_device: int, physics_engine, sim_params: gymapi.SimParams):\n        \"\"\"Create an Isaac Gym sim object.\n\n        Args:\n            compute_device: ID of compute device to use.\n            graphics_device: ID of graphics device to use.\n            physics_engine: physics engine to use (`gymapi.SIM_PHYSX` or `gymapi.SIM_FLEX`)\n            sim_params: sim params to use.\n        Returns:\n            the Isaac Gym sim object.\n        \"\"\"\n        sim = _create_sim_once(self.gym, compute_device, graphics_device, physics_engine, sim_params)\n        if sim is None:\n            print(\"*** Failed to create sim\")\n            quit()\n\n        return sim\n\n    def get_state(self):\n        \"\"\"Returns the state buffer of the environment (the privileged observations for asymmetric training).\"\"\"\n        return torch.clamp(self.states_buf, -self.clip_obs, 
self.clip_obs).to(self.rl_device)\n\n    @abc.abstractmethod\n    def pre_physics_step(self, actions: torch.Tensor):\n        \"\"\"Apply the actions to the environment (eg by setting torques, position targets).\n\n        Args:\n            actions: the actions to apply\n        \"\"\"\n\n    @abc.abstractmethod\n    def post_physics_step(self):\n        \"\"\"Compute reward and observations, reset any environments that require it.\"\"\"\n\n    def step(self, actions: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor, Dict[str, Any]]:\n        \"\"\"Step the physics of the environment.\n\n        Args:\n            actions: actions to apply\n        Returns:\n            Observations, rewards, resets, info\n            Observations are dict of observations (currently only one member called 'obs')\n        \"\"\"\n\n        # randomize actions\n        if self.dr_randomizations.get('actions', None):\n            actions = self.dr_randomizations['actions']['noise_lambda'](actions)\n\n        action_tensor = torch.clamp(actions, -self.clip_actions, self.clip_actions)\n        # apply actions\n        self.pre_physics_step(action_tensor)\n\n        # step physics and render each frame\n        for i in range(self.control_freq_inv):\n            if self.force_render:\n                self.render()\n            self.gym.simulate(self.sim)\n\n        # to fix!\n        if self.device == 'cpu':\n            self.gym.fetch_results(self.sim, True)\n\n        # compute observations, rewards, resets, ...\n        self.post_physics_step()\n\n        # fill time out buffer: set to 1 if we reached the max episode length AND the reset buffer is 1. 
Timeout == 1 makes sense only if the reset buffer is 1.\n        self.timeout_buf = (self.progress_buf >= self.max_episode_length - 1) & (self.reset_buf != 0)\n\n        # randomize observations\n        if self.dr_randomizations.get('observations', None):\n            self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf)\n\n        self.extras[\"time_outs\"] = self.timeout_buf.to(self.rl_device)\n\n        self.obs_dict[\"obs\"] = torch.clamp(self.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n\n        # asymmetric actor-critic\n        if self.num_states > 0:\n            self.obs_dict[\"states\"] = self.get_state()\n\n        return self.obs_dict, self.rew_buf.to(self.rl_device), self.reset_buf.to(self.rl_device), self.extras\n\n    def zero_actions(self) -> torch.Tensor:\n        \"\"\"Returns a buffer with zero actions.\n\n        Returns:\n            A buffer of zero torch actions\n        \"\"\"\n        actions = torch.zeros([self.num_envs, self.num_actions], dtype=torch.float32, device=self.rl_device)\n\n        return actions\n\n    def reset_idx(self, env_idx):\n        \"\"\"Reset environments with the indices given in env_idx.\n        Should be implemented in an environment class inherited from VecTask.\n        \"\"\"\n        pass\n\n    def reset(self):\n        \"\"\"Called only once, when the environment starts, to provide the first observations.\n        Doesn't calculate observations. 
Actual reset and observation calculation need to be implemented by user.\n        Returns:\n            Observation dictionary\n        \"\"\"\n        self.obs_dict[\"obs\"] = torch.clamp(self.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n\n        # asymmetric actor-critic\n        if self.num_states > 0:\n            self.obs_dict[\"states\"] = self.get_state()\n\n        return self.obs_dict\n\n    def reset_done(self):\n        \"\"\"Reset the environment.\n        Returns:\n            Observation dictionary, indices of environments being reset\n        \"\"\"\n        done_env_ids = self.reset_buf.nonzero(as_tuple=False).flatten()\n        if len(done_env_ids) > 0:\n            self.reset_idx(done_env_ids)\n\n        self.obs_dict[\"obs\"] = torch.clamp(self.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n\n        # asymmetric actor-critic\n        if self.num_states > 0:\n            self.obs_dict[\"states\"] = self.get_state()\n\n        return self.obs_dict, done_env_ids\n\n    def render(self, mode=\"rgb_array\"):\n        \"\"\"Draw the frame to the viewer, and check for keyboard events.\"\"\"\n        if self.viewer:\n            # check for window closed\n            if self.gym.query_viewer_has_closed(self.viewer):\n                sys.exit()\n\n            # check for keyboard events\n            for evt in self.gym.query_viewer_action_events(self.viewer):\n                if evt.action == \"QUIT\" and evt.value > 0:\n                    sys.exit()\n                elif evt.action == \"toggle_viewer_sync\" and evt.value > 0:\n                    self.enable_viewer_sync = not self.enable_viewer_sync\n\n            # fetch results\n            if self.device != 'cpu':\n                self.gym.fetch_results(self.sim, True)\n\n            # step graphics\n            if self.enable_viewer_sync:\n                self.gym.step_graphics(self.sim)\n                self.gym.draw_viewer(self.viewer, self.sim, True)\n\n              
  # Wait for dt to elapse in real time.\n                # This synchronizes the physics simulation with the rendering rate.\n                self.gym.sync_frame_time(self.sim)\n\n            else:\n                self.gym.poll_viewer_events(self.viewer)\n\n            if self.virtual_display and mode == \"rgb_array\":\n                img = self.virtual_display.grab()\n                return np.array(img)\n\n    def __parse_sim_params(self, physics_engine: str, config_sim: Dict[str, Any]) -> gymapi.SimParams:\n        \"\"\"Parse the config dictionary for physics stepping settings.\n\n        Args:\n            physics_engine: which physics engine to use. \"physx\" or \"flex\"\n            config_sim: dict of sim configuration parameters\n        Returns\n            IsaacGym SimParams object with updated settings.\n        \"\"\"\n        sim_params = gymapi.SimParams()\n\n        # check correct up-axis\n        if config_sim[\"up_axis\"] not in [\"z\", \"y\"]:\n            msg = f\"Invalid physics up-axis: {config_sim['up_axis']}\"\n            print(msg)\n            raise ValueError(msg)\n\n        # assign general sim parameters\n        sim_params.dt = config_sim[\"dt\"]\n        sim_params.num_client_threads = config_sim.get(\"num_client_threads\", 0)\n        sim_params.use_gpu_pipeline = config_sim[\"use_gpu_pipeline\"]\n        sim_params.substeps = config_sim.get(\"substeps\", 2)\n\n        # assign up-axis\n        if config_sim[\"up_axis\"] == \"z\":\n            sim_params.up_axis = gymapi.UP_AXIS_Z\n        else:\n            sim_params.up_axis = gymapi.UP_AXIS_Y\n\n        # assign gravity\n        sim_params.gravity = gymapi.Vec3(*config_sim[\"gravity\"])\n\n        # configure physics parameters\n        if physics_engine == \"physx\":\n            # set the parameters\n            if \"physx\" in config_sim:\n                for opt in config_sim[\"physx\"].keys():\n                    if opt == \"contact_collection\":\n                        
setattr(sim_params.physx, opt, gymapi.ContactCollection(config_sim[\"physx\"][opt]))\n                    else:\n                        setattr(sim_params.physx, opt, config_sim[\"physx\"][opt])\n        else:\n            # set the parameters\n            if \"flex\" in config_sim:\n                for opt in config_sim[\"flex\"].keys():\n                    setattr(sim_params.flex, opt, config_sim[\"flex\"][opt])\n\n        # return the configured params\n        return sim_params\n\n    \"\"\"\n    Domain Randomization methods\n    \"\"\"\n\n    def get_actor_params_info(self, dr_params: Dict[str, Any], env):\n        \"\"\"Generate flat lists of actor params, their names and ranges.\n\n        Returns:\n            A tuple (params, names, lows, highs) of flat lists holding parameter\n            values, their names, and the lower/upper bounds of their randomization\n            ranges, or None if no actor_params are configured.\n        \"\"\"\n\n        if \"actor_params\" not in dr_params:\n            return None\n        params = []\n        names = []\n        lows = []\n        highs = []\n        param_getters_map = get_property_getter_map(self.gym)\n        for actor, actor_properties in dr_params[\"actor_params\"].items():\n            handle = self.gym.find_actor_handle(env, actor)\n            for prop_name, prop_attrs in actor_properties.items():\n                if prop_name == 'color':\n                    continue  # this is set randomly\n                props = param_getters_map[prop_name](env, handle)\n                if not isinstance(props, list):\n                    props = [props]\n                for prop_idx, prop in enumerate(props):\n                    for attr, attr_randomization_params in prop_attrs.items():\n                        name = prop_name+'_' + str(prop_idx) + '_'+attr\n                        lo_hi = attr_randomization_params['range']\n                        distr = attr_randomization_params['distribution']\n                        if 'uniform' not in distr:\n                            lo_hi = (-1.0*float('Inf'), float('Inf'))\n                        if isinstance(prop, np.ndarray):\n                            for 
attr_idx in range(prop[attr].shape[0]):\n                                params.append(prop[attr][attr_idx])\n                                names.append(name+'_'+str(attr_idx))\n                                lows.append(lo_hi[0])\n                                highs.append(lo_hi[1])\n                        else:\n                            params.append(getattr(prop, attr))\n                            names.append(name)\n                            lows.append(lo_hi[0])\n                            highs.append(lo_hi[1])\n        return params, names, lows, highs\n\n    def apply_randomizations(self, dr_params):\n        \"\"\"Apply domain randomizations to the environment.\n\n        Note that currently randomizations can only be applied on resets, due to PhysX limitations.\n\n        Args:\n            dr_params: parameters for domain randomization to use.\n        \"\"\"\n\n        # If we don't have a randomization frequency, randomize every step\n        rand_freq = dr_params.get(\"frequency\", 1)\n\n        # First, determine what to randomize:\n        #   - non-environment parameters when > frequency steps have passed since the last non-environment randomization\n        #   - physical environments in the reset buffer, which have exceeded the randomization frequency threshold\n        #   - on the first call, randomize everything\n        self.last_step = self.gym.get_frame_count(self.sim)\n        if self.first_randomization:\n            do_nonenv_randomize = True\n            env_ids = list(range(self.num_envs))\n        else:\n            do_nonenv_randomize = (self.last_step - self.last_rand_step) >= rand_freq\n            rand_envs = torch.where(self.randomize_buf >= rand_freq, torch.ones_like(self.randomize_buf), torch.zeros_like(self.randomize_buf))\n            rand_envs = torch.logical_and(rand_envs, self.reset_buf)\n            env_ids = torch.nonzero(rand_envs, as_tuple=False).squeeze(-1).tolist()\n            self.randomize_buf[rand_envs] 
= 0\n\n        if do_nonenv_randomize:\n            self.last_rand_step = self.last_step\n\n        param_setters_map = get_property_setter_map(self.gym)\n        param_setter_defaults_map = get_default_setter_args(self.gym)\n        param_getters_map = get_property_getter_map(self.gym)\n\n        # On first iteration, check the number of buckets\n        if self.first_randomization:\n            check_buckets(self.gym, self.envs, dr_params)\n\n        for nonphysical_param in [\"observations\", \"actions\"]:\n            if nonphysical_param in dr_params and do_nonenv_randomize:\n                dist = dr_params[nonphysical_param][\"distribution\"]\n                op_type = dr_params[nonphysical_param][\"operation\"]\n                sched_type = dr_params[nonphysical_param][\"schedule\"] if \"schedule\" in dr_params[nonphysical_param] else None\n                sched_step = dr_params[nonphysical_param][\"schedule_steps\"] if \"schedule\" in dr_params[nonphysical_param] else None\n                op = operator.add if op_type == 'additive' else operator.mul\n\n                if sched_type == 'linear':\n                    sched_scaling = 1.0 / sched_step * \\\n                        min(self.last_step, sched_step)\n                elif sched_type == 'constant':\n                    sched_scaling = 0 if self.last_step < sched_step else 1\n                else:\n                    sched_scaling = 1\n\n                if dist == 'gaussian':\n                    mu, var = dr_params[nonphysical_param][\"range\"]\n                    mu_corr, var_corr = dr_params[nonphysical_param].get(\"range_correlated\", [0., 0.])\n\n                    if op_type == 'additive':\n                        mu *= sched_scaling\n                        var *= sched_scaling\n                        mu_corr *= sched_scaling\n                        var_corr *= sched_scaling\n                    elif op_type == 'scaling':\n                        var = var * sched_scaling  # scale up var 
over time\n                        mu = mu * sched_scaling + 1.0 * \\\n                            (1.0 - sched_scaling)  # linearly interpolate\n\n                        var_corr = var_corr * sched_scaling  # scale up var over time\n                        mu_corr = mu_corr * sched_scaling + 1.0 * \\\n                            (1.0 - sched_scaling)  # linearly interpolate\n\n                    def noise_lambda(tensor, param_name=nonphysical_param):\n                        params = self.dr_randomizations[param_name]\n                        corr = params.get('corr', None)\n                        if corr is None:\n                            corr = torch.randn_like(tensor)\n                            params['corr'] = corr\n                        corr = corr * params['var_corr'] + params['mu_corr']\n                        return op(\n                            tensor, corr + torch.randn_like(tensor) * params['var'] + params['mu'])\n\n                    self.dr_randomizations[nonphysical_param] = {'mu': mu, 'var': var, 'mu_corr': mu_corr, 'var_corr': var_corr, 'noise_lambda': noise_lambda}\n\n                elif dist == 'uniform':\n                    lo, hi = dr_params[nonphysical_param][\"range\"]\n                    lo_corr, hi_corr = dr_params[nonphysical_param].get(\"range_correlated\", [0., 0.])\n\n                    if op_type == 'additive':\n                        lo *= sched_scaling\n                        hi *= sched_scaling\n                        lo_corr *= sched_scaling\n                        hi_corr *= sched_scaling\n                    elif op_type == 'scaling':\n                        lo = lo * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        hi = hi * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        lo_corr = lo_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)\n                        hi_corr = hi_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)\n\n                    def 
noise_lambda(tensor, param_name=nonphysical_param):\n                        params = self.dr_randomizations[param_name]\n                        corr = params.get('corr', None)\n                        if corr is None:\n                            corr = torch.randn_like(tensor)\n                            params['corr'] = corr\n                        corr = corr * (params['hi_corr'] - params['lo_corr']) + params['lo_corr']\n                        return op(tensor, corr + torch.rand_like(tensor) * (params['hi'] - params['lo']) + params['lo'])\n\n                    self.dr_randomizations[nonphysical_param] = {'lo': lo, 'hi': hi, 'lo_corr': lo_corr, 'hi_corr': hi_corr, 'noise_lambda': noise_lambda}\n\n        if \"sim_params\" in dr_params and do_nonenv_randomize:\n            prop_attrs = dr_params[\"sim_params\"]\n            prop = self.gym.get_sim_params(self.sim)\n\n            if self.first_randomization:\n                self.original_props[\"sim_params\"] = {\n                    attr: getattr(prop, attr) for attr in dir(prop)}\n\n            for attr, attr_randomization_params in prop_attrs.items():\n                apply_random_samples(\n                    prop, self.original_props[\"sim_params\"], attr, attr_randomization_params, self.last_step)\n\n            self.gym.set_sim_params(self.sim, prop)\n\n        # If self.actor_params_generator is initialized: use it to\n        # sample actor simulation params. This gives users the\n        # freedom to generate samples from arbitrary distributions,\n        # e.g. 
use full-covariance distributions instead of the DR's\n        # default of treating each simulation parameter independently.\n        extern_offsets = {}\n        if self.actor_params_generator is not None:\n            for env_id in env_ids:\n                self.extern_actor_params[env_id] = \\\n                    self.actor_params_generator.sample()\n                extern_offsets[env_id] = 0\n\n        for actor, actor_properties in dr_params[\"actor_params\"].items():\n            for env_id in env_ids:\n                env = self.envs[env_id]\n                handle = self.gym.find_actor_handle(env, actor)\n                extern_sample = self.extern_actor_params[env_id]\n\n                for prop_name, prop_attrs in actor_properties.items():\n                    if prop_name == 'color':\n                        num_bodies = self.gym.get_actor_rigid_body_count(\n                            env, handle)\n                        for n in range(num_bodies):\n                            self.gym.set_rigid_body_color(env, handle, n, gymapi.MESH_VISUAL,\n                                                          gymapi.Vec3(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))\n                        continue\n                    if prop_name == 'scale':\n                        setup_only = prop_attrs.get('setup_only', False)\n                        if (setup_only and not self.sim_initialized) or not setup_only:\n                            attr_randomization_params = prop_attrs\n                            sample = generate_random_samples(attr_randomization_params, 1,\n                                                             self.last_step, None)\n                            og_scale = 1\n                            if attr_randomization_params['operation'] == 'scaling':\n                                new_scale = og_scale * sample\n                            elif attr_randomization_params['operation'] == 'additive':\n                       
         new_scale = og_scale + sample\n                            self.gym.set_actor_scale(env, handle, new_scale)\n                        continue\n\n                    prop = param_getters_map[prop_name](env, handle)\n                    set_random_properties = True\n                    if isinstance(prop, list):\n                        if self.first_randomization:\n                            self.original_props[prop_name] = [\n                                {attr: getattr(p, attr) for attr in dir(p)} for p in prop]\n                        for p, og_p in zip(prop, self.original_props[prop_name]):\n                            for attr, attr_randomization_params in prop_attrs.items():\n                                setup_only = attr_randomization_params.get('setup_only', False)\n                                if (setup_only and not self.sim_initialized) or not setup_only:\n                                    smpl = None\n                                    if self.actor_params_generator is not None:\n                                        smpl, extern_offsets[env_id] = get_attr_val_from_sample(\n                                            extern_sample, extern_offsets[env_id], p, attr)\n                                    apply_random_samples(\n                                        p, og_p, attr, attr_randomization_params,\n                                        self.last_step, smpl)\n                                else:\n                                    set_random_properties = False\n                    else:\n                        if self.first_randomization:\n                            self.original_props[prop_name] = deepcopy(prop)\n                        for attr, attr_randomization_params in prop_attrs.items():\n                            setup_only = attr_randomization_params.get('setup_only', False)\n                            if (setup_only and not self.sim_initialized) or not setup_only:\n                                smpl = 
None\n                                if self.actor_params_generator is not None:\n                                    smpl, extern_offsets[env_id] = get_attr_val_from_sample(\n                                        extern_sample, extern_offsets[env_id], prop, attr)\n                                apply_random_samples(\n                                    prop, self.original_props[prop_name], attr,\n                                    attr_randomization_params, self.last_step, smpl)\n                            else:\n                                set_random_properties = False\n\n                    if set_random_properties:\n                        setter = param_setters_map[prop_name]\n                        default_args = param_setter_defaults_map[prop_name]\n                        setter(env, handle, prop, *default_args)\n\n        if self.actor_params_generator is not None:\n            for env_id in env_ids:  # check that we used all dims in sample\n                if extern_offsets[env_id] > 0:\n                    extern_sample = self.extern_actor_params[env_id]\n                    if extern_offsets[env_id] != extern_sample.shape[0]:\n                        print('env_id', env_id,\n                              'extern_offset', extern_offsets[env_id],\n                              'vs extern_sample.shape', extern_sample.shape)\n                        raise Exception(\"Invalid extern_sample size\")\n\n        self.first_randomization = False\n\n"
  },
  {
    "path": "timechamber/tasks/data/assets/mjcf/amp_humanoid_sword_shield.xml",
    "content": "<mujoco model=\"humanoid\">\n\n  <statistic extent=\"2\" center=\"0 0 1\"/>\n\n  <option timestep=\"0.00555\"/>\n\n  <default>\n    <motor ctrlrange=\"-1 1\" ctrllimited=\"true\"/>\n    <default class=\"body\">\n      <geom type=\"capsule\" condim=\"1\" friction=\"1.0 0.05 0.05\" solimp=\".9 .99 .003\" solref=\".015 1\"/>\n      <joint type=\"hinge\" damping=\"0.1\" stiffness=\"5\" armature=\".007\" limited=\"true\" solimplimit=\"0 .99 .01\"/>\n      <site size=\".04\" group=\"3\"/>\n      <default class=\"force-torque\">\n        <site type=\"box\" size=\".01 .01 .02\" rgba=\"1 0 0 1\" />\n      </default>\n      <default class=\"touch\">\n        <site type=\"capsule\" rgba=\"0 0 1 .3\"/>\n      </default>\n    </default>\n  </default>\n\n  <worldbody>\n    <geom name=\"floor\" type=\"plane\" conaffinity=\"1\" size=\"100 100 .2\" material=\"grid\"/>\n    <body name=\"pelvis\" pos=\"0 0 1\" childclass=\"body\">\n      <freejoint name=\"root\"/>\n      <site name=\"root\" class=\"force-torque\"/>\n      <geom name=\"pelvis\" type=\"sphere\" pos=\"0 0 0.07\" size=\".09\" density=\"2226\"/>\n      <geom name=\"upper_waist\" type=\"sphere\" pos=\"0 0 0.205\" size=\"0.07\" density=\"2226\"/>\n      <site name=\"pelvis\" class=\"touch\" type=\"sphere\" pos=\"0 0 0.07\" size=\"0.091\"/>\n      <site name=\"upper_waist\" class=\"touch\" type=\"sphere\" pos=\"0 0 0.205\" size=\"0.071\"/>\n\n      <body name=\"torso\" pos=\"0 0 0.236151\">\n        <light name=\"top\" pos=\"0 0 2\" mode=\"trackcom\"/>\n        <camera name=\"back\" pos=\"-3 0 1\" xyaxes=\"0 -1 0 1 0 2\" mode=\"trackcom\"/>\n        <camera name=\"side\" pos=\"0 -3 1\" xyaxes=\"1 0 0 0 1 2\" mode=\"trackcom\"/>\n        <joint name=\"abdomen_x\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-60 60\" stiffness=\"1000\" damping=\"100\" armature=\".02\"/>\n        <joint name=\"abdomen_y\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-60 90\" stiffness=\"1000\" damping=\"100\" armature=\".02\"/>\n        <joint 
name=\"abdomen_z\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-50 50\" stiffness=\"1000\" damping=\"100\" armature=\".02\"/>\n        <geom name=\"torso\" type=\"sphere\" pos=\"0 0 0.12\" size=\"0.11\" density=\"1794\"/>\n        <site name=\"torso\" class=\"touch\" type=\"sphere\" pos=\"0 0 0.12\" size=\"0.111\"/>\n\n        <geom name=\"right_clavicle\" fromto=\"-0.0060125 -0.0457775 0.2287955 -0.016835 -0.128177 0.2376182\" size=\".045\" density=\"1100\"/>\n        <geom name=\"left_clavicle\" fromto=\"-0.0060125 0.0457775 0.2287955 -0.016835 0.128177 0.2376182\" size=\".045\" density=\"1100\"/>\n\n        <body name=\"head\" pos=\"0 0 0.223894\">\n          <joint name=\"neck_x\" axis=\"1 0 0\" range=\"-50 50\" stiffness=\"100\" damping=\"10\" armature=\".01\"/>\n          <joint name=\"neck_y\" axis=\"0 1 0\" range=\"-40 60\" stiffness=\"100\" damping=\"10\" armature=\".01\"/>\n          <joint name=\"neck_z\" axis=\"0 0 1\" range=\"-45 45\" stiffness=\"100\" damping=\"10\" armature=\".01\"/>\n          <geom name=\"head\" type=\"sphere\" pos=\"0 0 0.175\" size=\"0.095\" density=\"1081\"/>\n          <site name=\"head\" class=\"touch\" pos=\"0 0 0.175\" type=\"sphere\" size=\"0.103\"/>\n          <camera name=\"egocentric\" pos=\".103 0 0.175\" xyaxes=\"0 -1 0 .1 0 1\" fovy=\"80\"/>\n        </body>\n\n        <body name=\"right_upper_arm\" pos=\"-0.02405 -0.18311 0.24350\">\n          <joint name=\"right_shoulder_x\" axis=\"1 0 0\" range=\"-180 45\" stiffness=\"400\" damping=\"40\" armature=\".02\"/>\n          <joint name=\"right_shoulder_y\" axis=\"0 1 0\" range=\"-180 60\" stiffness=\"400\" damping=\"40\" armature=\".02\"/>\n          <joint name=\"right_shoulder_z\" axis=\"0 0 1\"  range=\"-90 90\" stiffness=\"400\" damping=\"40\" armature=\".02\"/>\n          <geom name=\"right_upper_arm\" fromto=\"0 0 -0.05 0 0 -0.23\" size=\".045\" density=\"982\"/>\n          <site name=\"right_upper_arm\" class=\"touch\" pos=\"0 0 -0.14\" size=\"0.046 0.1\" zaxis=\"0 0 
1\"/>\n\n          <body name=\"right_lower_arm\" pos=\"0 0 -0.274788\">\n            <joint name=\"right_elbow\" axis=\"0 1 0\" range=\"-160 0\" stiffness=\"300\" damping=\"30\" armature=\".01\"/>\n            <geom name=\"right_lower_arm\" fromto=\"0 0 -0.0525 0 0 -0.1875\" size=\"0.04\" density=\"1056\"/>\n            <site name=\"right_lower_arm\" class=\"touch\" pos=\"0 0 -0.12\" size=\"0.041 0.0685\" zaxis=\"0 1 0\"/>\n\n            <body name=\"right_hand\" pos=\"0 0 -0.258947\">\n              <joint name=\"right_hand_x\" axis=\"1 0 0\" range=\"-90 90\" stiffness=\"100\" damping=\"10\" armature=\".01\"/>\n\t\t\t        <joint name=\"right_hand_y\" axis=\"0 1 0\" range=\"-90 90\" stiffness=\"100\" damping=\"10\" armature=\".003\"/>\n\t\t\t        <joint name=\"right_hand_z\" axis=\"0 0 1\"  range=\"-90 90\" stiffness=\"100\" damping=\"10\" armature=\".003\"/>\n\t\t\t        <geom name=\"right_hand\" type=\"sphere\" size=\".04\" density=\"1865\"/>\n              <site name=\"right_hand\" class=\"touch\" type=\"sphere\" size=\".041\"/>\n\n              <body name=\"sword\" pos=\"0.74 0 0\">\n                <geom name=\"sword_hilt\" fromto=\"-0.87 0 0 -0.64 0 0\" size=\"0.023\" density=\"300\"/>\n\t\t\t          <geom name=\"sword_blade\" type=\"box\" pos=\"-0.34 0 0\" size=\"0.34 0.01 0.035\" density=\"600\"/>\n              </body>\n            </body>\n          </body>\n        </body>\n\n        <body name=\"left_upper_arm\" pos=\"-0.02405 0.18311 0.24350\">\n          <joint name=\"left_shoulder_x\" axis=\"1 0 0\" range=\"-45 180\" stiffness=\"400\" damping=\"40\" armature=\".02\"/>\n          <joint name=\"left_shoulder_y\" axis=\"0 1 0\" range=\"-180 60\" stiffness=\"400\" damping=\"40\" armature=\".02\"/>\n          <joint name=\"left_shoulder_z\" axis=\"0 0 1\"  range=\"-90 90\" stiffness=\"400\" damping=\"40\" armature=\".02\"/>\n          <geom name=\"left_upper_arm\" fromto=\"0 0 -0.05 0 0 -0.23\" size=\"0.045\" density=\"982\"/>\n          <site 
name=\"left_upper_arm\" class=\"touch\" pos=\"0 0 -0.14\" size=\"0.046 0.1\" zaxis=\"0 0 1\"/>\n\n          <body name=\"left_lower_arm\" pos=\"0 0 -0.274788\">\n            <joint name=\"left_elbow\" axis=\"0 1 0\" range=\"-160 0\" stiffness=\"300\" damping=\"30\" armature=\".01\"/>\n            <geom name=\"left_lower_arm\" fromto=\"0 0 -0.0525 0 0 -0.1875\" size=\"0.04\" density=\"1056\"/>\n            <site name=\"left_lower_arm\" class=\"touch\" pos=\"0 0 -0.1\" size=\"0.041 0.0685\" zaxis=\"0 0 1\"/>\n\n            <body name=\"shield\" pos=\"0 0.07 -0.12\">\n              <geom name=\"shield\" type=\"cylinder\" fromto=\"0 0 0 0 0.03 0\" size=\"0.3\" density=\"250\"/>\n            </body>\n            \n            <body name=\"left_hand\" pos=\"0 0 -0.258947\">\n              <geom name=\"left_hand\" type=\"sphere\" size=\".04\" density=\"1865\"/>\n              <site name=\"left_hand\" class=\"touch\" type=\"sphere\" size=\".041\"/>\n            </body>\n          </body>\n        </body>\n      </body>\n\n      <body name=\"right_thigh\" pos=\"0 -0.084887 0\">\n        <site name=\"right_hip\" class=\"force-torque\"/>\n        <joint name=\"right_hip_x\" axis=\"1 0 0\" range=\"-60 15\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        <joint name=\"right_hip_y\" axis=\"0 1 0\" range=\"-140 60\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        <joint name=\"right_hip_z\" axis=\"0 0 1\" range=\"-60 35\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        <geom name=\"right_thigh\" fromto=\"0 0 -0.06 0 0 -0.36\" size=\"0.055\" density=\"1269\"/>\n        <site name=\"right_thigh\" class=\"touch\" pos=\"0 0 -0.21\" size=\"0.056 0.301\" zaxis=\"0 0 -1\"/>\n\n        <body name=\"right_shin\" pos=\"0 0 -0.421546\">\n          <site name=\"right_knee\" class=\"force-torque\" pos=\"0 0 0\"/>\n          <joint name=\"right_knee\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"0 160\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        
  <geom name=\"right_shin\" fromto=\"0 0 -0.045 0 0 -0.355\"  size=\".05\" density=\"1014\"/>\n          <site name=\"right_shin\" class=\"touch\" pos=\"0 0 -0.2\" size=\"0.051 0.156\" zaxis=\"0 0 -1\"/>\n\n          <body name=\"right_foot\" pos=\"0 0 -0.409870\">\n            <site name=\"right_ankle\" class=\"force-torque\"/>\n            <joint name=\"right_ankle_x\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-30 30\" stiffness=\"400\" damping=\"40\" armature=\".01\"/>\n            <joint name=\"right_ankle_y\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-55 55\" stiffness=\"400\" damping=\"40\" armature=\".01\"/>\n            <joint name=\"right_ankle_z\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-40 40\" stiffness=\"400\" damping=\"40\" armature=\".01\"/>\n            <geom name=\"right_foot\" type=\"box\" pos=\"0.045 0 -0.0225\" size=\"0.0885 0.045 0.0275\" density=\"1141\"/>\n            <site name=\"right_foot\" class=\"touch\" type=\"box\" pos=\"0.045 0 -0.0225\" size=\"0.0895 0.055 0.0285\"/>\n          </body>\n        </body>\n      </body>\n\n      <body name=\"left_thigh\" pos=\"0 0.084887 0\">\n        <site name=\"left_hip\" class=\"force-torque\"/>\n        <joint name=\"left_hip_x\" axis=\"1 0 0\" range=\"-15 60\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        <joint name=\"left_hip_y\" axis=\"0 1 0\" range=\"-140 60\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        <joint name=\"left_hip_z\" axis=\"0 0 1\" range=\"-35 60\" stiffness=\"500\" damping=\"50\" armature=\".02\"/>\n        <geom name=\"left_thigh\" fromto=\"0 0 -0.06 0 0 -0.36\" size=\".055\" density=\"1269\"/>\n        <site name=\"left_thigh\" class=\"touch\" pos=\"0 0 -0.21\" size=\"0.056 0.301\" zaxis=\"0 0 -1\"/>\n\n        <body name=\"left_shin\" pos=\"0 0 -0.421546\">\n          <site name=\"left_knee\" class=\"force-torque\" pos=\"0 0 .02\"/>\n          <joint name=\"left_knee\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"0 160\" stiffness=\"500\" damping=\"50\" 
armature=\".02\"/>\n          <geom name=\"left_shin\" fromto=\"0 0 -0.045 0 0 -0.355\"  size=\".05\" density=\"1014\"/>\n          <site name=\"left_shin\" class=\"touch\" pos=\"0 0 -0.2\" size=\"0.051 0.156\" zaxis=\"0 0 -1\"/>\n\n          <body name=\"left_foot\" pos=\"0 0 -0.409870\">\n            <site name=\"left_ankle\" class=\"force-torque\"/>\n            <joint name=\"left_ankle_x\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-30 30\" stiffness=\"400\" damping=\"40\" armature=\".01\"/>\n            <joint name=\"left_ankle_y\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-55 55\" stiffness=\"400\" damping=\"40\" armature=\".01\"/>\n            <joint name=\"left_ankle_z\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-40 40\" stiffness=\"400\" damping=\"40\" armature=\".01\"/>\n            <geom name=\"left_foot\" type=\"box\" pos=\"0.045 0 -0.0225\" size=\"0.0885 0.045 0.0275\" density=\"1141\"/>\n            <site name=\"left_foot\" class=\"touch\" type=\"box\" pos=\"0.045 0 -0.0225\" size=\"0.0895 0.055 0.0285\"/>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n\n  <actuator>\n    <motor name='abdomen_x'       \tgear='200' \tjoint='abdomen_x'/>\n    <motor name='abdomen_y'       \tgear='200' \tjoint='abdomen_y'/>\n    <motor name='abdomen_z'       \tgear='200' \tjoint='abdomen_z'/>\n    <motor name='neck_x'          \tgear='50' \tjoint='neck_x'/>\n    <motor name='neck_y'            gear='50' \tjoint='neck_y'/>\n    <motor name='neck_z'           \tgear='50' \tjoint='neck_z'/>\n    <motor name='right_shoulder_x' \tgear='100' \tjoint='right_shoulder_x'/>\n    <motor name='right_shoulder_y' \tgear='100' \tjoint='right_shoulder_y'/>\n    <motor name='right_shoulder_z' \tgear='100' \tjoint='right_shoulder_z'/>\n    <motor name='right_elbow'     \tgear='70' \tjoint='right_elbow'/>\n    <motor name='right_hand_x' \t  \tgear='50' \tjoint='right_hand_x'/>\n    <motor name='right_hand_y'    \tgear='50' \tjoint='right_hand_y'/>\n    <motor 
name='right_hand_z'    \tgear='50' \tjoint='right_hand_z'/>\n    <motor name='left_shoulder_x' \tgear='100' \tjoint='left_shoulder_x'/>\n    <motor name='left_shoulder_y' \tgear='100' \tjoint='left_shoulder_y'/>\n    <motor name='left_shoulder_z' \tgear='100' \tjoint='left_shoulder_z'/>\n    <motor name='left_elbow'      \tgear='70' \tjoint='left_elbow'/>\n    <motor name='right_hip_x'     \tgear='200' \tjoint='right_hip_x'/>\n    <motor name='right_hip_z'     \tgear='200' \tjoint='right_hip_z'/>\n    <motor name='right_hip_y'     \tgear='200' \tjoint='right_hip_y'/>\n    <motor name='right_knee'      \tgear='150' \tjoint='right_knee'/>\n    <motor name='right_ankle_x'   \tgear='90' \tjoint='right_ankle_x'/>\n    <motor name='right_ankle_y'   \tgear='90' \tjoint='right_ankle_y'/>\n    <motor name='right_ankle_z'   \tgear='90' \tjoint='right_ankle_z'/>\n    <motor name='left_hip_x'      \tgear='200' \tjoint='left_hip_x'/>\n    <motor name='left_hip_z'      \tgear='200' \tjoint='left_hip_z'/>\n    <motor name='left_hip_y'      \tgear='200' \tjoint='left_hip_y'/>\n    <motor name='left_knee'       \tgear='150' \tjoint='left_knee'/>\n    <motor name='left_ankle_x'    \tgear='90' \tjoint='left_ankle_x'/>\n    <motor name='left_ankle_y'    \tgear='90' \tjoint='left_ankle_y'/>\n    <motor name='left_ankle_z'    \tgear='90' \tjoint='left_ankle_z'/>\n  </actuator>\n\n  <sensor>\n    <subtreelinvel name=\"pelvis_subtreelinvel\" body=\"pelvis\"/>\n    <accelerometer name=\"root_accel\"    site=\"root\"/>\n    <velocimeter name=\"root_vel\"        site=\"root\"/>\n    <gyro name=\"root_gyro\"              site=\"root\"/>\n\n    <force name=\"left_ankle_force\"       site=\"left_ankle\"/>\n    <force name=\"right_ankle_force\"      site=\"right_ankle\"/>\n    <force name=\"left_knee_force\"        site=\"left_knee\"/>\n    <force name=\"right_knee_force\"       site=\"right_knee\"/>\n    <force name=\"left_hip_force\"         site=\"left_hip\"/>\n    <force 
name=\"right_hip_force\"        site=\"right_hip\"/>\n\n    <torque name=\"left_ankle_torque\"     site=\"left_ankle\"/>\n    <torque name=\"right_ankle_torque\"    site=\"right_ankle\"/>\n    <torque name=\"left_knee_torque\"      site=\"left_knee\"/>\n    <torque name=\"right_knee_torque\"     site=\"right_knee\"/>\n    <torque name=\"left_hip_torque\"       site=\"left_hip\"/>\n    <torque name=\"right_hip_torque\"      site=\"right_hip\"/>\n\n    <touch name=\"pelvis_touch\"           site=\"pelvis\"/>\n    <touch name=\"upper_waist_touch\"      site=\"upper_waist\"/>\n    <touch name=\"torso_touch\"            site=\"torso\"/>\n    <touch name=\"head_touch\"             site=\"head\"/>\n    <touch name=\"right_upper_arm_touch\"  site=\"right_upper_arm\"/>\n    <touch name=\"right_lower_arm_touch\"  site=\"right_lower_arm\"/>\n    <touch name=\"right_hand_touch\"       site=\"right_hand\"/>\n    <touch name=\"left_upper_arm_touch\"   site=\"left_upper_arm\"/>\n    <touch name=\"left_lower_arm_touch\"   site=\"left_lower_arm\"/>\n    <touch name=\"left_hand_touch\"        site=\"left_hand\"/>\n    <touch name=\"right_thigh_touch\"      site=\"right_thigh\"/>\n    <touch name=\"right_shin_touch\"       site=\"right_shin\"/>\n    <touch name=\"right_foot_touch\"       site=\"right_foot\"/>\n    <touch name=\"left_thigh_touch\"       site=\"left_thigh\"/>\n    <touch name=\"left_shin_touch\"        site=\"left_shin\"/>\n    <touch name=\"left_foot_touch\"        site=\"left_foot\"/>\n  </sensor>\n\n</mujoco>\n"
  },
  {
    "path": "timechamber/tasks/data/motions/reallusion_sword_shield/README.txt",
    "content": "This motion data is provided courtesy of Reallusion,\nstrictly for noncommercial use. The original motion data\nis available at:\nhttps://actorcore.reallusion.com/motion/pack/studio-mocap-sword-and-shield-stunts\nhttps://actorcore.reallusion.com/motion/pack/studio-mocap-sword-and-shield-moves\n"
  },
  {
    "path": "timechamber/tasks/data/motions/reallusion_sword_shield/dataset_reallusion_sword_shield.yaml",
    "content": "motions:\n  - file: \"RL_Avatar_Atk_2xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_2xCombo02_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_2xCombo03_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_2xCombo04_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_2xCombo05_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo02_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo03_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo04_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo05_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo06_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_3xCombo07_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_4xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_4xCombo02_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_4xCombo03_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_SlashDown_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_SlashLeft_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_SlashRight_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_SlashUp_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_Spin_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_Stab_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Counter_Atk01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Counter_Atk02_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Counter_Atk03_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Counter_Atk04_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Kill_2xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Kill_2xCombo02_Motion.npy\"\n    
weight: 0.00724638\n  - file: \"RL_Avatar_Kill_3xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Kill_3xCombo02_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Kill_4xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_Jump_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Atk_Kick_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Atk_ShieldCharge_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Atk_ShieldSwipe01_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Atk_ShieldSwipe02_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Counter_Atk05_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Standoff_Feint_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Dodge_Backward_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_RunBackward_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_WalkBackward01_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_WalkBackward02_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_Dodgle_Left_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_RunLeft_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_WalkLeft01_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_WalkLeft02_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_Dodgle_Right_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_RunRight_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_WalkRight01_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_WalkRight02_Motion.npy\"\n    weight: 0.01552795\n  - file: \"RL_Avatar_RunForward_Motion.npy\"\n    weight: 0.02070393\n  - file: \"RL_Avatar_WalkForward01_Motion.npy\"\n    weight: 0.02070393\n  - file: \"RL_Avatar_WalkForward02_Motion.npy\"\n    weight: 0.02070393\n  - file: \"RL_Avatar_Standoff_Circle_Motion.npy\"\n    weight: 0.06211180\n  - file: \"RL_Avatar_TurnLeft90_Motion.npy\"\n    weight: 0.03105590\n  - 
file: \"RL_Avatar_TurnLeft180_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_TurnRight90_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_TurnRight180_Motion.npy\"\n    weight: 0.03105590\n  - file: \"RL_Avatar_Fall_Backward_Motion.npy\"\n    weight: 0.00869565\n  - file: \"RL_Avatar_Fall_Left_Motion.npy\"\n    weight: 0.00869565\n  - file: \"RL_Avatar_Fall_Right_Motion.npy\"\n    weight: 0.00869565\n  - file: \"RL_Avatar_Fall_SpinLeft_Motion.npy\"\n    weight: 0.00869565\n  - file: \"RL_Avatar_Fall_SpinRight_Motion.npy\"\n    weight: 0.00869565\n  - file: \"RL_Avatar_Idle_Alert(0)_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Idle_Alert_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Idle_Battle(0)_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Idle_Battle_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Idle_Ready(0)_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Idle_Ready_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Standoff_Swing_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Taunt_PoundChest_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Taunt_Roar_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Taunt_ShieldKnock_Motion.npy\"\n    weight: 0.00434783\n  - file: \"RL_Avatar_Shield_BlockBackward_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Shield_BlockCrouch_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Shield_BlockDown_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Shield_BlockLeft_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Shield_BlockRight_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Shield_BlockUp_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryBackward01_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryBackward02_Motion.npy\"\n    weight: 0.00289855\n  - file: 
\"RL_Avatar_Sword_ParryBackward03_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryBackward04_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryCrouch_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryDown_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryLeft_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryRight_Motion.npy\"\n    weight: 0.00289855\n  - file: \"RL_Avatar_Sword_ParryUp_Motion.npy\"\n    weight: 0.00289855\n"
  },
  {
    "path": "timechamber/tasks/ma_ant_battle.py",
    "content": "from typing import Tuple\nimport os\n\nimport torch\nfrom isaacgym import gymtorch\nfrom isaacgym.gymtorch import *\n\nfrom timechamber.utils.torch_jit_utils import *\nfrom .base.ma_vec_task import MA_VecTask\n\n\nclass MA_Ant_Battle(MA_VecTask):\n\n    def __init__(self, cfg, sim_device, rl_device, graphics_device_id, headless, virtual_screen_capture, force_render):\n\n        self.extras = None\n        self.cfg = cfg\n        self.randomization_params = self.cfg[\"task\"][\"randomization_params\"]\n        self.randomize = self.cfg[\"task\"][\"randomize\"]\n\n        self.max_episode_length = self.cfg[\"env\"][\"episodeLength\"]\n        self.termination_height = self.cfg[\"env\"][\"terminationHeight\"]\n        self.plane_static_friction = self.cfg[\"env\"][\"plane\"][\"staticFriction\"]\n        self.plane_dynamic_friction = self.cfg[\"env\"][\"plane\"][\"dynamicFriction\"]\n        self.plane_restitution = self.cfg[\"env\"][\"plane\"][\"restitution\"]\n        self.action_scale = self.cfg[\"env\"][\"control\"][\"actionScale\"]\n        self.joints_at_limit_cost_scale = self.cfg[\"env\"][\"jointsAtLimitCost\"]\n        self.dof_vel_scale = self.cfg[\"env\"][\"dofVelocityScale\"]\n        self.ant_agents_state = []\n        self.win_reward_scale = 2000\n        self.move_to_op_reward_scale = 1.\n        self.stay_in_center_reward_scale = 0.2\n        self.action_cost_scale = -0.000025\n        self.push_scale = 1.\n        self.dense_reward_scale = 1.0\n        self.hp_decay_scale = 1.\n        self.Kp = self.cfg[\"env\"][\"control\"][\"stiffness\"]\n        self.Kd = self.cfg[\"env\"][\"control\"][\"damping\"]\n        self.cfg[\"env\"][\"numObservations\"] = 32 + 27 * (self.cfg[\"env\"].get(\"numAgents\", 1) - 1)\n        self.cfg[\"env\"][\"numActions\"] = 8\n        self.borderline_space = cfg[\"env\"][\"borderlineSpace\"]\n        self.borderline_space_unit = self.borderline_space / self.max_episode_length\n        self.ant_body_colors = 
[gymapi.Vec3(*rgb_arr) for rgb_arr in self.cfg[\"env\"][\"color\"]]\n        super().__init__(config=self.cfg, sim_device=sim_device, rl_device=rl_device,\n                         graphics_device_id=graphics_device_id,\n                         headless=headless)\n\n        self.use_central_value = False\n        self.obs_idxs = torch.eye(4, dtype=torch.float32, device=self.device)\n        if self.viewer is not None:\n            for i, env in enumerate(self.envs):\n                self._add_circle_borderline(env, self.borderline_space)\n            cam_pos = gymapi.Vec3(15.0, 0.0, 3.4)\n            cam_target = gymapi.Vec3(10.0, 0.0, 0.0)\n            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)\n\n        # get gym GPU state tensors\n        actor_root_state = self.gym.acquire_actor_root_state_tensor(self.sim)\n        dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim)\n        sensor_tensor = self.gym.acquire_force_sensor_tensor(self.sim)\n\n        sensors_per_env = 4\n        self.vec_sensor_tensor = gymtorch.wrap_tensor(sensor_tensor).view(self.num_envs,\n                                                                          sensors_per_env * 6)\n\n        self.gym.refresh_dof_state_tensor(self.sim)\n        self.gym.refresh_actor_root_state_tensor(self.sim)\n\n        self.root_states = gymtorch.wrap_tensor(actor_root_state)\n        print(f'root_states:{self.root_states.shape}')\n        self.initial_root_states = self.root_states.clone()\n        self.initial_root_states[:, 7:13] = 0  # set lin_vel and ang_vel to 0\n\n        # create some wrapper tensors for different slices\n        self.dof_state = gymtorch.wrap_tensor(dof_state_tensor)\n        print(f'dof:{self.dof_state.shape}')\n        dof_state_shaped = self.dof_state.view(self.num_envs, -1, 2)\n        for idx in range(self.num_agents):\n            ant_root_state = self.root_states[idx::self.num_agents]\n            ant_dof_pos = dof_state_shaped[:, idx 
* self.num_dof:(idx + 1) * self.num_dof, 0]\n            ant_dof_vel = dof_state_shaped[:, idx * self.num_dof:(idx + 1) * self.num_dof, 1]\n            self.ant_agents_state.append((ant_root_state, ant_dof_pos, ant_dof_vel))\n\n        self.initial_dof_pos = torch.zeros_like(self.ant_agents_state[0][1], device=self.device, dtype=torch.float)\n        zero_tensor = torch.tensor([0.0], device=self.device)\n        self.initial_dof_pos = torch.where(self.dof_limits_lower > zero_tensor, self.dof_limits_lower,\n                                           torch.where(self.dof_limits_upper < zero_tensor, self.dof_limits_upper,\n                                                       self.initial_dof_pos))\n        self.initial_dof_vel = torch.zeros_like(self.ant_agents_state[0][2], device=self.device, dtype=torch.float)\n        self.dt = self.cfg[\"sim\"][\"dt\"]\n\n        torques = self.gym.acquire_dof_force_tensor(self.sim)\n        self.torques = gymtorch.wrap_tensor(torques).view(self.num_envs, self.num_agents * self.num_dof)\n\n        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat(\n            (self.num_agents * self.num_envs, 1))\n        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat(\n            (self.num_agents * self.num_envs, 1))\n        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat(\n            (self.num_agents * self.num_envs, 1))\n\n    def allocate_buffers(self):\n        self.obs_buf = torch.zeros((self.num_agents * self.num_envs, self.num_obs), device=self.device,\n                                   dtype=torch.float)\n        self.rew_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.float)\n        self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)\n        self.timeout_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        
self.progress_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.randomize_buf = torch.zeros(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.long)\n        self.extras = {'ranks': torch.zeros((self.num_envs, self.num_agents), device=self.device, dtype=torch.long),\n                       'win': torch.zeros((self.num_envs * (self.num_agents - 1),), device=self.device,\n                                          dtype=torch.bool),\n                       'lose': torch.zeros((self.num_envs * (self.num_agents - 1),), device=self.device,\n                                           dtype=torch.bool),\n                       'draw': torch.zeros((self.num_envs * (self.num_agents - 1),), device=self.device,\n                                           dtype=torch.bool)}\n\n    def create_sim(self):\n        self.up_axis_idx = self.set_sim_params_up_axis(self.sim_params, 'z')\n        self.sim = super().create_sim(self.device_id, self.graphics_device_id, self.physics_engine, self.sim_params)\n        lines = []\n        borderline_height = 0.01\n        for height in range(20):\n            for angle in range(360):\n                begin_point = [np.cos(np.radians(angle)), np.sin(np.radians(angle)), borderline_height * height]\n                end_point = [np.cos(np.radians(angle + 1)), np.sin(np.radians(angle + 1)), borderline_height * height]\n                lines.append(begin_point)\n                lines.append(end_point)\n        self.lines = np.array(lines, dtype=np.float32)\n        self._create_ground_plane()\n        print(f'num envs {self.num_envs} env spacing {self.cfg[\"env\"][\"envSpacing\"]}')\n        self._create_envs(self.num_envs, self.cfg[\"env\"]['envSpacing'], int(np.sqrt(self.num_envs)))\n\n        # If randomizing, apply once immediately on startup before the first sim step\n        if self.randomize:\n            self.apply_randomizations(self.randomization_params)\n\n    def 
_add_circle_borderline(self, env, radius):\n        lines = self.lines * radius\n        colors = np.array([[1, 0, 0]] * (len(lines) // 2), dtype=np.float32)\n        self.gym.add_lines(self.viewer, env, len(lines) // 2, lines, colors)\n\n    def _create_ground_plane(self):\n        plane_params = gymapi.PlaneParams()\n        plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0)\n        plane_params.static_friction = self.plane_static_friction\n        plane_params.dynamic_friction = self.plane_dynamic_friction\n        self.gym.add_ground(self.sim, plane_params)\n\n    def _create_envs(self, num_envs, spacing, num_per_row):\n        lower = gymapi.Vec3(-spacing, -spacing, 0.0)\n        upper = gymapi.Vec3(spacing, spacing, spacing)\n\n        asset_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../assets')\n        asset_file = \"mjcf/nv_ant.xml\"\n\n        if \"asset\" in self.cfg[\"env\"]:\n            asset_file = self.cfg[\"env\"][\"asset\"].get(\"assetFileName\", asset_file)\n\n        asset_path = os.path.join(asset_root, asset_file)\n        asset_root = os.path.dirname(asset_path)\n        asset_file = os.path.basename(asset_path)\n\n        asset_options = gymapi.AssetOptions()\n        # Note - DOF mode is set in the MJCF file and loaded by Isaac Gym\n        asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE\n        asset_options.angular_damping = 0.0\n        ant_assets = []\n        for _ in range(self.num_agents):\n            ant_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options)\n            ant_assets.append(ant_asset)\n        dof_props = self.gym.get_asset_dof_properties(ant_assets[0])\n\n        self.num_dof = self.gym.get_asset_dof_count(ant_assets[0])\n        self.num_bodies = self.gym.get_asset_rigid_body_count(ant_assets[0])\n        for i in range(self.num_dof):\n            dof_props['driveMode'][i] = gymapi.DOF_MODE_POS\n            dof_props['stiffness'][i] = self.Kp\n            
dof_props['damping'][i] = self.Kd\n\n        start_pose = gymapi.Transform()\n        start_pose.p = gymapi.Vec3(-self.borderline_space + 1, -self.borderline_space + 1, 1.)\n        self.start_rotation = torch.tensor([start_pose.r.x, start_pose.r.y, start_pose.r.z, start_pose.r.w],\n                                           device=self.device)\n\n        self.torso_index = 0\n        self.num_bodies = self.gym.get_asset_rigid_body_count(ant_assets[0])\n        body_names = [self.gym.get_asset_rigid_body_name(ant_assets[0], i) for i in range(self.num_bodies)]\n        extremity_names = [s for s in body_names if \"foot\" in s]\n        self.extremities_index = torch.zeros(len(extremity_names), dtype=torch.long, device=self.device)\n        print(body_names, extremity_names, self.extremities_index)\n        # create force sensors attached to the \"feet\"\n        extremity_indices = [self.gym.find_asset_rigid_body_index(ant_assets[0], name) for name in extremity_names]\n        sensor_pose = gymapi.Transform()\n        for body_idx in extremity_indices:\n            self.gym.create_asset_force_sensor(ant_assets[0], body_idx, sensor_pose)\n\n        self.ant_handles = []\n        self.actor_indices = []\n        self.envs = []\n        self.dof_limits_lower = []\n        self.dof_limits_upper = []\n\n        for i in range(self.num_envs):\n            # create env instance\n            env_ptr = self.gym.create_env(\n                self.sim, lower, upper, num_per_row\n            )\n            # create actor instance\n            for j in range(self.num_agents):\n                ant_handle = self.gym.create_actor(env_ptr, ant_assets[j], start_pose, \"ant_\" + str(j), i, -1, 0)\n                actor_index = self.gym.get_actor_index(env_ptr, ant_handle, gymapi.DOMAIN_SIM)\n                self.gym.set_actor_dof_properties(env_ptr, ant_handle, dof_props)\n                self.actor_indices.append(actor_index)\n                
self.gym.enable_actor_dof_force_sensors(env_ptr, ant_handle)\n                self.ant_handles.append(ant_handle)\n                for k in range(self.num_bodies):\n                    self.gym.set_rigid_body_color(\n                        env_ptr, ant_handle, k, gymapi.MESH_VISUAL, self.ant_body_colors[j])\n            self.envs.append(env_ptr)\n\n        dof_prop = self.gym.get_actor_dof_properties(self.envs[0], self.ant_handles[0])\n\n        for j in range(self.num_dof):\n            if dof_prop['lower'][j] > dof_prop['upper'][j]:\n                self.dof_limits_lower.append(dof_prop['upper'][j])\n                self.dof_limits_upper.append(dof_prop['lower'][j])\n            else:\n                self.dof_limits_lower.append(dof_prop['lower'][j])\n                self.dof_limits_upper.append(dof_prop['upper'][j])\n\n        self.dof_limits_lower = to_torch(self.dof_limits_lower, device=self.device)\n        self.dof_limits_upper = to_torch(self.dof_limits_upper, device=self.device)\n        self.actor_indices = to_torch(self.actor_indices, device=self.device).to(dtype=torch.int32)\n\n        for i in range(len(extremity_names)):\n            self.extremities_index[i] = self.gym.find_actor_rigid_body_handle(self.envs[0], self.ant_handles[0],\n                                                                              extremity_names[i])\n\n    def compute_reward(self, actions):\n\n        self.rew_buf[:], self.reset_buf[:], self.extras['ranks'][:], self.extras['win'], self.extras['lose'], \\\n        self.extras[\n            'draw'] = compute_ant_reward(\n            self.obs_buf,\n            self.reset_buf,\n            self.progress_buf,\n            self.torques,\n            self.extras['ranks'],\n            self.termination_height,\n            self.max_episode_length,\n            self.borderline_space,\n            self.borderline_space_unit,\n            self.win_reward_scale,\n            self.stay_in_center_reward_scale,\n            
self.action_cost_scale,\n            self.push_scale,\n            self.joints_at_limit_cost_scale,\n            self.dense_reward_scale,\n            self.dt,\n            self.num_agents\n        )\n\n    def compute_observations(self):\n        self.gym.refresh_dof_state_tensor(self.sim)\n        self.gym.refresh_actor_root_state_tensor(self.sim)\n        self.gym.refresh_force_sensor_tensor(self.sim)\n        self.gym.refresh_dof_force_tensor(self.sim)\n        for agent_idx in range(self.num_agents):\n            self.obs_buf[agent_idx * self.num_envs:(agent_idx + 1) * self.num_envs, :] = compute_ant_observations(\n                self.ant_agents_state,\n                self.progress_buf,\n                self.dof_limits_lower,\n                self.dof_limits_upper,\n                self.dof_vel_scale,\n                self.termination_height,\n                self.borderline_space_unit,\n                self.borderline_space,\n                self.num_agents,\n                agent_idx,\n            )\n\n    def reset_idx(self, env_ids):\n        # print('reset.....', env_ids)\n        # Randomization can happen only at reset time, since it can reset actor positions on GPU\n        if self.randomize:\n            self.apply_randomizations(self.randomization_params)\n\n        positions = torch_rand_float(-0.2, 0.2, (len(env_ids), self.num_dof), device=self.device)\n        velocities = torch_rand_float(-0.1, 0.1, (len(env_ids), self.num_dof), device=self.device)\n\n        for agent_idx in range(self.num_agents):\n            root_state, dof_pos, dof_vel = self.ant_agents_state[agent_idx]\n            dof_pos[env_ids] = tensor_clamp(self.initial_dof_pos[env_ids] + positions, self.dof_limits_lower,\n                                            self.dof_limits_upper)\n            dof_vel[env_ids] = velocities\n        agent_env_ids = expand_env_ids(env_ids, self.num_agents)\n        env_ids_int32 = self.actor_indices[agent_env_ids]\n        rand_angle = 
torch.rand((len(env_ids),), device=self.device) * torch.pi * 2  # random angle in [0, 2*pi)\n\n        rand_pos = (self.borderline_space * torch.ones((len(agent_env_ids), 2), device=self.device) -\n                    torch.rand((len(agent_env_ids), 2), device=self.device))\n\n        unit_angle = 2 * torch.pi / self.num_agents\n        for agent_idx in range(self.num_agents):\n            rand_pos[agent_idx::self.num_agents, 0] *= torch.cos(rand_angle + agent_idx * unit_angle)\n            rand_pos[agent_idx::self.num_agents, 1] *= torch.sin(rand_angle + agent_idx * unit_angle)\n        rand_floats = torch_rand_float(-1.0, 1.0, (len(agent_env_ids), 1), device=self.device)\n        rand_rotation = quat_from_angle_axis(rand_floats[:, 0] * np.pi, self.z_unit_tensor[agent_env_ids])\n        self.root_states[agent_env_ids] = self.initial_root_states[agent_env_ids]\n        self.root_states[agent_env_ids, :2] = rand_pos\n        self.root_states[agent_env_ids, 3:7] = rand_rotation\n        self.gym.set_actor_root_state_tensor_indexed(self.sim,\n                                                     gymtorch.unwrap_tensor(self.root_states),\n                                                     gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))\n\n        self.gym.set_dof_state_tensor_indexed(self.sim,\n                                              gymtorch.unwrap_tensor(self.dof_state),\n                                              gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))\n        self.progress_buf[env_ids] = 0\n        self.reset_buf[env_ids] = 0\n        self.extras['ranks'][env_ids] = 0\n\n    def pre_physics_step(self, actions):\n        # actions.shape = [num_envs * num_agents, num_actions], stacked as follows:\n        # {[(agent1_act_1, agent1_act2)|(agent2_act1, agent2_act2)|...]_(env0),\n        #  [(agent1_act_1, agent1_act2)|(agent2_act1, agent2_act2)|...]_(env1),\n        #  ... 
}\n\n        self.actions = torch.tensor([], device=self.device)\n        for agent_idx in range(self.num_agents):\n            self.actions = torch.cat((self.actions, actions[agent_idx * self.num_envs:(agent_idx + 1) * self.num_envs]),\n                                     dim=-1)\n        tmp_actions = self.extras['ranks'].unsqueeze(-1).repeat_interleave(self.num_actions, dim=-1).view(self.num_envs,\n                                                                                                          self.num_actions * self.num_agents)\n        zero_actions = torch.zeros_like(tmp_actions, dtype=torch.float)\n        self.actions = torch.where(tmp_actions > 0, zero_actions, self.actions)\n\n        # self.actions now has shape [num_envs, num_agents * num_actions]\n\n        targets = self.actions\n\n        self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(targets))\n\n    def post_physics_step(self):\n        self.progress_buf += 1\n        self.randomize_buf += 1\n\n        resets = self.reset_buf.reshape(self.num_envs, 1).sum(dim=1)\n        # print(resets)\n        env_ids = (resets == 1).nonzero(as_tuple=False).flatten()\n        if len(env_ids) > 0:\n            self.reset_idx(env_ids)\n\n        self.compute_observations()\n        self.compute_reward(self.actions)\n\n        if self.viewer is not None:\n            self.gym.clear_lines(self.viewer)\n            for i, env in enumerate(self.envs):\n                self._add_circle_borderline(env, self.borderline_space - self.borderline_space_unit * self.progress_buf[\n                    i].item())\n\n    def get_number_of_agents(self):\n        # only train 1 agent\n        return 1\n\n    def zero_actions(self) -> torch.Tensor:\n        \"\"\"Returns a buffer with zero actions.\n\n        Returns:\n            A buffer of zero torch actions\n        \"\"\"\n        actions = torch.zeros([self.num_envs * self.num_agents, 
self.num_actions], dtype=torch.float32,\n                              device=self.rl_device)\n        self.extras['win'] = self.extras['lose'] = self.extras['draw'] = 0\n        return actions\n\n    def clear_count(self):\n        self.dense_reward_scale *= 0.9\n        # match the shape/dtype of 'ranks' created in allocate_buffers\n        self.extras['ranks'] = torch.zeros((self.num_envs, self.num_agents), device=self.device, dtype=torch.long)\n\n\n#####################################################################\n###=========================jit functions=========================###\n#####################################################################\n\n\n@torch.jit.script\ndef expand_env_ids(env_ids, n_agents):\n    # type: (Tensor, int) -> Tensor\n    device = env_ids.device\n    agent_env_ids = torch.zeros((n_agents * len(env_ids)), device=device, dtype=torch.long)\n    for idx in range(n_agents):\n        agent_env_ids[idx::n_agents] = env_ids * n_agents + idx\n    return agent_env_ids\n\n\n@torch.jit.script\ndef compute_ant_reward(\n        obs_buf,\n        reset_buf,\n        progress_buf,\n        torques,\n        now_rank,\n        termination_height,\n        max_episode_length,\n        borderline_space,\n        borderline_space_unit,\n        win_reward_scale,\n        stay_in_center_reward_scale,\n        action_cost_scale,\n        push_scale,\n        joints_at_limit_cost_scale,\n        dense_reward_scale,\n        dt,\n        num_agents\n):\n    # type: (Tensor, Tensor, Tensor,Tensor,Tensor,float,float,float,float,float,float,float,float,float,float,float,int) -> Tuple[Tensor, Tensor,Tensor,Tensor,Tensor,Tensor]\n    obs = obs_buf.view(num_agents, -1, obs_buf.shape[1])\n    nxt_rank_val = num_agents - torch.count_nonzero(now_rank, dim=-1).view(-1, 1).repeat_interleave(num_agents, dim=-1)\n    is_out = torch.sum(torch.square(obs[:, :, 0:2]), dim=-1) >= \\\n             (borderline_space - progress_buf * borderline_space_unit).square()\n    nxt_rank = 
torch.where((torch.transpose(is_out, 0, 1) > 0) & (now_rank == 0), nxt_rank_val, now_rank)\n    # reset agents\n    tmp_ones = torch.ones_like(reset_buf)\n    reset = torch.where(is_out[0, :], tmp_ones, reset_buf)\n    reset = torch.where(progress_buf >= max_episode_length - 1, tmp_ones, reset)\n    reset = torch.where(torch.min(is_out[1:], dim=0).values, tmp_ones, reset)\n    tmp_reset = reset.view(-1, 1).repeat_interleave(num_agents, dim=-1)\n    nxt_rank = torch.where((tmp_reset == 1) & (nxt_rank == 0),\n                           nxt_rank_val - 1,\n                           nxt_rank)\n    # compute metric logic\n    tmp_reset = reset.view(1, -1).repeat_interleave(num_agents - 1, dim=0)\n    tmp_zeros = torch.zeros_like(is_out[1:], dtype=torch.bool)\n    wins = torch.ones_like(is_out[1:], dtype=torch.bool)\n    loses = torch.ones_like(is_out[1:], dtype=torch.bool)\n    draws = (progress_buf >= max_episode_length - 1).view(1, -1).repeat_interleave(num_agents - 1, dim=0)\n    wins = torch.where(is_out[1:], wins & (tmp_reset == 1), tmp_zeros)\n    draws = torch.where(is_out[1:] == 0, draws & (tmp_reset == 1), tmp_zeros)\n    loses = torch.where(is_out[1:] == 0, loses & (tmp_reset == 1) & (draws == 0), tmp_zeros)\n\n    sparse_reward = 1.0 * reset\n    reward_per_rank = 2 * win_reward_scale / (num_agents - 1)\n    sparse_reward = sparse_reward * (win_reward_scale - (nxt_rank[:, 0] - 1) * reward_per_rank)\n    stay_in_center_reward = stay_in_center_reward_scale * torch.exp(-torch.linalg.norm(obs[0, :, :2], dim=-1))\n    dof_at_limit_cost = torch.sum(obs[0, :, 13:21] > 0.99, dim=-1) * joints_at_limit_cost_scale\n    action_cost_penalty = torch.sum(torch.square(torques), dim=1) * action_cost_scale\n    # print(\"torques:\", torques[0, 2])\n    not_move_penalty = torch.exp(-torch.sum(torch.abs(torques), dim=1))\n    # print(f'action:...{action_cost_penalty.shape}')\n    dense_reward = dof_at_limit_cost + action_cost_penalty + not_move_penalty + stay_in_center_reward\n  
  total_reward = sparse_reward + dense_reward * dense_reward_scale\n\n    return total_reward, reset, nxt_rank, wins.flatten(), loses.flatten(), draws.flatten()\n\n\n@torch.jit.script\ndef compute_ant_observations(\n        ant_agents_state,\n        progress_buf,\n        dof_limits_lower,\n        dof_limits_upper,\n        dof_vel_scale,\n        termination_height,\n        borderline_space_unit,\n        borderline_space,\n        num_agents,\n        agent_idx,\n):\n    # type: (List[Tuple[Tensor,Tensor,Tensor]],Tensor,Tensor,Tensor,float,float,float,float,int,int)->Tensor\n    # tot length:13+8+8+1+1+(num_agents-1)*(7+2+8+8+1)\n    self_root_state, self_dof_pos, self_dof_vel = ant_agents_state[agent_idx]\n    dof_pos_scaled = unscale(self_dof_pos, dof_limits_lower, dof_limits_upper)\n    now_border_space = (borderline_space - progress_buf * borderline_space_unit).unsqueeze(-1)\n    obs = torch.cat((self_root_state[:, :13], dof_pos_scaled, self_dof_vel * dof_vel_scale,\n                     now_border_space - torch.sqrt(torch.sum(self_root_state[:, :2].square(), dim=-1)).unsqueeze(-1),\n                     # dis to border\n                     now_border_space,\n                     torch.unsqueeze(self_root_state[:, 2] < termination_height, -1)), dim=-1)\n    for op_idx in range(num_agents):\n        if op_idx == agent_idx:\n            continue\n        op_root_state, op_dof_pos, op_dof_vel = ant_agents_state[op_idx]\n        dof_pos_scaled = unscale(op_dof_pos, dof_limits_lower, dof_limits_upper)\n        obs = torch.cat((obs, op_root_state[:, :7], self_root_state[:, :2] - op_root_state[:, :2],\n                         dof_pos_scaled, op_dof_vel * dof_vel_scale,\n                         now_border_space - torch.sqrt(torch.sum(op_root_state[:, :2].square(), dim=-1)).unsqueeze(-1),\n                         torch.unsqueeze(op_root_state[:, 2] < termination_height, -1)), dim=-1)\n    # print(obs.shape)\n    return obs\n\n\n@torch.jit.script\ndef 
randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):\n    return quat_mul(quat_from_angle_axis(rand0 * np.pi, x_unit_tensor),\n                    quat_from_angle_axis(rand1 * np.pi, y_unit_tensor))\n"
  },
  {
    "path": "timechamber/tasks/ma_ant_sumo.py",
    "content": "from typing import Tuple\nimport numpy as np\nimport os\nimport math\nimport torch\nimport random\n\nfrom isaacgym import gymtorch\nfrom isaacgym import gymapi\nfrom isaacgym.gymtorch import *\n# from torch.tensor import Tensor\n\nfrom timechamber.utils.torch_jit_utils import *\nfrom .base.vec_task import VecTask\nfrom .base.ma_vec_task import MA_VecTask\n\n\n# todo critic_state full obs\nclass MA_Ant_Sumo(MA_VecTask):\n\n    def __init__(self, cfg, sim_device, rl_device, graphics_device_id, headless, virtual_screen_capture, force_render):\n\n        self.cfg = cfg\n        self.randomization_params = self.cfg[\"task\"][\"randomization_params\"]\n        self.randomize = self.cfg[\"task\"][\"randomize\"]\n\n        self.max_episode_length = self.cfg[\"env\"][\"episodeLength\"]\n\n        self.termination_height = self.cfg[\"env\"][\"terminationHeight\"]\n        self.borderline_space = cfg[\"env\"][\"borderlineSpace\"]\n        self.plane_static_friction = self.cfg[\"env\"][\"plane\"][\"staticFriction\"]\n        self.plane_dynamic_friction = self.cfg[\"env\"][\"plane\"][\"dynamicFriction\"]\n        self.plane_restitution = self.cfg[\"env\"][\"plane\"][\"restitution\"]\n        self.action_scale = self.cfg[\"env\"][\"control\"][\"actionScale\"]\n        self.joints_at_limit_cost_scale = self.cfg[\"env\"][\"jointsAtLimitCost\"]\n        self.dof_vel_scale = self.cfg[\"env\"][\"dofVelocityScale\"]\n\n        self.draw_penalty_scale = -1000\n        self.win_reward_scale = 2000\n        self.move_to_op_reward_scale = 1.\n        self.stay_in_center_reward_scale = 0.2\n        self.action_cost_scale = -0.000025\n        self.push_scale = 1.\n        self.dense_reward_scale = 1.\n        self.hp_decay_scale = 1.\n\n        self.Kp = self.cfg[\"env\"][\"control\"][\"stiffness\"]\n        self.Kd = self.cfg[\"env\"][\"control\"][\"damping\"]\n\n        # see func: compute_ant_observations() for details\n        # self.cfg[\"env\"][\"numObservations\"] = 
48 # dof pos(2) + dof vel(2) + dof action(2) + feet force sensor(force&torque, 6)\n        self.cfg[\"env\"][\"numObservations\"] = 40\n        self.cfg[\"env\"][\"numActions\"] = 8\n        self.cfg[\"env\"][\"numAgents\"] = 2\n        self.use_central_value = False\n\n        super().__init__(config=self.cfg, sim_device=sim_device, rl_device=rl_device,\n                         graphics_device_id=graphics_device_id,\n                         headless=headless, virtual_screen_capture=virtual_screen_capture,\n                         force_render=force_render)\n\n        if self.viewer is not None:\n            for env in self.envs:\n                self._add_circle_borderline(env)\n            cam_pos = gymapi.Vec3(15.0, 0.0, 3.0)\n            cam_target = gymapi.Vec3(10.0, 0.0, 0.0)\n            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)\n\n        # get gym GPU state tensors\n        actor_root_state = self.gym.acquire_actor_root_state_tensor(self.sim)\n        dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim)\n        sensor_tensor = self.gym.acquire_force_sensor_tensor(self.sim)\n\n        sensors_per_env = 4\n        self.vec_sensor_tensor = gymtorch.wrap_tensor(sensor_tensor).view(self.num_envs * self.num_agents,\n                                                                          sensors_per_env * 6)\n\n        self.gym.refresh_dof_state_tensor(self.sim)\n        self.gym.refresh_actor_root_state_tensor(self.sim)\n\n        self.root_states = gymtorch.wrap_tensor(actor_root_state)\n        print(f'root_states:{self.root_states.shape}')\n        self.initial_root_states = self.root_states.clone()\n        self.initial_root_states[:, 7:13] = 0  # set lin_vel and ang_vel to 0\n\n        # create some wrapper tensors for different slices\n        self.dof_state = gymtorch.wrap_tensor(dof_state_tensor)\n        print(f\"dof state shape: {self.dof_state.shape}\")\n        self.dof_pos = 
self.dof_state.view(self.num_envs, -1, 2)[:, :self.num_dof, 0]\n        self.dof_pos_op = self.dof_state.view(self.num_envs, -1, 2)[:, self.num_dof:2 * self.num_dof, 0]\n        self.dof_vel = self.dof_state.view(self.num_envs, -1, 2)[:, :self.num_dof, 1]\n        self.dof_vel_op = self.dof_state.view(self.num_envs, -1, 2)[:, self.num_dof:2 * self.num_dof, 1]\n\n        self.initial_dof_pos = torch.zeros_like(self.dof_pos, device=self.device, dtype=torch.float)\n        zero_tensor = torch.tensor([0.0], device=self.device)\n        self.initial_dof_pos = torch.where(self.dof_limits_lower > zero_tensor, self.dof_limits_lower,\n                                           torch.where(self.dof_limits_upper < zero_tensor, self.dof_limits_upper,\n                                                       self.initial_dof_pos))\n        self.initial_dof_vel = torch.zeros_like(self.dof_vel, device=self.device, dtype=torch.float)\n        self.dt = self.cfg[\"sim\"][\"dt\"]\n\n        torques = self.gym.acquire_dof_force_tensor(self.sim)\n        self.torques = gymtorch.wrap_tensor(torques).view(self.num_envs, 2 * self.num_dof)\n\n        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat((2 * self.num_envs, 1))\n        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat((2 * self.num_envs, 1))\n        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat((2 * self.num_envs, 1))\n\n        self.hp = torch.ones((self.num_envs,), device=self.device, dtype=torch.float32) * 100\n        self.hp_op = torch.ones((self.num_envs,), device=self.device, dtype=torch.float32) * 100\n\n    def allocate_buffers(self):\n        self.obs_buf = torch.zeros((self.num_agents * self.num_envs, self.num_obs), device=self.device,\n                                   dtype=torch.float)\n        self.rew_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.float)\n        
self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long)\n        self.timeout_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.progress_buf = torch.zeros(\n            self.num_envs, device=self.device, dtype=torch.long)\n        self.randomize_buf = torch.zeros(\n            self.num_envs * self.num_agents, device=self.device, dtype=torch.long)\n        self.extras = {\n            'win': torch.zeros(((self.num_agents - 1) * self.num_envs,), device=self.device, dtype=torch.bool),\n            'lose': torch.zeros(((self.num_agents - 1) * self.num_envs,), device=self.device, dtype=torch.bool),\n            'draw': torch.zeros(((self.num_agents - 1) * self.num_envs,), device=self.device, dtype=torch.bool)}\n\n    def create_sim(self):\n        self.up_axis_idx = self.set_sim_params_up_axis(self.sim_params, 'z')\n        self.sim = super().create_sim(self.device_id, self.graphics_device_id, self.physics_engine, self.sim_params)\n\n        self._create_ground_plane()\n        print(f'num envs {self.num_envs} env spacing {self.cfg[\"env\"][\"envSpacing\"]}')\n        self._create_envs(self.num_envs, self.cfg[\"env\"]['envSpacing'], int(np.sqrt(self.num_envs)))\n\n        # If randomizing, apply once immediately on startup before the first sim step\n        if self.randomize:\n            self.apply_randomizations(self.randomization_params)\n\n    def _add_circle_borderline(self, env):\n        lines = []\n        borderline_height = 0.01\n        for height in range(20):\n            for angle in range(360):\n                begin_point = [np.cos(np.radians(angle)), np.sin(np.radians(angle)), borderline_height * height]\n                end_point = [np.cos(np.radians(angle + 1)), np.sin(np.radians(angle + 1)), borderline_height * height]\n                lines.append(begin_point)\n                lines.append(end_point)\n        lines = np.array(lines, dtype=np.float32) * self.borderline_space\n        colors = np.array([[1, 0, 0]] * int(len(lines) / 2), dtype=np.float32)\n        self.gym.add_lines(self.viewer, env, int(len(lines) / 2), lines, colors)\n\n    def _create_ground_plane(self):\n        plane_params = gymapi.PlaneParams()\n        plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0)\n        plane_params.static_friction = self.plane_static_friction\n        plane_params.dynamic_friction = self.plane_dynamic_friction\n        self.gym.add_ground(self.sim, plane_params)\n\n    def _create_envs(self, num_envs, spacing, num_per_row):\n        lower = gymapi.Vec3(-spacing, -spacing, 0.0)\n        upper = gymapi.Vec3(spacing, spacing, spacing)\n\n        asset_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../assets')\n        asset_file = \"mjcf/nv_ant.xml\"\n\n        if \"asset\" in self.cfg[\"env\"]:\n            asset_file = self.cfg[\"env\"][\"asset\"].get(\"assetFileName\", asset_file)\n\n        asset_path = os.path.join(asset_root, asset_file)\n        asset_root = os.path.dirname(asset_path)\n        asset_file = os.path.basename(asset_path)\n\n        asset_options = gymapi.AssetOptions()\n        # Note - DOF mode is set in the MJCF file and loaded by Isaac Gym\n        asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE\n        asset_options.angular_damping = 0.0\n\n        ant_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options)\n        ant_asset_op = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options)\n        dof_props = self.gym.get_asset_dof_properties(ant_asset)\n\n        self.num_dof = self.gym.get_asset_dof_count(ant_asset)\n        self.num_bodies = self.gym.get_asset_rigid_body_count(ant_asset)  # 9 bodies = 4 legs x 2 links + 1 torso\n        for i in range(self.num_dof):\n            dof_props['driveMode'][i] = gymapi.DOF_MODE_POS\n            dof_props['stiffness'][i] = self.Kp\n            dof_props['damping'][i] = self.Kd\n\n        box_pose = 
gymapi.Transform()\n        box_pose.p = gymapi.Vec3(0, 0, 0)\n        start_pose = gymapi.Transform()\n        start_pose.p = gymapi.Vec3(-self.borderline_space + 1, -self.borderline_space + 1, 1.)\n        start_pose_op = gymapi.Transform()\n        start_pose_op.p = gymapi.Vec3(self.borderline_space - 1, self.borderline_space - 1, 1.)\n\n        print(start_pose.p, start_pose_op.p)\n        self.start_rotation = torch.tensor([start_pose.r.x, start_pose.r.y, start_pose.r.z, start_pose.r.w],\n                                           device=self.device)\n\n        self.torso_index = 0\n        self.num_bodies = self.gym.get_asset_rigid_body_count(ant_asset)\n        body_names = [self.gym.get_asset_rigid_body_name(ant_asset, i) for i in range(self.num_bodies)]\n        extremity_names = [s for s in body_names if \"foot\" in s]\n        self.extremities_index = torch.zeros(len(extremity_names), dtype=torch.long, device=self.device)\n\n        # create force sensors attached to the \"feet\"\n        extremity_indices = [self.gym.find_asset_rigid_body_index(ant_asset, name) for name in extremity_names]\n        sensor_pose = gymapi.Transform()\n        sensor_pose_op = gymapi.Transform()\n        for body_idx in extremity_indices:\n            self.gym.create_asset_force_sensor(ant_asset, body_idx, sensor_pose)\n            self.gym.create_asset_force_sensor(ant_asset_op, body_idx, sensor_pose_op)\n\n        self.ant_handles = []\n        self.actor_indices = []\n        self.actor_indices_op = []\n        self.actor_handles_op = []\n        self.envs = []\n        self.pos_before = torch.zeros(2, device=self.device)\n        self.dof_limits_lower = []\n        self.dof_limits_upper = []\n\n        for i in range(self.num_envs):\n            # create env instance\n            env_ptr = self.gym.create_env(\n                self.sim, lower, upper, num_per_row\n            )\n            ant_handle = self.gym.create_actor(env_ptr, ant_asset, start_pose, \"ant\", i, 
-1, 0)\n            actor_index = self.gym.get_actor_index(env_ptr, ant_handle, gymapi.DOMAIN_SIM)\n            self.gym.set_actor_dof_properties(env_ptr, ant_handle, dof_props)\n            self.actor_indices.append(actor_index)\n            self.gym.enable_actor_dof_force_sensors(env_ptr, ant_handle)\n\n            ant_handle_op = self.gym.create_actor(env_ptr, ant_asset_op, start_pose_op, \"ant_op\", i, -1, 0)\n            actor_index_op = self.gym.get_actor_index(env_ptr, ant_handle_op, gymapi.DOMAIN_SIM)\n            self.gym.set_actor_dof_properties(env_ptr, ant_handle_op, dof_props)\n\n            self.actor_indices_op.append(actor_index_op)\n            for j in range(self.num_bodies):\n                self.gym.set_rigid_body_color(\n                    env_ptr, ant_handle, j, gymapi.MESH_VISUAL, gymapi.Vec3(0.97, 0.38, 0.06))\n                self.gym.set_rigid_body_color(\n                    env_ptr, ant_handle_op, j, gymapi.MESH_VISUAL, gymapi.Vec3(0.24, 0.38, 0.06))\n\n            self.envs.append(env_ptr)\n            self.ant_handles.append(ant_handle)\n            self.actor_handles_op.append(ant_handle_op)\n\n        dof_prop = self.gym.get_actor_dof_properties(env_ptr, ant_handle)\n\n        for j in range(self.num_dof):\n            if dof_prop['lower'][j] > dof_prop['upper'][j]:\n                self.dof_limits_lower.append(dof_prop['upper'][j])\n                self.dof_limits_upper.append(dof_prop['lower'][j])\n            else:\n                self.dof_limits_lower.append(dof_prop['lower'][j])\n                self.dof_limits_upper.append(dof_prop['upper'][j])\n\n        self.dof_limits_lower = to_torch(self.dof_limits_lower, device=self.device)\n        self.dof_limits_upper = to_torch(self.dof_limits_upper, device=self.device)\n        self.actor_indices = to_torch(self.actor_indices, dtype=torch.long, device=self.device)\n        self.actor_indices_op = to_torch(self.actor_indices_op, dtype=torch.long, device=self.device)\n\n        for i 
in range(len(extremity_names)):\n            self.extremities_index[i] = self.gym.find_actor_rigid_body_handle(self.envs[0], self.ant_handles[0],\n                                                                              extremity_names[i])\n\n    def compute_reward(self, actions):\n\n        self.rew_buf[:], self.reset_buf[:], self.hp[:], self.hp_op[:], \\\n        self.extras['win'], self.extras['lose'], self.extras['draw'] = compute_ant_reward(\n            self.obs_buf[:self.num_envs],\n            self.obs_buf[self.num_envs:],\n            self.reset_buf,\n            self.progress_buf,\n            self.pos_before,\n            self.torques[:, :self.num_dof],\n            self.hp,\n            self.hp_op,\n            self.termination_height,\n            self.max_episode_length,\n            self.borderline_space,\n            self.draw_penalty_scale,\n            self.win_reward_scale,\n            self.move_to_op_reward_scale,\n            self.stay_in_center_reward_scale,\n            self.action_cost_scale,\n            self.push_scale,\n            self.joints_at_limit_cost_scale,\n            self.dense_reward_scale,\n            self.hp_decay_scale,\n            self.dt,\n        )\n\n    def compute_observations(self):\n        self.gym.refresh_dof_state_tensor(self.sim)\n        self.gym.refresh_actor_root_state_tensor(self.sim)\n        self.gym.refresh_force_sensor_tensor(self.sim)\n        self.gym.refresh_dof_force_tensor(self.sim)\n        self.obs_buf[:self.num_envs] = \\\n            compute_ant_observations(\n                self.root_states[0::2],\n                self.root_states[1::2],\n                self.dof_pos,\n                self.dof_vel,\n                self.dof_limits_lower,\n                self.dof_limits_upper,\n                self.dof_vel_scale,\n                self.termination_height\n            )\n\n        self.obs_buf[self.num_envs:] = compute_ant_observations(\n            self.root_states[1::2],\n            
self.root_states[0::2],\n            self.dof_pos_op,\n            self.dof_vel_op,\n            self.dof_limits_lower,\n            self.dof_limits_upper,\n            self.dof_vel_scale,\n            self.termination_height\n        )\n\n    def reset_idx(self, env_ids):\n        # print('reset.....', env_ids)\n        # Randomization can happen only at reset time, since it can reset actor positions on GPU\n        if self.randomize:\n            self.apply_randomizations(self.randomization_params)\n\n        positions = torch_rand_float(-0.2, 0.2, (len(env_ids), self.num_dof), device=self.device)\n        velocities = torch_rand_float(-0.1, 0.1, (len(env_ids), self.num_dof), device=self.device)\n\n        self.dof_pos[env_ids] = tensor_clamp(self.initial_dof_pos[env_ids] + positions, self.dof_limits_lower,\n                                             self.dof_limits_upper)\n        self.dof_vel[env_ids] = velocities\n\n        self.dof_pos_op[env_ids] = tensor_clamp(self.initial_dof_pos[env_ids] + positions, self.dof_limits_lower,\n                                                self.dof_limits_upper)\n        self.dof_vel_op[env_ids] = velocities\n\n        env_ids_int32 = (torch.cat((self.actor_indices[env_ids], self.actor_indices_op[env_ids]))).to(dtype=torch.int32)\n        agent_env_ids = expand_env_ids(env_ids, 2)\n\n        rand_angle = torch.rand((len(env_ids),), device=self.device) * torch.pi * 2\n\n        rand_pos = torch.ones((len(agent_env_ids), 2), device=self.device) * (\n                self.borderline_space * torch.ones((len(agent_env_ids), 2), device=self.device) - torch.rand(\n            (len(agent_env_ids), 2), device=self.device) * 2)\n        rand_pos[0::2, 0] *= torch.cos(rand_angle)\n        rand_pos[0::2, 1] *= torch.sin(rand_angle)\n        rand_pos[1::2, 0] *= torch.cos(rand_angle + torch.pi)\n        rand_pos[1::2, 1] *= torch.sin(rand_angle + torch.pi)\n        rand_floats = torch_rand_float(-1.0, 1.0, (len(agent_env_ids), 3), 
device=self.device)\n        rand_rotation = quat_from_angle_axis(rand_floats[:, 1] * np.pi, self.z_unit_tensor[agent_env_ids])\n        rand_rotation2 = quat_from_angle_axis(rand_floats[:, 2] * np.pi, self.z_unit_tensor[agent_env_ids])\n        self.root_states[agent_env_ids] = self.initial_root_states[agent_env_ids]\n        self.root_states[agent_env_ids, :2] = rand_pos\n        self.root_states[agent_env_ids[1::2], 3:7] = rand_rotation[1::2]\n        self.root_states[agent_env_ids[0::2], 3:7] = rand_rotation2[0::2]\n        self.gym.set_actor_root_state_tensor_indexed(self.sim,\n                                                     gymtorch.unwrap_tensor(self.root_states),\n                                                     gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))\n\n        self.gym.set_dof_state_tensor_indexed(self.sim,\n                                              gymtorch.unwrap_tensor(self.dof_state),\n                                              gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))\n        self.pos_before = self.root_states[0::2, :2].clone()\n\n        self.progress_buf[env_ids] = 0\n        self.reset_buf[env_ids] = 0\n\n    def pre_physics_step(self, actions):\n        # actions.shape = [num_envs * num_agents, num_actions], stacked as followed:\n        # {[(agent1_act_1, agent1_act2)|(agent2_act1, agent2_act2)|...]_(env0),\n        #  [(agent1_act_1, agent1_act2)|(agent2_act1, agent2_act2)|...]_(env1),\n        #  ... 
}\n\n        self.actions = actions.clone().to(self.device)\n        self.actions = torch.cat((self.actions[:self.num_envs], self.actions[self.num_envs:]), dim=-1)\n\n        # reshape [num_envs * num_agents, num_actions] to [num_envs, num_agents * num_actions]\n        targets = self.actions\n\n        self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(targets))\n\n    def post_physics_step(self):\n        self.progress_buf += 1\n        self.randomize_buf += 1\n\n        self.compute_observations()\n        self.compute_reward(self.actions)\n        self.pos_before = self.obs_buf[:self.num_envs, :2].clone()\n\n    def get_number_of_agents(self):\n        # train one agent with index 0\n        return 1\n\n    def zero_actions(self) -> torch.Tensor:\n        \"\"\"Returns a buffer with zero actions.\n\n        Returns:\n            A buffer of zero torch actions\n        \"\"\"\n        actions = torch.zeros([self.num_envs * self.num_agents, self.num_actions], dtype=torch.float32,\n                              device=self.rl_device)\n\n        return actions\n\n    def clear_count(self):\n        self.dense_reward_scale *= 0.9\n        self.extras['win'][:] = 0\n        self.extras['draw'][:] = 0\n\n\n#####################################################################\n###=========================jit functions=========================###\n#####################################################################\n\n\n@torch.jit.script\ndef expand_env_ids(env_ids, n_agents):\n    # type: (Tensor, int) -> Tensor\n    device = env_ids.device\n    agent_env_ids = torch.zeros((n_agents * len(env_ids)), device=device, dtype=torch.long)\n    for idx in range(n_agents):\n        agent_env_ids[idx::n_agents] = env_ids * n_agents + idx\n    return agent_env_ids\n\n\n@torch.jit.script\ndef compute_move_reward(\n        pos,\n        pos_before,\n        target,\n        dt,\n        move_to_op_reward_scale\n):\n    # type: 
(Tensor,Tensor,Tensor,float,float) -> Tensor\n    move_vec = (pos - pos_before) / dt\n    direction = target - pos_before\n    direction = torch.div(direction, torch.linalg.norm(direction, dim=-1).view(-1, 1))\n    s = torch.sum(move_vec * direction, dim=-1)\n    return torch.maximum(s, torch.zeros_like(s)) * move_to_op_reward_scale\n\n\n@torch.jit.script\ndef compute_ant_reward(\n        obs_buf,\n        obs_buf_op,\n        reset_buf,\n        progress_buf,\n        pos_before,\n        torques,\n        hp,\n        hp_op,\n        termination_height,\n        max_episode_length,\n        borderline_space,\n        draw_penalty_scale,\n        win_reward_scale,\n        move_to_op_reward_scale,\n        stay_in_center_reward_scale,\n        action_cost_scale,\n        push_scale,\n        joints_at_limit_cost_scale,\n        dense_reward_scale,\n        hp_decay_scale,\n        dt,\n):\n    # type: (Tensor, Tensor, Tensor, Tensor,Tensor,Tensor,Tensor,Tensor,float, float,float, float,float,float,float,float,float,float,float,float,float) -> Tuple[Tensor, Tensor,Tensor,Tensor,Tensor,Tensor,Tensor]\n\n    hp -= (obs_buf[:, 2] < termination_height) * hp_decay_scale\n    hp_op -= (obs_buf_op[:, 2] < termination_height) * hp_decay_scale\n    is_out = torch.sum(torch.square(obs_buf[:, 0:2]), dim=-1) >= borderline_space ** 2\n    is_out_op = torch.sum(torch.square(obs_buf_op[:, 0:2]), dim=-1) >= borderline_space ** 2\n    is_out = is_out | (hp <= 0)\n    is_out_op = is_out_op | (hp_op <= 0)\n    # reset agents\n    tmp_ones = torch.ones_like(reset_buf)\n    reset = torch.where(is_out, tmp_ones, reset_buf)\n    reset = torch.where(is_out_op, tmp_ones, reset)\n    reset = torch.where(progress_buf >= max_episode_length - 1, tmp_ones, reset)\n\n    hp = torch.where(reset > 0, tmp_ones * 100., hp)\n    hp_op = torch.where(reset > 0, tmp_ones * 100., hp_op)\n\n    win_reward = win_reward_scale * is_out_op\n    lose_penalty = -win_reward_scale * is_out\n    draw_penalty = 
torch.where(progress_buf >= max_episode_length - 1, tmp_ones * draw_penalty_scale,\n                               torch.zeros_like(reset, dtype=torch.float))\n    move_reward = compute_move_reward(obs_buf[:, 0:2], pos_before,\n                                      obs_buf_op[:, 0:2], dt,\n                                      move_to_op_reward_scale)\n    # stay_in_center_reward = stay_in_center_reward_scale * torch.exp(-torch.linalg.norm(obs_buf[:, :2], dim=-1))\n    dof_at_limit_cost = torch.sum(obs_buf[:, 13:21] > 0.99, dim=-1) * joints_at_limit_cost_scale\n    push_reward = -push_scale * torch.exp(-torch.linalg.norm(obs_buf_op[:, :2], dim=-1))\n    action_cost_penalty = torch.sum(torch.square(torques), dim=1) * action_cost_scale\n    not_move_penalty = -10 * torch.exp(-torch.sum(torch.abs(torques), dim=1))\n    dense_reward = move_reward + dof_at_limit_cost + push_reward + action_cost_penalty + not_move_penalty\n    total_reward = win_reward + lose_penalty + draw_penalty + dense_reward * dense_reward_scale\n\n    return total_reward, reset, hp, hp_op, is_out_op, is_out, progress_buf >= max_episode_length - 1\n\n\n@torch.jit.script\ndef compute_ant_observations(\n        root_states,\n        root_states_op,\n        dof_pos,\n        dof_vel,\n        dof_limits_lower,\n        dof_limits_upper,\n        dof_vel_scale,\n        termination_height\n):\n    # type: (Tensor,Tensor,Tensor,Tensor,Tensor,Tensor,float,float)->Tensor\n    dof_pos_scaled = unscale(dof_pos, dof_limits_lower, dof_limits_upper)\n    obs = torch.cat(\n        (root_states[:, :13], dof_pos_scaled, dof_vel * dof_vel_scale, root_states_op[:, :7],\n         root_states[:, :2] - root_states_op[:, :2], torch.unsqueeze(root_states[:, 2] < termination_height, -1),\n         torch.unsqueeze(root_states_op[:, 2] < termination_height, -1)), dim=-1)\n\n    return obs\n\n\n@torch.jit.script\ndef randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):\n    return 
quat_mul(quat_from_angle_axis(rand0 * np.pi, x_unit_tensor),\n                    quat_from_angle_axis(rand1 * np.pi, y_unit_tensor))\n"
  },
  {
    "path": "timechamber/tasks/ma_humanoid_strike.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\nimport math\n\nfrom isaacgym import gymapi, gymtorch\nfrom isaacgym.torch_utils import *\n\nimport timechamber.tasks.ase_humanoid_base.humanoid_amp_task as humanoid_amp_task\nfrom timechamber.utils import torch_utils\n\n\nclass HumanoidStrike(humanoid_amp_task.HumanoidAMPTask):\n    def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless):\n        super().__init__(cfg=cfg,\n                         sim_params=sim_params,\n                         physics_engine=physics_engine,\n                         device_type=device_type,\n                         device_id=device_id,\n                         headless=headless)\n\n        self.ego_to_op_damage = torch.zeros_like(self.reset_buf, device=self.device, dtype=torch.float)\n        self.op_to_ego_damage = torch.zeros_like(self.reset_buf, device=self.device, dtype=torch.float)\n\n        self._prev_root_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float)\n        self._prev_root_pos_op = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float)\n        self._prev_body_ang_vel = torch.zeros([self.num_envs, self.num_bodies, 3],\n                                          device=self.device, dtype=torch.float32)\n        self._prev_body_vel = torch.zeros([self.num_envs, self.num_bodies, 3],\n                                          device=self.device, 
dtype=torch.float32)\n\n        strike_body_names = cfg[\"env\"][\"strikeBodyNames\"]\n        self._strike_body_ids = self._build_body_ids_tensor(self.envs[0], self.humanoid_handles[0], strike_body_names)\n        force_body_names = cfg[\"env\"][\"forceBodies\"]\n        self._force_body_ids = self._build_body_ids_tensor(self.envs[0], self.humanoid_handles[0], force_body_names)\n\n        if self.viewer is not None:\n            for env in self.envs:\n                self._add_rectangle_borderline(env)\n\n            cam_pos = gymapi.Vec3(15.0, 0.0, 3.0)\n            cam_target = gymapi.Vec3(10.0, 0.0, 0.0)\n            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)\n\n        ###### Reward Definition ######\n\n        return\n\n    def get_task_obs_size(self):\n        obs_size = 0\n        if (self._enable_task_obs):\n            obs_size = 50\n        return obs_size\n\n    def _create_envs(self, num_envs, spacing, num_per_row):\n\n        super()._create_envs(num_envs, spacing, num_per_row)\n        return\n\n    def _build_body_ids_tensor(self, env_ptr, actor_handle, body_names):\n        env_ptr = self.envs[0]\n        actor_handle = self.humanoid_handles[0]\n        body_ids = []\n\n        for body_name in body_names:\n            body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name)\n            assert body_id != -1\n            body_ids.append(body_id)\n\n        body_ids = to_torch(body_ids, device=self.device, dtype=torch.long)\n        return body_ids\n\n    def _reset_actors(self, env_ids):\n        positions = torch_rand_float(-0.2, 0.2, (len(env_ids), self.num_dof), device=self.device)\n        velocities = torch_rand_float(-0.1, 0.1, (len(env_ids), 
self.num_dof), device=self.device)\n        self._dof_pos[env_ids] = tensor_clamp(self._initial_dof_pos[env_ids] + positions, self.dof_limits_lower,\n                                             self.dof_limits_upper)\n        self._dof_vel[env_ids] = velocities\n\n        self._dof_pos_op[env_ids] = tensor_clamp(self._initial_dof_pos[env_ids] + positions, self.dof_limits_lower,\n                                                self.dof_limits_upper)\n        self._dof_vel_op[env_ids] = velocities\n\n        agent_env_ids = expand_env_ids(env_ids, 2)\n\n        rand_angle = torch.rand((len(env_ids),), device=self.device) * math.pi * 2\n        rand_pos = torch.ones((len(agent_env_ids), 2), device=self.device) * (\n                self.borderline_space * torch.ones((len(agent_env_ids), 2), device=self.device) - torch.rand(\n            (len(agent_env_ids), 2), device=self.device) * 2)\n        rand_pos[0::2, 0] *= torch.cos(rand_angle)\n        rand_pos[0::2, 1] *= torch.sin(rand_angle)\n        rand_pos[1::2, 0] *= torch.cos(rand_angle + math.pi)\n        rand_pos[1::2, 1] *= torch.sin(rand_angle + math.pi)\n\n        rand_floats = torch_rand_float(-1.0, 1.0, (len(agent_env_ids), 3), device=self.device)\n        rand_rotation = quat_from_angle_axis(rand_floats[:, 1] * np.pi, self.z_unit_tensor[agent_env_ids])\n        rand_rotation2 = quat_from_angle_axis(rand_floats[:, 2] * np.pi, self.z_unit_tensor[agent_env_ids])\n\n        self._humanoid_root_states[agent_env_ids] = self._initial_humanoid_root_states[agent_env_ids]\n        self._humanoid_root_states[agent_env_ids, :2] = rand_pos\n        self._humanoid_root_states[agent_env_ids[1::2], 3:7] = rand_rotation[1::2]\n        self._humanoid_root_states[agent_env_ids[0::2], 3:7] = rand_rotation2[0::2]\n        \n        return\n\n    def _reset_env_tensors(self, env_ids):\n        super()._reset_env_tensors(env_ids)\n        self.ego_to_op_damage[env_ids] = 0\n        self.op_to_ego_damage[env_ids] = 0\n        
return\n\n    def pre_physics_step(self, actions):\n        super().pre_physics_step(actions)\n        # self._prev_root_pos[:] = self._humanoid_root_states[self.humanoid_indices, 0:3]\n        # self._prev_root_pos_op[:] = self._humanoid_root_states[self.humanoid_indices_op, 0:3]\n        # self._prev_body_ang_vel[:] = self._rigid_body_ang_vel[]\n        return\n\n    def post_physics_step(self):\n        super().post_physics_step()\n        self._prev_body_ang_vel[:] = self._rigid_body_ang_vel[:]\n        self._prev_body_vel[:] = self._rigid_body_vel[:]\n\n    def _compute_observations(self):\n\n        obs, obs_op = self._compute_humanoid_obs()\n        if (self._enable_task_obs):\n            task_obs, task_obs_op = self._compute_task_obs()\n            obs = torch.cat([obs, task_obs], dim=-1)\n            obs_op = torch.cat([obs_op, task_obs_op], dim=-1)\n        self.obs_buf[:self.num_envs] = obs\n        self.obs_buf[self.num_envs:] = obs_op\n        return\n\n    def _compute_task_obs(self):\n        body_pos = self._rigid_body_pos\n        body_rot = self._rigid_body_rot\n        body_vel = self._rigid_body_vel\n\n        body_pos_op = self._rigid_body_pos_op\n        body_rot_op = self._rigid_body_rot_op\n        body_vel_op = self._rigid_body_vel_op\n\n        # num_envs, 13\n        root_states = self._humanoid_root_states[self.humanoid_indices]\n        root_states_op = self._humanoid_root_states[self.humanoid_indices_op]\n\n        obs = compute_strike_observations(root_states, root_states_op, \n                                          body_pos, body_rot,\n                                          body_pos_op, body_vel_op,\n                                          borderline=self.borderline_space\n                                          )\n        obs_op = compute_strike_observations(root_states=root_states_op,\n                                             root_states_op=root_states,\n                                             
body_pos=body_pos_op,\n                                             body_rot=body_rot_op,\n                                             body_pos_op=body_pos,\n                                             body_vel_op=body_vel,\n                                             borderline=self.borderline_space)\n        return obs, obs_op\n\n    def _compute_reward(self, actions):\n\n        root_states = self._humanoid_root_states[self.humanoid_indices]\n        root_states_op = self._humanoid_root_states[self.humanoid_indices_op]\n\n        body_pos = self._rigid_body_pos\n        body_vel = self._rigid_body_vel\n        prev_body_vel = self._prev_body_vel\n        \n        body_ang_vel = self._rigid_body_ang_vel\n        prev_body_ang_vel = self._prev_body_ang_vel\n        contact_force = self._contact_forces\n        \n        body_pos_op = self._rigid_body_pos_op\n        contact_force_op = self._contact_forces_op\n\n        self.rew_buf[:], force_ego_to_op, force_op_to_ego = compute_strike_reward(root_states=root_states,\n                                                root_states_op=root_states_op,\n                                                body_pos=body_pos,\n                                                body_ang_vel=body_ang_vel,\n                                                prev_body_ang_vel=prev_body_ang_vel,\n                                                body_vel=body_vel,\n                                                prev_body_vel=prev_body_vel,\n                                                body_pos_op=body_pos_op,\n                                                force_body_ids=self._force_body_ids,\n                                                strike_body_ids=self._strike_body_ids,\n                                                contact_force=contact_force,\n                                                contact_force_op=contact_force_op,\n                                                contact_body_ids=self._contact_body_ids,\n      
                                          borderline=self.borderline_space,\n                                                termination_heights=self._termination_heights,\n                                                dt=self.dt)\n        self.ego_to_op_damage += force_ego_to_op\n        self.op_to_ego_damage += force_op_to_ego\n        return\n\n    def _compute_reset(self):\n        self.reset_buf[:], self._terminate_buf[:],\\\n            self.extras['win'], self.extras['lose'], self.extras['draw'] = \\\n            compute_humanoid_reset(self.reset_buf, self.progress_buf,\n                                   self.ego_to_op_damage,\n                                   self.op_to_ego_damage,\n                                   self._contact_forces, \n                                   self._contact_forces_op,\n                                   self._contact_body_ids,\n                                   self._rigid_body_pos,\n                                   self._rigid_body_pos_op,\n                                   self.max_episode_length,\n                                   self._enable_early_termination,\n                                   self._termination_heights,\n                                   self.borderline_space)\n        return\n\n#####################################################################\n###=========================jit functions=========================###\n#####################################################################\n\n@torch.jit.script\ndef compute_strike_observations(root_states, root_states_op, body_pos, body_rot,\n                                body_pos_op, body_vel_op, borderline,\n                                ):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor,Tensor,float) -> Tensor\n    root_pos = root_states[:, 0:3]\n    root_rot = root_states[:, 3:7]\n    ego_sword_pos = body_pos[:, 6, :]\n    ego_sword_rot = body_rot[:, 6, :]\n    ego_shield_pos = body_pos[:, 9, :]\n    ego_shield_rot = body_rot[:, 9, 
:]\n\n    root_pos_op = root_states_op[:, 0:3]\n    root_rot_op = root_states_op[:, 3:7]\n    root_vel_op = root_states_op[:, 7:10]\n    root_ang_op = root_states_op[:, 10:13]\n    op_sword_pos = body_pos_op[:, 6, :]\n    op_sword_vel = body_vel_op[:, 6, :]\n    op_torso_pos = body_pos_op[:, 1, :]\n    op_torso_vel = body_vel_op[:, 1, :]\n    op_head_pos = body_pos_op[:, 2, :]\n    op_head_vel = body_vel_op[:, 2, :]\n    op_right_upper_arm_pos = body_pos_op[:, 3, :]\n    op_right_thigh_pos = body_pos_op[:, 11, :]\n    op_left_thigh_pos = body_pos_op[:, 14, :]\n\n    ##*******************************************************##\n    relative_x_1 =  borderline - root_pos[:, 0]\n    relative_x_2 = root_pos[:, 0] + borderline\n    relative_x = torch.minimum(relative_x_1, relative_x_2)\n    relative_x = torch.unsqueeze(relative_x, -1)\n    relative_y_1 =  borderline - root_pos[:, 1]\n    relative_y_2 = root_pos[:,1] + borderline\n    relative_y = torch.minimum(relative_y_1, relative_y_2)\n    relative_y = torch.unsqueeze(relative_y, -1)\n    ##*******************************************************##\n\n    heading_rot = torch_utils.calc_heading_quat_inv(root_rot)\n    sword_rot = torch_utils.calc_heading_quat_inv(ego_sword_rot)\n    shield_rot = torch_utils.calc_heading_quat_inv(ego_shield_rot)\n\n    local_op_relative_pos = root_pos_op - root_pos\n    local_op_relative_pos[..., -1] = root_pos_op[..., -1]\n    local_op_relative_pos = quat_rotate(heading_rot, local_op_relative_pos)\n\n    local_op_vel = quat_rotate(heading_rot, root_vel_op)\n    local_op_ang_vel = quat_rotate(heading_rot, root_ang_op)\n\n    local_op_rot = quat_mul(heading_rot, root_rot_op)\n    local_op_rot_obs = torch_utils.quat_to_tan_norm(local_op_rot)\n    ##*******************************************************##\n\n    # op sword relative ego position and vel\n    local_op_relative_sword_pos = op_sword_pos - root_pos\n    local_op_relative_sword_pos = quat_rotate(heading_rot, 
local_op_relative_sword_pos)\n    local_op_sword_vel = quat_rotate(heading_rot, op_sword_vel)\n    \n    # op sword relative ego shield position and vel\n    local_op_sword_shield_pos = op_sword_pos - ego_shield_pos\n    local_op_sword_shield_pos = quat_rotate(shield_rot, local_op_sword_shield_pos)\n    local_op_sword_shield_vel = quat_rotate(shield_rot, op_sword_vel)\n    \n    # relative position and vel of ego sword and op up body\n    relative_sword_torso_pos = op_torso_pos - ego_sword_pos\n    relative_sword_torso_pos = quat_rotate(sword_rot, relative_sword_torso_pos)\n    relative_sword_torso_vel = quat_rotate(sword_rot, op_torso_vel)\n    relative_sword_head_pos = op_head_pos - ego_sword_pos\n    relative_sword_head_pos = quat_rotate(sword_rot, relative_sword_head_pos)\n    relative_sword_head_vel = quat_rotate(sword_rot, op_head_vel)\n    relative_sword_right_arm_pos = op_right_upper_arm_pos - ego_sword_pos\n    relative_sword_right_arm_pos = quat_rotate(sword_rot, relative_sword_right_arm_pos)\n    relative_sword_right_thigh_pos = op_right_thigh_pos - ego_sword_pos\n    relative_sword_right_thigh_pos = quat_rotate(sword_rot, relative_sword_right_thigh_pos)\n    relative_sword_left_thigh_pos = op_left_thigh_pos - ego_sword_pos\n    relative_sword_left_thigh_pos = quat_rotate(sword_rot, relative_sword_left_thigh_pos)\n\n    obs = torch.cat([relative_x, relative_y,\n                     local_op_relative_pos, local_op_rot_obs,\n                     local_op_vel, local_op_ang_vel,\n                     local_op_relative_sword_pos, local_op_sword_vel,\n                     local_op_sword_shield_pos, local_op_sword_shield_vel,\n                     relative_sword_torso_pos, relative_sword_torso_vel,\n                     relative_sword_head_pos, relative_sword_head_vel,\n                     relative_sword_right_arm_pos, relative_sword_right_thigh_pos,\n                     relative_sword_left_thigh_pos\n                     ], dim=-1)\n    return 
obs\n\n@torch.jit.script\ndef compute_strike_reward(root_states, root_states_op, body_pos, body_ang_vel,\n                          prev_body_ang_vel, body_vel, prev_body_vel,\n                          body_pos_op, force_body_ids, strike_body_ids,\n                          contact_force, contact_force_op, contact_body_ids,\n                          borderline, termination_heights, dt):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]\n\n    op_fall_reward_w = 200.0\n    ego_fall_out_reward_w = 50.0\n    shield_to_sword_pos_reward_w = 1.0\n    damage_reward_w = 8.0\n    sword_to_op_reward_w = 0.8\n    reward_energy_w = 3.0\n    reward_strike_vel_acc_w = 3.0\n    reward_face_w = 4.0\n    reward_foot_to_op_w = 10.0\n    reward_kick_w = 2.0\n\n    num_envs = root_states.shape[0]\n    reward = torch.zeros((num_envs, 1), dtype=torch.float32)\n    root_xy_pos = root_states[:, 0:2]\n    root_pos = root_states[:, 0:3]\n    ego_sword_pos = body_pos[:, 6, 0:3]\n    ego_shield_pos = body_pos[:, 9, 0:3]\n    ego_right_foot_pos = body_pos[:, 13, 0:3]\n    op_sword_pos = body_pos_op[:, 6, 0:3]\n    op_torso_pos = body_pos_op[:, 1, 0:3]\n    op_right_thigh_pos = body_pos_op[:, 11, 0:3]\n    op_left_thigh_pos = body_pos_op[:, 14, 0:3]\n    root_pos_xy_op = root_states_op[:, 0:2]\n    root_pos_xy = root_states[:, 0:2]\n    root_pos_op = root_states_op[:, 0:3]\n    root_rot = root_states[:, 3:7]\n    root_rot_op = root_states_op[:, 3:7]\n    up = torch.zeros_like(root_pos_op)\n    up[..., -1] = 1\n    contact_buf = contact_force.clone()\n    contact_buf_op = contact_force_op.clone()\n\n    ##*****************r energy******************##\n    strike_body_vel = body_vel[:, strike_body_ids, :]\n    strike_body_vel_norm = torch.sum(torch.norm(strike_body_vel, dim=-1), dim=1)\n    strike_body_vel_norm = torch.clamp(strike_body_vel_norm, max=20)\n    distance = 
root_pos_xy_op - root_xy_pos\n    distance = torch.norm(distance, dim=-1)\n    zeros = torch.zeros_like(distance)\n    k_dist = torch.exp(-10 * torch.maximum(zeros, distance - 2.0))\n    r_energy = k_dist * strike_body_vel_norm\n    r_energy = r_energy / 20.\n    \n    strike_vel_diff = body_vel[:, strike_body_ids, :] - prev_body_vel[:, strike_body_ids, :]\n    strike_vel_acc = strike_vel_diff / dt\n    strike_vel_acc = torch.sum(torch.norm(strike_vel_acc, dim=-1), dim=1)\n    strike_vel_acc = torch.clamp(strike_vel_acc, max=1000)\n    strike_vel_acc = k_dist * strike_vel_acc / 500\n    r_strike_vel_acc = strike_vel_acc\n    ##*****************r damage******************##\n    ego_to_op_force = contact_buf_op[:, force_body_ids, :]\n\n    op_to_ego_force = contact_buf[:, force_body_ids, :]\n\n    force_ego_to_op = torch.norm(ego_to_op_force, dim=2).sum(dim=1)\n\n    force_op_to_ego = torch.norm(op_to_ego_force, dim=2).sum(dim=1)\n\n    r_damage = force_ego_to_op - force_op_to_ego * 2\n    r_damage = torch.clamp(r_damage, min=-200.0)\n    r_damage /= 100\n\n    ##*****************r kick******************##\n    ego_foot_op_torso_distance = op_torso_pos - ego_right_foot_pos\n    ego_foot_op_torso_err = torch.norm(ego_foot_op_torso_distance, dim=-1)\n    succ_foot = ego_foot_op_torso_err < 0.1\n    r_foot_to_op = torch.exp(-0.5 * ego_foot_op_torso_err)\n    constant_r = torch.ones_like(r_foot_to_op)\n    r_foot_to_op = torch.where(succ_foot, constant_r, r_foot_to_op)\n    \n    foot_height = ego_right_foot_pos[..., 2]\n    succ_kick = foot_height >= 0.4\n    constant_r_kick = torch.ones_like(foot_height)\n    r_kick = torch.where(succ_kick, constant_r_kick, foot_height)\n    \n    ##*****************r close******************##\n    # sword -> torso\n    pos_err_scale1 = 1.0\n    pos_err_scale2 = 2.0\n\n    sword_torso_distance = op_torso_pos - ego_sword_pos\n    sword_torso_err = torch.sum(sword_torso_distance * 
sword_torso_distance, dim=-1)\n\n    sword_right_thigh_distance = op_right_thigh_pos - ego_sword_pos\n    sword_right_thigh_err = torch.sum(sword_right_thigh_distance * sword_right_thigh_distance, dim=-1)\n\n    sword_left_thigh_distance = op_left_thigh_pos - ego_sword_pos\n    sword_left_thigh_err = torch.sum(sword_left_thigh_distance * sword_left_thigh_distance, dim=-1)\n\n    sword_sword_distance = op_sword_pos - ego_sword_pos\n    sword_sword_err = torch.sum(sword_sword_distance * sword_sword_distance, dim=-1)\n    \n    r_close = torch.exp(-pos_err_scale1 * sword_torso_err) # -> [0, 1]\n    r_close += torch.exp(-pos_err_scale1 * sword_right_thigh_err)\n    r_close += torch.exp(-pos_err_scale1 * sword_left_thigh_err)\n    r_close += torch.exp(-pos_err_scale2 * sword_sword_err)\n    ##*****************r shield with op sword******************##\n    pos_err_scale3 = 2.0\n    ego_shield_op_sword_distance = op_sword_pos - ego_shield_pos\n    ego_shield_op_sword_err = torch.sum(ego_shield_op_sword_distance * ego_shield_op_sword_distance, dim=-1)\n    r_shield_to_sword = torch.exp(-pos_err_scale3 * ego_shield_op_sword_err)\n\n    ##*****************r face******************##\n    tar_dir = root_pos_xy_op - root_xy_pos\n    tar_dir = torch.nn.functional.normalize(tar_dir, dim=-1)\n\n    heading_rot = torch_utils.calc_heading_quat(root_rot)\n    facing_dir = torch.zeros_like(root_pos)\n    facing_dir[..., 0] = 1.0\n    facing_dir = quat_rotate(heading_rot, facing_dir)\n    facing_err = torch.sum(tar_dir * facing_dir[..., 0:2], dim=-1)\n    facing_reward = torch.clamp_min(facing_err, 0.0)\n\n    ##*****************r op fall******************##\n    masked_contact_buf_op = contact_buf_op.clone()\n    masked_contact_buf_op[:, contact_body_ids, :] = 0\n    fall_contact_op = torch.any(torch.abs(masked_contact_buf_op) > 0.1, dim=-1)\n    fall_contact_op = torch.any(fall_contact_op, dim=-1)\n\n    body_height_op = 
body_pos_op[..., 2]\n    fall_height_op = body_height_op < termination_heights\n    fall_height_op[:, contact_body_ids] = False\n    fall_height_op = torch.any(fall_height_op, dim=-1)\n    has_fallen_op = torch.logical_and(fall_contact_op, fall_height_op)\n\n    op_up = quat_rotate(root_rot_op, up)\n    op_rot_err = torch.sum(up * op_up, dim=-1)\n    op_rot_r = 0.6 * torch.clamp_min(1.0 - op_rot_err, 0.0) # -> [0, 1] succ = op_rot_err < 0.2\n    op_rot_r = torch.where(has_fallen_op, torch.ones_like(op_rot_r), op_rot_r)\n\n    # test, when op fall, then r_close = 0 to encourage to agents separate.\n    r_separate = torch.norm((root_pos_xy_op - root_pos_xy), dim=-1)\n    r_separate = torch.where(r_separate > 0.1, r_separate, torch.zeros_like(r_separate))\n    r_close = torch.where(has_fallen_op, r_separate, r_close)\n    r_shield_to_sword = torch.where(has_fallen_op, torch.zeros_like(r_shield_to_sword), r_shield_to_sword)\n    \n    ##*****************r penalty******************##\n    relative_x_1 =  borderline - root_xy_pos[:, 0]\n    relative_x_2 = root_xy_pos[:, 0] + borderline\n    relative_x = torch.minimum(relative_x_1, relative_x_2)\n    relative_x = relative_x < 0\n    relative_y_1 =  borderline - root_xy_pos[:, 1]\n    relative_y_2 = root_xy_pos[:,1] + borderline\n    relative_y = torch.minimum(relative_y_1, relative_y_2)\n    relative_y = relative_y < 0\n    is_out = relative_x | relative_y\n    r_penalty = is_out * 1.0\n\n    masked_contact_buf = contact_force.clone()\n    masked_contact_buf[:, contact_body_ids, :] = 0\n    fall_contact = torch.any(torch.abs(masked_contact_buf) > 0.1, dim=-1)\n    fall_contact = torch.any(fall_contact, dim=-1)\n    body_height = body_pos[..., 2]\n    fall_height = body_height < termination_heights  \n    fall_height[:, contact_body_ids] = False\n    fall_height = torch.any(fall_height, dim=-1)\n    has_fallen_ego = torch.logical_and(fall_contact, fall_height)\n    r_penalty += has_fallen_ego * 1.0\n\n    
##*****************total reward******************##\n    reward = -r_penalty * ego_fall_out_reward_w + op_rot_r * op_fall_reward_w + \\\n        r_shield_to_sword * shield_to_sword_pos_reward_w + r_close * sword_to_op_reward_w +\\\n            r_damage * damage_reward_w + r_energy * reward_energy_w + facing_reward * reward_face_w + \\\n                r_strike_vel_acc * reward_strike_vel_acc_w + r_foot_to_op * reward_foot_to_op_w +\\\n                    r_kick * reward_kick_w\n\n    return reward, force_ego_to_op, force_op_to_ego\n\n\n@torch.jit.script\ndef compute_humanoid_reset(reset_buf, progress_buf, ego_to_op_damage, op_to_ego_damage,\n                           contact_buf, contact_buf_op, contact_body_ids,\n                           rigid_body_pos, rigid_body_pos_op, max_episode_length,\n                           enable_early_termination, termination_heights, borderline):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float, bool, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]\n\n    terminated = torch.zeros_like(reset_buf)\n\n    if (enable_early_termination):\n        masked_contact_buf = contact_buf.clone()\n        masked_contact_buf_op = contact_buf_op.clone()\n        masked_contact_buf[:, contact_body_ids, :] = 0\n        masked_contact_buf_op[:, contact_body_ids, :] = 0\n        fall_contact = torch.any(torch.abs(masked_contact_buf) > 0.1, dim=-1)\n        fall_contact = torch.any(fall_contact, dim=-1)\n        fall_contact_op = torch.any(torch.abs(masked_contact_buf_op) > 0.1, dim=-1)\n        fall_contact_op = torch.any(fall_contact_op, dim=-1)\n\n        body_height = rigid_body_pos[..., 2]\n        body_height_op = rigid_body_pos_op[..., 2]\n        fall_height = body_height < termination_heights\n        fall_height_op = body_height_op < termination_heights\n        fall_height[:, contact_body_ids] = False\n        fall_height_op[:, contact_body_ids] = False\n        fall_height = 
torch.any(fall_height, dim=-1)\n        fall_height_op = torch.any(fall_height_op, dim=-1)\n\n        ## out area\n        root_pos = rigid_body_pos[:, 0, 0:2]\n        root_pos_op = rigid_body_pos_op[:, 0, 0:2]\n        relative_x_1 =  borderline - root_pos[:, 0]\n        relative_x_2 = root_pos[:, 0] + borderline\n        relative_x = torch.minimum(relative_x_1, relative_x_2)\n        relative_x = relative_x < 0\n        relative_y_1 =  borderline - root_pos[:, 1]\n        relative_y_2 = root_pos[:,1] + borderline\n        relative_y = torch.minimum(relative_y_1, relative_y_2)\n        relative_y = relative_y < 0\n        is_out_ego = relative_x | relative_y\n\n        relative_x_1_op =  borderline - root_pos_op[:, 0]\n        relative_x_2_op = root_pos_op[:, 0] + borderline\n        relative_x_op = torch.minimum(relative_x_1_op, relative_x_2_op)\n        relative_x_op = relative_x_op < 0\n        relative_y_1_op =  borderline - root_pos_op[:, 1]\n        relative_y_2_op = root_pos_op[:,1] + borderline\n        relative_y_op = torch.minimum(relative_y_1_op, relative_y_2_op)\n        relative_y_op = relative_y_op < 0\n        is_out_op = relative_x_op | relative_y_op\n\n        is_out = is_out_ego | is_out_op\n        \n        has_failed = is_out\n\n        # first timestep can sometimes still have nonzero contact forces\n        # so only check after first couple of steps\n        has_failed *= (progress_buf > 1)\n\n        terminated = torch.where(has_failed, torch.ones_like(reset_buf), terminated)\n    damage_ego_more_than_op = ego_to_op_damage > op_to_ego_damage\n    damage_op_more_than_ego = op_to_ego_damage > ego_to_op_damage\n\n    reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), terminated)\n    win = torch.where(reset, damage_ego_more_than_op, torch.zeros_like(reset_buf, dtype=torch.bool))\n    lose = torch.where(reset, damage_op_more_than_ego, torch.zeros_like(reset_buf, dtype=torch.bool))\n    draw = 
torch.where(reset, ego_to_op_damage == op_to_ego_damage, torch.zeros_like(reset_buf, dtype=torch.bool))\n    \n    \n    return reset, terminated, win, lose, draw\n\n@torch.jit.script\ndef expand_env_ids(env_ids, n_agents):\n    # type: (Tensor, int) -> Tensor\n    device = env_ids.device\n    agent_env_ids = torch.zeros((n_agents * len(env_ids)), device=device, dtype=torch.long)\n    for idx in range(n_agents):\n        agent_env_ids[idx::n_agents] = env_ids * n_agents + idx\n    return agent_env_ids\n"
  },
  {
    "path": "timechamber/train.py",
    "content": "# train.py\n# Script to train policies in Isaac Gym\n#\n# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\nimport datetime\nimport isaacgym\nimport os\nimport hydra\nimport yaml\nfrom omegaconf import DictConfig, OmegaConf\nfrom hydra.utils import to_absolute_path\nimport gym\n\nfrom timechamber.utils.reformat import omegaconf_to_dict, print_dict\nfrom timechamber.utils.utils import set_np_formatting, set_seed\nfrom timechamber.utils.rlgames_utils import RLGPUEnv, RLGPUAlgoObserver, get_rlgames_env_creator\nfrom rl_games.common import env_configurations, vecenv\nfrom rl_games.torch_runner import Runner\nfrom rl_games.algos_torch import model_builder\nfrom timechamber.ase import ase_agent\nfrom timechamber.ase import ase_models\nfrom timechamber.ase import ase_network_builder\nfrom timechamber.ase import hrl_models\nfrom timechamber.ase import hrl_network_builder\nfrom timechamber.learning import ppo_sp_agent\nfrom timechamber.learning import hrl_sp_agent\nfrom timechamber.learning import ppo_sp_player\nfrom timechamber.learning import hrl_sp_player\nfrom timechamber.learning import vectorized_models\nfrom timechamber.learning import vectorized_network_builder\nimport timechamber\n\n\n## OmegaConf & Hydra Config\n\n# Resolvers used in hydra configs (see https://omegaconf.readthedocs.io/en/2.1_branch/usage.html#resolvers)\n@hydra.main(config_name=\"config\", config_path=\"./cfg\")\ndef launch_rlg_hydra(cfg: DictConfig):\n\n    time_str = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n    run_name = 
f\"{cfg.wandb_name}_{time_str}\"\n\n    # ensure checkpoints can be specified as relative paths\n    if cfg.checkpoint:\n        cfg.checkpoint = to_absolute_path(cfg.checkpoint)\n\n    cfg_dict = omegaconf_to_dict(cfg)\n    print_dict(cfg_dict)\n\n    # set numpy formatting for printing only\n    set_np_formatting()\n\n    rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n    if cfg.multi_gpu:\n        # torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py\n        cfg.sim_device = f'cuda:{rank}'\n        cfg.rl_device = f'cuda:{rank}'\n\n    # sets seed. if seed is -1 will pick a random one\n    cfg.seed += rank\n    cfg.seed = set_seed(cfg.seed, torch_deterministic=cfg.torch_deterministic, rank=rank)\n\n    if cfg.wandb_activate and rank == 0:\n        # Make sure to install WandB if you actually use this.\n        import wandb\n\n        run = wandb.init(\n            project=cfg.wandb_project,\n            group=cfg.wandb_group,\n            entity=cfg.wandb_entity,\n            config=cfg_dict,\n            sync_tensorboard=True,\n            name=run_name,\n            resume=\"allow\",\n        )\n\n    def create_env_thunk(**kwargs):\n        envs = timechamber.make(\n            cfg.seed,\n            cfg.task_name,\n            cfg.task.env.numEnvs,\n            cfg.sim_device,\n            cfg.rl_device,\n            cfg.graphics_device_id,\n            cfg.device_type,\n            cfg.headless,\n            cfg.multi_gpu,\n            cfg.capture_video,\n            cfg.force_render,\n            cfg,\n            **kwargs,\n        )\n        if cfg.capture_video:\n            envs.is_vector_env = True\n            envs = gym.wrappers.RecordVideo(\n                envs,\n                f\"videos/{run_name}\",\n                step_trigger=lambda step: step % cfg.capture_video_freq == 0,\n                video_length=cfg.capture_video_len,\n            )\n        return envs\n\n    # register the rl-games adapter to use inside the runner\n    
vecenv.register('RLGPU',\n                    lambda config_name, num_actors, **kwargs: RLGPUEnv(config_name, num_actors, **kwargs))\n\n    env_configurations.register('rlgpu', {\n        'vecenv_type': 'RLGPU',\n        'env_creator': create_env_thunk,\n    })\n\n    # register custom network builders and agents\n    def build_runner(algo_observer):\n        runner = Runner(algo_observer)\n        runner.algo_factory.register_builder('self_play_continuous', lambda **kwargs: ppo_sp_agent.SPAgent(**kwargs))\n        runner.algo_factory.register_builder('self_play_hrl', lambda **kwargs: hrl_sp_agent.HRLSPAgent(**kwargs))\n        runner.algo_factory.register_builder('ase', lambda **kwargs: ase_agent.ASEAgent(**kwargs))\n\n        runner.player_factory.register_builder('self_play_continuous',\n                                               lambda **kwargs: ppo_sp_player.SPPlayer(**kwargs))\n        runner.player_factory.register_builder('self_play_hrl',\n                                               lambda **kwargs: hrl_sp_player.HRLSPPlayer(**kwargs))\n        model_builder.register_model('hrl', lambda network, **kwargs: hrl_models.ModelHRLContinuous(network))\n        model_builder.register_model('ase', lambda network, **kwargs: ase_models.ModelASEContinuous(network))\n        model_builder.register_model('vectorized_a2c',\n                                     lambda network, **kwargs: vectorized_models.ModelVectorizedA2C(network))\n        model_builder.register_network('vectorized_a2c',\n                                       lambda **kwargs: vectorized_network_builder.VectorizedA2CBuilder())\n        model_builder.register_network('ase', lambda **kwargs: ase_network_builder.ASEBuilder())\n        model_builder.register_network('hrl', lambda **kwargs: hrl_network_builder.HRLBuilder())\n        \n        return runner\n\n    rlg_config_dict = omegaconf_to_dict(cfg.train)\n\n    # convert CLI arguments into dictionary\n    # create runner and set 
the settings\n    runner = build_runner(RLGPUAlgoObserver())\n    runner.load(rlg_config_dict)\n    runner.reset()\n\n    # dump config dict\n    experiment_dir = os.path.join('runs', cfg.train.params.config.name)\n    os.makedirs(experiment_dir, exist_ok=True)\n    with open(os.path.join(experiment_dir, 'config.yaml'), 'w') as f:\n        f.write(OmegaConf.to_yaml(cfg))\n\n    if cfg.multi_gpu:\n        import horovod.torch as hvd\n\n        rank = hvd.rank()\n    else:\n        rank = 0\n\n    if cfg.wandb_activate and rank == 0:\n        # Make sure to install WandB if you actually use this.\n        import wandb\n\n        wandb.init(\n            project=cfg.wandb_project,\n            group=cfg.wandb_group,\n            entity=cfg.wandb_entity,\n            config=cfg_dict,\n            sync_tensorboard=True,\n            id=run_name,\n            resume=\"allow\",\n            monitor_gym=True,\n        )\n\n    runner.run({\n        'train': not cfg.test,\n        'play': cfg.test,\n        'checkpoint': cfg.checkpoint,\n        'sigma': None\n    })\n\n    if cfg.wandb_activate and rank == 0:\n        wandb.finish()\n\n\nif __name__ == \"__main__\":\n    launch_rlg_hydra()\n"
  },
  {
    "path": "timechamber/utils/config.py",
    "content": "import os\nimport sys\nimport yaml\n\nfrom isaacgym import gymapi\nfrom isaacgym import gymutil\n\nimport numpy as np\nimport random\nimport torch\n\nSIM_TIMESTEP = 1.0 / 60.0\n\ndef parse_sim_params(args, cfg):\n    # initialize sim\n    sim_params = gymapi.SimParams()\n    sim_params.dt = SIM_TIMESTEP\n    sim_params.num_client_threads = args.num_subscenes\n    if args.physics_engine == \"flex\":\n        if args.device_type != \"cpu\":\n            print(\"WARNING: Using Flex with GPU instead of PHYSX!\")\n        sim_params.flex.shape_collision_margin = 0.01\n        sim_params.flex.num_outer_iterations = 4\n        sim_params.flex.num_inner_iterations = 10\n    elif args.physics_engine == \"physx\":\n        sim_params.physx.solver_type = 1\n        sim_params.physx.num_position_iterations = 4\n        sim_params.physx.num_velocity_iterations = 0\n        sim_params.physx.num_threads = 4\n        sim_params.physx.use_gpu = args.use_gpu\n        sim_params.physx.num_subscenes = args.num_subscenes\n        sim_params.physx.max_gpu_contact_pairs = 8 * 1024 * 1024\n\n    sim_params.use_gpu_pipeline = args.use_gpu_pipeline\n    sim_params.physx.use_gpu = args.use_gpu\n\n    # if sim options are provided in cfg, parse them and update/override above:\n    if \"sim\" in cfg:\n        gymutil.parse_sim_config(cfg[\"sim\"], sim_params)\n\n    # Override num_threads if passed on the command line\n    if args.physics_engine == \"physx\" and args.num_threads > 0:\n        sim_params.physx.num_threads = args.num_threads\n\n    return sim_params"
  },
  {
    "path": "timechamber/utils/gym_util.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom timechamber.utils import logger\nfrom isaacgym import gymapi\nimport numpy as np\nimport torch\nfrom isaacgym.torch_utils import *\nfrom isaacgym import gymtorch\n\ndef setup_gym_viewer(config):\n    gym = initialize_gym(config)\n    sim, viewer = configure_gym(gym, config)\n    return gym, sim, viewer\n\n\ndef initialize_gym(config):\n    gym = gymapi.acquire_gym()\n    if not gym.initialize():\n        logger.warn(\"*** Failed to initialize gym\")\n        quit()\n\n    return gym\n\n\ndef configure_gym(gym, config):\n    engine, render = config['engine'], config['render']\n\n    # physics engine settings\n    if(engine == 'FLEX'):\n        sim_engine = gymapi.SIM_FLEX\n    elif(engine == 'PHYSX'):\n        sim_engine = gymapi.SIM_PHYSX\n    else:\n        logger.warn(\"Unknown physics engine, 
defaulting to FLEX\")\n        sim_engine = gymapi.SIM_FLEX\n\n    # gym viewer\n    if render:\n        # create viewer\n        sim = gym.create_sim(0, 0, sim_type=sim_engine)\n        viewer = gym.create_viewer(\n            sim, int(gymapi.DEFAULT_VIEWER_WIDTH / 1.25),\n            int(gymapi.DEFAULT_VIEWER_HEIGHT / 1.25)\n        )\n\n        if viewer is None:\n            logger.warn(\"*** Failed to create viewer\")\n            quit()\n\n        # enable left mouse click or space bar for throwing projectiles\n        if config['add_projectiles']:\n            gym.subscribe_viewer_mouse_event(viewer, gymapi.MOUSE_LEFT_BUTTON, \"shoot\")\n            gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_SPACE, \"shoot\")\n\n    else:\n        sim = gym.create_sim(0, -1)\n        viewer = None\n\n    # simulation params\n    scene_config = config['env']['scene']\n    sim_params = gymapi.SimParams()\n    sim_params.solver_type = scene_config['SolverType']\n    sim_params.num_outer_iterations = scene_config['NumIterations']\n    sim_params.num_inner_iterations = scene_config['NumInnerIterations']\n    sim_params.relaxation = scene_config.get('Relaxation', 0.75)\n    sim_params.warm_start = scene_config.get('WarmStart', 0.25)\n    sim_params.geometric_stiffness = scene_config.get('GeometricStiffness', 1.0)\n    sim_params.shape_collision_margin = 0.01\n\n    sim_params.gravity = gymapi.Vec3(0.0, -9.8, 0.0)\n    gym.set_sim_params(sim, sim_params)\n\n    return sim, viewer\n\n\ndef parse_states_from_reference_states(reference_states, progress):\n    # parse reference states from DeepMimicState\n    global_quats_ref = torch.tensor(\n        reference_states._global_rotation[(progress,)].numpy(),\n        dtype=torch.double\n    ).cuda()\n    ts_ref = torch.tensor(\n        reference_states._translation[(progress,)].numpy(),\n        dtype=torch.double\n    ).cuda()\n    vels_ref = torch.tensor(\n        reference_states._velocity[(progress,)].numpy(),\n        
dtype=torch.double\n    ).cuda()\n    avels_ref = torch.tensor(\n        reference_states._angular_velocity[(progress,)].numpy(),\n        dtype=torch.double\n    ).cuda()\n    return global_quats_ref, ts_ref, vels_ref, avels_ref\n\n\ndef parse_states_from_reference_states_with_motion_id(precomputed_state,\n                                                      progress, motion_id):\n    assert len(progress) == len(motion_id)\n    # get the global id\n    global_id = precomputed_state['motion_offset'][motion_id] + progress\n    global_id = np.minimum(global_id,\n                           precomputed_state['global_quats_ref'].shape[0] - 1)\n\n    # parse reference states from DeepMimicState\n    global_quats_ref = precomputed_state['global_quats_ref'][global_id]\n    ts_ref = precomputed_state['ts_ref'][global_id]\n    vels_ref = precomputed_state['vels_ref'][global_id]\n    avels_ref = precomputed_state['avels_ref'][global_id]\n    return global_quats_ref, ts_ref, vels_ref, avels_ref\n\n\ndef parse_dof_state_with_motion_id(precomputed_state, dof_state,\n                                   progress, motion_id):\n    assert len(progress) == len(motion_id)\n    # get the global id\n    global_id = precomputed_state['motion_offset'][motion_id] + progress\n    # NOTE: it should never reach the dof_state.shape, cause the episode is\n    # terminated 2 steps before\n    global_id = np.minimum(global_id, dof_state.shape[0] - 1)\n\n    # parse reference states from DeepMimicState\n    return dof_state[global_id]\n\n\ndef get_flatten_ids(precomputed_state):\n    motion_offsets = precomputed_state['motion_offset']\n    init_state_id, init_motion_id, global_id = [], [], []\n    for i_motion in range(len(motion_offsets) - 1):\n        i_length = motion_offsets[i_motion + 1] - motion_offsets[i_motion]\n        init_state_id.extend(range(i_length))\n        init_motion_id.extend([i_motion] * i_length)\n        if len(global_id) == 0:\n            global_id.extend(range(0, 
i_length))\n        else:\n            global_id.extend(range(global_id[-1] + 1,\n                                   global_id[-1] + i_length + 1))\n    return np.array(init_state_id), np.array(init_motion_id), \\\n        np.array(global_id)\n\n\ndef parse_states_from_reference_states_with_global_id(precomputed_state,\n                                                      global_id):\n    # get the global id\n    global_id = global_id % precomputed_state['global_quats_ref'].shape[0]\n\n    # parse reference states from DeepMimicState\n    global_quats_ref = precomputed_state['global_quats_ref'][global_id]\n    ts_ref = precomputed_state['ts_ref'][global_id]\n    vels_ref = precomputed_state['vels_ref'][global_id]\n    avels_ref = precomputed_state['avels_ref'][global_id]\n    return global_quats_ref, ts_ref, vels_ref, avels_ref\n\n\ndef get_robot_states_from_torch_tensor(config, ts, global_quats, vels, avels,\n                                       init_rot, progress, motion_length=-1,\n                                       actions=None, relative_rot=None,\n                                       motion_id=None, num_motion=None,\n                                       motion_onehot_matrix=None):\n    info = {}\n    # the observation with quaternion-based representation\n    torso_height = ts[..., 0, 1].cpu().numpy()\n    gttrny, gqny, vny, avny, info['root_yaw_inv'] = \\\n        quaternion_math.compute_observation_return_info(global_quats, ts,\n                                                        vels, avels)\n    joint_obs = np.concatenate([gttrny.cpu().numpy(), gqny.cpu().numpy(),\n                                vny.cpu().numpy(), avny.cpu().numpy()], axis=-1)\n    joint_obs = joint_obs.reshape(joint_obs.shape[0], -1)\n    num_envs = joint_obs.shape[0]\n    obs = np.concatenate([torso_height[:, np.newaxis], joint_obs], -1)\n\n    # the previous action\n    if config['env_action_ob']:\n        obs = np.concatenate([obs, actions], axis=-1)\n\n    # the 
orientation\n    if config['env_orientation_ob']:\n        if relative_rot is not None:\n            obs = np.concatenate([obs, relative_rot], axis=-1)\n        else:\n            curr_rot = global_quats[np.arange(num_envs)][:, 0]\n            curr_rot = curr_rot.reshape(num_envs, -1, 4)\n            relative_rot = quaternion_math.compute_orientation_drift(\n                init_rot, curr_rot\n            ).cpu().numpy()\n            obs = np.concatenate([obs, relative_rot], axis=-1)\n\n    if config['env_frame_ob']:\n        # NOTE: np.float is a removed alias for the builtin float; use np.float64\n        if type(motion_length) == np.ndarray:\n            motion_length = motion_length.astype(np.float64)\n            progress_ob = np.expand_dims(progress.astype(np.float64) /\n                                         motion_length, axis=-1)\n        else:\n            progress_ob = np.expand_dims(progress.astype(np.float64) /\n                                         float(motion_length), axis=-1)\n        obs = np.concatenate([obs, progress_ob], axis=-1)\n\n    if config['env_motion_ob'] and not config['env_motion_ob_onehot']:\n        motion_id_ob = np.expand_dims(motion_id.astype(np.float64) /\n                                      float(num_motion), axis=-1)\n        obs = np.concatenate([obs, motion_id_ob], axis=-1)\n    elif config['env_motion_ob'] and config['env_motion_ob_onehot']:\n        motion_id_ob = motion_onehot_matrix[motion_id]\n        obs = np.concatenate([obs, motion_id_ob], axis=-1)\n\n    return obs, info\n\n\ndef get_xyzoffset(start_ts, end_ts, root_yaw_inv):\n    xyoffset = (end_ts - start_ts)[:, [0], :].reshape(1, -1, 1, 3)\n    ryinv = root_yaw_inv.reshape(1, -1, 1, 4)\n\n    calibrated_xyz_offset = quaternion_math.quat_apply(ryinv, xyoffset)[0, :, 0, :]\n    return calibrated_xyz_offset\n"
  },
  {
    "path": "timechamber/utils/logger.py",
"content": "# -----------------------------------------------------------------------------\n#   @brief:\n#       The logger here will be called all across the project. It is inspired\n#   by Yuxin Wu (ppwwyyxx@gmail.com)\n#\n#   @author:\n#       Tingwu Wang, 2017, Feb, 20th\n# -----------------------------------------------------------------------------\n\nimport logging\nimport sys\nimport os\nimport datetime\nfrom termcolor import colored\n\n__all__ = ['set_file_handler']  # the actual worker is the '_logger'\n\n\nclass _MyFormatter(logging.Formatter):\n    '''\n        @brief:\n            a class to make sure the format could be used\n    '''\n\n    def format(self, record):\n        date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')\n        msg = '%(message)s'\n\n        if record.levelno == logging.WARNING:\n            fmt = date + ' ' + \\\n                colored('WRN', 'red', attrs=[]) + ' ' + msg\n        elif record.levelno == logging.ERROR or \\\n                record.levelno == logging.CRITICAL:\n            fmt = date + ' ' + \\\n                colored('ERR', 'red', attrs=['underline']) + ' ' + msg\n        else:\n            fmt = date + ' ' + msg\n\n        if hasattr(self, '_style'):\n            # Python3 compatibility\n            self._style._fmt = fmt\n        self._fmt = fmt\n\n        return super(self.__class__, self).format(record)\n\n\n_logger = logging.getLogger('joint_embedding')\n_logger.propagate = False\n_logger.setLevel(logging.INFO)\n\n# set the console output handler\ncon_handler = logging.StreamHandler(sys.stdout)\ncon_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))\n_logger.addHandler(con_handler)\n\n\nclass GLOBAL_PATH(object):\n\n    def __init__(self, path=None):\n        if path is None:\n            path = os.getcwd()\n        self.path = path\n\n    def _set_path(self, path):\n        self.path = path\n\n    def _get_path(self):\n        return self.path\n\n\nPATH = GLOBAL_PATH()\n\n\ndef set_file_handler(path=None, prefix='', time_str=''):\n    # set the file output handler\n    if time_str == '':\n        file_name = prefix + \\\n            datetime.datetime.now().strftime(\"%A_%d_%B_%Y_%I:%M%p\") + '.log'\n    else:\n        file_name = prefix + time_str + '.log'\n\n    if path is None:\n        mod = sys.modules['__main__']\n        path = os.path.join(os.path.abspath(mod.__file__), '..', '..', 'log')\n    else:\n        path = os.path.join(path, 'log')\n    path = os.path.abspath(path)\n\n    path = os.path.join(path, file_name)\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n    PATH._set_path(path)\n    path = os.path.join(path, file_name)\n    from tensorboard_logger import configure\n    configure(path)\n\n    file_handler = logging.FileHandler(\n        filename=os.path.join(path, 'logger'), encoding='utf-8', mode='w')\n    file_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))\n    _logger.addHandler(file_handler)\n\n    _logger.info('Log file set to {}'.format(path))\n    return path\n\n\ndef _get_path():\n    return PATH._get_path()\n\n\n_LOGGING_METHOD = ['info', 'warning', 'error', 'critical',\n                   'warn', 'exception', 'debug']\n\n# export logger functions\nfor func in _LOGGING_METHOD:\n    locals()[func] = getattr(_logger, func)\n"
  },
  {
    "path": "timechamber/utils/motion_lib.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport numpy as np\nimport os\nimport yaml\n\nfrom timechamber.tasks.ase_humanoid_base.poselib.poselib.skeleton.skeleton3d import SkeletonMotion\nfrom timechamber.tasks.ase_humanoid_base.poselib.poselib.core.rotation3d import *\nfrom isaacgym.torch_utils import *\n\nfrom utils import torch_utils\n\nimport torch\n\nUSE_CACHE = True\nprint(\"MOVING MOTION DATA TO GPU, USING CACHE:\", USE_CACHE)\n\nif not USE_CACHE:\n    old_numpy = torch.Tensor.numpy\n    class Patch:\n        def numpy(self):\n            if self.is_cuda:\n                return self.to(\"cpu\").numpy()\n            else:\n                return old_numpy(self)\n\n    torch.Tensor.numpy = Patch.numpy\n\nclass DeviceCache:\n    def __init__(self, obj, device):\n        self.obj = obj\n        self.device = device\n\n        keys = dir(obj)\n        num_added = 0\n        for k in keys:\n            try:\n                out = getattr(obj, k)\n            except Exception:\n                print(\"Error for key=\", k)\n                continue\n\n            if isinstance(out, torch.Tensor):\n                if out.is_floating_point():\n                    out = out.to(self.device, dtype=torch.float32)\n                else:\n                    # NOTE: Tensor.to is not in-place; the result must be reassigned\n                    out = out.to(self.device)\n                setattr(self, k, out)\n                num_added += 1\n            elif isinstance(out, np.ndarray):\n                out = torch.tensor(out)\n                if out.is_floating_point():\n                    out = out.to(self.device, dtype=torch.float32)\n                else:\n                    out = out.to(self.device)\n                setattr(self, k, out)\n                num_added += 1\n\n        print(\"Total added\", num_added)\n\n    def __getattr__(self, string):\n        out = getattr(self.obj, string)\n        return out\n\n\nclass MotionLib():\n    def __init__(self, motion_file, dof_body_ids, dof_offsets,\n                 key_body_ids, device):\n        self._dof_body_ids = dof_body_ids\n        self._dof_offsets = dof_offsets\n        self._num_dof = dof_offsets[-1]\n        self._key_body_ids = torch.tensor(key_body_ids, device=device)\n        self._device = device\n        self._load_motions(motion_file)\n\n        motions = self._motions\n        self.gts = torch.cat([m.global_translation for m in motions], dim=0).float()\n        self.grs = torch.cat([m.global_rotation for m in motions], dim=0).float()\n        self.lrs = torch.cat([m.local_rotation for m in motions], dim=0).float()\n        self.grvs = torch.cat([m.global_root_velocity for m in motions], dim=0).float()\n        self.gravs = torch.cat([m.global_root_angular_velocity for m in motions], dim=0).float()\n        self.dvs = torch.cat([m.dof_vels for m in motions], dim=0).float()\n\n        lengths = self._motion_num_frames\n        lengths_shifted = lengths.roll(1)\n        lengths_shifted[0] = 0\n        self.length_starts = lengths_shifted.cumsum(0)\n\n        self.motion_ids = torch.arange(len(self._motions), dtype=torch.long, device=self._device)\n\n        return\n\n    def num_motions(self):\n        return len(self._motions)\n\n    def get_total_length(self):\n        return sum(self._motion_lengths)\n\n    def get_motion(self, motion_id):\n        return self._motions[motion_id]\n\n    def sample_motions(self, n):\n        motion_ids = torch.multinomial(self._motion_weights, num_samples=n, replacement=True)\n\n        # m = self.num_motions()\n        # motion_ids = 
np.random.choice(m, size=n, replace=True, p=self._motion_weights)\n        # motion_ids = torch.tensor(motion_ids, device=self._device, dtype=torch.long)\n        return motion_ids\n\n    def sample_time(self, motion_ids, truncate_time=None):\n        n = len(motion_ids)\n        phase = torch.rand(motion_ids.shape, device=self._device)\n        \n        motion_len = self._motion_lengths[motion_ids]\n        if (truncate_time is not None):\n            assert(truncate_time >= 0.0)\n            motion_len -= truncate_time\n\n        motion_time = phase * motion_len\n        return motion_time\n\n    def get_motion_length(self, motion_ids):\n        return self._motion_lengths[motion_ids]\n\n    def get_motion_state(self, motion_ids, motion_times):\n        n = len(motion_ids)\n        num_bodies = self._get_num_bodies()\n        num_key_bodies = self._key_body_ids.shape[0]\n\n        motion_len = self._motion_lengths[motion_ids]\n        num_frames = self._motion_num_frames[motion_ids]\n        dt = self._motion_dt[motion_ids]\n\n        frame_idx0, frame_idx1, blend = self._calc_frame_blend(motion_times, motion_len, num_frames, dt)\n\n        f0l = frame_idx0 + self.length_starts[motion_ids]\n        f1l = frame_idx1 + self.length_starts[motion_ids]\n\n        root_pos0 = self.gts[f0l, 0]\n        root_pos1 = self.gts[f1l, 0]\n\n        root_rot0 = self.grs[f0l, 0]\n        root_rot1 = self.grs[f1l, 0]\n\n        local_rot0 = self.lrs[f0l]\n        local_rot1 = self.lrs[f1l]\n\n        root_vel = self.grvs[f0l]\n\n        root_ang_vel = self.gravs[f0l]\n        \n        key_pos0 = self.gts[f0l.unsqueeze(-1), self._key_body_ids.unsqueeze(0)]\n        key_pos1 = self.gts[f1l.unsqueeze(-1), self._key_body_ids.unsqueeze(0)]\n\n        dof_vel = self.dvs[f0l]\n\n        vals = [root_pos0, root_pos1, local_rot0, local_rot1, root_vel, root_ang_vel, key_pos0, key_pos1]\n        for v in vals:\n            assert v.dtype != torch.float64\n\n\n        blend = 
blend.unsqueeze(-1)\n\n        root_pos = (1.0 - blend) * root_pos0 + blend * root_pos1\n\n        root_rot = torch_utils.slerp(root_rot0, root_rot1, blend)\n\n        blend_exp = blend.unsqueeze(-1)\n        key_pos = (1.0 - blend_exp) * key_pos0 + blend_exp * key_pos1\n        \n        local_rot = torch_utils.slerp(local_rot0, local_rot1, torch.unsqueeze(blend, axis=-1))\n        dof_pos = self._local_rotation_to_dof(local_rot)\n\n        return root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos\n    \n    def _load_motions(self, motion_file):\n        self._motions = []\n        self._motion_lengths = []\n        self._motion_weights = []\n        self._motion_fps = []\n        self._motion_dt = []\n        self._motion_num_frames = []\n        self._motion_files = []\n\n        total_len = 0.0\n\n        motion_files, motion_weights = self._fetch_motion_files(motion_file)\n        num_motion_files = len(motion_files)\n        for f in range(num_motion_files):\n            curr_file = motion_files[f]\n            print(\"Loading {:d}/{:d} motion files: {:s}\".format(f + 1, num_motion_files, curr_file))\n            curr_motion = SkeletonMotion.from_file(curr_file)\n\n            motion_fps = curr_motion.fps\n            curr_dt = 1.0 / motion_fps\n\n            num_frames = curr_motion.tensor.shape[0]\n            curr_len = 1.0 / motion_fps * (num_frames - 1)\n\n            self._motion_fps.append(motion_fps)\n            self._motion_dt.append(curr_dt)\n            self._motion_num_frames.append(num_frames)\n \n            curr_dof_vels = self._compute_motion_dof_vels(curr_motion)\n            curr_motion.dof_vels = curr_dof_vels\n\n            # Moving motion tensors to the GPU\n            if USE_CACHE:\n                curr_motion = DeviceCache(curr_motion, self._device)                \n            else:\n                curr_motion.tensor = curr_motion.tensor.to(self._device)\n                curr_motion._skeleton_tree._parent_indices 
= curr_motion._skeleton_tree._parent_indices.to(self._device)\n                curr_motion._skeleton_tree._local_translation = curr_motion._skeleton_tree._local_translation.to(self._device)\n                curr_motion._rotation = curr_motion._rotation.to(self._device)\n\n            self._motions.append(curr_motion)\n            self._motion_lengths.append(curr_len)\n            \n            curr_weight = motion_weights[f]\n            self._motion_weights.append(curr_weight)\n            self._motion_files.append(curr_file)\n\n        self._motion_lengths = torch.tensor(self._motion_lengths, device=self._device, dtype=torch.float32)\n\n        self._motion_weights = torch.tensor(self._motion_weights, dtype=torch.float32, device=self._device)\n        self._motion_weights /= self._motion_weights.sum()\n\n        self._motion_fps = torch.tensor(self._motion_fps, device=self._device, dtype=torch.float32)\n        self._motion_dt = torch.tensor(self._motion_dt, device=self._device, dtype=torch.float32)\n        self._motion_num_frames = torch.tensor(self._motion_num_frames, device=self._device)\n\n\n        num_motions = self.num_motions()\n        total_len = self.get_total_length()\n\n        print(\"Loaded {:d} motions with a total length of {:.3f}s.\".format(num_motions, total_len))\n\n        return\n\n    def _fetch_motion_files(self, motion_file):\n        ext = os.path.splitext(motion_file)[1]\n        if (ext == \".yaml\"):\n            dir_name = os.path.dirname(motion_file)\n            motion_files = []\n            motion_weights = []\n\n            with open(os.path.join(os.getcwd(), motion_file), 'r') as f:\n                motion_config = yaml.load(f, Loader=yaml.SafeLoader)\n\n            motion_list = motion_config['motions']\n            for motion_entry in motion_list:\n                curr_file = motion_entry['file']\n                curr_weight = motion_entry['weight']\n                assert(curr_weight >= 0)\n\n                curr_file = 
os.path.join(dir_name, curr_file)\n                motion_weights.append(curr_weight)\n                motion_files.append(curr_file)\n        else:\n            motion_files = [motion_file]\n            motion_weights = [1.0]\n\n        return motion_files, motion_weights\n\n    def _calc_frame_blend(self, time, len, num_frames, dt):\n\n        phase = time / len\n        phase = torch.clip(phase, 0.0, 1.0)\n\n        frame_idx0 = (phase * (num_frames - 1)).long()\n        frame_idx1 = torch.min(frame_idx0 + 1, num_frames - 1)\n        blend = (time - frame_idx0 * dt) / dt\n\n        return frame_idx0, frame_idx1, blend\n\n    def _get_num_bodies(self):\n        motion = self.get_motion(0)\n        num_bodies = motion.num_joints\n        return num_bodies\n\n    def _compute_motion_dof_vels(self, motion):\n        num_frames = motion.tensor.shape[0]\n        dt = 1.0 / motion.fps\n        dof_vels = []\n\n        for f in range(num_frames - 1):\n            local_rot0 = motion.local_rotation[f]\n            local_rot1 = motion.local_rotation[f + 1]\n            frame_dof_vel = self._local_rotation_to_dof_vel(local_rot0, local_rot1, dt)\n            frame_dof_vel = frame_dof_vel\n            dof_vels.append(frame_dof_vel)\n        \n        dof_vels.append(dof_vels[-1])\n        dof_vels = torch.stack(dof_vels, dim=0)\n\n        return dof_vels\n    \n    def _local_rotation_to_dof(self, local_rot):\n        body_ids = self._dof_body_ids\n        dof_offsets = self._dof_offsets\n\n        n = local_rot.shape[0]\n        dof_pos = torch.zeros((n, self._num_dof), dtype=torch.float, device=self._device)\n\n        for j in range(len(body_ids)):\n            body_id = body_ids[j]\n            joint_offset = dof_offsets[j]\n            joint_size = dof_offsets[j + 1] - joint_offset\n\n            if (joint_size == 3):\n                joint_q = local_rot[:, body_id]\n                joint_exp_map = torch_utils.quat_to_exp_map(joint_q)\n                dof_pos[:, 
joint_offset:(joint_offset + joint_size)] = joint_exp_map\n            elif (joint_size == 1):\n                joint_q = local_rot[:, body_id]\n                joint_theta, joint_axis = torch_utils.quat_to_angle_axis(joint_q)\n                joint_theta = joint_theta * joint_axis[..., 1] # assume joint is always along y axis\n\n                joint_theta = normalize_angle(joint_theta)\n                dof_pos[:, joint_offset] = joint_theta\n\n            else:\n                print(\"Unsupported joint type\")\n                assert(False)\n\n        return dof_pos\n\n    def _local_rotation_to_dof_vel(self, local_rot0, local_rot1, dt):\n        body_ids = self._dof_body_ids\n        dof_offsets = self._dof_offsets\n\n        dof_vel = torch.zeros([self._num_dof], device=self._device)\n\n        diff_quat_data = quat_mul_norm(quat_inverse(local_rot0), local_rot1)\n        diff_angle, diff_axis = quat_angle_axis(diff_quat_data)\n        local_vel = diff_axis * diff_angle.unsqueeze(-1) / dt\n        local_vel = local_vel\n\n        for j in range(len(body_ids)):\n            body_id = body_ids[j]\n            joint_offset = dof_offsets[j]\n            joint_size = dof_offsets[j + 1] - joint_offset\n\n            if (joint_size == 3):\n                joint_vel = local_vel[body_id]\n                dof_vel[joint_offset:(joint_offset + joint_size)] = joint_vel\n\n            elif (joint_size == 1):\n                assert(joint_size == 1)\n                joint_vel = local_vel[body_id]\n                dof_vel[joint_offset] = joint_vel[1] # assume joint is always along y axis\n\n            else:\n                print(\"Unsupported joint type\")\n                assert(False)\n\n        return dof_vel"
  },
  {
    "path": "timechamber/utils/reformat.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom omegaconf import DictConfig, OmegaConf\nfrom typing import Dict\n\ndef omegaconf_to_dict(d: DictConfig) -> Dict:\n    \"\"\"Converts an omegaconf DictConfig to a python Dict, respecting variable interpolation.\"\"\"\n    ret = {}\n    for k, v in d.items():\n        if isinstance(v, DictConfig):\n            ret[k] = omegaconf_to_dict(v)\n        else:\n            ret[k] = v\n    return ret\n\ndef print_dict(val, nesting: int = -4, start: bool = True):\n    \"\"\"Outputs a nested dictionary.\"\"\"\n    if type(val) == dict:\n        if not start:\n            print('')\n        nesting += 4\n        for k in val:\n            print(nesting * ' ', end='')\n            print(k, end=': ')\n            print_dict(val[k], nesting, start=False)\n    else:\n        print(val)\n\n# EOF\n"
  },
  {
    "path": "timechamber/utils/rlgames_utils.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom rl_games.common import env_configurations, vecenv\nfrom rl_games.common.algo_observer import AlgoObserver\nfrom rl_games.algos_torch import torch_ext\nfrom timechamber.utils.utils import set_seed\nimport torch\nimport numpy as np\nfrom typing import Callable\nfrom isaacgym import gymapi\nfrom isaacgym import gymutil\nfrom omegaconf import DictConfig\nfrom timechamber.tasks import isaacgym_task_map\nfrom timechamber.utils.vec_task_wrappers import VecTaskPythonWrapper\nfrom timechamber.utils.config import parse_sim_params\n\nSIM_TIMESTEP = 1.0 / 60.0\n\ndef get_rlgames_env_creator(\n        # used to create the vec task\n        seed: int,\n        cfg: DictConfig,\n        task_config: dict,\n        task_name: str,\n        sim_device: str,\n        rl_device: str,\n        graphics_device_id: int,\n        headless: bool,\n        device_type: str = \"cuda\",\n        # Used to handle multi-gpu case\n        multi_gpu: bool = False,\n        post_create_hook: Callable = None,\n        virtual_screen_capture: bool = False,\n        force_render: bool = False,\n):\n    \"\"\"Parses the configuration parameters for the environment task and creates a VecTask\n\n    Args:\n        task_config: environment configuration.\n        task_name: Name of the task, used to evaluate based on the imported name (eg 'Trifinger')\n        sim_device: The type of env device, eg 'cuda:0'\n        rl_device: Device that RL will be done on, eg 'cuda:0'\n  
      graphics_device_id: Graphics device ID.\n        headless: Whether to run in headless mode.\n        multi_gpu: Whether to use multiple GPUs\n        post_create_hook: Hooks to be called after environment creation.\n            [Needed to setup WandB only for one of the RL Games instances when doing multiple GPUs]\n        virtual_screen_capture: Set to True to allow users to get the captured screen as an RGB array via `env.render(mode='rgb_array')`.\n        force_render: Set to True to always force rendering in the steps (if the `control_freq_inv` is greater than 1 we suggest setting this arg to True)\n    Returns:\n        A VecTaskPython object.\n    \"\"\"\n    def create_rlgpu_env():\n        \"\"\"\n        Creates the task from configurations and wraps it using RL-games wrappers if required.\n        \"\"\"\n\n        # create native task and pass custom config\n        if task_name == \"MA_Humanoid_Strike\":\n            sim_params = parse_sim_params(cfg, task_config)\n            if cfg.physics_engine == \"physx\":\n                physics_engine = gymapi.SIM_PHYSX\n            elif cfg.physics_engine == \"flex\":\n                physics_engine = gymapi.SIM_FLEX\n            task = isaacgym_task_map[task_name](\n                cfg=task_config,\n                sim_params=sim_params,\n                physics_engine=physics_engine,\n                device_type=device_type,\n                device_id=graphics_device_id,\n                headless=headless\n            )\n            env = VecTaskPythonWrapper(task, rl_device,\n                                       task_config.get(\"clip_observations\", np.inf),\n                                       task_config.get(\"clip_actions\", 1.0),\n                                       AMP=True)\n        else:\n            task = isaacgym_task_map[task_name](\n                cfg=task_config,\n                rl_device=rl_device,\n                sim_device=sim_device,\n                
graphics_device_id=graphics_device_id,\n                headless=headless,\n                virtual_screen_capture=virtual_screen_capture,\n                force_render=force_render,\n            )\n            env = VecTaskPythonWrapper(task, rl_device, task_config.get(\"clip_observations\", np.inf), task_config.get(\"clip_actions\", 1.0))\n        \n        if post_create_hook is not None:\n            post_create_hook()\n\n        return env\n    return create_rlgpu_env\n\n\nclass RLGPUAlgoObserver(AlgoObserver):\n    \"\"\"Allows us to log stats from the env along with the algorithm running stats. \"\"\"\n\n    def __init__(self):\n        pass\n\n    def after_init(self, algo):\n        self.algo = algo\n        self.mean_scores = torch_ext.AverageMeter(1, self.algo.games_to_track).to(self.algo.ppo_device)\n        self.ep_infos = []\n        self.direct_info = {}\n        self.writer = self.algo.writer\n\n    def process_infos(self, infos, done_indices):\n        assert isinstance(infos, dict), \"RLGPUAlgoObserver expects dict info\"\n        if isinstance(infos, dict):\n            if 'episode' in infos:\n                self.ep_infos.append(infos['episode'])\n\n            if len(infos) > 0 and isinstance(infos, dict):  # allow direct logging from env\n                self.direct_info = {}\n                for k, v in infos.items():\n                    # only log scalars\n                    if isinstance(v, float) or isinstance(v, int) or (isinstance(v, torch.Tensor) and len(v.shape) == 0):\n                        self.direct_info[k] = v\n\n    def after_clear_stats(self):\n        self.mean_scores.clear()\n\n    def after_print_stats(self, frame, epoch_num, total_time):\n        if self.ep_infos:\n            for key in self.ep_infos[0]:\n                    infotensor = torch.tensor([], device=self.algo.device)\n                    for ep_info in self.ep_infos:\n                        # handle scalar and zero dimensional tensor infos\n                 
       if not isinstance(ep_info[key], torch.Tensor):\n                            ep_info[key] = torch.Tensor([ep_info[key]])\n                        if len(ep_info[key].shape) == 0:\n                            ep_info[key] = ep_info[key].unsqueeze(0)\n                        infotensor = torch.cat((infotensor, ep_info[key].to(self.algo.device)))\n                    value = torch.mean(infotensor)\n                    self.writer.add_scalar('Episode/' + key, value, epoch_num)\n            self.ep_infos.clear()\n        \n        for k, v in self.direct_info.items():\n            self.writer.add_scalar(f'{k}/frame', v, frame)\n            self.writer.add_scalar(f'{k}/iter', v, epoch_num)\n            self.writer.add_scalar(f'{k}/time', v, total_time)\n\n        if self.mean_scores.current_size > 0:\n            mean_scores = self.mean_scores.get_mean()\n            self.writer.add_scalar('scores/mean', mean_scores, frame)\n            self.writer.add_scalar('scores/iter', mean_scores, epoch_num)\n            self.writer.add_scalar('scores/time', mean_scores, total_time)\n\n\nclass RLGPUEnv(vecenv.IVecEnv):\n    def __init__(self, config_name, num_actors, **kwargs):\n        self.env = env_configurations.configurations[config_name]['env_creator'](**kwargs)\n        self.use_global_obs = (self.env.num_states > 0)\n\n        self.full_state = {}\n        self.full_state[\"obs\"] = self.reset()\n        if self.use_global_obs:\n            self.full_state[\"states\"] = self.env.get_state()\n        return\n\n    def step(self, action):\n        next_obs, reward, is_done, info = self.env.step(action)\n\n        # todo: improve, return only dictionary\n        self.full_state[\"obs\"] = next_obs\n        if self.use_global_obs:\n            self.full_state[\"states\"] = self.env.get_state()\n        return self.full_state, reward, is_done, info\n\n    def reset(self, env_ids=None):\n        self.full_state[\"obs\"] = self.env.reset(env_ids)\n        if 
self.use_global_obs:\n            self.full_state[\"states\"] = self.env.get_state()\n        return self.full_state\n\n    def get_number_of_agents(self):\n        return self.env.get_number_of_agents()\n\n    def get_env_info(self):\n        info = {}\n        info['action_space'] = self.env.action_space\n        info['observation_space'] = self.env.observation_space\n        info['amp_observation_space'] = self.env.amp_observation_space\n\n        if self.use_global_obs:\n            info['state_space'] = self.env.state_space\n            print(info['action_space'], info['observation_space'], info['state_space'])\n        else:\n            print(info['action_space'], info['observation_space'])\n\n        return info\n"
  },
  {
    "path": "timechamber/utils/torch_jit_utils.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\nimport numpy as np\nfrom isaacgym.torch_utils import *\n\n\n@torch.jit.script\ndef compute_heading_and_up(\n    torso_rotation, inv_start_rot, to_target, vec0, vec1, up_idx\n):\n    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]\n    num_envs = torso_rotation.shape[0]\n    target_dirs = normalize(to_target)\n\n    torso_quat = quat_mul(torso_rotation, inv_start_rot)\n    up_vec = get_basis_vector(torso_quat, vec1).view(num_envs, 3)\n    heading_vec = get_basis_vector(torso_quat, vec0).view(num_envs, 3)\n    up_proj = up_vec[:, up_idx]\n    heading_proj = torch.bmm(heading_vec.view(\n        num_envs, 1, 3), target_dirs.view(num_envs, 3, 1)).view(num_envs)\n\n    return torso_quat, up_proj, heading_proj, up_vec, heading_vec\n\n\n@torch.jit.script\ndef compute_rot(torso_quat, velocity, ang_velocity, targets, torso_positions):\n    vel_loc = quat_rotate_inverse(torso_quat, velocity)\n    angvel_loc = quat_rotate_inverse(torso_quat, ang_velocity)\n\n    roll, pitch, yaw = get_euler_xyz(torso_quat)\n\n    walk_target_angle = torch.atan2(targets[:, 2] - torso_positions[:, 2],\n                                    targets[:, 0] - torso_positions[:, 0])\n    angle_to_target = walk_target_angle - yaw\n\n    return vel_loc, angvel_loc, roll, pitch, yaw, angle_to_target\n\n\n@torch.jit.script\ndef quat_axis(q, axis=0):\n    # type: (Tensor, int) -> Tensor\n    basis_vec = 
torch.zeros(q.shape[0], 3, device=q.device)\n    basis_vec[:, axis] = 1\n    return quat_rotate(q, basis_vec)\n\n\n\"\"\"\nNormalization and Denormalization of Tensors\n\"\"\"\n\n\n@torch.jit.script\ndef scale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Normalizes a given input tensor to a range of [-1, 1].\n\n    @note It uses pytorch broadcasting functionality to deal with batched input.\n\n    Args:\n        x: Input tensor of shape (N, dims).\n        lower: The minimum value of the tensor. Shape (dims,)\n        upper: The maximum value of the tensor. Shape (dims,)\n\n    Returns:\n        Normalized transform of the tensor. Shape (N, dims)\n    \"\"\"\n    # default value of center\n    offset = (lower + upper) * 0.5\n    # return normalized tensor\n    return 2 * (x - offset) / (upper - lower)\n\n\n@torch.jit.script\ndef unscale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Denormalizes a given input tensor from range of [-1, 1] to (lower, upper).\n\n    @note It uses pytorch broadcasting functionality to deal with batched input.\n\n    Args:\n        x: Input tensor of shape (N, dims).\n        lower: The minimum value of the tensor. Shape (dims,)\n        upper: The maximum value of the tensor. Shape (dims,)\n\n    Returns:\n        Denormalized transform of the tensor. Shape (N, dims)\n    \"\"\"\n    # default value of center\n    offset = (lower + upper) * 0.5\n    # return normalized tensor\n    return x * (upper - lower) * 0.5 + offset\n\n@torch.jit.script\ndef saturate(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Clamps a given input tensor to (lower, upper).\n\n    @note It uses pytorch broadcasting functionality to deal with batched input.\n\n    Args:\n        x: Input tensor of shape (N, dims).\n        lower: The minimum value of the tensor. 
Shape (dims,)\n        upper: The maximum value of the tensor. Shape (dims,)\n\n    Returns:\n        Clamped transform of the tensor. Shape (N, dims)\n    \"\"\"\n    return torch.max(torch.min(x, upper), lower)\n\n\"\"\"\nRotation conversions\n\"\"\"\n\n@torch.jit.script\ndef quat_diff_rad(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Get the difference in radians between two quaternions.\n\n    Args:\n        a: first quaternion, shape (N, 4)\n        b: second quaternion, shape (N, 4)\n    Returns:\n        Difference in radians, shape (N,)\n    \"\"\"\n    b_conj = quat_conjugate(b)\n    mul = quat_mul(a, b_conj)\n    # 2 * torch.acos(torch.abs(mul[:, -1]))\n    return 2.0 * torch.asin(\n        torch.clamp(\n            torch.norm(\n                mul[:, 0:3],\n                p=2, dim=-1), max=1.0)\n    )\n\n\n@torch.jit.script\ndef local_to_world_space(pos_offset_local: torch.Tensor, pose_global: torch.Tensor):\n    \"\"\" Convert a point from the local frame to the global frame\n    Args:\n        pos_offset_local: Point in local frame. Shape: [N, 3]\n        pose_global: The spatial pose of this point. Shape: [N, 7]\n    Returns:\n        Position in the global frame. Shape: [N, 3]\n    \"\"\"\n    quat_pos_local = torch.cat(\n        [pos_offset_local, torch.zeros(pos_offset_local.shape[0], 1, dtype=torch.float32, device=pos_offset_local.device)],\n        dim=-1\n    )\n    quat_global = pose_global[:, 3:7]\n    quat_global_conj = quat_conjugate(quat_global)\n    pos_offset_global = quat_mul(quat_global, quat_mul(quat_pos_local, quat_global_conj))[:, 0:3]\n\n    result_pos_global = pos_offset_global + pose_global[:, 0:3]\n\n    return result_pos_global\n\n# NB: do not make this function jit, since it is passed around as an argument.\ndef normalise_quat_in_pose(pose):\n    \"\"\"Takes a pose and normalises the quaternion portion of it.\n\n    Args:\n        pose: shape N, 7\n    Returns:\n        Pose with normalised quat. 
Shape N, 7\n    \"\"\"\n    pos = pose[:, 0:3]\n    quat = pose[:, 3:7]\n    quat /= torch.norm(quat, dim=-1, p=2).reshape(-1, 1)\n    return torch.cat([pos, quat], dim=-1)\n\n@torch.jit.script\ndef my_quat_rotate(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = q_vec * \\\n        torch.bmm(q_vec.view(shape[0], 1, 3), v.view(\n            shape[0], 3, 1)).squeeze(-1) * 2.0\n    return a + b + c\n\n@torch.jit.script\ndef quat_to_angle_axis(q):\n    # type: (Tensor) -> Tuple[Tensor, Tensor]\n    # computes axis-angle representation from quaternion q\n    # q must be normalized\n    min_theta = 1e-5\n    qx, qy, qz, qw = 0, 1, 2, 3\n\n    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])\n    angle = 2 * torch.acos(q[..., qw])\n    angle = normalize_angle(angle)\n    sin_theta_expand = sin_theta.unsqueeze(-1)\n    axis = q[..., qx:qw] / sin_theta_expand\n\n    mask = sin_theta > min_theta\n    default_axis = torch.zeros_like(axis)\n    default_axis[..., -1] = 1\n\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n    return angle, axis\n\n@torch.jit.script\ndef angle_axis_to_exp_map(angle, axis):\n    # type: (Tensor, Tensor) -> Tensor\n    # compute exponential map from axis-angle\n    angle_expand = angle.unsqueeze(-1)\n    exp_map = angle_expand * axis\n    return exp_map\n\n@torch.jit.script\ndef quat_to_exp_map(q):\n    # type: (Tensor) -> Tensor\n    # compute exponential map from quaternion\n    # q must be normalized\n    angle, axis = quat_to_angle_axis(q)\n    exp_map = angle_axis_to_exp_map(angle, axis)\n    return exp_map\n\n@torch.jit.script\ndef quat_to_tan_norm(q):\n    # type: (Tensor) -> Tensor\n    # represents a rotation using the tangent and normal vectors\n    ref_tan = torch.zeros_like(q[..., 
0:3])\n    ref_tan[..., 0] = 1\n    tan = my_quat_rotate(q, ref_tan)\n    \n    ref_norm = torch.zeros_like(q[..., 0:3])\n    ref_norm[..., -1] = 1\n    norm = my_quat_rotate(q, ref_norm)\n    \n    norm_tan = torch.cat([tan, norm], dim=len(tan.shape) - 1)\n    return norm_tan\n\n@torch.jit.script\ndef euler_xyz_to_exp_map(roll, pitch, yaw):\n    # type: (Tensor, Tensor, Tensor) -> Tensor\n    q = quat_from_euler_xyz(roll, pitch, yaw)\n    exp_map = quat_to_exp_map(q)\n    return exp_map\n\n@torch.jit.script\ndef exp_map_to_angle_axis(exp_map):\n    min_theta = 1e-5\n\n    angle = torch.norm(exp_map, dim=-1)\n    angle_exp = torch.unsqueeze(angle, dim=-1)\n    axis = exp_map / angle_exp\n    angle = normalize_angle(angle)\n\n    default_axis = torch.zeros_like(exp_map)\n    default_axis[..., -1] = 1\n\n    mask = angle > min_theta\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n\n    return angle, axis\n\n@torch.jit.script\ndef exp_map_to_quat(exp_map):\n    angle, axis = exp_map_to_angle_axis(exp_map)\n    q = quat_from_angle_axis(angle, axis)\n    return q\n\n@torch.jit.script\ndef slerp(q0, q1, t):\n    # type: (Tensor, Tensor, Tensor) -> Tensor\n    qx, qy, qz, qw = 0, 1, 2, 3\n\n    cos_half_theta = q0[..., qw] * q1[..., qw] \\\n                   + q0[..., qx] * q1[..., qx] \\\n                   + q0[..., qy] * q1[..., qy] \\\n                   + q0[..., qz] * q1[..., qz]\n    \n    neg_mask = cos_half_theta < 0\n    q1 = q1.clone()\n    q1[neg_mask] = -q1[neg_mask]\n    cos_half_theta = torch.abs(cos_half_theta)\n    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)\n\n    half_theta = torch.acos(cos_half_theta)\n    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)\n\n    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta\n    ratioB = torch.sin(t * half_theta) / sin_half_theta\n    \n    new_q_x = ratioA * 
q0[..., qx:qx+1] + ratioB * q1[..., qx:qx+1]\n    new_q_y = ratioA * q0[..., qy:qy+1] + ratioB * q1[..., qy:qy+1]\n    new_q_z = ratioA * q0[..., qz:qz+1] + ratioB * q1[..., qz:qz+1]\n    new_q_w = ratioA * q0[..., qw:qw+1] + ratioB * q1[..., qw:qw+1]\n\n    cat_dim = len(new_q_w.shape) - 1\n    new_q = torch.cat([new_q_x, new_q_y, new_q_z, new_q_w], dim=cat_dim)\n\n    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q)\n    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)\n\n    return new_q\n\n@torch.jit.script\ndef calc_heading(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading direction from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    ref_dir = torch.zeros_like(q[..., 0:3])\n    ref_dir[..., 0] = 1\n    rot_dir = my_quat_rotate(q, ref_dir)\n\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n    return heading\n\n@torch.jit.script\ndef calc_heading_quat(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(heading, axis)\n    return heading_q\n\n@torch.jit.script\ndef calc_heading_quat_inv(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(-heading, axis)\n    return heading_q\n\n\n# EOF\n"
  },
  {
    "path": "timechamber/utils/torch_utils.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\nimport numpy as np\n\nfrom isaacgym.torch_utils import *\n\n@torch.jit.script\ndef quat_to_angle_axis(q):\n    # type: (Tensor) -> Tuple[Tensor, Tensor]\n    # computes axis-angle representation from quaternion q\n    # q must be normalized\n    min_theta = 1e-5\n    qx, qy, qz, qw = 0, 1, 2, 3\n\n    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])\n    angle = 2 * torch.acos(q[..., qw])\n    angle = normalize_angle(angle)\n    sin_theta_expand = sin_theta.unsqueeze(-1)\n    axis = q[..., qx:qw] / sin_theta_expand\n\n    mask = torch.abs(sin_theta) > min_theta\n    default_axis = torch.zeros_like(axis)\n    default_axis[..., -1] = 1\n\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n    return angle, axis\n\n@torch.jit.script\ndef angle_axis_to_exp_map(angle, axis):\n    # type: (Tensor, Tensor) -> Tensor\n    # compute exponential map from axis-angle\n    angle_expand = angle.unsqueeze(-1)\n    exp_map = angle_expand * axis\n    return exp_map\n\n@torch.jit.script\ndef quat_to_exp_map(q):\n    # type: (Tensor) -> Tensor\n    # compute exponential map from quaternion\n    # q must be normalized\n    angle, axis = quat_to_angle_axis(q)\n    exp_map = angle_axis_to_exp_map(angle, axis)\n    return exp_map\n\n@torch.jit.script\ndef quat_to_tan_norm(q):\n    # type: (Tensor) -> Tensor\n    # represents a rotation 
using the tangent and normal vectors\n    ref_tan = torch.zeros_like(q[..., 0:3])\n    ref_tan[..., 0] = 1\n    tan = quat_rotate(q, ref_tan)\n    \n    ref_norm = torch.zeros_like(q[..., 0:3])\n    ref_norm[..., -1] = 1\n    norm = quat_rotate(q, ref_norm)\n    \n    norm_tan = torch.cat([tan, norm], dim=len(tan.shape) - 1)\n    return norm_tan\n\n@torch.jit.script\ndef euler_xyz_to_exp_map(roll, pitch, yaw):\n    # type: (Tensor, Tensor, Tensor) -> Tensor\n    q = quat_from_euler_xyz(roll, pitch, yaw)\n    exp_map = quat_to_exp_map(q)\n    return exp_map\n\n@torch.jit.script\ndef exp_map_to_angle_axis(exp_map):\n    min_theta = 1e-5\n\n    angle = torch.norm(exp_map, dim=-1)\n    angle_exp = torch.unsqueeze(angle, dim=-1)\n    axis = exp_map / angle_exp\n    angle = normalize_angle(angle)\n\n    default_axis = torch.zeros_like(exp_map)\n    default_axis[..., -1] = 1\n\n    mask = torch.abs(angle) > min_theta\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n\n    return angle, axis\n\n@torch.jit.script\ndef exp_map_to_quat(exp_map):\n    angle, axis = exp_map_to_angle_axis(exp_map)\n    q = quat_from_angle_axis(angle, axis)\n    return q\n\n@torch.jit.script\ndef slerp(q0, q1, t):\n    # type: (Tensor, Tensor, Tensor) -> Tensor\n    cos_half_theta = torch.sum(q0 * q1, dim=-1)\n\n    neg_mask = cos_half_theta < 0\n    q1 = q1.clone()\n    q1[neg_mask] = -q1[neg_mask]\n    cos_half_theta = torch.abs(cos_half_theta)\n    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)\n\n    half_theta = torch.acos(cos_half_theta)\n    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)\n\n    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta\n    ratioB = torch.sin(t * half_theta) / sin_half_theta\n    \n    new_q = ratioA * q0 + ratioB * q1\n\n    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q)\n    
new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)\n\n    return new_q\n\n@torch.jit.script\ndef calc_heading(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading direction from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    ref_dir = torch.zeros_like(q[..., 0:3])\n    ref_dir[..., 0] = 1\n    rot_dir = quat_rotate(q, ref_dir)\n\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n    return heading\n\n@torch.jit.script\ndef calc_heading_quat(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(heading, axis)\n    return heading_q\n\n@torch.jit.script\ndef calc_heading_quat_inv(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(-heading, axis)\n    return heading_q"
  },
  {
    "path": "timechamber/utils/utils.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n# python\n\nimport numpy as np\nimport torch\nimport random\nimport os\nfrom collections import OrderedDict\nimport time\nfrom isaacgym import gymapi\nfrom isaacgym import gymutil\n\ndef set_np_formatting():\n    \"\"\" formats numpy print \"\"\"\n    np.set_printoptions(edgeitems=30, infstr='inf',\n                        linewidth=4000, nanstr='nan', precision=2,\n                        suppress=False, threshold=10000, formatter=None)\n\n\ndef set_seed(seed, torch_deterministic=False, rank=0):\n    \"\"\" set seed across modules \"\"\"\n    if seed == -1 and torch_deterministic:\n        seed = 42 + rank\n    elif seed == -1:\n        seed = np.random.randint(0, 10000)\n    else:\n        seed = seed + rank\n\n    print(\"Setting seed: {}\".format(seed))\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n    if torch_deterministic:\n        # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'\n        torch.backends.cudnn.benchmark = False\n        torch.backends.cudnn.deterministic = True\n        torch.use_deterministic_algorithms(True)\n    else:\n        torch.backends.cudnn.benchmark = True\n        torch.backends.cudnn.deterministic = False\n\n    return seed\n\ndef load_check(checkpoint, 
normalize_input: bool, normalize_value: bool):\n    extras = OrderedDict()\n    if normalize_value and 'value_mean_std.running_mean' not in checkpoint['model'].keys():\n        extras['value_mean_std.running_mean'] = checkpoint['reward_mean_std']['running_mean']\n        extras['value_mean_std.running_var'] = checkpoint['reward_mean_std']['running_var']\n        extras['value_mean_std.count'] = checkpoint['reward_mean_std']['count']\n\n    if normalize_input and 'running_mean_std.running_mean' not in checkpoint['model'].keys():\n        extras['running_mean_std.running_mean'] = checkpoint['running_mean_std']['running_mean']\n        extras['running_mean_std.running_var'] = checkpoint['running_mean_std']['running_var']\n        extras['running_mean_std.count'] = checkpoint['running_mean_std']['count']\n    \n    extras.update(checkpoint['model'])\n    checkpoint['model'] = extras\n    return checkpoint\n\ndef safe_filesystem_op(func, *args, **kwargs):\n    \"\"\"\n    This is to prevent spurious crashes related to saving checkpoints or restoring from checkpoints in a Network\n    Filesystem environment (i.e. 
NGC cloud or SLURM)\n    \"\"\"\n    num_attempts = 5\n    for attempt in range(num_attempts):\n        try:\n            return func(*args, **kwargs)\n        except Exception as exc:\n            print(f'Exception {exc} when trying to execute {func} with args:{args} and kwargs:{kwargs}...')\n            wait_sec = 2 ** attempt\n            print(f'Waiting {wait_sec} sec before trying again...')\n            time.sleep(wait_sec)\n\n    raise RuntimeError(f'Could not execute {func}, giving up after {num_attempts} attempts...')\n\ndef safe_load(filename, device=None):\n    if device is not None:\n        return safe_filesystem_op(torch.load, filename, map_location=device)\n    else:\n        return safe_filesystem_op(torch.load, filename)\n\ndef load_checkpoint(filename, device=None):\n    print(\"=> loading checkpoint '{}'\".format(filename))\n    state = safe_load(filename, device=device)\n    return state\n\ndef print_actor_info(gym, env, actor_handle):\n\n    name = gym.get_actor_name(env, actor_handle)\n\n    body_names = gym.get_actor_rigid_body_names(env, actor_handle)\n    body_dict = gym.get_actor_rigid_body_dict(env, actor_handle)\n\n    joint_names = gym.get_actor_joint_names(env, actor_handle)\n    joint_dict = gym.get_actor_joint_dict(env, actor_handle)\n\n    dof_names = gym.get_actor_dof_names(env, actor_handle)\n    dof_dict = gym.get_actor_dof_dict(env, actor_handle)\n\n    print()\n    print(\"===== Actor: %s =======================================\" % name)\n\n    print(\"\\nBodies\")\n    print(body_names)\n    print(body_dict)\n\n    print(\"\\nJoints\")\n    print(joint_names)\n    print(joint_dict)\n\n    print(\"\\nDegrees Of Freedom (DOFs)\")\n    print(dof_names)\n    print(dof_dict)\n    print()\n\n    # Get body state information\n    body_states = gym.get_actor_rigid_body_states(\n        env, actor_handle, gymapi.STATE_ALL)\n\n    # Print some state slices\n    print(\"Poses from Body State:\")\n    print(body_states['pose'])          # 
print just the poses\n\n    print(\"\\nVelocities from Body State:\")\n    print(body_states['vel'])          # print just the velocities\n    print()\n\n    # iterate through bodies and print name and position\n    body_positions = body_states['pose']['p']\n    for i in range(len(body_names)):\n        print(\"Body '%s' has position\" % body_names[i], body_positions[i])\n\n    print(\"\\nDOF states:\")\n\n    # get DOF states\n    dof_states = gym.get_actor_dof_states(env, actor_handle, gymapi.STATE_ALL)\n\n    # print some state slices\n    # Print all states for each degree of freedom\n    print(dof_states)\n    print()\n\n    # iterate through DOFs and print name and position\n    dof_positions = dof_states['pos']\n    for i in range(len(dof_names)):\n        print(\"DOF '%s' has position\" % dof_names[i], dof_positions[i])\n\ndef print_asset_info(asset, name, gym):\n    print(\"======== Asset info %s: ========\" % (name))\n    num_bodies = gym.get_asset_rigid_body_count(asset)\n    num_joints = gym.get_asset_joint_count(asset)\n    num_dofs = gym.get_asset_dof_count(asset)\n    print(\"Got %d bodies, %d joints, and %d DOFs\" %\n          (num_bodies, num_joints, num_dofs))\n\n    # Iterate through bodies\n    print(\"Bodies:\")\n    for i in range(num_bodies):\n        name = gym.get_asset_rigid_body_name(asset, i)\n        print(\" %2d: '%s'\" % (i, name))\n\n    # Iterate through joints\n    print(\"Joints:\")\n    for i in range(num_joints):\n        name = gym.get_asset_joint_name(asset, i)\n        type = gym.get_asset_joint_type(asset, i)\n        type_name = gym.get_joint_type_string(type)\n        print(\" %2d: '%s' (%s)\" % (i, name, type_name))\n\n    # iterate through degrees of freedom (DOFs)\n    print(\"DOFs:\")\n    for i in range(num_dofs):\n        name = gym.get_asset_dof_name(asset, i)\n        type = gym.get_asset_dof_type(asset, i)\n        type_name = gym.get_dof_type_string(type)\n        print(\" %2d: '%s' (%s)\" % (i, name, 
type_name))\n\n# EOF\n"
  },
  {
    "path": "timechamber/utils/vec_task.py",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom gym import spaces\n\nfrom isaacgym import gymtorch\nfrom isaacgym.torch_utils import to_torch\nimport torch\nimport numpy as np\n\n\n# VecEnv Wrapper for RL training\nclass VecTask():\n    def __init__(self, task, rl_device, clip_observations=5.0, clip_actions=1.0):\n        self.task = task\n\n        self.num_environments = task.num_envs\n        self.num_agents = 1  # used for multi-agent environments\n        self.num_observations = task.num_obs\n        self.num_states = task.num_states\n        self.num_actions = task.num_actions\n\n        self.obs_space = spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf)\n        self.state_space = spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf)\n        self.act_space = spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.)\n\n        self.clip_obs = clip_observations\n        self.clip_actions = clip_actions\n        self.rl_device = rl_device\n\n        print(\"RL device: \", rl_device)\n\n    def step(self, actions):\n        raise NotImplementedError\n\n    def reset(self):\n        raise NotImplementedError\n\n    def get_number_of_agents(self):\n        return self.num_agents\n\n    @property\n    def observation_space(self):\n        return self.obs_space\n\n    @property\n    def action_space(self):\n        return self.act_space\n\n    @property\n    def num_envs(self):\n        return self.num_environments\n\n    @property\n    def num_acts(self):\n        return self.num_actions\n\n    
@property\n    def num_obs(self):\n        return self.num_observations\n\n\n# C++ CPU Class\nclass VecTaskCPU(VecTask):\n    def __init__(self, task, rl_device, sync_frame_time=False, clip_observations=5.0, clip_actions=1.0):\n        super().__init__(task, rl_device, clip_observations=clip_observations, clip_actions=clip_actions)\n        self.sync_frame_time = sync_frame_time\n\n    def step(self, actions):\n        actions = actions.cpu().numpy()\n        self.task.render(self.sync_frame_time)\n\n        obs, rewards, resets, extras = self.task.step(np.clip(actions, -self.clip_actions, self.clip_actions))\n\n        return (to_torch(np.clip(obs, -self.clip_obs, self.clip_obs), dtype=torch.float, device=self.rl_device),\n                to_torch(rewards, dtype=torch.float, device=self.rl_device),\n                to_torch(resets, dtype=torch.uint8, device=self.rl_device), [])\n\n    def reset(self):\n        actions = 0.01 * (1 - 2 * np.random.rand(self.num_envs, self.num_actions)).astype('f')\n\n        # step the simulator\n        obs, rewards, resets, extras = self.task.step(actions)\n\n        return to_torch(np.clip(obs, -self.clip_obs, self.clip_obs), dtype=torch.float, device=self.rl_device)\n\n\n# C++ GPU Class\nclass VecTaskGPU(VecTask):\n    def __init__(self, task, rl_device, clip_observations=5.0, clip_actions=1.0):\n        super().__init__(task, rl_device, clip_observations=clip_observations, clip_actions=clip_actions)\n\n        self.obs_tensor = gymtorch.wrap_tensor(self.task.obs_tensor, counts=(self.task.num_envs, self.task.num_obs))\n        self.rewards_tensor = gymtorch.wrap_tensor(self.task.rewards_tensor, counts=(self.task.num_envs,))\n        self.resets_tensor = gymtorch.wrap_tensor(self.task.resets_tensor, counts=(self.task.num_envs,))\n\n    def step(self, actions):\n        self.task.render(False)\n        actions_clipped = torch.clamp(actions, -self.clip_actions, self.clip_actions)\n        actions_tensor = 
gymtorch.unwrap_tensor(actions_clipped)\n\n        self.task.step(actions_tensor)\n\n        return torch.clamp(self.obs_tensor, -self.clip_obs, self.clip_obs), self.rewards_tensor, self.resets_tensor, []\n\n    def reset(self):\n        actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device))\n        actions_tensor = gymtorch.unwrap_tensor(actions)\n\n        # step the simulator\n        self.task.step(actions_tensor)\n\n        return torch.clamp(self.obs_tensor, -self.clip_obs, self.clip_obs)\n\n\n# Python CPU/GPU Class\nclass VecTaskPython(VecTask):\n\n    def get_state(self):\n        return torch.clamp(self.task.states_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n\n    def step(self, actions):\n        actions_tensor = torch.clamp(actions, -self.clip_actions, self.clip_actions)\n\n        self.task.step(actions_tensor)\n\n        return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device), self.task.rew_buf.to(self.rl_device), self.task.reset_buf.to(self.rl_device), self.task.extras\n\n    def reset(self):\n        actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device))\n\n        # step the simulator\n        self.task.step(actions)\n\n        return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n"
  },
  {
    "path": "timechamber/utils/vec_task_wrappers.py",
    "content": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n#\n# 3. Neither the name of the copyright holder nor the names of its\n#    contributors may be used to endorse or promote products derived from\n#    this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nfrom gym import spaces\nimport numpy as np\nimport torch\nfrom timechamber.utils.vec_task import VecTaskCPU, VecTaskGPU, VecTaskPython\n\nclass VecTaskCPUWrapper(VecTaskCPU):\n    def __init__(self, task, rl_device, sync_frame_time=False, clip_observations=5.0, clip_actions=1.0):\n        super().__init__(task, rl_device, sync_frame_time, clip_observations, clip_actions)\n        return\n\nclass VecTaskGPUWrapper(VecTaskGPU):\n    def __init__(self, task, rl_device, clip_observations=5.0, clip_actions=1.0):\n        super().__init__(task, rl_device, clip_observations, clip_actions)\n        return\n\n\nclass VecTaskPythonWrapper(VecTaskPython):\n    def __init__(self, task, rl_device, clip_observations=5.0, clip_actions=1.0, AMP=False):\n        super().__init__(task, rl_device, clip_observations, clip_actions)\n        if AMP:\n            self._amp_obs_space = spaces.Box(np.ones(task.get_num_amp_obs()) * -np.Inf, np.ones(task.get_num_amp_obs()) * np.Inf)\n        else:\n            self._amp_obs_space = None\n        return\n\n    def reset(self, env_ids=None):\n        self.task.reset(env_ids)\n        return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device)\n\n    @property\n    def amp_observation_space(self):\n        return self._amp_obs_space\n\n    def fetch_amp_obs_demo(self, num_samples):\n        return self.task.fetch_amp_obs_demo(num_samples)"
  }
]