Full Code of inspirai/TimeChamber for AI

main af3f3571c99a cached
201 files
120.5 MB
199.9k tokens
972 symbols
1 requests
Download .txt
Showing preview only (856K chars total). Download the full file or copy to clipboard to get everything.
Repository: inspirai/TimeChamber
Branch: main
Commit: af3f3571c99a
Files: 201
Total size: 120.5 MB

Directory structure:
gitextract_rvpupy7y/

├── .gitattributes
├── .gitignore
├── LICENSE
├── LISENCE/
│   └── isaacgymenvs/
│       └── LICENSE
├── README.md
├── assets/
│   └── mjcf/
│       └── nv_ant.xml
├── docs/
│   └── environments.md
├── setup.py
└── timechamber/
    ├── __init__.py
    ├── ase/
    │   ├── ase_agent.py
    │   ├── ase_models.py
    │   ├── ase_network_builder.py
    │   ├── ase_players.py
    │   ├── hrl_agent.py
    │   ├── hrl_models.py
    │   ├── hrl_network_builder.py
    │   ├── hrl_players.py
    │   └── utils/
    │       ├── amp_agent.py
    │       ├── amp_datasets.py
    │       ├── amp_models.py
    │       ├── amp_network_builder.py
    │       ├── amp_players.py
    │       ├── common_agent.py
    │       ├── common_player.py
    │       └── replay_buffer.py
    ├── cfg/
    │   ├── config.yaml
    │   ├── task/
    │   │   ├── MA_Ant_Battle.yaml
    │   │   ├── MA_Ant_Sumo.yaml
    │   │   └── MA_Humanoid_Strike.yaml
    │   └── train/
    │       ├── MA_Ant_BattlePPO.yaml
    │       ├── MA_Ant_SumoPPO.yaml
    │       ├── MA_Humanoid_StrikeHRL.yaml
    │       └── base/
    │           └── ase_humanoid_hrl.yaml
    ├── learning/
    │   ├── common_agent.py
    │   ├── common_player.py
    │   ├── hrl_sp_agent.py
    │   ├── hrl_sp_player.py
    │   ├── pfsp_player_pool.py
    │   ├── ppo_sp_agent.py
    │   ├── ppo_sp_player.py
    │   ├── replay_buffer.py
    │   ├── vectorized_models.py
    │   └── vectorized_network_builder.py
    ├── models/
    │   ├── Humanoid_Strike/
    │   │   ├── policy.pth
    │   │   └── policy_op.pth
    │   ├── ant_battle_2agents/
    │   │   └── policy.pth
    │   ├── ant_battle_3agents/
    │   │   └── policy.pth
    │   └── ant_sumo/
    │       └── policy.pth
    ├── tasks/
    │   ├── __init__.py
    │   ├── ase_humanoid_base/
    │   │   ├── base_task.py
    │   │   ├── humanoid.py
    │   │   ├── humanoid_amp.py
    │   │   ├── humanoid_amp_task.py
    │   │   └── poselib/
    │   │       ├── README.md
    │   │       ├── data/
    │   │       │   ├── 01_01_cmu.fbx
    │   │       │   ├── 07_01_cmu.fbx
    │   │       │   ├── 08_02_cmu.fbx
    │   │       │   ├── 09_11_cmu.fbx
    │   │       │   ├── 49_08_cmu.fbx
    │   │       │   ├── 55_01_cmu.fbx
    │   │       │   ├── amp_humanoid_tpose.npy
    │   │       │   ├── cmu_tpose.npy
    │   │       │   ├── configs/
    │   │       │   │   ├── retarget_cmu_to_amp.json
    │   │       │   │   └── retarget_sfu_to_amp.json
    │   │       │   └── sfu_tpose.npy
    │   │       ├── fbx_importer.py
    │   │       ├── generate_amp_humanoid_tpose.py
    │   │       ├── mjcf_importer.py
    │   │       ├── poselib/
    │   │       │   ├── __init__.py
    │   │       │   ├── core/
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── backend/
    │   │       │   │   │   ├── __init__.py
    │   │       │   │   │   ├── abstract.py
    │   │       │   │   │   └── logger.py
    │   │       │   │   ├── rotation3d.py
    │   │       │   │   ├── tensor_utils.py
    │   │       │   │   └── tests/
    │   │       │   │       ├── __init__.py
    │   │       │   │       └── test_rotation.py
    │   │       │   ├── skeleton/
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── backend/
    │   │       │   │   │   ├── __init__.py
    │   │       │   │   │   └── fbx/
    │   │       │   │   │       ├── __init__.py
    │   │       │   │   │       ├── fbx_backend.py
    │   │       │   │   │       └── fbx_read_wrapper.py
    │   │       │   │   └── skeleton3d.py
    │   │       │   └── visualization/
    │   │       │       ├── __init__.py
    │   │       │       ├── common.py
    │   │       │       ├── core.py
    │   │       │       ├── plt_plotter.py
    │   │       │       ├── simple_plotter_tasks.py
    │   │       │       ├── skeleton_plotter_tasks.py
    │   │       │       └── tests/
    │   │       │           ├── __init__.py
    │   │       │           └── test_plotter.py
    │   │       └── retarget_motion.py
    │   ├── base/
    │   │   ├── __init__.py
    │   │   ├── ma_vec_task.py
    │   │   └── vec_task.py
    │   ├── data/
    │   │   ├── assets/
    │   │   │   └── mjcf/
    │   │   │       └── amp_humanoid_sword_shield.xml
    │   │   ├── models/
    │   │   │   └── llc_reallusion_sword_shield.pth
    │   │   └── motions/
    │   │       └── reallusion_sword_shield/
    │   │           ├── README.txt
    │   │           ├── RL_Avatar_Atk_2xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo03_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo04_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo05_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo03_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo04_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo05_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo06_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo07_Motion.npy
    │   │           ├── RL_Avatar_Atk_4xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Atk_4xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Atk_4xCombo03_Motion.npy
    │   │           ├── RL_Avatar_Atk_Jump_Motion.npy
    │   │           ├── RL_Avatar_Atk_Kick_Motion.npy
    │   │           ├── RL_Avatar_Atk_ShieldCharge_Motion.npy
    │   │           ├── RL_Avatar_Atk_ShieldSwipe01_Motion.npy
    │   │           ├── RL_Avatar_Atk_ShieldSwipe02_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashDown_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashLeft_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashRight_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashUp_Motion.npy
    │   │           ├── RL_Avatar_Atk_Spin_Motion.npy
    │   │           ├── RL_Avatar_Atk_Stab_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk01_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk02_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk03_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk04_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk05_Motion.npy
    │   │           ├── RL_Avatar_Dodge_Backward_Motion.npy
    │   │           ├── RL_Avatar_Dodgle_Left_Motion.npy
    │   │           ├── RL_Avatar_Dodgle_Right_Motion.npy
    │   │           ├── RL_Avatar_Fall_Backward_Motion.npy
    │   │           ├── RL_Avatar_Fall_Left_Motion.npy
    │   │           ├── RL_Avatar_Fall_Right_Motion.npy
    │   │           ├── RL_Avatar_Fall_SpinLeft_Motion.npy
    │   │           ├── RL_Avatar_Fall_SpinRight_Motion.npy
    │   │           ├── RL_Avatar_Idle_Alert(0)_Motion.npy
    │   │           ├── RL_Avatar_Idle_Alert_Motion.npy
    │   │           ├── RL_Avatar_Idle_Battle(0)_Motion.npy
    │   │           ├── RL_Avatar_Idle_Battle_Motion.npy
    │   │           ├── RL_Avatar_Idle_Ready(0)_Motion.npy
    │   │           ├── RL_Avatar_Idle_Ready_Motion.npy
    │   │           ├── RL_Avatar_Kill_2xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Kill_2xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Kill_3xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Kill_3xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Kill_4xCombo01_Motion.npy
    │   │           ├── RL_Avatar_RunBackward_Motion.npy
    │   │           ├── RL_Avatar_RunForward_Motion.npy
    │   │           ├── RL_Avatar_RunLeft_Motion.npy
    │   │           ├── RL_Avatar_RunRight_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockBackward_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockCrouch_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockDown_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockLeft_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockRight_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockUp_Motion.npy
    │   │           ├── RL_Avatar_Standoff_Circle_Motion.npy
    │   │           ├── RL_Avatar_Standoff_Feint_Motion.npy
    │   │           ├── RL_Avatar_Standoff_Swing_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward01_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward02_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward03_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward04_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryCrouch_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryDown_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryLeft_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryRight_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryUp_Motion.npy
    │   │           ├── RL_Avatar_Taunt_PoundChest_Motion.npy
    │   │           ├── RL_Avatar_Taunt_Roar_Motion.npy
    │   │           ├── RL_Avatar_Taunt_ShieldKnock_Motion.npy
    │   │           ├── RL_Avatar_TurnLeft180_Motion.npy
    │   │           ├── RL_Avatar_TurnLeft90_Motion.npy
    │   │           ├── RL_Avatar_TurnRight180_Motion.npy
    │   │           ├── RL_Avatar_TurnRight90_Motion.npy
    │   │           ├── RL_Avatar_WalkBackward01_Motion.npy
    │   │           ├── RL_Avatar_WalkBackward02_Motion.npy
    │   │           ├── RL_Avatar_WalkForward01_Motion.npy
    │   │           ├── RL_Avatar_WalkForward02_Motion.npy
    │   │           ├── RL_Avatar_WalkLeft01_Motion.npy
    │   │           ├── RL_Avatar_WalkLeft02_Motion.npy
    │   │           ├── RL_Avatar_WalkRight01_Motion.npy
    │   │           ├── RL_Avatar_WalkRight02_Motion.npy
    │   │           └── dataset_reallusion_sword_shield.yaml
    │   ├── ma_ant_battle.py
    │   ├── ma_ant_sumo.py
    │   └── ma_humanoid_strike.py
    ├── train.py
    └── utils/
        ├── config.py
        ├── gym_util.py
        ├── logger.py
        ├── motion_lib.py
        ├── reformat.py
        ├── rlgames_utils.py
        ├── torch_jit_utils.py
        ├── torch_utils.py
        ├── utils.py
        ├── vec_task.py
        └── vec_task_wrappers.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================


================================================
FILE: .gitignore
================================================
videos
/timechamber/logs
*train_dir*
*ige_logs*
*.egg-info
/.vs
/.vscode
/_package
/shaders
._tmptext.txt
__pycache__/
/timechamber/tasks/__pycache__
/timechamber/utils/__pycache__
/timechamber/tasks/base/__pycache__
/tools/format/.lastrun
*.pyc
_doxygen
/rlisaacgymenvsgpu/logs
/timechamber/benchmarks/results
/timechamber/simpletests/results
*.pxd2
/tests/logs
/timechamber/balance_bot.xml
/timechamber/quadcopter.xml
/timechamber/ingenuity.xml
logs*
nn/
runs/
.idea
outputs/
*.hydra*
/timechamber/wandb
/test
.gitlab



================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 MIT Inspir.ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: LISENCE/isaacgymenvs/LICENSE
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

================================================
FILE: README.md
================================================
# TimeChamber: A Massively Parallel Large Scale Self-Play Framework

****

**TimeChamber** is a large-scale self-play framework running on parallel simulation.
Running self-play algorithms always requires lots of hardware resources, especially in 3D physically simulated
environments.
We provide a self-play framework that can achieve fast training and evaluation with **ONLY ONE GPU**.
TimeChamber is developed with the following key features:

- **Parallel Simulation**: TimeChamber is built on [Isaac Gym](https://developer.nvidia.com/isaac-gym). Isaac Gym is
  a fast GPU-based simulation platform. It supports running thousands of environments in parallel on a single GPU. For
  example, on one NVIDIA Laptop RTX 3070Ti GPU, TimeChamber can reach **80,000+
  mean FPS** by running 4,096 environments in parallel.
- **Parallel Evaluation**: TimeChamber can quickly calculate the ELO ratings (which represent combat power) of dozens
  of policies. It also supports multi-player ELO calculations
  by [multi-elo](https://github.com/djcunningham0/multielo). Inspired by vectorization techniques
  for [fast population-based training](https://github.com/instadeepai/fastpbrl), we leverage
  vectorized models to evaluate different policies in parallel.
- **Prioritized Fictitious Self-Play Benchmark**: We implement a classic PPO self-play algorithm on top
  of [rl_games](https://github.com/Denys88/rl_games), with a prioritized player pool to avoid cycles and improve the
  diversity of the training policies.

<div align=center>
<img src="assets/images/algorithm.jpg" align="center" width="600"/>
</div> 

- **Competitive Multi-Agent Tasks**: Inspired by [OpenAI RoboSumo](https://github.com/openai/robosumo) and [ASE](https://github.com/nv-tlabs/ASE), we introduce three
  competitive multi-agent tasks (Ant Sumo, Ant
  Battle and Humanoid Strike) as examples.
  The efficiency of our self-play framework has been tested on these tasks. After days of training, our agents can
  discover some interesting
  physical skills like pulling, jumping, etc. **Contributions of your own environments are welcome!**


## Installation

****
Download and follow the installation instructions of Isaac Gym: https://developer.nvidia.com/isaac-gym  
Ensure that Isaac Gym works on your system by running one of the examples from the `python/examples`
directory, like `joint_monkey.py`. If you have any trouble running the samples, please follow troubleshooting steps
described in the [Isaac Gym Preview Release 3/4 installation instructions](https://developer.nvidia.com/isaac-gym).  
Then install this repo:

```bash
pip install -e .
```

## Quick Start

****

### Tasks

Source code for tasks can be found in `timechamber/tasks`. The detailed settings of state/action/reward are
described [here](./docs/environments.md).
More interesting tasks will come soon.

#### Humanoid Strike

Humanoid Strike is a 3D environment with two simulated humanoid physics characters. Each character is equipped with a sword and a shield and has 37 degrees of freedom.
The game will be restarted if one agent goes outside the arena. We measure how much the player damaged the opponent and how much the player was damaged by the opponent in the terminated step to determine the winner.

<div align=center>
<img src="assets/images/humanoid_strike.gif" align="center" width="600"/>
</div> 



#### Ant Sumo

Ant Sumo is a 3D environment with simulated physics that allows pairs of ant agents to compete against each other.
To win, the agent has to push the opponent out of the ring. Every agent has 100 hp. At each step, if the agent's body
touches the ground, its hp is reduced by 1. The agent whose hp reaches 0 is eliminated.
<div align=center>
<img src="assets/images/ant_sumo.gif" align="center" width="600"/>
</div> 

#### Ant Battle

Ant Battle is an expanded environment of Ant Sumo. It supports more than two agents competing against
each other. The battle ring radius will shrink, and an agent that goes out of the ring is eliminated.
<div align=center>
<img src="assets/images/ant_battle.gif" align="center" width="600"/>
</div>  

### Self-Play Training

To train your policy for tasks, for example:

```bash
# run self-play training for Humanoid Strike task
python train.py task=MA_Humanoid_Strike headless=True
```

```bash
# run self-play training for Ant Sumo task
python train.py task=MA_Ant_Sumo train=MA_Ant_SumoPPO headless=True
```

```bash
# run self-play training for Ant Battle task
python train.py task=MA_Ant_Battle train=MA_Ant_BattlePPO headless=True
```

Key arguments to the training script
follow [IsaacGymEnvs Configuration and command line arguments](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs/blob/main/README.md#configuration-and-command-line-arguments).
Other training arguments follow [rl_games config parameters](https://github.com/Denys88/rl_games#config-parameters);
you can change them in `timechamber/cfg/train/*.yaml`. There are some specific arguments for self-play training:

- `num_agents`: Sets the number of agents for the Ant Battle environment; it should be larger than 1.
- `op_checkpoint`: Path to the checkpoint used to load the initial opponent policy.
  If it's empty, the opponent agent will use a random policy.
- `update_win_rate`: Win-rate threshold for adding the current policy to the opponent's player pool.
- `player_pool_length`: The max size of the player pool, following FIFO rules.
- `games_to_check`: Warm-up for training; the player pool won't be updated until the current policy has played this
  number of games.
- `max_update_steps`: If the current policy's update iterations exceed this number, the current policy will be added to
  the opponent player pool.

### Policies Evaluation

To evaluate your policies, for example:

```bash
# run testing for Ant Sumo policy
python train.py task=MA_Ant_Sumo train=MA_Ant_SumoPPO test=True num_envs=4 minibatch_size=32 headless=False checkpoint='models/ant_sumo/policy.pth'
```

```bash
# run testing for Humanoid Strike policy
python train.py task=MA_Humanoid_Strike train=MA_Humanoid_StrikeHRL test=True num_envs=4 minibatch_size=32 headless=False checkpoint='models/Humanoid_Strike/policy.pth' op_checkpoint='models/Humanoid_Strike/policy_op.pth'
```

You can set the opponent agent policy using `op_checkpoint`. If it's empty, the opponent agent will use the same policy
as `checkpoint`.  
We use vectorized models to accelerate the evaluation of policies. Put policies into the checkpoint dir and let them
compete with each other in parallel:

```bash
# run testing for Ant Sumo policy
python train.py task=MA_Ant_Sumo train=MA_Ant_SumoPPO test=True headless=True checkpoint='models/ant_sumo' player_pool_type=vectorized
```

There are some specific arguments for self-play evaluation; you can change them in `timechamber/cfg/train/*.yaml`:

- `games_num`: Total number of evaluation episodes.
- `record_elo`: Set to `True` to record the ELO ratings of your policies; after evaluation, you can check `elo.jpg` in
  your checkpoint dir.

<div align=center>
  <img src="assets/images/elo.jpg" align="center" width="400"/>
</div>

- `init_elo`: Initial ELO rating of each policy.

### Building Your Own Task

You can build your own task by
following [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs/blob/main/README.md#creating-an-environment);
make sure the obs shape is correct and `info` contains `win`, `lose` and `draw`:

```python
import isaacgym
import timechamber
import torch

envs = timechamber.make(
    seed=0,
    task="MA_Ant_Sumo",
    num_envs=2,
    sim_device="cuda:0",
    rl_device="cuda:0",
)
# the obs shape should be (num_agents*num_envs,num_obs).
# the obs of training agent is (:num_envs,num_obs)
print("Observation space is", envs.observation_space)
print("Action space is", envs.action_space)
obs = envs.reset()
for _ in range(20):
    obs, reward, done, info = envs.step(
        torch.rand((2 * 2,) + envs.action_space.shape, device="cuda:0")
    )
# info:
# {'win': tensor([Bool, Bool])
# 'lose': tensor([Bool, Bool])
# 'draw': tensor([Bool, Bool])}

```

## Citing

If you use timechamber in your research please use the following citation:

```bibtex
@misc{InspirAI,
  author = {Huang Ziming and Ziyi Liu and Wu Yutong and Flood Sung},
  title = {TimeChamber: A Massively Parallel Large Scale Self-Play Framework},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/inspirai/TimeChamber}},
}
```

================================================
FILE: assets/mjcf/nv_ant.xml
================================================
<mujoco model="ant">
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>

  <default>
    <joint armature="0.01" damping="0.1" limited="true"/>
    <geom condim="3" density="5.0" friction="1.5 0.1 0.1" margin="0.01" rgba="0.97 0.38 0.06 1"/>
  </default>

  <compiler inertiafromgeom="true" angle="degree"/>

  <option timestep="0.016" iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="pyramidal"/>

  <size nconmax="50" njmax="200" nstack="10000"/>
  <visual>
      <map force="0.1" zfar="30"/>
      <rgba haze="0.15 0.25 0.35 1"/>
      <quality shadowsize="2048"/>
      <global offwidth="800" offheight="800"/>
  </visual>

  <asset>
      <texture type="skybox" builtin="gradient" rgb1="0.3 0.5 0.7" rgb2="0 0 0" width="512" height="512"/> 
      <texture name="texplane" type="2d" builtin="checker" rgb1=".2 .3 .4" rgb2=".1 0.15 0.2" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
      <texture name="texgeom" type="cube" builtin="flat" mark="cross" width="127" height="1278" 
          rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>  

      <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
      <material name="matgeom" texture="texgeom" texuniform="true" rgba="0.8 0.6 .4 1"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 0" size="0 0 .25" type="plane" material="matplane" condim="3"/>

    <light directional="false" diffuse=".2 .2 .2" specular="0 0 0" pos="0 0 5" dir="0 0 -1" castshadow="false"/>
    <light mode="targetbodycom" target="torso" directional="false" diffuse=".8 .8 .8" specular="0.3 0.3 0.3" pos="0 0 4.0" dir="0 0 -1"/>

    <body name="torso" pos="0 0 0.75">
      <freejoint name="root"/>
      <geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>
      <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule" rgba=".999 .2 .1 1"/>
      <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
      <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
      <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule" rgba=".999 .2 .02 1"/>

      <body name="front_left_leg" pos="0.2 0.2 0">
        <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-40 40" type="hinge"/>
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule" rgba=".999 .2 .1 1"/>
        <body pos="0.2 0.2 0" name="front_left_foot">
          <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 100" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule" rgba=".999 .2 .1 1"/>
        </body>
      </body>
      <body name="front_right_leg" pos="-0.2 0.2 0">
        <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-40 40" type="hinge"/>
        <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
        <body pos="-0.2 0.2 0" name="front_right_foot">
          <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-100 -30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
        </body>
      </body>
      <body name="left_back_leg" pos="-0.2 -0.2 0">
        <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-40 40" type="hinge"/>
        <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
        <body pos="-0.2 -0.2 0" name="left_back_foot">
          <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-100 -30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
        </body>
      </body>
      <body name="right_back_leg" pos="0.2 -0.2 0">
        <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-40 40" type="hinge"/>
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule" rgba=".999 .2 .1 1"/>
        <body pos="0.2 -0.2 0" name="right_back_foot">
          <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 100" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule" rgba=".999 .2 .1 1"/>
        </body>
      </body>
    </body>
  </worldbody>

  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="15"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="15"/>
  </actuator>
</mujoco>


================================================
FILE: docs/environments.md
================================================
## Environments

We provide a detailed description of the environment here.

### Humanoid Strike

Humanoid Strike is a 3D environment with two simulated humanoid physics characters. Each character is equipped with a sword and a shield and has 37 degrees of freedom.
The game will be restarted if one agent goes outside the arena or the game reaches the maximum episode steps. We measure how much the player damaged the opponent and how much the player was damaged by the opponent in the terminated step to determine the winner.

#### <span id="obs1-1">Low-Level Observation Space</span>

|  Index  |          Description           |
|:-------:|:------------------------------:|
|  0   |           Height of the root from the ground.            |
|  1 - 48  |         Positions of the bodies in the character’s local coordinate frame.         |
|  49 - 150  |      Rotations of the bodies in the character’s local coordinate frame.      |
| 151 - 201 |      Linear velocities of the bodies in the character’s local coordinate frame.       |
| 202 - 252 |      Angular velocities of the bodies in the character’s local coordinate frame.          |


#### <span id="obs1-2">High-Level Observation Space</span>

|  Index  |          Description           |
|:-------:|:------------------------------:|
|  0 - 1  |    relative distance from the borderline            |
|  2 - 4  |    relative distance from the opponent          |
|  5 - 10  |      Rotation of the opponent's root in the character’s local coordinate frame.      |
| 11 - 13 |      Linear velocity of the opponent's root in the character’s local coordinate frame.       |
| 14 - 16 |      Angular velocity of the opponent's root in the character’s local coordinate frame.         |
| 17 - 19 |      relative distance between the ego agent and the opponent's sword         |
| 20 - 22 |      Linear velocity of the opponent's sword in the character’s local coordinate frame.          |
| 23 - 25 |      relative distance between the ego agent's shield and the opponent's sword        |
| 26 - 28 | relative velocity between the ego agent's shield and the opponent's sword |
|   29 - 31    |   relative distance between the ego agent's sword and the opponent's torso    |
|   32 - 34    | relative velocity between the ego agent's sword and the opponent's torso  |
|   35 - 37    |   relative distance between the ego agent's sword and the opponent's head    |
|   38 - 40    | relative velocity between the ego agent's sword and the opponent's head  |
|   41 - 43    |   relative distance between the ego agent's sword and the opponent's right arm    |
|   44 - 46    | relative distance between the ego agent's sword and the opponent's right thigh  |
|   47 - 49    | relative distance between the ego agent's sword and the opponent's left thigh  |


#### <span id="action1-1">Low-Level Action Space</span>

| Index |    Description    |
|:-----:|:-----------------:|
| 0 - 30 | target rotations of each character’s joints |

#### <span id="action1-2">High-Level Action Space</span>

| Index |    Description    |
|:-----:|:-----------------:|
| 0 - 63 | latent skill variables |

#### <span id="r1">Rewards</span>

The weights of reward components are as follows:

```python
op_fall_reward_w = 200.0
ego_fall_out_reward_w = 50.0
shield_to_sword_pos_reward_w = 1.0
damage_reward_w = 8.0
sword_to_op_reward_w = 0.8
reward_energy_w = 3.0
reward_strike_vel_acc_w = 3.0
reward_face_w = 4.0
reward_foot_to_op_w = 10.0
reward_kick_w = 2.0
```


### Ant Sumo

Ant Sumo is a 3D environment with simulated physics that allows pairs of ant agents to compete against each other.
To win, the agent has to push the opponent out of the ring. Every agent has 100 hp. At each step, if the agent's body
touches the ground, its hp is reduced by 1. The agent whose hp reaches 0 is eliminated.

#### <span id="obs2">Observation Space</span>

|  Index  |          Description           |
|:-------:|:------------------------------:|
|  0 - 2  |           self position            |
|  3 - 6  |         self rotation          |
|  7 - 9  |      self linear velocity      |
| 10 - 12 |      self angular velocity       |
| 13 - 20 |          self dof position          |
| 21 - 28 |       self dof velocity        |
| 29 - 31 |         opponent position          |
| 32 - 35 |       opponent rotation        |
| 36 - 37 | self-to-opponent position vector (x, y) |
|   38    |   is self body touch ground    |
|   39    | is opponent body touch ground  |

#### <span id="action2">Action Space</span>

| Index |    Description    |
|:-----:|:-----------------:|
| 0 - 7 | self dof position |

#### <span id="r2">Rewards</span>

The reward consists of two parts: sparse reward and dense reward.

```python
win_reward = 2000
lose_penalty = -2000
draw_penalty = -1000
dense_reward_scale = 1.
dof_at_limit_cost = torch.sum(obs_buf[:, 13:21] > 0.99, dim=-1) * joints_at_limit_cost_scale
push_reward = -push_scale * torch.exp(-torch.linalg.norm(obs_buf_op[:, :2], dim=-1))
action_cost_penalty = torch.sum(torch.square(torques), dim=1) * action_cost_scale
not_move_penalty = -10 * torch.exp(-torch.sum(torch.abs(torques), dim=1))
dense_reward = move_reward + dof_at_limit_cost + push_reward + action_cost_penalty + not_move_penalty
total_reward = win_reward + lose_penalty + draw_penalty + dense_reward * dense_reward_scale
```

### Ant Battle

Ant Battle is an expanded environment of Ant Sumo. It supports more than two agents competing against
each other. The battle ring radius will shrink, and an agent that goes out of the ring is eliminated.

#### <span id="obs3">Observation Space</span>

|  Index  |              Description               |
|:-------:|:--------------------------------------:|
|  0 - 2  |               self position                |
|  3 - 6  |             self rotation              |
|  7 - 9  |          self linear velocity          |
| 10 - 12 |          self angular velocity           |
| 13 - 20 |              self dof position              |
| 21 - 28 |           self dof velocity            |
|   29    |    border radius minus self distance to centre    |
|   30    |             border radius              |
|   31    |       is self body touching ground        |
| 32 - 34 |            opponent_1 position             |
| 35 - 38 |          opponent_1 rotation           |
| 39 - 40 |    self-to-opponent_1 position vector (x, y)    |
| 41 - 48 |          opponent_1 dof position           |
| 49 - 56 |        opponent_1 dof velocity         |
|   57    | border radius minus opponent_1 distance to centre |
|   58    |    is opponent_1 body touching ground     |
|   ...   |                  ...                   |

#### <span id="action3">Action Space</span>

| Index |    Description    |
|:-----:|:-----------------:|
| 0 - 7 | self dof position |

#### <span id="r3">Rewards</span>

The reward consists of two parts: sparse reward and dense reward.

```python
win_reward_scale = 2000
reward_per_rank = 2 * win_reward_scale / (num_agents - 1)
sparse_reward = sparse_reward * (win_reward_scale - (nxt_rank[:, 0] - 1) * reward_per_rank)
stay_in_center_reward = stay_in_center_reward_scale * torch.exp(-torch.linalg.norm(obs[0, :, :2], dim=-1))
dof_at_limit_cost = torch.sum(obs[0, :, 13:21] > 0.99, dim=-1) * joints_at_limit_cost_scale
action_cost_penalty = torch.sum(torch.square(torques), dim=1) * action_cost_scale
not_move_penalty = torch.exp(-torch.sum(torch.abs(torques), dim=1))
dense_reward = dof_at_limit_cost + action_cost_penalty + not_move_penalty + stay_in_center_reward
total_reward = sparse_reward + dense_reward * dense_reward_scale
```

================================================
FILE: setup.py
================================================
"""Installation script for the 'timechamber' python package."""

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from setuptools import setup, find_packages

import os

root_dir = os.path.dirname(os.path.realpath(__file__))

# Minimum dependencies required prior to installation
INSTALL_REQUIRES = [
    # RL
    "gym==0.24",
    "torch",
    "omegaconf",
    "termcolor",
    "dill",
    "hydra-core>=1.1",
    "rl-games==1.5.2",
    "pyvirtualdisplay",
    "multielo @ git+https://github.com/djcunningham0/multielo.git@440f7922b90ff87009f8283d6491eb0f704e6624",
    "matplotlib==3.5.2",
    "pytest==7.1.2",
]

# Installation operation
setup(
    name="timechamber",
    author="ZeldaHuang, Ziyi Liu",
    version="0.0.1",
    description="A Massively Parallel Large Scale Self-Play Framework",
    keywords=["robotics", "rl"],
    include_package_data=True,
    # PEP 440 only allows ".*" wildcards with "==" / "!="; the previous
    # ">=3.6.*" is rejected as an invalid specifier by modern setuptools.
    python_requires=">=3.6",
    install_requires=INSTALL_REQUIRES,
    packages=find_packages("."),
    # Trove classifiers name a single version each.
    classifiers=[
        "Natural Language :: English",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
    ],
    zip_safe=False,
)

# EOF


================================================
FILE: timechamber/__init__.py
================================================
import hydra
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from timechamber.utils.reformat import omegaconf_to_dict


# Custom OmegaConf interpolation resolvers usable from the YAML configs:
#   ${eq:a,b}                    -> case-insensitive equality
#   ${contains:a,b}              -> case-insensitive substring test (a in b)
#   ${if:pred,a,b}               -> a if pred is truthy, else b
#   ${resolve_default:default,x} -> default when x is the empty string, else x
_CUSTOM_RESOLVERS = (
    ('eq', lambda x, y: x.lower() == y.lower()),
    ('contains', lambda x, y: x.lower() in y.lower()),
    ('if', lambda pred, a, b: a if pred else b),
    ('resolve_default', lambda default, arg: default if arg == '' else arg),
)
for _resolver_name, _resolver_fn in _CUSTOM_RESOLVERS:
    OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)
del _resolver_name, _resolver_fn


def make(
    seed: int, 
    task: str, 
    num_envs: int, 
    sim_device: str,
    rl_device: str,
    graphics_device_id: int = -1,
    device_type: str = "cuda",
    headless: bool = False,
    multi_gpu: bool = False,
    virtual_screen_capture: bool = False,
    force_render: bool = True,
    cfg: DictConfig = None
):
    """Build a TimeChamber environment through the rl_games env creator.

    Args:
        seed: Forwarded to ``get_rlgames_env_creator``. NOTE(review): the
            task config's own ``'seed'`` entry is taken from ``cfg.seed``,
            not from this argument -- confirm the two are meant to agree.
        task: Task name used for the ``task=<task>`` Hydra override when
            ``cfg`` is None. If a Hydra config is already initialized, this
            argument is replaced by the currently selected task choice.
        num_envs: Written to ``task.env.numEnvs``, but only when ``cfg`` is
            None; a passed-in ``cfg`` is used unchanged in this respect.
        sim_device, graphics_device_id, device_type, headless, multi_gpu,
        virtual_screen_capture, force_render: Forwarded unchanged to
            ``get_rlgames_env_creator``.
        rl_device: Forwarded to ``get_rlgames_env_creator`` and also stored
            in the task config as ``'rl_device'``.
        cfg: Full config to reuse. When None, ``config`` is composed from
            ``./cfg`` with the ``task=<task>`` override.

    Returns:
        The object produced by calling the creator returned from
        ``get_rlgames_env_creator``.
    """
    # Imported lazily, only when an environment is actually created.
    from timechamber.utils.rlgames_utils import get_rlgames_env_creator
    # create hydra config if no config passed in
    if cfg is None:
        # reset current hydra config if already parsed (but not passed in here)
        if HydraConfig.initialized():
            # The active Hydra task choice takes precedence over the `task` argument.
            task = HydraConfig.get().runtime.choices['task']
            hydra.core.global_hydra.GlobalHydra.instance().clear()

        with initialize(config_path="./cfg"):
            cfg = compose(config_name="config", overrides=[f"task={task}"])
            task_dict = omegaconf_to_dict(cfg.task)
            task_dict['env']['numEnvs'] = num_envs
    # reuse existing config
    else:
        task_dict = omegaconf_to_dict(cfg.task)
    task_dict['seed'] = cfg.seed
    task_dict['rl_device'] = rl_device
    # Optional motion-file override; `cfg` must define `motion_file` (a falsy
    # value leaves the task's own setting untouched).
    if cfg.motion_file:
        task_dict['env']['motion_file'] = cfg.motion_file
    
    create_rlgpu_env = get_rlgames_env_creator(
        seed=seed,
        cfg=cfg,
        task_config=task_dict,
        task_name=task_dict["name"],
        sim_device=sim_device,
        rl_device=rl_device,
        graphics_device_id=graphics_device_id,
        headless=headless,
        device_type=device_type,
        multi_gpu=multi_gpu,
        virtual_screen_capture=virtual_screen_capture,
        force_render=force_render,
    )
    return create_rlgpu_env()


================================================
FILE: timechamber/ase/ase_agent.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import torch
import torch.nn as nn
from isaacgym.torch_utils import *
from rl_games.algos_torch import torch_ext
from rl_games.common import a2c_common
from rl_games.algos_torch.running_mean_std import RunningMeanStd

from timechamber.ase import ase_network_builder
from timechamber.ase.utils import amp_agent 

class ASEAgent(amp_agent.AMPAgent):
    """AMP agent whose policy, critic and encoder are conditioned on ASE latents.

    On top of ``amp_agent.AMPAgent`` this agent:

    * keeps one latent vector per actor (``self._ase_latents``), sampled from
      the network; latents are resampled on environment reset and again once
      the task's ``progress_buf`` reaches a randomly drawn step count in
      ``[latent_steps_min, latent_steps_max)``;
    * adds an encoder reward and encoder loss based on how well the encoder
      prediction from AMP observations aligns with the active latent;
    * optionally adds an encoder gradient penalty and an action-diversity loss.
    """

    def __init__(self, base_name, config):
        super().__init__(base_name, config)
        return

    def init_tensors(self):
        """Allocate ASE latent buffers on top of the base AMP tensors."""
        super().init_tensors()

        batch_shape = self.experience_buffer.obs_base_shape
        # Latent active at each rollout step, so training sees the latent that
        # conditioned each stored action.
        self.experience_buffer.tensor_dict['ase_latents'] = torch.zeros(batch_shape + (self._latent_dim,),
                                                                dtype=torch.float32, device=self.ppo_device)

        # Currently active latent per actor (last dimension of the batch shape).
        self._ase_latents = torch.zeros((batch_shape[-1], self._latent_dim), dtype=torch.float32,
                                         device=self.ppo_device)

        self.tensor_list += ['ase_latents']

        # Step (compared against the task's progress_buf) at which each
        # actor's latent is next resampled.
        self._latent_reset_steps = torch.zeros(batch_shape[-1], dtype=torch.int32, device=self.ppo_device)
        num_envs = self.vec_env.env.task.num_envs
        env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.ppo_device)
        self._reset_latent_step_count(env_ids)

        return

    def play_steps(self):
        """Collect one rollout of ``horizon_length`` steps and build the batch dict.

        Rewards stored in the batch are the task rewards combined with the
        discriminator and encoder rewards (see ``_combine_rewards``); returns
        are computed from advantages given by ``discount_values``.
        """
        self.set_eval()

        epinfos = []
        done_indices = []
        update_list = self.update_list

        for n in range(self.horizon_length):
            # Reset the envs that finished on the previous step (this also
            # resamples their latents, see env_reset).
            self.obs = self.env_reset(done_indices)
            self.experience_buffer.update_data('obses', n, self.obs['obs'])

            # Resample latents whose scheduled reset step has been reached.
            self._update_latents()

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, self._ase_latents, masks)
            else:
                res_dict = self.get_action_values(self.obs, self._ase_latents, self._rand_action_probs)

            for k in update_list:
                self.experience_buffer.update_data(k, n, res_dict[k])

            if self.has_central_value:
                self.experience_buffer.update_data('states', n, self.obs['states'])

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
            shaped_rewards = self.rewards_shaper(rewards)
            self.experience_buffer.update_data('rewards', n, shaped_rewards)
            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])
            self.experience_buffer.update_data('dones', n, self.dones)
            self.experience_buffer.update_data('amp_obs', n, infos['amp_obs'])
            self.experience_buffer.update_data('ase_latents', n, self._ase_latents)
            self.experience_buffer.update_data('rand_action_mask', n, res_dict['rand_action_mask'])

            # Zero the bootstrap value wherever the task reports termination.
            terminated = infos['terminate'].float()
            terminated = terminated.unsqueeze(-1)
            next_vals = self._eval_critic(self.obs, self._ase_latents)
            next_vals *= (1.0 - terminated)
            self.experience_buffer.update_data('next_values', n, next_vals)

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            # Keep every num_agents-th done index. NOTE(review): presumably
            # all agents of one env finish together and are laid out
            # contiguously -- confirm against the task's agent layout.
            done_indices = all_done_indices[::self.num_agents]

            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

            if (self.vec_env.env.task.viewer):
                self._amp_debug(infos, self._ase_latents)

            done_indices = done_indices[:, 0]

        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()
        mb_values = self.experience_buffer.tensor_dict['values']
        mb_next_values = self.experience_buffer.tensor_dict['next_values']

        mb_rewards = self.experience_buffer.tensor_dict['rewards']
        mb_amp_obs = self.experience_buffer.tensor_dict['amp_obs']
        mb_ase_latents = self.experience_buffer.tensor_dict['ase_latents']
        amp_rewards = self._calc_amp_rewards(mb_amp_obs, mb_ase_latents)
        mb_rewards = self._combine_rewards(mb_rewards, amp_rewards)

        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)
        mb_returns = mb_advs + mb_values

        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)
        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = self.batch_size

        for k, v in amp_rewards.items():
            batch_dict[k] = a2c_common.swap_and_flatten01(v)

        return batch_dict

    def get_action_values(self, obs_dict, ase_latents, rand_action_probs):
        """Evaluate the policy (and value) for ``obs_dict`` without gradients.

        For each actor, a Bernoulli draw with probability ``rand_action_probs``
        decides whether the sampled action is kept (mask 1) or replaced by the
        mean ``mus`` (mask 0). The mask is returned as
        ``res_dict['rand_action_mask']`` and later weights the PPO losses.
        """
        processed_obs = self._preproc_obs(obs_dict['obs'])

        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs' : processed_obs,
            'rnn_states' : self.rnn_states,
            'ase_latents': ase_latents
        }

        with torch.no_grad():
            res_dict = self.model(input_dict)
            if self.has_central_value:
                states = obs_dict['states']
                input_dict = {
                    'is_train': False,
                    'states' : states,
                }
                value = self.get_central_value(input_dict)
                res_dict['values'] = value

        if self.normalize_value:
            res_dict['values'] = self.value_mean_std(res_dict['values'], True)

        rand_action_mask = torch.bernoulli(rand_action_probs)
        det_action_mask = rand_action_mask == 0.0
        res_dict['actions'][det_action_mask] = res_dict['mus'][det_action_mask]
        res_dict['rand_action_mask'] = rand_action_mask

        return res_dict

    def prepare_dataset(self, batch_dict):
        """Extend the base dataset with the per-sample ASE latents."""
        super().prepare_dataset(batch_dict)

        ase_latents = batch_dict['ase_latents']
        self.dataset.values_dict['ase_latents'] = ase_latents

        return

    def calc_gradients(self, input_dict):
        """Run one optimization step on a minibatch and record ``train_result``.

        Actor, entropy and bound losses are averaged only over samples whose
        ``rand_action_mask`` is 1; the critic loss is averaged over all.
        Discriminator and encoder losses use the first ``_amp_minibatch_size``
        samples.
        """
        self.set_train()

        value_preds_batch = input_dict['old_values']
        old_action_log_probs_batch = input_dict['old_logp_actions']
        advantage = input_dict['advantages']
        old_mu_batch = input_dict['mu']
        old_sigma_batch = input_dict['sigma']
        return_batch = input_dict['returns']
        actions_batch = input_dict['actions']
        obs_batch = input_dict['obs']
        obs_batch = self._preproc_obs(obs_batch)

        amp_obs = input_dict['amp_obs'][0:self._amp_minibatch_size]
        amp_obs = self._preproc_amp_obs(amp_obs)
        # The encoder gradient penalty differentiates w.r.t. amp_obs.
        if (self._enable_enc_grad_penalty()):
            amp_obs.requires_grad_(True)

        amp_obs_replay = input_dict['amp_obs_replay'][0:self._amp_minibatch_size]
        amp_obs_replay = self._preproc_amp_obs(amp_obs_replay)

        amp_obs_demo = input_dict['amp_obs_demo'][0:self._amp_minibatch_size]
        amp_obs_demo = self._preproc_amp_obs(amp_obs_demo)
        amp_obs_demo.requires_grad_(True)

        ase_latents = input_dict['ase_latents']

        rand_action_mask = input_dict['rand_action_mask']
        rand_action_sum = torch.sum(rand_action_mask)

        lr = self.last_lr
        kl = 1.0
        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        batch_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'obs' : obs_batch,
            'amp_obs' : amp_obs,
            'amp_obs_replay' : amp_obs_replay,
            'amp_obs_demo' : amp_obs_demo,
            'ase_latents': ase_latents
        }

        rnn_masks = None
        if self.is_rnn:
            rnn_masks = input_dict['rnn_masks']
            batch_dict['rnn_states'] = input_dict['rnn_states']
            batch_dict['seq_length'] = self.seq_len

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            res_dict = self.model(batch_dict)
            action_log_probs = res_dict['prev_neglogp']
            values = res_dict['values']
            entropy = res_dict['entropy']
            mu = res_dict['mus']
            sigma = res_dict['sigmas']
            disc_agent_logit = res_dict['disc_agent_logit']
            disc_agent_replay_logit = res_dict['disc_agent_replay_logit']
            disc_demo_logit = res_dict['disc_demo_logit']
            enc_pred = res_dict['enc_pred']

            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)
            a_loss = a_info['actor_loss']
            a_clipped = a_info['actor_clipped'].float()

            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
            c_loss = c_info['critic_loss']

            b_loss = self.bound_loss(mu)

            # Policy-side losses only count samples that used a sampled action.
            c_loss = torch.mean(c_loss)
            a_loss = torch.sum(rand_action_mask * a_loss) / rand_action_sum
            entropy = torch.sum(rand_action_mask * entropy) / rand_action_sum
            b_loss = torch.sum(rand_action_mask * b_loss) / rand_action_sum
            a_clip_frac = torch.sum(rand_action_mask * a_clipped) / rand_action_sum

            disc_agent_cat_logit = torch.cat([disc_agent_logit, disc_agent_replay_logit], dim=0)
            disc_info = self._disc_loss(disc_agent_cat_logit, disc_demo_logit, amp_obs_demo)
            disc_loss = disc_info['disc_loss']

            enc_latents = batch_dict['ase_latents'][0:self._amp_minibatch_size]
            enc_loss_mask = rand_action_mask[0:self._amp_minibatch_size]
            enc_info = self._enc_loss(enc_pred, enc_latents, batch_dict['amp_obs'], enc_loss_mask)
            enc_loss = enc_info['enc_loss']

            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss \
                 + self._disc_coef * disc_loss + self._enc_coef * enc_loss

            if (self._enable_amp_diversity_bonus()):
                diversity_loss = self._diversity_loss(batch_dict['obs'], mu, batch_dict['ase_latents'])
                diversity_loss = torch.sum(rand_action_mask * diversity_loss) / rand_action_sum
                loss += self._amp_diversity_bonus * diversity_loss
                a_info['amp_diversity_loss'] = diversity_loss

            a_info['actor_loss'] = a_loss
            a_info['actor_clip_frac'] = a_clip_frac
            c_info['critic_loss'] = c_loss

            if self.multi_gpu:
                self.optimizer.zero_grad()
            else:
                for param in self.model.parameters():
                    param.grad = None

        self.scaler.scale(loss).backward()
        #TODO: Refactor this ugliest code of the year
        if self.truncate_grads:
            if self.multi_gpu:
                self.optimizer.synchronize()
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
                with self.optimizer.skip_synchronize():
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
            else:
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        with torch.no_grad():
            reduce_kl = not self.is_rnn
            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
            if self.is_rnn:
                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()  #/ sum_mask

        self.train_result = {
            'entropy': entropy,
            'kl': kl_dist,
            'last_lr': self.last_lr,
            'lr_mul': lr_mul,
            'b_loss': b_loss
        }
        self.train_result.update(a_info)
        self.train_result.update(c_info)
        self.train_result.update(disc_info)
        self.train_result.update(enc_info)

        return

    def env_reset(self, env_ids=None):
        """Reset envs via the base agent, then resample their latents and schedules.

        ``env_ids=None`` means all ``num_envs`` environments.
        """
        obs = super().env_reset(env_ids)

        if (env_ids is None):
            num_envs = self.vec_env.env.task.num_envs
            env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.ppo_device)

        if (len(env_ids) > 0):
            self._reset_latents(env_ids)
            self._reset_latent_step_count(env_ids)

        return obs

    def _reset_latent_step_count(self, env_ids):
        """Draw a fresh resample step in [latent_steps_min, latent_steps_max) for ``env_ids``."""
        self._latent_reset_steps[env_ids] = torch.randint_like(self._latent_reset_steps[env_ids], low=self._latent_steps_min,
                                                         high=self._latent_steps_max)
        return

    def _load_config_params(self, config):
        """Read ASE-specific hyperparameters from ``config``.

        ``latent_steps_min``/``latent_steps_max`` default to ``np.inf``, which
        ``torch.randint_like`` cannot use, so in practice they must be set.
        """
        super()._load_config_params(config)

        self._latent_dim = config['latent_dim']
        self._latent_steps_min = config.get('latent_steps_min', np.inf)
        self._latent_steps_max = config.get('latent_steps_max', np.inf)
        self._amp_diversity_bonus = config['amp_diversity_bonus']
        self._amp_diversity_tar = config['amp_diversity_tar']

        self._enc_coef = config['enc_coef']
        self._enc_weight_decay = config['enc_weight_decay']
        self._enc_reward_scale = config['enc_reward_scale']
        self._enc_grad_penalty = config['enc_grad_penalty']

        self._enc_reward_w = config['enc_reward_w']

        return

    def _build_net_config(self):
        """Add the latent shape to the base network config."""
        config = super()._build_net_config()
        config['ase_latent_shape'] = (self._latent_dim,)
        return config

    def _reset_latents(self, env_ids):
        """Sample new latents for ``env_ids``; recolor characters when a viewer is active."""
        n = len(env_ids)
        z = self._sample_latents(n)
        self._ase_latents[env_ids] = z

        if (self.vec_env.env.task.viewer):
            self._change_char_color(env_ids)

        return

    def _sample_latents(self, n):
        """Sample ``n`` latents from the network's latent sampler."""
        z = self.model.a2c_network.sample_latents(n)
        return z

    def _update_latents(self):
        """Resample latents whose scheduled reset step has been reached, then push the schedule forward."""
        new_latent_envs = self._latent_reset_steps <= self.vec_env.env.task.progress_buf

        need_update = torch.any(new_latent_envs)
        if (need_update):
            new_latent_env_ids = new_latent_envs.nonzero(as_tuple=False).flatten()
            # _reset_latents already recolors the characters when a viewer is
            # active, so no separate _change_char_color call is made here.
            self._reset_latents(new_latent_env_ids)
            self._latent_reset_steps[new_latent_env_ids] += torch.randint_like(self._latent_reset_steps[new_latent_env_ids],
                                                                               low=self._latent_steps_min,
                                                                               high=self._latent_steps_max)

        return

    def _eval_actor(self, obs, ase_latents):
        """Return the model's actor output (mu, sigma) for ``obs`` and ``ase_latents``."""
        output = self.model.eval_actor(obs=obs, ase_latents=ase_latents)
        return output

    def _eval_critic(self, obs_dict, ase_latents):
        """Evaluate the critic on ``obs_dict['obs']``, un-normalizing the value if enabled."""
        self.model.eval()
        obs = obs_dict['obs']
        processed_obs = self._preproc_obs(obs)
        value = self.model.eval_critic(processed_obs, ase_latents)

        if self.normalize_value:
            value = self.value_mean_std(value, True)
        return value

    def _calc_amp_rewards(self, amp_obs, ase_latents):
        """Return the discriminator and encoder rewards as a dict."""
        disc_r = self._calc_disc_rewards(amp_obs)
        enc_r = self._calc_enc_rewards(amp_obs, ase_latents)
        output = {
            'disc_rewards': disc_r,
            'enc_rewards': enc_r
        }
        return output

    def _calc_enc_rewards(self, amp_obs, ase_latents):
        """Encoder reward: max(sum(enc_pred * latent), 0) scaled by ``enc_reward_scale``."""
        with torch.no_grad():
            enc_pred = self._eval_enc(amp_obs)
            err = self._calc_enc_error(enc_pred, ase_latents)
            enc_r = torch.clamp_min(-err, 0.0)
            enc_r *= self._enc_reward_scale

        return enc_r

    def _enc_loss(self, enc_pred, ase_latent, enc_obs, loss_mask):
        """Encoder loss: mean alignment error plus optional regularizers.

        Args:
            enc_pred: Encoder predictions for ``enc_obs``.
            ase_latent: Latents the predictions should align with.
            enc_obs: Encoder inputs; must require grad when the encoder
                gradient penalty is enabled.
            loss_mask: Currently unused -- the loss is averaged over all samples.

        Returns:
            dict with ``'enc_loss'`` (including weight decay and gradient
            penalty when enabled) and, when the penalty is enabled, a
            detached ``'enc_grad_penalty'``.
        """
        enc_err = self._calc_enc_error(enc_pred, ase_latent)
        enc_loss = torch.mean(enc_err)

        # L2 weight decay on the encoder weights.
        if (self._enc_weight_decay != 0):
            enc_weights = self.model.a2c_network.get_enc_weights()
            enc_weights = torch.cat(enc_weights, dim=-1)
            enc_weight_decay = torch.sum(torch.square(enc_weights))
            enc_loss = enc_loss + self._enc_weight_decay * enc_weight_decay

        enc_grad_penalty = None
        if (self._enable_enc_grad_penalty()):
            # Penalize the squared gradient of the encoder error w.r.t. its inputs.
            enc_obs_grad = torch.autograd.grad(enc_err, enc_obs, grad_outputs=torch.ones_like(enc_err),
                                               create_graph=True, retain_graph=True, only_inputs=True)
            enc_obs_grad = enc_obs_grad[0]
            enc_obs_grad = torch.sum(torch.square(enc_obs_grad), dim=-1)
            enc_grad_penalty = torch.mean(enc_obs_grad)

            enc_loss = enc_loss + self._enc_grad_penalty * enc_grad_penalty

        # Built only after every term is added, so the reported loss does not
        # depend on in-place mutation of a tensor already stored in the dict.
        enc_info = {
            'enc_loss': enc_loss
        }
        if enc_grad_penalty is not None:
            enc_info['enc_grad_penalty'] = enc_grad_penalty.detach()

        return enc_info

    def _diversity_loss(self, obs, action_params, ase_latents):
        """Squared deviation of the action/latent diversity ratio from ``amp_diversity_tar``.

        Compares the policy means for freshly sampled latents against
        ``action_params`` (both clamped to [-1, 1]) and divides the mean squared
        action difference by ``0.5 - 0.5 * <new_z, ase_latents>``.
        """
        assert(self.model.a2c_network.is_continuous)

        n = obs.shape[0]
        assert(n == action_params.shape[0])

        new_z = self._sample_latents(n)
        mu, sigma = self._eval_actor(obs=obs, ase_latents=new_z)

        clipped_action_params = torch.clamp(action_params, -1.0, 1.0)
        clipped_mu = torch.clamp(mu, -1.0, 1.0)

        a_diff = clipped_action_params - clipped_mu
        a_diff = torch.mean(torch.square(a_diff), dim=-1)

        z_diff = new_z * ase_latents
        z_diff = torch.sum(z_diff, dim=-1)
        z_diff = 0.5 - 0.5 * z_diff

        diversity_bonus = a_diff / (z_diff + 1e-5)
        diversity_loss = torch.square(self._amp_diversity_tar - diversity_bonus)

        return diversity_loss

    def _calc_enc_error(self, enc_pred, ase_latent):
        """Negative dot product between prediction and latent (last dim kept)."""
        err = enc_pred * ase_latent
        err = -torch.sum(err, dim=-1, keepdim=True)
        return err

    def _enable_enc_grad_penalty(self):
        return self._enc_grad_penalty != 0

    def _enable_amp_diversity_bonus(self):
        return self._amp_diversity_bonus != 0

    def _eval_enc(self, amp_obs):
        """Run the encoder on preprocessed AMP observations."""
        proc_amp_obs = self._preproc_amp_obs(amp_obs)
        return self.model.a2c_network.eval_enc(proc_amp_obs)

    def _combine_rewards(self, task_rewards, amp_rewards):
        """Weighted sum of task, discriminator and encoder rewards."""
        disc_r = amp_rewards['disc_rewards']
        enc_r = amp_rewards['enc_rewards']
        combined_rewards = self._task_reward_w * task_rewards \
                         + self._disc_reward_w * disc_r \
                         + self._enc_reward_w * enc_r
        return combined_rewards

    def _record_train_batch_info(self, batch_dict, train_info):
        super()._record_train_batch_info(batch_dict, train_info)
        train_info['enc_rewards'] = batch_dict['enc_rewards']
        return

    def _log_train_info(self, train_info, frame):
        """Log encoder / diversity losses and encoder reward statistics."""
        super()._log_train_info(train_info, frame)

        self.writer.add_scalar('losses/enc_loss', torch_ext.mean_list(train_info['enc_loss']).item(), frame)

        if (self._enable_amp_diversity_bonus()):
            self.writer.add_scalar('losses/amp_diversity_loss', torch_ext.mean_list(train_info['amp_diversity_loss']).item(), frame)

        enc_reward_std, enc_reward_mean = torch.std_mean(train_info['enc_rewards'])
        self.writer.add_scalar('info/enc_reward_mean', enc_reward_mean.item(), frame)
        self.writer.add_scalar('info/enc_reward_std', enc_reward_std.item(), frame)

        if (self._enable_enc_grad_penalty()):
            self.writer.add_scalar('info/enc_grad_penalty', torch_ext.mean_list(train_info['enc_grad_penalty']).item(), frame)

        return

    def _change_char_color(self, env_ids):
        """Set one random color (base grey plus a random offset of fixed norm) on ``env_ids``."""
        base_col = np.array([0.4, 0.4, 0.4])
        range_col = np.array([0.0706, 0.149, 0.2863])
        range_sum = np.linalg.norm(range_col)

        rand_col = np.random.uniform(0.0, 1.0, size=3)
        rand_col = range_sum * rand_col / np.linalg.norm(rand_col)
        rand_col += base_col
        self.vec_env.env.task.set_char_color(rand_col, env_ids)
        return

    def _amp_debug(self, info, ase_latents):
        """Print discriminator prediction and rewards for the first sample (viewer debugging)."""
        with torch.no_grad():
            amp_obs = info['amp_obs']
            disc_pred = self._eval_disc(amp_obs)
            amp_rewards = self._calc_amp_rewards(amp_obs, ase_latents)
            disc_reward = amp_rewards['disc_rewards']
            enc_reward = amp_rewards['enc_rewards']

            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]
            disc_reward = disc_reward.cpu().numpy()[0, 0]
            enc_reward = enc_reward.cpu().numpy()[0, 0]
            print("disc_pred: ", disc_pred, disc_reward, enc_reward)
        return

================================================
FILE: timechamber/ase/ase_models.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from timechamber.ase.utils import amp_models

class ModelASEContinuous(amp_models.ModelAMPContinuous):
    """Continuous-action ASE model built on top of the AMP continuous model."""

    def __init__(self, network):
        super().__init__(network)
        return

    def build(self, config):
        """Build the ``'ase'`` network from ``config`` and wrap it in :class:`Network`.

        Side effect: prints the name of every parameter of the built network.
        """
        net = self.network_builder.build('ase', **config)
        for param_name, _unused in net.named_parameters():
            print(param_name)
        network_kwargs = dict(
            obs_shape=config['input_shape'],
            normalize_value=config.get('normalize_value', False),
            normalize_input=config.get('normalize_input', False),
            value_size=config.get('value_size', 1),
        )
        return ModelASEContinuous.Network(net, **network_kwargs)


    class Network(amp_models.ModelAMPContinuous.Network):
        """AMP network wrapper that also reports encoder predictions in training mode."""

        def __init__(self, a2c_network, obs_shape, normalize_value, normalize_input, value_size):
            super().__init__(
                a2c_network,
                obs_shape=obs_shape,
                normalize_value=normalize_value,
                normalize_input=normalize_input,
                value_size=value_size,
            )
            return

        def forward(self, input_dict):
            """Run the AMP forward pass; when training, add ``enc_pred`` from ``amp_obs``."""
            training = input_dict.get('is_train', True)
            outputs = super().forward(input_dict)

            if training:
                outputs["enc_pred"] = self.a2c_network.eval_enc(input_dict['amp_obs'])

            return outputs

        def eval_actor(self, obs, ase_latents, use_hidden_latents=False):
            """Return ``(mu, sigma)`` for normalized ``obs``.

            ``use_hidden_latents`` is accepted for signature parity with
            ``eval_critic`` but is not forwarded to the network.
            """
            normed_obs = self.norm_obs(obs)
            mu, sigma = self.a2c_network.eval_actor(obs=normed_obs, ase_latents=ase_latents)
            return mu, sigma

        def eval_critic(self, obs, ase_latents, use_hidden_latents=False):
            """Return the critic value for normalized ``obs`` and ``ase_latents``."""
            return self.a2c_network.eval_critic(self.norm_obs(obs), ase_latents, use_hidden_latents)

================================================
FILE: timechamber/ase/ase_network_builder.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from rl_games.algos_torch import torch_ext
from rl_games.algos_torch import layers
from rl_games.algos_torch import network_builder

import torch
import torch.nn as nn
import numpy as np
import enum

from timechamber.ase.utils import amp_network_builder

# Half-width of the uniform range used to initialise the weights of the
# encoder's final linear layer (U(-scale, scale)).
ENC_LOGIT_INIT_SCALE = 0.1


class LatentType(enum.Enum):
    """Named latent-space kinds. Not referenced elsewhere in this module."""

    uniform = 0  # presumably: latents drawn from a uniform distribution — confirm
    sphere = 1   # presumably: latents on the unit hypersphere — confirm

class ASEBuilder(amp_network_builder.AMPBuilder):
    """rl_games network builder producing latent-conditioned actor-critic networks.

    Extends the AMP builder: the actor and critic additionally receive an ASE
    latent, and the network carries an encoder head that maps AMP observations
    to a unit-length vector in the latent space.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    class Network(amp_network_builder.AMPBuilder.Network):
        """Latent-conditioned actor-critic with discriminator and encoder heads.

        kwargs read by ``__init__``: ``actions_num``, ``input_shape``,
        ``value_size`` (default 1), ``num_seqs`` (default 1),
        ``amp_input_shape`` and ``ase_latent_shape``.
        """
        def __init__(self, params, **kwargs):
            actions_num = kwargs.get('actions_num')
            input_shape = kwargs.get('input_shape')
            self.value_size = kwargs.get('value_size', 1)
            self.num_seqs = num_seqs = kwargs.get('num_seqs', 1)
            amp_input_shape = kwargs.get('amp_input_shape')
            self._ase_latent_shape = kwargs.get('ase_latent_shape')

            # Initialise via the rl_games base network directly instead of
            # AMPBuilder.Network.__init__; the layers are built explicitly below.
            network_builder.NetworkBuilder.BaseNetwork.__init__(self)

            self.load(params)

            actor_out_size, critic_out_size = self._build_actor_critic_net(input_shape, self._ase_latent_shape)

            self.value = torch.nn.Linear(critic_out_size, self.value_size)
            self.value_act = self.activations_factory.create(self.value_activation)

            if self.is_discrete:
                self.logits = torch.nn.Linear(actor_out_size, actions_num)
            # For multi-discrete action spaces actions_num is a tuple.
            if self.is_multi_discrete:
                self.logits = torch.nn.ModuleList([torch.nn.Linear(actor_out_size, num) for num in actions_num])
            if self.is_continuous:
                self.mu = torch.nn.Linear(actor_out_size, actions_num)
                self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
                mu_init = self.init_factory.create(**self.space_config['mu_init'])
                self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])

                sigma_init = self.init_factory.create(**self.space_config['sigma_init'])

                if (not self.space_config['learn_sigma']):
                    # Constant (non-trainable) sigma parameter.
                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=False, dtype=torch.float32), requires_grad=False)
                elif self.space_config['fixed_sigma']:
                    # State-independent but trainable sigma parameter.
                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
                else:
                    # State-dependent sigma predicted from the actor features.
                    self.sigma = torch.nn.Linear(actor_out_size, actions_num)

            mlp_init = self.init_factory.create(**self.initializer)
            if self.has_cnn:
                cnn_init = self.init_factory.create(**self.cnn['initializer'])

            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                    cnn_init(m.weight)
                    if getattr(m, "bias", None) is not None:
                        torch.nn.init.zeros_(m.bias)
                if isinstance(m, nn.Linear):
                    mlp_init(m.weight)
                    if getattr(m, "bias", None) is not None:
                        torch.nn.init.zeros_(m.bias)

            # The custom trunks re-apply their own initialisation after the generic pass above.
            self.actor_mlp.init_params()
            self.critic_mlp.init_params()

            if self.is_continuous:
                mu_init(self.mu.weight)
                if self.space_config['fixed_sigma']:
                    sigma_init(self.sigma)
                else:
                    sigma_init(self.sigma.weight)

            self._build_disc(amp_input_shape)
            self._build_enc(amp_input_shape)

            return

        def load(self, params):
            """Load base/AMP settings, then the encoder settings from ``params['enc']``."""
            super().load(params)

            self._enc_units = params['enc']['units']
            self._enc_activation = params['enc']['activation']
            self._enc_initializer = params['enc']['initializer']
            self._enc_separate = params['enc']['separate']

            return

        def forward(self, obs_dict):
            """Run actor and critic.

            Reads ``obs``, ``ase_latents``, and optionally ``rnn_states`` (default
            None) and ``use_hidden_latents`` (default False) from ``obs_dict``.
            Returns the actor outputs tuple extended with ``(value, states)``;
            for continuous actions this is ``(mu, sigma, value, states)``.
            """
            obs = obs_dict['obs']
            ase_latents = obs_dict['ase_latents']
            states = obs_dict.get('rnn_states', None)
            use_hidden_latents = obs_dict.get('use_hidden_latents', False)

            actor_outputs = self.eval_actor(obs, ase_latents, use_hidden_latents)
            value = self.eval_critic(obs, ase_latents, use_hidden_latents)

            output = actor_outputs + (value, states)

            return output

        def eval_critic(self, obs, ase_latents, use_hidden_latents=False):
            """Return the value head output for ``obs`` conditioned on ``ase_latents``."""
            c_out = self.critic_cnn(obs)
            c_out = c_out.contiguous().view(c_out.size(0), -1)

            c_out = self.critic_mlp(c_out, ase_latents, use_hidden_latents)
            value = self.value_act(self.value(c_out))
            return value

        def eval_actor(self, obs, ase_latents, use_hidden_latents=False):
            """Return the action-distribution parameters for ``obs`` and ``ase_latents``.

            Discrete: logits; multi-discrete: list of logits; continuous: (mu, sigma).
            """
            a_out = self.actor_cnn(obs)
            a_out = a_out.contiguous().view(a_out.size(0), -1)
            a_out = self.actor_mlp(a_out, ase_latents, use_hidden_latents)

            if self.is_discrete:
                logits = self.logits(a_out)
                return logits

            if self.is_multi_discrete:
                logits = [logit(a_out) for logit in self.logits]
                return logits

            if self.is_continuous:
                mu = self.mu_act(self.mu(a_out))
                if self.space_config['fixed_sigma']:
                    # Broadcast the state-independent sigma to mu's batch shape.
                    sigma = mu * 0.0 + self.sigma_act(self.sigma)
                else:
                    sigma = self.sigma_act(self.sigma(a_out))

                return mu, sigma
            return

        def get_enc_weights(self):
            """Return the flattened weight tensors of every Linear layer in the encoder."""
            weights = []
            for m in self._enc_mlp.modules():
                if isinstance(m, nn.Linear):
                    weights.append(torch.flatten(m.weight))

            weights.append(torch.flatten(self._enc.weight))
            return weights

        def _build_actor_critic_net(self, input_shape, ase_latent_shape):
            """Build the actor/critic trunks and return their output widths.

            The actor is an ``AMPStyleCatNet1``. With ``separate`` the critic is
            its own ``AMPMLPNet``; otherwise the critic reuses the actor trunk
            (the original code left ``critic_mlp`` undefined in that case and
            crashed with ``AttributeError``).
            """
            style_units = [512, 256]
            style_dim = ase_latent_shape[-1]

            self.actor_cnn = nn.Sequential()
            self.critic_cnn = nn.Sequential()

            act_fn = self.activations_factory.create(self.activation)
            initializer = self.init_factory.create(**self.initializer)

            self.actor_mlp = AMPStyleCatNet1(obs_size=input_shape[-1],
                                             ase_latent_size=ase_latent_shape[-1],
                                             units=self.units,
                                             activation=act_fn,
                                             style_units=style_units,
                                             style_dim=style_dim,
                                             initializer=initializer)

            if self.separate:
                self.critic_mlp = AMPMLPNet(obs_size=input_shape[-1],
                                            ase_latent_size=ase_latent_shape[-1],
                                            units=self.units,
                                            activation=act_fn,
                                            initializer=initializer)
            else:
                # Shared trunk: the critic consumes the actor's features.
                self.critic_mlp = self.actor_mlp

            actor_out_size = self.actor_mlp.get_out_size()
            critic_out_size = self.critic_mlp.get_out_size()

            return actor_out_size, critic_out_size

        def _build_enc(self, input_shape):
            """Build the encoder: an MLP (own or shared with the discriminator) plus a latent head."""
            if (self._enc_separate):
                mlp_args = {
                    'input_size' : input_shape[0],
                    'units' : self._enc_units,
                    'activation' : self._enc_activation,
                    'dense_func' : torch.nn.Linear
                }
                self._enc_mlp = self._build_mlp(**mlp_args)

                mlp_init = self.init_factory.create(**self._enc_initializer)
                for m in self._enc_mlp.modules():
                    if isinstance(m, nn.Linear):
                        mlp_init(m.weight)
                        if getattr(m, "bias", None) is not None:
                            torch.nn.init.zeros_(m.bias)
            else:
                self._enc_mlp = self._disc_mlp

            # Use the last Linear layer's width rather than assuming modules()[-2]
            # is a Linear; modules() deduplicates shared activation instances,
            # which would make positional indexing pick the wrong module.
            linear_layers = [m for m in self._enc_mlp.modules() if isinstance(m, nn.Linear)]
            mlp_out_size = linear_layers[-1].out_features
            self._enc = torch.nn.Linear(mlp_out_size, self._ase_latent_shape[-1])

            torch.nn.init.uniform_(self._enc.weight, -ENC_LOGIT_INIT_SCALE, ENC_LOGIT_INIT_SCALE)
            torch.nn.init.zeros_(self._enc.bias)

            return

        def eval_enc(self, amp_obs):
            """Encode AMP observations into unit-length vectors in the latent space."""
            enc_mlp_out = self._enc_mlp(amp_obs)
            enc_output = self._enc(enc_mlp_out)
            enc_output = torch.nn.functional.normalize(enc_output, dim=-1)

            return enc_output

        def sample_latents(self, n):
            """Sample ``n`` latents: standard-normal draws normalised to unit length."""
            device = next(self._enc.parameters()).device
            z = torch.normal(torch.zeros([n, self._ase_latent_shape[-1]], device=device))
            z = torch.nn.functional.normalize(z, dim=-1)
            return z

    def build(self, name, **kwargs):
        net = ASEBuilder.Network(self.params, **kwargs)
        return net


class AMPMLPNet(torch.nn.Module):
    """Plain MLP trunk over the concatenation ``[obs, latent]``.

    Each hidden width in ``units`` contributes a ``Linear`` followed by the
    given ``activation`` module (the same instance is reused for every layer).
    """
    def __init__(self, obs_size, ase_latent_size, units, activation, initializer):
        super().__init__()

        input_size = obs_size + ase_latent_size
        print('build amp mlp net:', input_size)

        self._units = units
        self._initializer = initializer

        layers = []
        prev_width = input_size
        for width in units:
            layers += [torch.nn.Linear(prev_width, width), activation]
            prev_width = width

        self._mlp = nn.Sequential(*layers)
        self.init_params()
        return

    def forward(self, obs, latent, skip_style):
        # `skip_style` is accepted only so this net is call-compatible with
        # AMPStyleCatNet1; it has no effect here.
        joint = torch.cat([obs, latent], dim=-1)
        return self._mlp(joint)

    def init_params(self):
        """Apply the weight initializer to every Linear layer and zero its bias."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            self._initializer(layer.weight)
            if getattr(layer, "bias", None) is not None:
                torch.nn.init.zeros_(layer.bias)
        return

    def get_out_size(self):
        """Width of the final hidden layer."""
        return self._units[-1]

class AMPStyleCatNet1(torch.nn.Module):
    """Actor trunk that embeds the latent into a style code and concatenates it with obs.

    Style path: ``latent -> style MLP -> Linear(style_dim) -> tanh``.
    Trunk: ``[obs, style] -> (Linear -> activation) for each width in units``.
    """
    def __init__(self, obs_size, ase_latent_size, units, activation,
                 style_units, style_dim, initializer):
        super().__init__()

        print('build amp style cat net:', obs_size, ase_latent_size)

        self._activation = activation
        self._initializer = initializer
        self._units = units
        self._style_dim = style_dim
        self._style_activation = torch.tanh

        # Style embedding: hidden MLP, then projection down to style_dim.
        self._style_mlp = self._build_style_mlp(style_units, ase_latent_size)
        self._style_dense = torch.nn.Linear(style_units[-1], style_dim)

        # Main trunk consumes the observation concatenated with the style code.
        widths = [obs_size + style_dim] + list(units)
        self._dense_layers = nn.ModuleList(
            [torch.nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        self.init_params()

        return

    def forward(self, obs, latent, skip_style):
        # With skip_style the caller supplies an already-computed style code.
        style = latent if skip_style else self.eval_style(latent)

        h = torch.cat([obs, style], dim=-1)
        for layer in self._dense_layers:
            h = self._activation(layer(h))

        return h

    def eval_style(self, latent):
        """Map a latent to its tanh-squashed style code."""
        return self._style_activation(self._style_dense(self._style_mlp(latent)))

    def init_params(self):
        """Initialise every Linear layer, then re-draw the style projection from U(-1, 1)."""
        scale_init_range = 1.0

        for m in self.modules():
            if not isinstance(m, nn.Linear):
                continue
            self._initializer(m.weight)
            if getattr(m, "bias", None) is not None:
                torch.nn.init.zeros_(m.bias)

        # Overrides the generic initializer for the style projection weights only.
        nn.init.uniform_(self._style_dense.weight, -scale_init_range, scale_init_range)
        return

    def get_out_size(self):
        """Width of the final trunk layer."""
        return self._units[-1]

    def _build_style_mlp(self, style_units, input_size):
        """Build ``Linear -> activation`` pairs over ``style_units`` starting at ``input_size``."""
        sizes = [input_size] + list(style_units)
        layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            layers.append(torch.nn.Linear(n_in, n_out))
            layers.append(self._activation)

        return nn.Sequential(*layers)

================================================
FILE: timechamber/ase/ase_players.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from pytest import param
import torch

from isaacgym.torch_utils import *
from rl_games.algos_torch import players

from timechamber.ase.utils import amp_players
import timechamber.ase.ase_network_builder as ase_network_builder

class ASEPlayer(amp_players.AMPPlayerContinuous):
    """Evaluation player for ASE low-level policies.

    Keeps one latent per environment, resamples all of them periodically
    (every ``latent_steps_min``..``latent_steps_max`` steps) and on env reset,
    and feeds them to the model alongside the observations.
    """
    def __init__(self, params):
        config = params['config']
        self._latent_dim = config['latent_dim']
        self._latent_steps_min = config.get('latent_steps_min', np.inf)
        self._latent_steps_max = config.get('latent_steps_max', np.inf)

        self._enc_reward_scale = config['enc_reward_scale']

        super().__init__(params)

        if (hasattr(self, 'env')) and self.env is not None:
            batch_size = self.env.task.num_envs
        else:
            batch_size = self.env_info['num_envs']
        self._ase_latents = torch.zeros((batch_size, self._latent_dim), dtype=torch.float32,
                                         device=self.device)

        return

    def run(self):
        self._reset_latent_step_count()
        super().run()
        return

    def get_action(self, obs_dict, is_determenistic=False):
        """Return rescaled actions for ``obs_dict['obs']`` under the current latents.

        Uses the distribution mean when ``is_determenistic`` is true, otherwise
        the sampled action; the result is clamped to [-1, 1] before rescaling.
        """
        self._update_latents()

        obs = obs_dict['obs']
        if len(obs.size()) == len(self.obs_shape):
            obs = obs.unsqueeze(0)
        obs = self._preproc_obs(obs)
        ase_latents = self._ase_latents

        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs' : obs,
            'rnn_states' : self.states,
            'ase_latents': ase_latents
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        mu = res_dict['mus']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if is_determenistic:
            current_action = mu
        else:
            current_action = action
        current_action = torch.squeeze(current_action.detach())
        return  players.rescale_actions(self.actions_low, self.actions_high, torch.clamp(current_action, -1.0, 1.0))

    def env_reset(self, env_ids=None):
        obs = super().env_reset(env_ids)
        self._reset_latents(env_ids)
        return obs

    def _build_net_config(self):
        config = super()._build_net_config()
        config['ase_latent_shape'] = (self._latent_dim,)
        return config

    def _reset_latents(self, done_env_ids=None):
        """Resample latents (and recolour characters) for ``done_env_ids``, or all envs if None."""
        if (done_env_ids is None):
            num_envs = self.env.task.num_envs
            done_env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.device)

        rand_vals = self.model.a2c_network.sample_latents(len(done_env_ids))
        self._ase_latents[done_env_ids] = rand_vals
        self._change_char_color(done_env_ids)

        return

    def _update_latents(self):
        """Count down the latent hold period; when it expires, resample all latents."""
        if (self._latent_step_count <= 0):
            self._reset_latents()
            self._reset_latent_step_count()

            if (self.env.task.viewer):
                # NOTE(review): _reset_latents already recoloured every env; this
                # second call only picks a different random colour.
                print("Sampling new amp latents------------------------------")
                num_envs = self.env.task.num_envs
                env_ids = to_torch(np.arange(num_envs), dtype=torch.long, device=self.device)
                self._change_char_color(env_ids)
        else:
            self._latent_step_count -= 1
        return

    def _reset_latent_step_count(self):
        """Draw how many steps to hold the current latents before resampling.

        Normally drawn uniformly from ``[latent_steps_min, latent_steps_max)``.
        If either bound is not finite (``np.inf`` is the default when the config
        omits it), the count is set to ``np.inf``. In that case latents are only
        resampled on env reset; ``np.random.randint`` would otherwise raise.
        If both bounds are equal, that value is used directly, because randint's
        exclusive upper bound would make the range empty.
        """
        steps_min = self._latent_steps_min
        steps_max = self._latent_steps_max
        if not (np.isfinite(steps_min) and np.isfinite(steps_max)):
            self._latent_step_count = np.inf
        elif steps_min == steps_max:
            self._latent_step_count = steps_min
        else:
            self._latent_step_count = np.random.randint(steps_min, steps_max)
        return

    def _calc_amp_rewards(self, amp_obs, ase_latents):
        disc_r = self._calc_disc_rewards(amp_obs)
        enc_r = self._calc_enc_rewards(amp_obs, ase_latents)
        output = {
            'disc_rewards': disc_r,
            'enc_rewards': enc_r
        }
        return output

    def _calc_enc_rewards(self, amp_obs, ase_latents):
        """Encoder reward: max(dot(enc_pred, latent), 0) scaled by ``enc_reward_scale``."""
        with torch.no_grad():
            enc_pred = self._eval_enc(amp_obs)
            err = self._calc_enc_error(enc_pred, ase_latents)
            enc_r = torch.clamp_min(-err, 0.0)
            enc_r *= self._enc_reward_scale

        return enc_r

    def _calc_enc_error(self, enc_pred, ase_latent):
        """Negative dot product between encoder prediction and latent, keeping the last dim."""
        err = enc_pred * ase_latent
        err = -torch.sum(err, dim=-1, keepdim=True)
        return err

    def _eval_enc(self, amp_obs):
        proc_amp_obs = self._preproc_amp_obs(amp_obs)
        return self.model.a2c_network.eval_enc(proc_amp_obs)

    def _amp_debug(self, info):
        """Print discriminator prediction and disc/enc rewards for the first env."""
        with torch.no_grad():
            amp_obs = info['amp_obs']
            ase_latents = self._ase_latents
            disc_pred = self._eval_disc(amp_obs)
            amp_rewards = self._calc_amp_rewards(amp_obs, ase_latents)
            disc_reward = amp_rewards['disc_rewards']
            enc_reward = amp_rewards['enc_rewards']

            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]
            disc_reward = disc_reward.cpu().numpy()[0, 0]
            enc_reward = enc_reward.cpu().numpy()[0, 0]
            print("disc_pred: ", disc_pred, disc_reward, enc_reward)
        return

    def _change_char_color(self, env_ids):
        """Apply one random colour (grey base plus a random offset) to the given envs."""
        base_col = np.array([0.4, 0.4, 0.4])
        range_col = np.array([0.0706, 0.149, 0.2863])
        range_sum = np.linalg.norm(range_col)

        # Random direction, rescaled to the same length as range_col.
        rand_col = np.random.uniform(0.0, 1.0, size=3)
        rand_col = range_sum * rand_col / np.linalg.norm(rand_col)
        rand_col += base_col
        self.env.task.set_char_color(rand_col, env_ids)
        return

================================================
FILE: timechamber/ase/hrl_agent.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import copy
from datetime import datetime
from distutils.command.config import config
from gym import spaces
import numpy as np
import os
import time
import yaml

from rl_games.algos_torch import torch_ext
from rl_games.algos_torch import central_value
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.common import a2c_common
from rl_games.common import datasets
from rl_games.common import schedulers
from rl_games.common import vecenv

import torch
from torch import optim

import timechamber.ase.utils.common_agent as common_agent 
import timechamber.ase.ase_agent as ase_agent
import timechamber.ase.ase_models as ase_models
import timechamber.ase.ase_network_builder as ase_network_builder

from tensorboardX import SummaryWriter

class HRLAgent(common_agent.CommonAgent):
    """High-level PPO agent driving a frozen, pre-trained ASE low-level controller (LLC).

    The high-level action has ``latent_dim`` entries (taken from the LLC
    config). Each high-level step runs ``llc_steps`` environment steps in which
    the LLC turns the normalised latent plus its observation slice into actions.
    The training reward is a weighted sum of the task and LLC-discriminator rewards.
    """
    def __init__(self, base_name, params):
        config = params['config']
        # The LLC latent size is read before the base constructor runs;
        # _setup_action_space uses it as actions_num (presumably invoked during
        # base-class construction — confirm).
        with open(os.path.join(os.getcwd(), config['llc_config']), 'r') as f:
            llc_config = yaml.load(f, Loader=yaml.SafeLoader)
            llc_config_params = llc_config['params']
            self._latent_dim = llc_config_params['config']['latent_dim']

        super().__init__(base_name, params)

        self._task_size = self.vec_env.env.task.get_task_obs_size()

        self._llc_steps = config['llc_steps']
        llc_checkpoint = config['llc_checkpoint']
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not llc_checkpoint:
            raise ValueError("HRLAgent requires a non-empty 'llc_checkpoint' in the train config")
        self._build_llc(llc_config_params, llc_checkpoint)

        return

    def env_step(self, actions):
        """Execute one high-level step as ``self._llc_steps`` low-level env steps.

        Rewards and ``infos['disc_rewards']`` are averaged over the sub-steps.
        ``dones`` and ``infos['terminate']`` are 1 where they fired on any sub-step.
        """
        actions = self.preprocess_actions(actions)
        obs = self.obs['obs']

        rewards = 0.0
        disc_rewards = 0.0
        done_count = 0.0
        terminate_count = 0.0
        for t in range(self._llc_steps):
            llc_actions = self._compute_llc_action(obs, actions)
            obs_dict, curr_rewards, curr_dones, infos = self.vec_env.step(llc_actions)

            # TODO
            obs = obs_dict['obs']

            rewards += curr_rewards
            done_count += curr_dones
            terminate_count += infos['terminate']

            amp_obs = infos['amp_obs']
            curr_disc_reward = self._calc_disc_reward(amp_obs)
            disc_rewards += curr_disc_reward

        rewards /= self._llc_steps
        disc_rewards /= self._llc_steps

        dones = torch.zeros_like(done_count)
        dones[done_count > 0] = 1.0
        terminate = torch.zeros_like(terminate_count)
        terminate[terminate_count > 0] = 1.0
        infos['terminate'] = terminate
        infos['disc_rewards'] = disc_rewards

        if self.is_tensor_obses:
            if self.value_size == 1:
                rewards = rewards.unsqueeze(1)
            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos
        else:
            if self.value_size == 1:
                rewards = np.expand_dims(rewards, axis=1)
            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos

    def cast_obs(self, obs):
        obs = super().cast_obs(obs)
        # Keep the LLC agent's tensor/non-tensor observation mode in sync.
        self._llc_agent.is_tensor_obses = self.is_tensor_obses
        return obs

    def preprocess_actions(self, actions):
        """Clamp high-level actions to [-1, 1]; convert to numpy for non-tensor envs."""
        clamped_actions = torch.clamp(actions, -1.0, 1.0)
        if not self.is_tensor_obses:
            clamped_actions = clamped_actions.cpu().numpy()
        return clamped_actions

    def play_steps(self):
        """Collect ``horizon_length`` high-level steps and build the training batch.

        Next-state values are zeroed for terminated envs. Task and
        discriminator rewards are combined before advantages are computed
        with ``self.discount_values``.
        """
        self.set_eval()

        epinfos = []
        done_indices = torch.tensor([], device=self.device, dtype=torch.long)
        update_list = self.update_list

        for n in range(self.horizon_length):
            self.obs = self.env_reset(done_indices)
            self.experience_buffer.update_data('obses', n, self.obs['obs'])

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            for k in update_list:
                self.experience_buffer.update_data(k, n, res_dict[k])

            if self.has_central_value:
                self.experience_buffer.update_data('states', n, self.obs['states'])

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
            shaped_rewards = self.rewards_shaper(rewards)
            self.experience_buffer.update_data('rewards', n, shaped_rewards)
            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])
            self.experience_buffer.update_data('dones', n, self.dones)

            self.experience_buffer.update_data('disc_rewards', n, infos['disc_rewards'])

            # No bootstrapping through true terminations.
            terminated = infos['terminate'].float()
            terminated = terminated.unsqueeze(-1)
            next_vals = self._eval_critic(self.obs)
            next_vals *= (1.0 - terminated)
            self.experience_buffer.update_data('next_values', n, next_vals)

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            done_indices = all_done_indices[::self.num_agents]

            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

            done_indices = done_indices[:, 0]

        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()
        mb_values = self.experience_buffer.tensor_dict['values']
        mb_next_values = self.experience_buffer.tensor_dict['next_values']

        mb_rewards = self.experience_buffer.tensor_dict['rewards']
        mb_disc_rewards = self.experience_buffer.tensor_dict['disc_rewards']
        mb_rewards = self._combine_rewards(mb_rewards, mb_disc_rewards)

        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)
        mb_returns = mb_advs + mb_values

        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)
        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = self.batch_size

        return batch_dict

    def _load_config_params(self, config):
        super()._load_config_params(config)

        self._task_reward_w = config['task_reward_w']
        self._disc_reward_w = config['disc_reward_w']
        return

    def _get_mean_rewards(self):
        # env_step averages rewards over llc_steps; scale back to per-high-level-step totals.
        rewards = super()._get_mean_rewards()
        rewards *= self._llc_steps
        return rewards

    def _setup_action_space(self):
        super()._setup_action_space()
        self.actions_num = self._latent_dim
        return

    def init_tensors(self):
        """Resize action/mu/sigma buffers to the latent size and add a disc_rewards buffer."""
        super().init_tensors()

        del self.experience_buffer.tensor_dict['actions']
        del self.experience_buffer.tensor_dict['mus']
        del self.experience_buffer.tensor_dict['sigmas']

        batch_shape = self.experience_buffer.obs_base_shape
        self.experience_buffer.tensor_dict['actions'] = torch.zeros(batch_shape + (self._latent_dim,),
                                                                dtype=torch.float32, device=self.ppo_device)
        self.experience_buffer.tensor_dict['mus'] = torch.zeros(batch_shape + (self._latent_dim,),
                                                                dtype=torch.float32, device=self.ppo_device)
        self.experience_buffer.tensor_dict['sigmas'] = torch.zeros(batch_shape + (self._latent_dim,),
                                                                dtype=torch.float32, device=self.ppo_device)

        self.experience_buffer.tensor_dict['disc_rewards'] = torch.zeros_like(self.experience_buffer.tensor_dict['rewards'])
        self.tensor_list += ['disc_rewards']

        return

    def _build_llc(self, config_params, checkpoint_file):
        """Create the ASE LLC agent, restore ``checkpoint_file`` and put it in eval mode."""
        llc_agent_config = self._build_llc_agent_config(config_params)
        self._llc_agent = ase_agent.ASEAgent('llc', llc_agent_config)
        self._llc_agent.restore(checkpoint_file)
        print("Loaded LLC checkpoint from {:s}".format(checkpoint_file))
        self._llc_agent.set_eval()
        return

    def _build_llc_agent_config(self, config_params, network=None):
        """Return LLC agent params whose observation space excludes the task observations.

        Works on a deep copy so the caller's ``config_params`` is not mutated.
        """
        llc_env_info = copy.deepcopy(self.env_info)
        obs_space = llc_env_info['observation_space']
        obs_size = obs_space.shape[0]
        obs_size -= self._task_size
        llc_env_info['observation_space'] = spaces.Box(obs_space.low[:obs_size], obs_space.high[:obs_size])

        params = copy.deepcopy(config_params)
        params['config']['network'] = network
        params['config']['num_actors'] = self.num_actors
        params['config']['features'] = {'observer' : self.algo_observer}
        params['config']['env_info'] = llc_env_info
        params['config']['device'] = self.device

        return params

    def _compute_llc_action(self, obs, actions):
        """Map high-level actions (normalised to unit-length latents) to LLC mean actions."""
        llc_obs = self._extract_llc_obs(obs)
        processed_obs = self._llc_agent._preproc_obs(llc_obs)

        z = torch.nn.functional.normalize(actions, dim=-1)
        mu, _ = self._llc_agent.model.eval_actor(obs=processed_obs, ase_latents=z)
        llc_action = mu
        llc_action = self._llc_agent.preprocess_actions(llc_action)

        return llc_action

    def _extract_llc_obs(self, obs):
        """Drop the trailing ``self._task_size`` task-observation features."""
        obs_size = obs.shape[-1]
        llc_obs = obs[..., :obs_size - self._task_size]
        return llc_obs

    def _calc_disc_reward(self, amp_obs):
        disc_reward = self._llc_agent._calc_disc_rewards(amp_obs)
        return disc_reward

    def _combine_rewards(self, task_rewards, disc_rewards):
        """Weighted sum of task and discriminator rewards."""
        combined_rewards = (self._task_reward_w * task_rewards
                            + self._disc_reward_w * disc_rewards)
        return combined_rewards

    def _record_train_batch_info(self, batch_dict, train_info):
        super()._record_train_batch_info(batch_dict, train_info)
        train_info['disc_rewards'] = batch_dict['disc_rewards']
        return

    def _log_train_info(self, train_info, frame):
        super()._log_train_info(train_info, frame)

        disc_reward_std, disc_reward_mean = torch.std_mean(train_info['disc_rewards'])
        self.writer.add_scalar('info/disc_reward_mean', disc_reward_mean.item(), frame)
        self.writer.add_scalar('info/disc_reward_std', disc_reward_std.item(), frame)
        return

================================================
FILE: timechamber/ase/hrl_models.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch.nn as nn
from rl_games.algos_torch.models import ModelA2CContinuousLogStd

class ModelHRLContinuous(ModelA2CContinuousLogStd):
    """Continuous-action A2C model (log-std parameterisation) for the HRL policy.

    The underlying network is built through the network builder registered
    under the name ``'amp'`` and wrapped in :class:`ModelHRLContinuous.Network`,
    which adds a critic-only evaluation path.
    """

    def __init__(self, network):
        super().__init__(network)

    def build(self, config):
        """Build the wrapped network from ``config`` and return the model wrapper.

        As a side effect, the name of every parameter of the freshly built
        network is printed to stdout.
        """
        a2c_net = self.network_builder.build('amp', **config)
        for param_name, _ in a2c_net.named_parameters():
            print(param_name)

        return ModelHRLContinuous.Network(
            a2c_net,
            obs_shape=config['input_shape'],
            normalize_value=config.get('normalize_value', False),
            normalize_input=config.get('normalize_input', False),
            value_size=config.get('value_size', 1),
        )

    class Network(ModelA2CContinuousLogStd.Network):
        """Model wrapper that additionally exposes ``eval_critic``."""

        def __init__(self, a2c_network, obs_shape, normalize_value, normalize_input, value_size):
            super().__init__(
                a2c_network,
                obs_shape=obs_shape,
                normalize_value=normalize_value,
                normalize_input=normalize_input,
                value_size=value_size,
            )

        def eval_critic(self, obs):
            """Return value estimates for ``obs``.

            Observations are passed through ``norm_obs`` before the critic and
            the raw critic output is mapped back with ``unnorm_value``.
            """
            return self.unnorm_value(self.a2c_network.eval_critic(self.norm_obs(obs)))

================================================
FILE: timechamber/ase/hrl_network_builder.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from rl_games.algos_torch import network_builder

import torch
import torch.nn as nn

from timechamber.ase import ase_network_builder

class HRLBuilder(network_builder.A2CBuilder):
    """rl_games A2C network builder for the HRL high-level policy.

    Identical to ``network_builder.A2CBuilder`` except that the actor mean is
    squashed with ``tanh``, a frozen sigma can be configured, and a
    critic-only ``eval_critic`` path is provided.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    class Network(network_builder.A2CBuilder.Network):
        def __init__(self, params, **kwargs):
            super().__init__(params, **kwargs)

            if self.is_continuous:
                if (not self.space_config['learn_sigma']):
                    # Fixed sigma: replace the parent's sigma with a frozen
                    # parameter initialised by the configured initializer.
                    actions_num = kwargs.get('actions_num')
                    sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
                    self.sigma = nn.Parameter(torch.zeros(actions_num, dtype=torch.float32), requires_grad=False)
                    sigma_init(self.sigma)

            return

        def forward(self, obs_dict):
            """Run the base network and squash the action mean into [-1, 1] with ``tanh``."""
            mu, sigma, value, states = super().forward(obs_dict)
            norm_mu = torch.tanh(mu)
            return norm_mu, sigma, value, states

        def eval_critic(self, obs):
            """Evaluate only the value head for (already normalised) ``obs``.

            When the network was built with ``separate`` the dedicated critic
            trunk is used. Otherwise the critic shares the actor trunk; in that
            case rl_games leaves ``critic_cnn``/``critic_mlp`` as empty
            ``nn.Sequential`` modules, so using them would feed raw observations
            straight into the value layer.
            """
            # Fall back to the critic trunk if the base class does not expose
            # `separate` (keeps the previous behaviour in that case).
            if getattr(self, 'separate', True):
                trunk_cnn, trunk_mlp = self.critic_cnn, self.critic_mlp
            else:
                trunk_cnn, trunk_mlp = self.actor_cnn, self.actor_mlp

            c_out = trunk_cnn(obs)
            c_out = c_out.contiguous().view(c_out.size(0), -1)
            c_out = trunk_mlp(c_out)
            value = self.value_act(self.value(c_out))
            return value

    def build(self, name, **kwargs):
        """Instantiate an :class:`HRLBuilder.Network` from the builder params."""
        net = HRLBuilder.Network(self.params, **kwargs)
        return net

================================================
FILE: timechamber/ase/hrl_players.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import copy
from gym import spaces
import numpy as np
import os
import torch 
import yaml
import time

from rl_games.algos_torch import players
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.common.player import BasePlayer

import timechamber.ase.utils.common_player as common_player
import timechamber.ase.ase_models as ase_models
import timechamber.ase.ase_network_builder as ase_network_builder
import timechamber.ase.ase_players as ase_players

class HRLPlayer(common_player.CommonPlayer):
    """Evaluation player for a high-level policy driving a frozen low-level controller.

    The high-level policy emits a latent action of size ``latent_dim`` (read
    from the LLC config file). Each high-level step is executed as
    ``config['llc_steps']`` environment steps. During those steps the
    low-level controller (an ``ase_players.ASEPlayer`` restored from
    ``config['llc_checkpoint']``) decodes the latent into env actions.
    """

    def __init__(self, params):
        config = params['config']
        # Load the LLC config before super().__init__(): _setup_action_space()
        # needs self._latent_dim (presumably invoked by the base constructor).
        with open(os.path.join(os.getcwd(), config['llc_config']), 'r') as f:
            llc_config = yaml.load(f, Loader=yaml.SafeLoader)
            llc_config_params = llc_config['params']
            self._latent_dim = llc_config_params['config']['latent_dim']

        super().__init__(params)

        # Number of trailing task-specific observation entries hidden from the LLC.
        self._task_size = self.env.task.get_task_obs_size()

        self._llc_steps = config['llc_steps']
        llc_checkpoint = config['llc_checkpoint']
        assert(llc_checkpoint != "")
        self._build_llc(llc_config_params, llc_checkpoint)

        return

    def get_action(self, obs_dict, is_determenistic = False):
        """Return the high-level latent action for ``obs_dict['obs']``, clamped to [-1, 1].

        Uses the policy mean when ``is_determenistic`` is set, otherwise the
        sampled action. An unbatched observation gets a leading batch
        dimension before the forward pass; the result is ``squeeze``d.
        """
        obs = obs_dict['obs']

        if len(obs.size()) == len(self.obs_shape):
            obs = obs.unsqueeze(0)
        proc_obs = self._preproc_obs(obs)
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs' : proc_obs,
            'rnn_states' : self.states
        }
        with torch.no_grad():
            res_dict = self.model(input_dict)
        mu = res_dict['mus']
        action = res_dict['actions']
        self.states = res_dict['rnn_states']
        if is_determenistic:
            current_action = mu
        else:
            current_action = action
        current_action = torch.squeeze(current_action.detach())
        clamped_actions = torch.clamp(current_action, -1.0, 1.0)

        return clamped_actions

    def run(self):
        """Play ``games_num * n_game_life`` games and print reward/step statistics.

        Done indices are subsampled with stride ``num_agents`` so each finished
        game is counted once. NOTE(review): this assumes the agents of one game
        occupy consecutive env slots — confirm against the env layout. If the
        step info contains ``'battle_won'`` or ``'scores'``, a win rate is
        printed as well.
        """
        n_games = self.games_num
        render = self.render_env
        n_game_life = self.n_game_life
        is_determenistic = self.is_determenistic
        sum_rewards = 0
        sum_steps = 0
        sum_game_res = 0
        n_games = n_games * n_game_life
        games_played = 0
        has_masks = False
        has_masks_func = getattr(self.env, "has_action_mask", None) is not None
        # Bound up front so the summary below is valid even if no game runs.
        print_game_res = False

        if has_masks_func:
            has_masks = self.env.has_action_mask()

        need_init_rnn = self.is_rnn
        for _ in range(n_games):
            if games_played >= n_games:
                break

            obs_dict = self.env_reset()
            batch_size = 1
            if len(obs_dict['obs'].size()) > len(self.obs_shape):
                batch_size = obs_dict['obs'].size()[0]
            self.batch_size = batch_size

            if need_init_rnn:
                self.init_rnn()
                need_init_rnn = False

            cr = torch.zeros(batch_size, dtype=torch.float32)
            steps = torch.zeros(batch_size, dtype=torch.float32)

            print_game_res = False

            done_indices = []

            for _ in range(self.max_steps):
                # Reset only the envs that finished on the previous step.
                obs_dict = self.env_reset(done_indices)

                if has_masks:
                    masks = self.env.get_action_mask()
                    action = self.get_masked_action(obs_dict, masks, is_determenistic)
                else:
                    action = self.get_action(obs_dict, is_determenistic)
                obs_dict, r, done, info = self.env_step(self.env, obs_dict, action)
                cr += r
                steps += 1

                self._post_step(info)

                if render:
                    self.env.render(mode = 'human')
                    time.sleep(self.render_sleep)

                all_done_indices = done.nonzero(as_tuple=False)
                done_indices = all_done_indices[::self.num_agents]
                done_count = len(done_indices)
                games_played += done_count

                if done_count > 0:
                    if self.is_rnn:
                        for s in self.states:
                            s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0

                    cur_rewards = cr[done_indices].sum().item()
                    cur_steps = steps[done_indices].sum().item()

                    cr = cr * (1.0 - done.float())
                    steps = steps * (1.0 - done.float())
                    sum_rewards += cur_rewards
                    sum_steps += cur_steps

                    game_res = 0.0
                    if isinstance(info, dict):
                        if 'battle_won' in info:
                            print_game_res = True
                            game_res = info.get('battle_won', 0.5)
                        if 'scores' in info:
                            print_game_res = True
                            game_res = info.get('scores', 0.5)
                    if self.print_stats:
                        if print_game_res:
                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count, 'w:', game_res)
                        else:
                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count)

                    sum_game_res += game_res
                    if batch_size//self.num_agents == 1 or games_played >= n_games:
                        break

                # nonzero() yields (k, 1) indices; flatten to 1-D for env_reset.
                done_indices = done_indices[:, 0]

        print(sum_rewards)
        if games_played == 0:
            # No game finished (e.g. n_games == 0 or max_steps too small);
            # averaging would divide by zero.
            print('no games finished; nothing to average')
            return

        if print_game_res:
            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life)
        else:
            print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life)

        return

    def env_step(self, env, obs_dict, action):
        """Execute one high-level step as ``self._llc_steps`` low-level env steps.

        Args:
            env: environment to step.
            obs_dict: current observation dict; ``obs_dict['obs']`` seeds the LLC.
            action: high-level latent action (torch tensor, decoded by the LLC).

        Returns:
            ``(obs, rewards, dones, infos)``. ``rewards`` is the mean reward over
            the low-level steps. ``dones`` is 1 for every env that reported done
            on any of those steps. ``infos`` comes from the last low-level step.
        """
        # `action` is decoded by the torch LLC in _compute_llc_action, so it is
        # kept as a tensor regardless of is_tensor_obses.
        obs = obs_dict['obs']
        rewards = 0.0
        done_count = 0.0
        # Accumulated for inspection only; not part of the return value.
        disc_rewards = 0.0
        for _ in range(self._llc_steps):
            llc_actions = self._compute_llc_action(obs, action)
            obs, curr_rewards, curr_dones, infos = env.step(llc_actions)

            rewards += curr_rewards
            done_count += curr_dones

            amp_obs = infos['amp_obs']
            curr_disc_reward = self._calc_disc_reward(amp_obs)
            curr_disc_reward = curr_disc_reward[0, 0].cpu().numpy()
            disc_rewards += curr_disc_reward

        rewards /= self._llc_steps
        dones = torch.zeros_like(done_count)
        dones[done_count > 0] = 1.0

        disc_rewards /= self._llc_steps

        if isinstance(obs, dict):
            obs = obs['obs']
        if obs.dtype == np.float64:
            obs = np.float32(obs)
        if self.value_size > 1:
            rewards = rewards[0]
        if self.is_tensor_obses:
            return obs, rewards.cpu(), dones.cpu(), infos
        else:
            if np.isscalar(dones):
                rewards = np.expand_dims(np.asarray(rewards), 0)
                dones = np.expand_dims(np.asarray(dones), 0)
            return torch.from_numpy(obs).to(self.device), torch.from_numpy(rewards), torch.from_numpy(dones), infos

    def _build_llc(self, config_params, checkpoint_file):
        """Create the low-level ``ASEPlayer`` and restore it from ``checkpoint_file``."""
        llc_agent_config = self._build_llc_agent_config(config_params)

        self._llc_agent = ase_players.ASEPlayer(llc_agent_config)
        self._llc_agent.restore(checkpoint_file)
        print("Loaded LLC checkpoint from {:s}".format(checkpoint_file))
        return

    def _build_llc_agent_config(self, config_params, network=None):
        """Return a copy of the LLC ``params`` adapted to this environment.

        The LLC observation space is this env's observation space without the
        trailing ``self._task_size`` task entries. The AMP observation shape
        and the env count are taken from the env. ``config_params`` itself is
        left unmodified.
        """
        llc_env_info = copy.deepcopy(self.env_info)
        obs_space = llc_env_info['observation_space']
        obs_size = obs_space.shape[0]
        obs_size -= self._task_size
        llc_env_info['observation_space'] = spaces.Box(obs_space.low[:obs_size], obs_space.high[:obs_size])
        llc_env_info['amp_observation_space'] = self.env.amp_observation_space.shape
        llc_env_info['num_envs'] = self.env.task.num_envs

        # Work on a copy so the caller's config dict is not mutated.
        params = copy.deepcopy(config_params)
        params['config']['network'] = network
        params['config']['env_info'] = llc_env_info

        return params

    def _setup_action_space(self):
        """Use the LLC latent size as the high-level action dimension."""
        super()._setup_action_space()
        self.actions_num = self._latent_dim
        return

    def _compute_llc_action(self, obs, actions):
        """Decode high-level latents ``actions`` into env actions with the frozen LLC.

        Latents are L2-normalised along the last dim. The LLC mean action is
        clamped to [-1, 1] and rescaled to ``[actions_low, actions_high]``.
        """
        llc_obs = self._extract_llc_obs(obs)
        processed_obs = self._llc_agent._preproc_obs(llc_obs)

        z = torch.nn.functional.normalize(actions, dim=-1)
        mu, _ = self._llc_agent.model.eval_actor(obs=processed_obs, ase_latents=z)
        llc_action = players.rescale_actions(self.actions_low, self.actions_high, torch.clamp(mu, -1.0, 1.0))

        return llc_action

    def _extract_llc_obs(self, obs):
        """Drop the trailing ``self._task_size`` task entries from the last obs dim."""
        obs_size = obs.shape[-1]
        llc_obs = obs[..., :obs_size - self._task_size]
        return llc_obs

    def _calc_disc_reward(self, amp_obs):
        """Delegate discriminator reward computation to the LLC agent."""
        disc_reward = self._llc_agent._calc_disc_rewards(amp_obs)
        return disc_reward


================================================
FILE: timechamber/ase/utils/amp_agent.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.algos_torch import torch_ext
from rl_games.common import a2c_common
from rl_games.common import schedulers
from rl_games.common import vecenv

from isaacgym.torch_utils import *

import time
from datetime import datetime
import numpy as np
from torch import optim
import torch 
from torch import nn

import timechamber.ase.utils.replay_buffer as replay_buffer
import timechamber.ase.utils.common_agent as common_agent 

from tensorboardX import SummaryWriter

class AMPAgent(common_agent.CommonAgent):
    def __init__(self, base_name, params):
        """Initialise the common agent and, if enabled, the AMP input normaliser."""
        super().__init__(base_name, params)

        # Optional running mean/std over AMP observations, controlled by the
        # `normalize_amp_input` config flag.
        if not self._normalize_amp_input:
            return
        amp_obs_shape = self._amp_observation_space.shape
        self._amp_input_mean_std = RunningMeanStd(amp_obs_shape).to(self.ppo_device)

    def init_tensors(self):
        """Create the base-class rollout tensors, then the AMP-specific buffers."""
        super().init_tensors()
        self._build_amp_buffers()
    
    def set_eval(self):
        """Switch to evaluation mode, including the AMP input normaliser when enabled."""
        super().set_eval()
        if not self._normalize_amp_input:
            return
        self._amp_input_mean_std.eval()

    def set_train(self):
        """Switch to training mode, including the AMP input normaliser when enabled."""
        super().set_train()
        if not self._normalize_amp_input:
            return
        self._amp_input_mean_std.train()

    def get_stats_weights(self):
        """Return normalisation statistics.

        When AMP input normalisation is enabled, its state dict is added under
        ``'amp_input_mean_std'``.
        """
        stats = super().get_stats_weights()
        if not self._normalize_amp_input:
            return stats
        stats['amp_input_mean_std'] = self._amp_input_mean_std.state_dict()
        return stats

    def set_stats_weights(self, weights):
        """Restore normalisation statistics from ``weights``.

        When AMP input normalisation is enabled, ``weights['amp_input_mean_std']``
        is loaded into the AMP normaliser.
        """
        super().set_stats_weights(weights)
        if self._normalize_amp_input:
            amp_stats = weights['amp_input_mean_std']
            self._amp_input_mean_std.load_state_dict(amp_stats)

    def play_steps(self):
        """Collect one rollout of ``self.horizon_length`` steps and build the training batch.

        Besides the standard PPO data, every step stores the AMP observations
        (``infos['amp_obs']``), the per-env random-action mask produced by
        ``get_action_values``, and next-state values (zeroed where the env
        reports ``terminate``). After the rollout, task rewards are combined
        with the rewards returned by ``_calc_amp_rewards`` before advantages
        and returns are computed.

        Returns:
            dict: flattened rollout tensors for ``self.tensor_list``, plus
            ``'returns'``, ``'played_frames'`` and every entry returned by
            ``_calc_amp_rewards``.
        """
        self.set_eval()

        epinfos = []  # NOTE(review): unused
        done_indices = []
        update_list = self.update_list

        for n in range(self.horizon_length):

            # Reset only the envs that finished on the previous step.
            self.obs = self.env_reset(done_indices)
            self.experience_buffer.update_data('obses', n, self.obs['obs'])

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs, self._rand_action_probs)

            for k in update_list:
                self.experience_buffer.update_data(k, n, res_dict[k]) 

            if self.has_central_value:
                self.experience_buffer.update_data('states', n, self.obs['states'])

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
            shaped_rewards = self.rewards_shaper(rewards)
            self.experience_buffer.update_data('rewards', n, shaped_rewards)
            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])
            self.experience_buffer.update_data('dones', n, self.dones)
            self.experience_buffer.update_data('amp_obs', n, infos['amp_obs'])
            self.experience_buffer.update_data('rand_action_mask', n, res_dict['rand_action_mask'])

            # Bootstrap from the next state's value, except for envs that
            # report `terminate` (their next value is zeroed).
            terminated = infos['terminate'].float()
            terminated = terminated.unsqueeze(-1)
            next_vals = self._eval_critic(self.obs)
            next_vals *= (1.0 - terminated)
            self.experience_buffer.update_data('next_values', n, next_vals)

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            # Keep every num_agents-th done index. NOTE(review): presumably one
            # entry per game, with agents of a game in consecutive env slots — confirm.
            done_indices = all_done_indices[::self.num_agents]
  
            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()

            # Clear running episode statistics for envs that just finished.
            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones
            
            # Debug visualisation only when a viewer is attached.
            if (self.vec_env.env.task.viewer):
                self._amp_debug(infos)
                
            # nonzero() yields (k, 1) indices; flatten to 1-D for env_reset.
            done_indices = done_indices[:, 0]

        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()
        mb_values = self.experience_buffer.tensor_dict['values']
        mb_next_values = self.experience_buffer.tensor_dict['next_values']

        # Mix task rewards with the AMP rewards computed over the whole rollout.
        mb_rewards = self.experience_buffer.tensor_dict['rewards']
        mb_amp_obs = self.experience_buffer.tensor_dict['amp_obs']
        amp_rewards = self._calc_amp_rewards(mb_amp_obs)
        mb_rewards = self._combine_rewards(mb_rewards, amp_rewards)

        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)
        mb_returns = mb_advs + mb_values

        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)
        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = self.batch_size

        for k, v in amp_rewards.items():
            batch_dict[k] = a2c_common.swap_and_flatten01(v)

        return batch_dict
    
    def get_action_values(self, obs_dict, rand_action_probs):
        """Run the policy on ``obs_dict`` for rollout collection.

        A Bernoulli draw with probability ``rand_action_probs`` decides, per
        entry, whether the sampled (stochastic) action is kept. Where the draw
        is 0, the action is replaced by the policy mean ``mus``. The draw itself
        is returned under ``'rand_action_mask'``.

        When a central value network is used, ``'values'`` comes from it
        (``obs_dict['states']``). Values are passed through
        ``value_mean_std`` when ``normalize_value`` is set.
        """
        # Preprocess before switching the model to eval mode (original order).
        processed_obs = self._preproc_obs(obs_dict['obs'])

        self.model.eval()
        model_input = {
            'is_train': False,
            'prev_actions': None,
            'obs': processed_obs,
            'rnn_states': self.rnn_states,
        }

        with torch.no_grad():
            res_dict = self.model(model_input)
            if self.has_central_value:
                res_dict['values'] = self.get_central_value({
                    'is_train': False,
                    'states': obs_dict['states'],
                })

        if self.normalize_value:
            res_dict['values'] = self.value_mean_std(res_dict['values'], True)

        rand_action_mask = torch.bernoulli(rand_action_probs)
        use_mean = rand_action_mask == 0.0
        res_dict['actions'][use_mean] = res_dict['mus'][use_mean]
        res_dict['rand_action_mask'] = rand_action_mask

        return res_dict

    def prepare_dataset(self, batch_dict):
        """Build the base dataset, then attach AMP observations and the random-action mask."""
        super().prepare_dataset(batch_dict)
        values = self.dataset.values_dict
        for key in ('amp_obs', 'amp_obs_demo', 'amp_obs_replay', 'rand_action_mask'):
            values[key] = batch_dict[key]

    def train_epoch(self):
        """Collect one rollout and run ``self.mini_epochs_num`` passes of minibatch updates.

        Before training, AMP demo observations and replay observations are
        sampled for the discriminator; until the replay buffer has data, the
        current rollout's AMP observations are used as replay. The learning
        rate is updated according to ``self.schedule_type``. After training,
        the rollout's AMP observations are stored in the replay buffer.

        Returns:
            dict: per-minibatch training statistics (as lists) plus
            ``'play_time'``, ``'update_time'`` and ``'total_time'``.
        """
        play_time_start = time.time()

        with torch.no_grad():
            if self.is_rnn:
                batch_dict = self.play_steps_rnn()
            else:
                batch_dict = self.play_steps()

        play_time_end = time.time()
        update_time_start = time.time()
        rnn_masks = batch_dict.get('rnn_masks', None)

        self._update_amp_demos()
        num_obs_samples = batch_dict['amp_obs'].shape[0]
        amp_obs_demo = self._amp_obs_demo_buffer.sample(num_obs_samples)['amp_obs']
        batch_dict['amp_obs_demo'] = amp_obs_demo

        if (self._amp_replay_buffer.get_total_count() == 0):
            batch_dict['amp_obs_replay'] = batch_dict['amp_obs']
        else:
            batch_dict['amp_obs_replay'] = self._amp_replay_buffer.sample(num_obs_samples)['amp_obs']

        self.set_train()

        self.curr_frames = batch_dict.pop('played_frames')
        self.prepare_dataset(batch_dict)
        self.algo_observer.after_steps()

        if self.has_central_value:
            self.train_central_value()

        train_info = None

        if self.is_rnn:
            frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement())
            print(frames_mask_ratio)

        for _ in range(0, self.mini_epochs_num):
            for i in range(len(self.dataset)):
                curr_train_info = self.train_actor_critic(self.dataset[i])

                if self.schedule_type == 'legacy':
                    if self.multi_gpu:
                        curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls')
                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item())
                    self.update_lr(self.last_lr)

                if (train_info is None):
                    train_info = dict()
                    for k, v in curr_train_info.items():
                        train_info[k] = [v]
                else:
                    for k, v in curr_train_info.items():
                        train_info[k].append(v)

            # train_info accumulates across mini-epochs, so this is the mean KL
            # over every minibatch processed so far.
            av_kls = torch_ext.mean_list(train_info['kl'])

            if self.schedule_type == 'standard':
                if self.multi_gpu:
                    av_kls = self.hvd.average_value(av_kls, 'ep_kls')
                self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
                self.update_lr(self.last_lr)

        if self.schedule_type == 'standard_epoch':
            if self.multi_gpu:
                # Average the epoch-wide mean KL across workers.
                av_kls = self.hvd.average_value(torch_ext.mean_list(train_info['kl']), 'ep_kls')
            self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
            self.update_lr(self.last_lr)

        update_time_end = time.time()
        play_time = play_time_end - play_time_start
        update_time = update_time_end - update_time_start
        total_time = update_time_end - play_time_start

        self._store_replay_amp_obs(batch_dict['amp_obs'])

        train_info['play_time'] = play_time
        train_info['update_time'] = update_time
        train_info['total_time'] = total_time
        self._record_train_batch_info(batch_dict, train_info)

        return train_info

    def calc_gradients(self, input_dict):
        """Compute the PPO + discriminator loss for one minibatch and step the optimiser.

        The actor loss, entropy, bound loss and clip fraction are averaged only
        over samples whose ``rand_action_mask`` is 1, i.e. whose rollout action
        was sampled rather than replaced by the policy mean (see
        ``get_action_values``). The critic loss is averaged over all samples.
        The discriminator loss uses the first ``self._amp_minibatch_size``
        agent, replay and demo AMP observations.

        Results are stored in ``self.train_result``.
        """
        self.set_train()

        value_preds_batch = input_dict['old_values']
        old_action_log_probs_batch = input_dict['old_logp_actions']
        advantage = input_dict['advantages']
        old_mu_batch = input_dict['mu']
        old_sigma_batch = input_dict['sigma']
        return_batch = input_dict['returns']
        actions_batch = input_dict['actions']
        obs_batch = input_dict['obs']
        obs_batch = self._preproc_obs(obs_batch)

        amp_obs = input_dict['amp_obs'][0:self._amp_minibatch_size]
        amp_obs = self._preproc_amp_obs(amp_obs)
        amp_obs_replay = input_dict['amp_obs_replay'][0:self._amp_minibatch_size]
        amp_obs_replay = self._preproc_amp_obs(amp_obs_replay)

        amp_obs_demo = input_dict['amp_obs_demo'][0:self._amp_minibatch_size]
        amp_obs_demo = self._preproc_amp_obs(amp_obs_demo)
        # Gradients w.r.t. demo observations are required by _disc_loss
        # (presumably for the discriminator gradient penalty).
        amp_obs_demo.requires_grad_(True)

        rand_action_mask = input_dict['rand_action_mask']
        # The mask is 0/1, so any non-empty selection already sums to >= 1 and
        # is unaffected by the clamp. A minibatch with only mean actions then
        # yields zero masked losses instead of NaN (0 / 0).
        rand_action_sum = torch.clamp_min(torch.sum(rand_action_mask), 1.0)

        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        batch_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'obs' : obs_batch,
            'amp_obs' : amp_obs,
            'amp_obs_replay' : amp_obs_replay,
            'amp_obs_demo' : amp_obs_demo
        }

        rnn_masks = None
        if self.is_rnn:
            rnn_masks = input_dict['rnn_masks']
            batch_dict['rnn_states'] = input_dict['rnn_states']
            batch_dict['seq_length'] = self.seq_len

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            res_dict = self.model(batch_dict)
            action_log_probs = res_dict['prev_neglogp']
            values = res_dict['values']
            entropy = res_dict['entropy']
            mu = res_dict['mus']
            sigma = res_dict['sigmas']
            disc_agent_logit = res_dict['disc_agent_logit']
            disc_agent_replay_logit = res_dict['disc_agent_replay_logit']
            disc_demo_logit = res_dict['disc_demo_logit']

            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)
            a_loss = a_info['actor_loss']
            a_clipped = a_info['actor_clipped'].float()

            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
            c_loss = c_info['critic_loss']

            b_loss = self.bound_loss(mu)

            c_loss = torch.mean(c_loss)
            a_loss = torch.sum(rand_action_mask * a_loss) / rand_action_sum
            entropy = torch.sum(rand_action_mask * entropy) / rand_action_sum
            b_loss = torch.sum(rand_action_mask * b_loss) / rand_action_sum
            a_clip_frac = torch.sum(rand_action_mask * a_clipped) / rand_action_sum

            # Agent and replay samples are both treated as "fake" for the discriminator.
            disc_agent_cat_logit = torch.cat([disc_agent_logit, disc_agent_replay_logit], dim=0)
            disc_info = self._disc_loss(disc_agent_cat_logit, disc_demo_logit, amp_obs_demo)
            disc_loss = disc_info['disc_loss']

            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss \
                 + self._disc_coef * disc_loss

            a_info['actor_loss'] = a_loss
            a_info['actor_clip_frac'] = a_clip_frac
            c_info['critic_loss'] = c_loss

            if self.multi_gpu:
                self.optimizer.zero_grad()
            else:
                for param in self.model.parameters():
                    param.grad = None

        self.scaler.scale(loss).backward()
        #TODO: Refactor this ugliest code of the year
        if self.truncate_grads:
            if self.multi_gpu:
                self.optimizer.synchronize()
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
                with self.optimizer.skip_synchronize():
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
            else:
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        with torch.no_grad():
            reduce_kl = not self.is_rnn
            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
            if self.is_rnn:
                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()  #/ sum_mask

        self.train_result = {
            'entropy': entropy,
            'kl': kl_dist,
            'last_lr': self.last_lr,
            'lr_mul': lr_mul,
            'b_loss': b_loss
        }
        self.train_result.update(a_info)
        self.train_result.update(c_info)
        self.train_result.update(disc_info)

        return

    def _load_config_params(self, config):
        """Load the AMP-specific settings from ``config`` on top of the base ones.

        A missing required key raises from the corresponding ``config[...]`` lookup;
        ``amp_minibatch_size`` larger than the PPO minibatch size trips an assert.
        """
        super()._load_config_params(config)

        # Eps-greedy rollouts mix deterministic and stochastic actions. The
        # deterministic ones yield smoother, less noisy motions that make for a better
        # discriminator: trained only on jittery motions from noisy actions, the
        # discriminator can learn to home in on the jitter to differentiate real
        # from fake samples.
        self._enable_eps_greedy = bool(config['enable_eps_greedy'])

        # Relative weights of the task reward and the discriminator (style) reward.
        self._task_reward_w = config['task_reward_w']
        self._disc_reward_w = config['disc_reward_w']

        self._amp_observation_space = self.env_info['amp_observation_space']
        self._amp_batch_size = int(config['amp_batch_size'])
        self._amp_minibatch_size = int(config['amp_minibatch_size'])
        assert self._amp_minibatch_size <= self.minibatch_size

        # Discriminator hyper-parameters are copied verbatim onto ``self._<key>``.
        for key in ('disc_coef', 'disc_logit_reg', 'disc_grad_penalty',
                    'disc_weight_decay', 'disc_reward_scale'):
            setattr(self, '_' + key, config[key])

        self._normalize_amp_input = config.get('normalize_amp_input', True)

    def _build_net_config(self):
        """Extend the base network config with the AMP observation shape."""
        net_config = super()._build_net_config()
        # The discriminator input size is taken from the AMP observation space.
        net_config['amp_input_shape'] = self._amp_observation_space.shape
        return net_config
    
    def _build_rand_action_probs(self):
        num_envs = self.vec_env.env.task.num_envs
        env_ids = to_torch(np.arange(num_envs), dtype=torch.float32, device=self.ppo_device)

        self._rand_action_probs = 1.0 - torch.exp(10 * (env_ids / (num_envs - 1.0) - 1.0))
        self._rand_action_probs[0] = 1.0
        self._rand_action_probs[-1] = 0.0
        
        if not self._enable_eps_greedy:
            self._rand_action_probs[:] = 1.0

        return

    def _init_train(self):
        """Run the base training setup, then pre-fill the demonstration buffer."""
        super()._init_train()
        self._init_amp_demo_buf()

    def _disc_loss(self, disc_agent_logit, disc_demo_logit, obs_demo):
        # prediction loss
        disc_loss_agent = self._disc_loss_neg(disc_agent_logit)
        disc_loss_demo = self._disc_loss_pos(disc_demo_logit)
        disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)

        # logit reg
        logit_weights = self.model.a2c_network.get_disc_logit_weights()
        disc_logit_loss = torch.sum(torch.square(logit_weights))
        disc_loss += self._disc_logit_reg * disc_logit_loss

        # grad penalty
        disc_demo_grad = torch.autograd.grad(disc_demo_logit, obs_demo, grad_outputs=torch.ones_like(disc_demo_logit),
                                             create_graph=True, retain_graph=True, only_inputs=True)
        disc_demo_grad = disc_demo_grad[0]
        disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
        disc_grad_penalty = torch.mean(disc_demo_grad)
        disc_loss += self._disc_grad_penalty * disc_grad_penalty

        # weight decay
        if (self._disc_weight_decay != 0):
            disc_weights = self.model.a2c_network.get_disc_weights()
            disc_weights = torch.cat(disc_weights, dim=-1)
            disc_weight_decay = torch.sum(torch.square(disc_weights))
            disc_loss += self._disc_weight_decay * disc_weight_decay

        disc_agent_acc, disc_demo_acc = self._compute_disc_acc(disc_agent_logit, disc_demo_logit)

        disc_info = {
            'disc_loss': disc_loss,
            'disc_grad_penalty': disc_grad_penalty.detach(),
            'disc_logit_loss': disc_logit_loss.detach(),
            'disc_agent_acc': disc_agent_acc.detach(),
            'disc_demo_acc': disc_demo_acc.detach(),
            'disc_agent_logit': disc_agent_logit.detach(),
            'disc_demo_logit': disc_demo_logit.detach()
        }
        return disc_info

    def _disc_loss_neg(self, disc_logits):
        bce = torch.nn.BCEWithLogitsLoss()
        loss = bce(disc_logits, torch.zeros_like(disc_logits))
        return loss
    
    def _disc_loss_pos(self, disc_logits):
        bce = torch.nn.BCEWithLogitsLoss()
        loss = bce(disc_logits, torch.ones_like(disc_logits))
        return loss

    def _compute_disc_acc(self, disc_agent_logit, disc_demo_logit):
        agent_acc = disc_agent_logit < 0
        agent_acc = torch.mean(agent_acc.float())
        demo_acc = disc_demo_logit > 0
        demo_acc = torch.mean(demo_acc.float())
        return agent_acc, demo_acc

    def _fetch_amp_obs_demo(self, num_samples):
        amp_obs_demo = self.vec_env.env.fetch_amp_obs_demo(num_samples)
        return amp_obs_demo

    def _build_amp_buffers(self):
        """Allocate AMP rollout storage, the demo/replay buffers and per-env action probs.

        Adds zero-initialised ``amp_obs`` and ``rand_action_mask`` tensors to the
        experience buffer and registers them in ``self.tensor_list``. It also creates
        the demo and agent-replay buffers with capacities from ``self.config``.
        """
        exp_buf = self.experience_buffer
        base_shape = exp_buf.obs_base_shape
        dev = self.ppo_device

        exp_buf.tensor_dict['amp_obs'] = torch.zeros(base_shape + self._amp_observation_space.shape,
                                                     device=dev)
        exp_buf.tensor_dict['rand_action_mask'] = torch.zeros(base_shape, dtype=torch.float32, device=dev)

        demo_capacity = int(self.config['amp_obs_demo_buffer_size'])
        self._amp_obs_demo_buffer = replay_buffer.ReplayBuffer(demo_capacity, dev)

        self._amp_replay_keep_prob = self.config['amp_replay_keep_prob']
        replay_capacity = int(self.config['amp_replay_buffer_size'])
        self._amp_replay_buffer = replay_buffer.ReplayBuffer(replay_capacity, dev)

        self._build_rand_action_probs()

        self.tensor_list += ['amp_obs', 'rand_action_mask']

    def _init_amp_demo_buf(self):
        buffer_size = self._amp_obs_demo_buffer.get_buffer_size()
        num_batches = int(np.ceil(buffer_size / self._amp_batch_size))

        for i in range(num_batches):
            curr_samples = self._fetch_amp_obs_demo(self._amp_batch_size)
            self._amp_obs_demo_buffer.store({'amp_obs': curr_samples})

        return
    
    def _update_amp_demos(self):
        new_amp_obs_demo = self._fetch_amp_obs_demo(self._amp_batch_size)
        self._amp_obs_demo_buffer.store({'amp_obs': new_amp_obs_demo})
        return

    def _preproc_amp_obs(self, amp_obs):
        if self._normalize_amp_input:
            amp_obs = self._amp_input_mean_std(amp_obs)
        return amp_obs

    def _combine_rewards(self, task_rewards, amp_rewards):
        disc_r = amp_rewards['disc_rewards']
        
        combined_rewards = self._task_reward_w * task_rewards + \
                         + self._disc_reward_w * disc_r
        return combined_rewards

    def _eval_disc(self, amp_obs):
        proc_amp_obs = self._preproc_amp_obs(amp_obs)
        return self.model.a2c_network.eval_disc(proc_amp_obs)
    
    def _calc_advs(self, batch_dict):
        returns = batch_dict['returns']
        values = batch_dict['values']
        rand_action_mask = batch_dict['rand_action_mask']

        advantages = returns - values
        advantages = torch.sum(advantages, axis=1)
        if self.normalize_advantage:
            advantages = torch_ext.normalization_with_masks(advantages, rand_action_mask)

        return advantages

    def _calc_amp_rewards(self, amp_obs):
        disc_r = self._calc_disc_rewards(amp_obs)
        output = {
            'disc_rewards': disc_r
        }
        return output

    def _calc_disc_rewards(self, amp_obs):
        with torch.no_grad():
            disc_logits = self._eval_disc(amp_obs)
            prob = 1 / (1 + torch.exp(-disc_logits)) 
            disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.ppo_device)))
            disc_r *= self._disc_reward_scale

        return disc_r

    def _store_replay_amp_obs(self, amp_obs):
        buf_size = self._amp_replay_buffer.get_buffer_size()
        buf_total_count = self._amp_replay_buffer.get_total_count()
        if (buf_total_count > buf_size):
            keep_probs = to_torch(np.array([self._amp_replay_keep_prob] * amp_obs.shape[0]), device=self.ppo_device)
            keep_mask = torch.bernoulli(keep_probs) == 1.0
            amp_obs = amp_obs[keep_mask]

        if (amp_obs.shape[0] > buf_size):
            rand_idx = torch.randperm(amp_obs.shape[0])
            rand_idx = rand_idx[:buf_size]
            amp_obs = amp_obs[rand_idx]

        self._amp_replay_buffer.store({'amp_obs': amp_obs})
        return

    
    def _record_train_batch_info(self, batch_dict, train_info):
        """Record base batch info, then keep the batch's discriminator rewards for logging."""
        super()._record_train_batch_info(batch_dict, train_info)
        disc_rewards = batch_dict['disc_rewards']
        train_info['disc_rewards'] = disc_rewards

    def _log_train_info(self, train_info, frame):
        """Write base training stats, then discriminator losses and diagnostics, to the writer."""
        super()._log_train_info(train_info, frame)

        # (tag, train_info key) pairs. Each value is averaged with
        # torch_ext.mean_list; presumably these are per-minibatch lists.
        tagged_keys = (
            ('losses/disc_loss', 'disc_loss'),
            ('info/disc_agent_acc', 'disc_agent_acc'),
            ('info/disc_demo_acc', 'disc_demo_acc'),
            ('info/disc_agent_logit', 'disc_agent_logit'),
            ('info/disc_demo_logit', 'disc_demo_logit'),
            ('info/disc_grad_penalty', 'disc_grad_penalty'),
            ('info/disc_logit_loss', 'disc_logit_loss'),
        )
        for tag, key in tagged_keys:
            self.writer.add_scalar(tag, torch_ext.mean_list(train_info[key]).item(), frame)

        reward_std, reward_mean = torch.std_mean(train_info['disc_rewards'])
        self.writer.add_scalar('info/disc_reward_mean', reward_mean.item(), frame)
        self.writer.add_scalar('info/disc_reward_std', reward_std.item(), frame)

    def _amp_debug(self, info):
        with torch.no_grad():
            amp_obs = info['amp_obs']
            amp_obs = amp_obs[0:1]
            disc_pred = self._eval_disc(amp_obs)
            amp_rewards = self._calc_amp_rewards(amp_obs)
            disc_reward = amp_rewards['disc_rewards']

            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]
            disc_reward = disc_reward.cpu().numpy()[0, 0]
            print("disc_pred: ", disc_pred, disc_reward)
        return

================================================
FILE: timechamber/ase/utils/amp_datasets.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
from rl_games.common import datasets

class AMPDataset(datasets.PPODataset):
    """PPO dataset that serves minibatches through a shuffled index permutation.

    Each minibatch gathers rows ``self._idx_buf[start:end]`` from every entry of
    ``values_dict`` except the base class's ``special_names``. The permutation is
    reshuffled after the last minibatch of a pass.
    """

    def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len):
        super().__init__(batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len)
        self._idx_buf = torch.randperm(batch_size)
        return

    def update_mu_sigma(self, mu, sigma):
        """Not supported by this dataset; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def _get_item(self, idx):
        """Return minibatch ``idx`` as a dict of tensors indexed by the current permutation."""
        start = idx * self.minibatch_size
        end = (idx + 1) * self.minibatch_size
        sample_idx = self._idx_buf[start:end]

        input_dict = {}
        for k, v in self.values_dict.items():
            if k not in self.special_names and v is not None:
                input_dict[k] = v[sample_idx]

        # Reshuffle after the final minibatch of a pass. `end >= batch_size` alone
        # never fires when batch_size is not a multiple of minibatch_size (the
        # dataset length is then floor(batch_size / minibatch_size)), which would
        # replay the same ordering every epoch; also check for the last index.
        if end >= self.batch_size or idx >= len(self) - 1:
            self._shuffle_idx_buf()

        return input_dict

    def _shuffle_idx_buf(self):
        """Draw a fresh permutation of ``range(batch_size)`` into ``_idx_buf`` in place."""
        self._idx_buf[:] = torch.randperm(self.batch_size)
        return

================================================
FILE: timechamber/ase/utils/amp_models.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch.nn as nn
from rl_games.algos_torch.models import ModelA2CContinuousLogStd


class ModelAMPContinuous(ModelA2CContinuousLogStd):
    """Continuous log-std actor-critic model whose network also carries an AMP discriminator."""

    def __init__(self, network):
        super().__init__(network)

    def build(self, config):
        """Build the ``'amp'`` network from ``config`` and wrap it in ``ModelAMPContinuous.Network``.

        ``input_shape`` is required. ``normalize_value`` and ``normalize_input``
        default to False, and ``value_size`` defaults to 1.
        """
        net = self.network_builder.build('amp', **config)
        # Debug aid: list every parameter name of the freshly built network.
        for param_name, _ in net.named_parameters():
            print(param_name)

        return ModelAMPContinuous.Network(
            net,
            obs_shape=config['input_shape'],
            normalize_value=config.get('normalize_value', False),
            normalize_input=config.get('normalize_input', False),
            value_size=config.get('value_size', 1),
        )

    class Network(ModelA2CContinuousLogStd.Network):
        """Model wrapper that additionally scores AMP observations during training."""

        # (input_dict key, result key) pairs scored by the discriminator when training.
        _DISC_INPUTS = (
            ('amp_obs', 'disc_agent_logit'),
            ('amp_obs_replay', 'disc_agent_replay_logit'),
            ('amp_obs_demo', 'disc_demo_logit'),
        )

        def __init__(self, a2c_network, obs_shape, normalize_value, normalize_input, value_size):
            super().__init__(a2c_network,
                             obs_shape=obs_shape,
                             normalize_value=normalize_value,
                             normalize_input=normalize_input,
                             value_size=value_size)

        def forward(self, input_dict):
            """Run the base forward pass.

            When ``input_dict['is_train']`` is true (the default), the result also
            gets the discriminator logits of the agent, replay and demo AMP
            observations. Those inputs are passed to ``eval_disc`` as given; no
            normalisation is applied here.
            """
            training = input_dict.get('is_train', True)
            result = super().forward(input_dict)

            if not training:
                return result

            for in_key, out_key in self._DISC_INPUTS:
                result[out_key] = self.a2c_network.eval_disc(input_dict[in_key])
            return result

        def eval_actor(self, obs):
            """Normalise ``obs`` and return the actor's ``(mu, sigma)``."""
            return self.a2c_network.eval_actor(obs=self.norm_obs(obs))

        def eval_critic(self, obs):
            """Normalise ``obs`` and return the critic's value estimate."""
            return self.a2c_network.eval_critic(self.norm_obs(obs))

================================================
FILE: timechamber/ase/utils/amp_network_builder.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from rl_games.algos_torch import torch_ext
from rl_games.algos_torch import layers
from rl_games.algos_torch import network_builder

import torch
import torch.nn as nn
import numpy as np

# Half-width of the uniform range [-scale, scale] used to initialise the
# discriminator's final logit-layer weights in AMPBuilder.Network._build_disc.
DISC_LOGIT_INIT_SCALE = 1.0

class AMPBuilder(network_builder.A2CBuilder):
    """A2C network builder whose networks carry an extra AMP discriminator head."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    class Network(network_builder.A2CBuilder.Network):
        """A2C actor/critic network extended with a discriminator MLP and logit layer.

        Discriminator settings come from ``params['disc']`` (``units``,
        ``activation``, ``initializer``). Its input size is the first entry of the
        ``amp_input_shape`` keyword argument.
        """

        def __init__(self, params, **kwargs):
            super().__init__(params, **kwargs)

            if self.is_continuous:
                if (not self.space_config['learn_sigma']):
                    # Non-learned sigma: (re)define it as a frozen parameter holding
                    # the configured initial value.
                    actions_num = kwargs.get('actions_num')
                    sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=False, dtype=torch.float32), requires_grad=False)
                    sigma_init(self.sigma)

            # _build_disc reads the _disc_* settings set in load(); presumably load()
            # runs inside the base constructor above — NOTE(review): confirm.
            amp_input_shape = kwargs.get('amp_input_shape')
            self._build_disc(amp_input_shape)

            return

        def load(self, params):
            """Load base A2C params, then the discriminator's ``params['disc']`` section."""
            super().load(params)

            self._disc_units = params['disc']['units']
            self._disc_activation = params['disc']['activation']
            self._disc_initializer = params['disc']['initializer']
            return

        def forward(self, obs_dict):
            """Return the actor outputs followed by ``(value, rnn_states)`` as one flat tuple.

            For continuous actions this is ``(mu, sigma, value, states)``. For discrete
            or multi-discrete actions it is ``(logits, value, states)``.
            """
            obs = obs_dict['obs']
            states = obs_dict.get('rnn_states', None)

            actor_outputs = self.eval_actor(obs)
            value = self.eval_critic(obs)

            # eval_actor returns a (mu, sigma) tuple only for continuous spaces; a
            # bare tensor (discrete) or list (multi-discrete) must be wrapped, or the
            # tuple concatenation below raises TypeError.
            if not isinstance(actor_outputs, tuple):
                actor_outputs = (actor_outputs,)

            output = actor_outputs + (value, states)

            return output

        def eval_actor(self, obs):
            """Run the actor trunk and heads.

            Returns logits (discrete), a list of logits (multi-discrete), a
            ``(mu, sigma)`` tuple (continuous), or None if no action type is set.
            """
            a_out = self.actor_cnn(obs)
            a_out = a_out.contiguous().view(a_out.size(0), -1)
            a_out = self.actor_mlp(a_out)

            if self.is_discrete:
                logits = self.logits(a_out)
                return logits

            if self.is_multi_discrete:
                logits = [logit(a_out) for logit in self.logits]
                return logits

            if self.is_continuous:
                mu = self.mu_act(self.mu(a_out))
                if self.space_config['fixed_sigma']:
                    # mu * 0.0 broadcasts the shared sigma vector to mu's batch shape.
                    sigma = mu * 0.0 + self.sigma_act(self.sigma)
                else:
                    sigma = self.sigma_act(self.sigma(a_out))

                return mu, sigma
            return

        def eval_critic(self, obs):
            """Run the critic trunk and return the value head's output."""
            c_out = self.critic_cnn(obs)
            c_out = c_out.contiguous().view(c_out.size(0), -1)
            c_out = self.critic_mlp(c_out)
            value = self.value_act(self.value(c_out))
            return value

        def eval_disc(self, amp_obs):
            """Return discriminator logits (one per sample) for ``amp_obs``."""
            disc_mlp_out = self._disc_mlp(amp_obs)
            disc_logits = self._disc_logits(disc_mlp_out)
            return disc_logits

        def get_disc_logit_weights(self):
            """Return the flattened weight matrix of the discriminator's logit layer."""
            return torch.flatten(self._disc_logits.weight)

        def get_disc_weights(self):
            """Return a list of flattened weights of every discriminator Linear layer."""
            weights = []
            for m in self._disc_mlp.modules():
                if isinstance(m, nn.Linear):
                    weights.append(torch.flatten(m.weight))

            weights.append(torch.flatten(self._disc_logits.weight))
            return weights

        def _build_disc(self, input_shape):
            """Create and initialise the discriminator MLP and its single-logit output layer.

            Args:
                input_shape: AMP observation shape; only ``input_shape[0]`` is used.
            """
            mlp_args = {
                'input_size' : input_shape[0],
                'units' : self._disc_units,
                'activation' : self._disc_activation,
                'dense_func' : torch.nn.Linear
            }
            self._disc_mlp = self._build_mlp(**mlp_args)

            mlp_out_size = self._disc_units[-1]
            self._disc_logits = torch.nn.Linear(mlp_out_size, 1)

            # Hidden layers: configured initializer for weights, zero biases.
            mlp_init = self.init_factory.create(**self._disc_initializer)
            for m in self._disc_mlp.modules():
                if isinstance(m, nn.Linear):
                    mlp_init(m.weight)
                    if getattr(m, "bias", None) is not None:
                        torch.nn.init.zeros_(m.bias)

            torch.nn.init.uniform_(self._disc_logits.weight, -DISC_LOGIT_INIT_SCALE, DISC_LOGIT_INIT_SCALE)
            torch.nn.init.zeros_(self._disc_logits.bias)

            return

    def build(self, name, **kwargs):
        """Instantiate an ``AMPBuilder.Network``; ``name`` is accepted but unused."""
        net = AMPBuilder.Network(self.params, **kwargs)
        return net

================================================
FILE: timechamber/ase/utils/amp_players.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch 

from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd

import timechamber.ase.utils.common_player as common_player

class AMPPlayerContinuous(common_player.CommonPlayer):
    """Evaluation player for AMP-trained continuous policies.

    On top of ``CommonPlayer`` it builds and restores an AMP-observation normaliser.
    When the task has a viewer, it prints the discriminator logit and reward for the
    first AMP observation after every step.
    """

    def __init__(self, params):
        config = params['config']
        # Set before super().__init__(): _build_net reads _normalize_amp_input, and
        # presumably the base constructor builds the network — TODO confirm.
        self._normalize_amp_input = config.get('normalize_amp_input', True)
        self._disc_reward_scale = config['disc_reward_scale']

        super().__init__(params)
        return

    def restore(self, fn):
        """Restore policy weights and, if enabled, the AMP normaliser from checkpoint ``fn``.

        ``fn == 'Base'`` skips restoring entirely.
        """
        if (fn != 'Base'):
            super().restore(fn)
            if self._normalize_amp_input:
                checkpoint = torch_ext.load_checkpoint(fn)
                self._amp_input_mean_std.load_state_dict(checkpoint['amp_input_mean_std'])
        return

    def _build_net(self, config):
        """Build the policy network and, if enabled, the AMP-input normaliser."""
        super()._build_net(config)

        if self._normalize_amp_input:
            # eval() presumably freezes the running statistics during play.
            self._amp_input_mean_std = RunningMeanStd(config['amp_input_shape']).to(self.device)
            self._amp_input_mean_std.eval()

        return

    def _post_step(self, info):
        """Run base post-step handling, then print AMP diagnostics when a viewer exists."""
        super()._post_step(info)
        if (self.env.task.viewer):
            self._amp_debug(info)
        return

    def _build_net_config(self):
        """Extend the base network config with ``amp_input_shape``.

        The shape comes from the live env's AMP observation space when an env is
        attached, otherwise from ``env_info['amp_observation_space']``.
        """
        config = super()._build_net_config()
        if getattr(self, 'env', None) is not None:
            amp_space = self.env.amp_observation_space
        else:
            amp_space = self.env_info['amp_observation_space']
        # Both sources hold a gym space; RunningMeanStd and the discriminator builder
        # (input_shape[0]) need its shape. Fall back to the raw value if it is
        # already a shape.
        config['amp_input_shape'] = getattr(amp_space, 'shape', amp_space)
        return config

    def _amp_debug(self, info):
        """Print the discriminator logit and reward for the first AMP observation in ``info``."""
        with torch.no_grad():
            amp_obs = info['amp_obs']
            amp_obs = amp_obs[0:1]
            disc_pred = self._eval_disc(amp_obs)
            amp_rewards = self._calc_amp_rewards(amp_obs)
            disc_reward = amp_rewards['disc_rewards']

            disc_pred = disc_pred.detach().cpu().numpy()[0, 0]
            disc_reward = disc_reward.cpu().numpy()[0, 0]
            print("disc_pred: ", disc_pred, disc_reward)

        return

    def _preproc_amp_obs(self, amp_obs):
        """Normalise AMP observations when ``_normalize_amp_input`` is set, else pass through."""
        if self._normalize_amp_input:
            amp_obs = self._amp_input_mean_std(amp_obs)
        return amp_obs

    def _eval_disc(self, amp_obs):
        """Preprocess ``amp_obs`` and return the discriminator's logits."""
        proc_amp_obs = self._preproc_amp_obs(amp_obs)
        return self.model.a2c_network.eval_disc(proc_amp_obs)

    def _calc_amp_rewards(self, amp_obs):
        """Return ``{'disc_rewards': ...}`` for ``amp_obs``."""
        disc_r = self._calc_disc_rewards(amp_obs)
        output = {
            'disc_rewards': disc_r
        }
        return output

    def _calc_disc_rewards(self, amp_obs):
        """Return ``-log(max(1 - sigmoid(logit), 1e-4)) * _disc_reward_scale`` without gradients."""
        with torch.no_grad():
            disc_logits = self._eval_disc(amp_obs)
            prob = 1 / (1 + torch.exp(-disc_logits))
            disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device)))
            disc_r *= self._disc_reward_scale
        return disc_r


================================================
FILE: timechamber/ase/utils/common_agent.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import copy
from datetime import datetime
from gym import spaces
import numpy as np
import os
import time
import yaml

from rl_games.algos_torch import a2c_continuous
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch import central_value
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.common import a2c_common
from rl_games.common import datasets
from rl_games.common import schedulers
from rl_games.common import vecenv

import torch
from torch import optim

import timechamber.ase.utils.amp_datasets as amp_datasets
from timechamber.utils.utils import load_check, load_checkpoint

from tensorboardX import SummaryWriter

class CommonAgent(a2c_continuous.A2CAgent):
    def __init__(self, base_name, params):
        """Build the agent: config, action space, policy network, optimizer and dataset.

        Args:
            base_name: name forwarded to ``a2c_common.A2CBase.__init__``.
            params: parameter dict; ``params['config']`` holds the training config.
        """
        # Calls A2CBase.__init__ directly (bypassing A2CAgent.__init__); the
        # network, optimizer and dataset are constructed below instead.
        a2c_common.A2CBase.__init__(self, base_name, params)
        self.config = config = params['config']
        self._load_config_params(config)

        self.is_discrete = False
        self._setup_action_space()
        self.bounds_loss_coef = config.get('bounds_loss_coef', None)
        self.clip_actions = config.get('clip_actions', True)
        self._save_intermediate = config.get('save_intermediate', False)

        net_config = self._build_net_config()
        self.model = self.network.build(net_config)
        self.model.to(self.ppo_device)
        self.states = None

        self.init_rnn_from_model(self.model)
        self.last_lr = float(self.last_lr)

        self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)

        if self.normalize_input:
            obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)
            self.running_mean_std = RunningMeanStd(obs_shape).to(self.ppo_device)

        # The central value network must exist before value normalization is
        # wired up, because the value normalizer is taken from it when present.
        if self.has_central_value:
            cv_config = {
                'state_shape' : torch_ext.shape_whc_to_cwh(self.state_shape), 
                'value_size' : self.value_size,
                'ppo_device' : self.ppo_device, 
                'num_agents' : self.num_agents, 
                'horizon_length' : self.horizon_length, 
                'num_actors' : self.num_actors, 
                'num_actions' : self.actions_num, 
                'seq_len' : self.seq_len, 
                'model' : self.central_value_config['network'],
                'config' : self.central_value_config, 
                'writter' : self.writer,
                'multi_gpu' : self.multi_gpu
            }
            self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

        if self.normalize_value:
            self.value_mean_std = (self.central_value_net.model.value_mean_std
                                   if self.has_central_value else self.model.value_mean_std)

        self.use_experimental_cv = self.config.get('use_experimental_cv', True)
        self.dataset = amp_datasets.AMPDataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
        self.algo_observer.after_init(self)

        return

    def init_tensors(self):
        """Allocate base buffers plus ``next_obses`` / ``next_values`` storage.

        Only ``next_obses`` is appended to ``tensor_list``; ``next_values`` is
        read straight from ``experience_buffer.tensor_dict`` in ``play_steps``.
        """
        super().init_tensors()
        tensors = self.experience_buffer.tensor_dict
        tensors['next_obses'] = torch.zeros_like(tensors['obses'])
        tensors['next_values'] = torch.zeros_like(tensors['values'])
        self.tensor_list += ['next_obses']

    def train(self):
        """Run the training loop until ``epoch_num`` exceeds ``self.max_epochs``.

        Every epoch collects a rollout and updates the networks via
        ``train_epoch``. On rank 0 it also logs performance/loss/reward scalars,
        updates the self-play manager (if configured) and writes checkpoints.

        Returns:
            ``(self.last_mean_rewards, epoch_num)`` when training stops. The
            stopping check lives only in the rank-0 branch.
        """
        self.init_tensors()
        self.last_mean_rewards = -100500
        total_time = 0
        self.frame = 0
        self.obs = self.env_reset()
        self.curr_frames = self.batch_size_envs

        model_output_file = os.path.join(self.nn_dir, self.config['name'])

        if self.multi_gpu:
            self.hvd.setup_algo(self)

        self._init_train()

        while True:
            epoch_num = self.update_epoch()
            train_info = self.train_epoch()

            sum_time = train_info['total_time']
            total_time += sum_time
            # Scalars for this epoch are logged at the frame count reached
            # before this epoch's frames are added below.
            frame = self.frame
            if self.multi_gpu:
                self.hvd.sync_stats(self)

            if self.rank == 0:
                play_time = train_info['play_time']
                curr_frames = self.curr_frames
                self.frame += curr_frames
                fps_step = curr_frames / play_time
                fps_total = curr_frames / sum_time
                if self.print_stats:
                    print(f'fps step: {fps_step:.1f} fps total: {fps_total:.1f}')

                self.writer.add_scalar('performance/total_fps', fps_total, frame)
                self.writer.add_scalar('performance/step_fps', fps_step, frame)
                self.writer.add_scalar('info/epochs', epoch_num, frame)
                self._log_train_info(train_info, frame)

                self.algo_observer.after_print_stats(frame, epoch_num, total_time)

                if self.game_rewards.current_size > 0:
                    mean_rewards = self._get_mean_rewards()
                    mean_lengths = self.game_lengths.get_mean()

                    for i in range(self.value_size):
                        self.writer.add_scalar('rewards{0}/frame'.format(i), mean_rewards[i], frame)
                        self.writer.add_scalar('rewards{0}/iter'.format(i), mean_rewards[i], epoch_num)
                        self.writer.add_scalar('rewards{0}/time'.format(i), mean_rewards[i], total_time)

                    self.writer.add_scalar('episode_lengths/frame', mean_lengths, frame)
                    self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)

                    if self.has_self_play_config:
                        self.self_play_manager.update(self)

                if self.save_freq > 0 and epoch_num % self.save_freq == 0:
                    self.save(model_output_file)
                    if self._save_intermediate:
                        self.save(model_output_file + '_' + str(epoch_num).zfill(8))

                if epoch_num > self.max_epochs:
                    self.save(model_output_file)
                    print('MAX EPOCHS NUM!')
                    return self.last_mean_rewards, epoch_num

    def set_full_state_weights(self, weights):
        """Restore model, optimizer, counters and (optionally) env state from ``weights``."""
        self.set_weights(weights)
        self.epoch_num = weights['epoch']
        if self.has_central_value:
            self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
        self.optimizer.load_state_dict(weights['optimizer'])
        self.frame = weights.get('frame', 0)
        self.last_mean_rewards = weights.get('last_mean_rewards', -100500)

        if self.vec_env is None:
            return
        # Forwarded even when the checkpoint has no 'env_state' (value None).
        self.vec_env.set_env_state(weights.get('env_state', None))

    def restore(self, fn):
        """Load the checkpoint at ``fn`` and apply it with ``set_full_state_weights``."""
        raw_state = load_checkpoint(fn, device=self.device)
        full_state = load_check(checkpoint=raw_state,
                                normalize_input=self.normalize_input,
                                normalize_value=self.normalize_value)
        self.set_full_state_weights(full_state)

    def train_epoch(self):
        """Collect one rollout and run ``mini_epochs_num`` passes of minibatch updates.

        Returns:
            dict mapping each stat name produced by ``train_actor_critic`` to a
            list of per-minibatch values (accumulated over all mini-epochs),
            plus ``'step_time'``, ``'play_time'``, ``'update_time'`` and
            ``'total_time'``.
        """
        play_time_start = time.time()
        with torch.no_grad():
            if self.is_rnn:
                batch_dict = self.play_steps_rnn()
            else:
                batch_dict = self.play_steps() 

        play_time_end = time.time()
        update_time_start = time.time()
        rnn_masks = batch_dict.get('rnn_masks', None)

        self.set_train()

        self.curr_frames = batch_dict.pop('played_frames')
        self.prepare_dataset(batch_dict)
        self.algo_observer.after_steps()

        if self.has_central_value:
            self.train_central_value()

        train_info = None

        if self.is_rnn:
            frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement())
            print(frames_mask_ratio)

        for _ in range(0, self.mini_epochs_num):
            for i in range(len(self.dataset)):
                curr_train_info = self.train_actor_critic(self.dataset[i])

                if self.schedule_type == 'legacy':
                    if self.multi_gpu:
                        curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls')
                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item())
                    self.update_lr(self.last_lr)

                if train_info is None:
                    train_info = {k: [v] for k, v in curr_train_info.items()}
                else:
                    for k, v in curr_train_info.items():
                        train_info[k].append(v)

            av_kls = torch_ext.mean_list(train_info['kl'])

            if self.schedule_type == 'standard':
                if self.multi_gpu:
                    av_kls = self.hvd.average_value(av_kls, 'ep_kls')
                self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
                self.update_lr(self.last_lr)

        if self.schedule_type == 'standard_epoch':
            # Average KL over every minibatch collected this epoch (train_info
            # accumulates across mini-epochs), then sync across workers.
            av_kls = torch_ext.mean_list(train_info['kl'])
            if self.multi_gpu:
                av_kls = self.hvd.average_value(av_kls, 'ep_kls')
            self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
            self.update_lr(self.last_lr)

        update_time_end = time.time()
        play_time = play_time_end - play_time_start
        update_time = update_time_end - update_time_start
        total_time = update_time_end - play_time_start

        train_info['step_time'] = batch_dict['step_time']
        train_info['play_time'] = play_time
        train_info['update_time'] = update_time
        train_info['total_time'] = total_time
        self._record_train_batch_info(batch_dict, train_info)

        return train_info

    def play_steps(self):
        """Collect one rollout of ``horizon_length`` steps and build the training batch.

        Each step stores the observation and policy outputs, steps the env, and
        records shaped rewards, next observations, dones and next-state critic
        values (zeroed where ``infos['terminate']`` is set). After the loop,
        advantages come from ``discount_values`` and returns = advantages + values.

        Returns:
            dict of the ``self.tensor_list`` buffers flattened with
            ``swap_and_flatten01``, plus ``'returns'`` and ``'played_frames'``.
        """
        self.set_eval()

        # NOTE(review): epinfos is never used below.
        epinfos = []
        done_indices = []
        update_list = self.update_list

        for n in range(self.horizon_length):
            # Pass the previous step's done indices to env_reset (empty list on
            # the first step).
            self.obs = self.env_reset(done_indices)
            self.experience_buffer.update_data('obses', n, self.obs['obs'])

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            for k in update_list:
                self.experience_buffer.update_data(k, n, res_dict[k])

            if self.has_central_value:
                self.experience_buffer.update_data('states', n, self.obs['states'])

            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
            shaped_rewards = self.rewards_shaper(rewards)
            self.experience_buffer.update_data('rewards', n, shaped_rewards)
            self.experience_buffer.update_data('next_obses', n, self.obs['obs'])
            self.experience_buffer.update_data('dones', n, self.dones)

            # Bootstrap from the next state's value unless the env reports a true
            # termination; other dones keep the bootstrap value.
            terminated = infos['terminate'].float()
            terminated = terminated.unsqueeze(-1)
            next_vals = self._eval_critic(self.obs)
            next_vals *= (1.0 - terminated)
            self.experience_buffer.update_data('next_values', n, next_vals)

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            # Keep every num_agents-th done row -- presumably one entry per game
            # when rows are grouped per agent; confirm against the env layout.
            done_indices = all_done_indices[::self.num_agents]
  
            self.game_rewards.update(self.current_rewards[done_indices])
            self.game_lengths.update(self.current_lengths[done_indices])
            self.algo_observer.process_infos(infos, done_indices)

            not_dones = 1.0 - self.dones.float()

            # Zero the running episode statistics for every done row.
            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

            # Take column 0 of the nonzero() result to get a 1-D index tensor for
            # the env_reset call at the top of the next step.
            done_indices = done_indices[:, 0]

        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()
        mb_values = self.experience_buffer.tensor_dict['values']
        mb_next_values = self.experience_buffer.tensor_dict['next_values']
        mb_rewards = self.experience_buffer.tensor_dict['rewards']
        
        mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values)
        mb_returns = mb_advs + mb_values

        batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list)
        batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = self.batch_size

        return batch_dict

    def prepare_dataset(self, batch_dict):
        """Convert a rollout batch into the PPO dataset (and central-value dataset).

        Advantages are computed from the raw batch via ``_calc_advs``. When
        ``normalize_value`` is set, values and returns are passed through
        ``value_mean_std`` before storage.
        """
        (obses, returns, _dones, values, actions,
         neglogpacs, mus, sigmas) = (batch_dict[key] for key in (
            'obses', 'returns', 'dones', 'values', 'actions',
            'neglogpacs', 'mus', 'sigmas'))
        rnn_states = batch_dict.get('rnn_states', None)
        rnn_masks = batch_dict.get('rnn_masks', None)

        advantages = self._calc_advs(batch_dict)

        if self.normalize_value:
            # value_mean_std is in train mode only for these two calls and is
            # switched back to eval afterwards.
            self.value_mean_std.train()
            values = self.value_mean_std(values)
            returns = self.value_mean_std(returns)
            self.value_mean_std.eval()

        self.dataset.update_values_dict({
            'old_values': values,
            'old_logp_actions': neglogpacs,
            'advantages': advantages,
            'returns': returns,
            'actions': actions,
            'obs': obses,
            'rnn_states': rnn_states,
            'rnn_masks': rnn_masks,
            'mu': mus,
            'sigma': sigmas,
        })

        if self.has_central_value:
            self.central_value_net.update_dataset({
                'old_values': values,
                'advantages': advantages,
                'returns': returns,
                'actions': actions,
                'obs': batch_dict['states'],
                'rnn_masks': rnn_masks,
            })

    def calc_gradients(self, input_dict):
        """Run one PPO optimisation step on a minibatch.

        Combines the clipped actor loss, the (optionally clipped) critic loss,
        the entropy bonus and, when ``bounds_loss_coef`` is set, the action-bound
        penalty. Back-propagates through ``self.scaler``, steps the optimizer,
        and stores statistics in ``self.train_result``.

        Args:
            input_dict: minibatch with the keys written by ``prepare_dataset``.
        """
        self.set_train()

        value_preds_batch = input_dict['old_values']
        old_action_log_probs_batch = input_dict['old_logp_actions']
        advantage = input_dict['advantages']
        old_mu_batch = input_dict['mu']
        old_sigma_batch = input_dict['sigma']
        return_batch = input_dict['returns']
        actions_batch = input_dict['actions']
        obs_batch = self._preproc_obs(input_dict['obs'])

        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        batch_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'obs': obs_batch,
        }

        if self.is_rnn:
            # NOTE(review): rnn masks are not applied to the losses below
            # (they are reduced with a plain torch.mean).
            batch_dict['rnn_states'] = input_dict['rnn_states']
            batch_dict['seq_length'] = self.seq_len

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            res_dict = self.model(batch_dict)
            action_log_probs = res_dict['prev_neglogp']
            values = res_dict['values']
            entropy = res_dict['entropy']
            mu = res_dict['mus']
            sigma = res_dict['sigmas']

            a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip)
            c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)

            a_loss = torch.mean(a_info['actor_loss'])
            c_loss = torch.mean(c_info['critic_loss'])
            entropy = torch.mean(entropy)

            loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy
            if self.bounds_loss_coef is not None:
                b_loss = torch.mean(self.bound_loss(mu))
                loss = loss + self.bounds_loss_coef * b_loss
            else:
                # No bound penalty configured: log a zero instead of failing on
                # torch.mean(0) / None * tensor.
                b_loss = torch.zeros((), device=mu.device)

            a_info['actor_loss'] = a_loss
            a_info['actor_clip_frac'] = torch.mean(a_info['actor_clipped'].float())

            if self.multi_gpu:
                self.optimizer.zero_grad()
            else:
                for param in self.model.parameters():
                    param.grad = None

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        with torch.no_grad():
            reduce_kl = not self.is_rnn
            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)

        self.train_result = {
            'entropy': entropy,
            'kl': kl_dist,
            'last_lr': self.last_lr,
            'lr_mul': lr_mul,
            'b_loss': b_loss
        }
        self.train_result.update(a_info)
        self.train_result.update(c_info)

        return

    def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_values):
        """Compute GAE advantages by sweeping the horizon backwards.

        The TD error bootstraps from ``mb_next_values`` as given; done flags only
        cut the recursive accumulation between steps.
        """
        mb_advs = torch.zeros_like(mb_rewards)
        running_adv = 0
        decay = self.gamma * self.tau

        for step in range(self.horizon_length - 1, -1, -1):
            continues = (1.0 - mb_fdones[step]).unsqueeze(1)
            td_error = mb_rewards[step] + self.gamma * mb_next_values[step] - mb_values[step]
            running_adv = td_error + decay * continues * running_adv
            mb_advs[step] = running_adv

        return mb_advs

    def env_reset(self, env_ids=None):
        """Call ``vec_env.reset(env_ids)`` and return the observations as tensors."""
        return self.obs_to_tensors(self.vec_env.reset(env_ids))

    def bound_loss(self, mu):
        """Quadratic penalty on action means outside the soft bound [-1, 1].

        Args:
            mu: action-mean tensor whose last dimension indexes action components.

        Returns:
            Per-sample penalty summed over the last dimension. When
            ``bounds_loss_coef`` is None a zero tensor of that same shape is
            returned, so callers can still apply torch reductions to it.
        """
        if self.bounds_loss_coef is None:
            return torch.zeros_like(mu[..., 0])

        soft_bound = 1.0
        mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2
        mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2
        return (mu_loss_low + mu_loss_high).sum(axis=-1)

    def _get_mean_rewards(self):
        """Return the mean of the episode rewards recorded in ``self.game_rewards``."""
        mean_rewards = self.game_rewards.get_mean()
        return mean_rewards

    def _load_config_params(self, config):
        """Read agent-specific settings from ``config`` (here only the learning rate)."""
        self.last_lr = config['learning_rate']

    def _build_net_config(self):
        """Return the keyword config passed to ``self.network.build``."""
        return dict(
            actions_num=self.actions_num,
            input_shape=torch_ext.shape_whc_to_cwh(self.obs_shape),
            num_seqs=self.num_actors * self.num_agents,
            value_size=self.env_info.get('value_size', 1),
            normalize_value=self.normalize_value,
            normalize_input=self.normalize_input,
        )

    def _setup_action_space(self):
        """Read action dimensionality and low/high bounds from ``env_info['action_space']``."""
        space = self.env_info['action_space']
        self.actions_num = space.shape[0]

        device = self.ppo_device
        # Copy before from_numpy so the tensors do not share memory with the space.
        self.actions_low = torch.from_numpy(space.low.copy()).float().to(device)
        self.actions_high = torch.from_numpy(space.high.copy()).float().to(device)

    def _init_train(self):
        """Hook called once by ``train`` before the epoch loop; no-op here."""

    def _eval_critic(self, obs_dict):
        """Put the model in eval mode and return its critic value for ``obs_dict['obs']``."""
        self.model.eval()
        return self.model.eval_critic(self._preproc_obs(obs_dict['obs']))

    def _actor_loss(self, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip):
        """PPO clipped surrogate loss (per sample) and a mask of clipped ratios.

        The inputs are negative log-probabilities, so the probability ratio is
        ``exp(old - new)``.
        """
        ratio = torch.exp(old_action_log_probs_batch - action_log_probs)
        clipped_ratio = torch.clamp(ratio, 1.0 - curr_e_clip, 1.0 + curr_e_clip)
        # Pessimistic bound of the two surrogates, negated into a loss.
        a_loss = torch.max(-(advantage * ratio), -(advantage * clipped_ratio))

        return {
            'actor_loss': a_loss,
            'actor_clipped': (torch.abs(ratio - 1.0) > curr_e_clip).detach(),
        }

    def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value):
        """Per-sample squared value error, optionally PPO-style clipped around old values."""
        if not clip_value:
            return {'critic_loss': (return_batch - values)**2}

        delta = (values - value_preds_batch).clamp(-curr_e_clip, curr_e_clip)
        unclipped_loss = (values - return_batch)**2
        clipped_loss = (value_preds_batch + delta - return_batch)**2
        return {'critic_loss': torch.max(unclipped_loss, clipped_loss)}
    
    def _calc_advs(self, batch_dict):
        """Advantages = returns - values summed over dim 1, optionally standardized."""
        advantages = torch.sum(batch_dict['returns'] - batch_dict['values'], axis=1)

        if not self.normalize_advantage:
            return advantages
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    def _record_train_batch_info(self, batch_dict, train_info):
        """Hook for subclasses to add batch statistics to ``train_info``; no-op here."""

    def _log_train_info(self, train_info, frame):
        """Write timing, loss and optimizer scalars from ``train_info`` at step ``frame``."""
        writer = self.writer
        writer.add_scalar('performance/update_time', train_info['update_time'], frame)
        writer.add_scalar('performance/play_time', train_info['play_time'], frame)

        for tag, key in (('losses/a_loss', 'actor_loss'),
                         ('losses/c_loss', 'critic_loss'),
                         ('losses/bounds_loss', 'b_loss'),
                         ('losses/entropy', 'entropy')):
            writer.add_scalar(tag, torch_ext.mean_list(train_info[key]).item(), frame)

        last_lr = train_info['last_lr'][-1]
        lr_mul = train_info['lr_mul'][-1]
        writer.add_scalar('info/last_lr', last_lr * lr_mul, frame)
        writer.add_scalar('info/lr_mul', lr_mul, frame)
        writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame)

        for tag, key in (('info/clip_frac', 'actor_clip_frac'), ('info/kl', 'kl')):
            writer.add_scalar(tag, torch_ext.mean_list(train_info[key]).item(), frame)


================================================
FILE: timechamber/ase/utils/common_player.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch 

from rl_games.algos_torch import players
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.common.player import BasePlayer
from timechamber.utils.utils import load_check, load_checkpoint

import numpy as np

class CommonPlayer(players.PpoPlayerContinuous):
    def __init__(self, params):
        """Initialise the base player, the action space and the evaluation network."""
        player_config = params['config']
        BasePlayer.__init__(self, params)
        self.network = player_config['network']

        self._setup_action_space()
        self.mask = [False]

        cfg = self.config
        self.normalize_input = cfg['normalize_input']
        self.normalize_value = cfg.get('normalize_value', False)

        self._build_net(self._build_net_config())

    def run(self):
        """Play evaluation games and print per-game and average statistics.

        Plays until ``games_num * n_game_life`` games have finished. Per-game
        rewards and steps are printed when ``print_stats`` is set; averages
        (and a win rate when the env reports ``battle_won``/``scores``) are
        printed at the end.
        """
        import time  # needed for render throttling; not imported at module level

        render = self.render_env
        n_game_life = self.n_game_life
        is_determenistic = self.is_determenistic
        n_games = self.games_num * n_game_life
        sum_rewards = 0
        sum_steps = 0
        sum_game_res = 0
        games_played = 0
        print_game_res = False

        has_masks = False
        if getattr(self.env, "has_action_mask", None) is not None:
            has_masks = self.env.has_action_mask()

        need_init_rnn = self.is_rnn
        for _ in range(n_games):
            if games_played >= n_games:
                break

            obs_dict = self.env_reset()
            batch_size = self.get_batch_size(obs_dict['obs'], 1)

            if need_init_rnn:
                self.init_rnn()
                need_init_rnn = False

            cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
            steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device)

            print_game_res = False

            done_indices = []

            for _step in range(self.max_steps):
                if has_masks:
                    masks = self.env.get_action_mask()
                    action = self.get_masked_action(obs_dict, masks, is_determenistic)
                else:
                    action = self.get_action(obs_dict, is_determenistic)
                obs_dict, r, done, info = self.env_step(self.env, action)
                # env_step returns the observation itself; wrap it in the dict
                # form that get_action reads.
                obs_dict = {'obs': obs_dict}
                cr += r
                steps += 1

                self._post_step(info)

                if render:
                    self.env.render(mode='human')
                    time.sleep(self.render_sleep)

                all_done_indices = done.nonzero(as_tuple=False)
                # Presumably one entry per game when rows are grouped per agent.
                done_indices = all_done_indices[::self.num_agents]
                done_count = len(done_indices)
                games_played += done_count

                if done_count > 0:
                    if self.is_rnn:
                        for s in self.states:
                            s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0

                    cur_rewards = cr[done_indices].sum().item()
                    cur_steps = steps[done_indices].sum().item()

                    cr = cr * (1.0 - done.float())
                    steps = steps * (1.0 - done.float())
                    sum_rewards += cur_rewards
                    sum_steps += cur_steps

                    game_res = 0.0
                    if isinstance(info, dict):
                        if 'battle_won' in info:
                            print_game_res = True
                            game_res = info.get('battle_won', 0.5)
                        if 'scores' in info:
                            print_game_res = True
                            game_res = info.get('scores', 0.5)
                    if self.print_stats:
                        if print_game_res:
                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count, 'w:', game_res)
                        else:
                            print('reward:', cur_rewards/done_count, 'steps:', cur_steps/done_count)

                    sum_game_res += game_res
                    if batch_size//self.num_agents == 1 or games_played >= n_games:
                        break

                done_indices = done_indices[:, 0]

        print(sum_rewards)
        if games_played == 0:
            # Avoid dividing by zero when no game finished within max_steps.
            print('no games finished')
            return

        av_reward = sum_rewards / games_played * n_game_life
        av_steps = sum_steps / games_played * n_game_life
        if print_game_res:
            print('av reward:', av_reward, 'av steps:', av_steps, 'winrate:', sum_game_res / games_played * n_game_life)
        else:
            print('av reward:', av_reward, 'av steps:', av_steps)

        return

    def get_action(self, obs_dict, is_determenistic=False):
        """Query the parent player's policy using the ``'obs'`` entry of ``obs_dict``."""
        observations = obs_dict['obs']
        return super().get_action(observations, is_determenistic)

    def env_step(self, env, actions):
        """Step ``env`` with ``actions`` and return ``(obs, rewards, dones, infos)``.

        Actions are moved to numpy when the env does not consume tensors, and
        numpy results are converted back to torch before returning.
        """
        if not self.is_tensor_obses:
            actions = actions.cpu().numpy()
        obs, rewards, dones, infos = env.step(actions)

        if hasattr(obs, 'dtype') and obs.dtype == np.float64:
            obs = np.float32(obs)
        if self.value_size > 1:
            rewards = rewards[0]

        if self.is_tensor_obses:
            return obs, rewards.to(self.device), dones.to(self.device), infos

        # Scalar results (single env) get a leading batch dimension.
        if np.isscalar(dones):
            rewards = np.expand_dims(np.asarray(rewards), 0)
            dones = np.expand_dims(np.asarray(dones), 0)
        return self.obs_to_torch(obs), torch.from_numpy(rewards), torch.from_numpy(dones), infos

    def _build_net(self, config):
        """Build the policy model on ``self.device`` in eval mode, plus the input normalizer."""
        self.model = model = self.network.build(config)
        model.to(self.device)
        model.eval()
        self.is_rnn = model.is_rnn()

        if not self.normalize_input:
            return
        input_shape = torch_ext.shape_whc_to_cwh(self.obs_shape)
        self.running_mean_std = RunningMeanStd(input_shape).to(self.device)
        self.running_mean_std.eval()

    def env_reset(self, env_ids=None):
        """Call ``self.env.reset(env_ids)`` and convert the observations with ``obs_to_torch``."""
        raw_obs = self.env.reset(env_ids)
        return self.obs_to_torch(raw_obs)

    def _post_step(self, info):
        """Hook called by ``run`` after every env step with its ``info``; no-op here."""

    def _build_net_config(self):
        """Return the keyword config passed to ``self.network.build`` for evaluation."""
        return {
            'actions_num': self.actions_num,
            'input_shape': torch_ext.shape_whc_to_cwh(self.obs_shape),
            'num_seqs': self.num_agents,
            'normalize_input': self.normalize_input,
            'normalize_value': self.normalize_value,
        }

    def restore(self, fn):
        """Load model weights (and the input normalizer, if saved) from checkpoint ``fn``."""
        state = load_check(checkpoint=load_checkpoint(fn, device=self.device),
                           normalize_input=self.normalize_input,
                           normalize_value=self.normalize_value)
        self.model.load_state_dict(state['model'])

        if not self.normalize_input or 'running_mean_std' not in state:
            return
        self.model.running_mean_std.load_state_dict(state['running_mean_std'])

    def _setup_action_space(self):
        """Read action dimensionality and low/high bounds from ``self.action_space``."""
        space = self.action_space
        self.actions_num = space.shape[0]
        self.actions_low, self.actions_high = (
            torch.from_numpy(bound.copy()).float().to(self.device)
            for bound in (space.low, space.high)
        )

================================================
FILE: timechamber/ase/utils/replay_buffer.py
================================================
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch

class ReplayBuffer():
    """Fixed-capacity circular buffer of named tensors with shuffled sampling.

    Entries are stored as a dict of tensors that share a leading (batch)
    dimension. The storage tensors are allocated lazily on the first call to
    ``store`` from the per-entry shapes of that first batch. Once the buffer is
    full, new entries overwrite the oldest ones.

    ``sample`` draws slot indices sequentially from a random permutation of
    slot positions and reshuffles the permutation once ``buffer_size`` indices
    have been consumed.
    """

    def __init__(self, buffer_size, device):
        """
        Args:
            buffer_size: maximum number of entries the buffer can hold.
            device: device on which the storage tensors are allocated.
        """
        self._head = 0  # next slot to be written
        self._total_count = 0  # entries ever stored (not capped at buffer_size)
        self._buffer_size = buffer_size
        self._device = device
        self._data_buf = None  # dict[str, Tensor]; allocated by the first store()
        self._sample_idx = torch.randperm(buffer_size)
        self._sample_head = 0  # read position within _sample_idx

        return

    def reset(self):
        """Forget all stored entries (storage is kept) and reshuffle sampling."""
        self._head = 0
        self._total_count = 0
        self._reset_sample_idx()
        return

    def get_buffer_size(self):
        """Return the maximum number of entries the buffer can hold."""
        return self._buffer_size

    def get_total_count(self):
        """Return how many entries have been stored since creation/reset."""
        return self._total_count

    def store(self, data_dict):
        """Append a batch of entries, overwriting the oldest ones when full.

        Args:
            data_dict: mapping from key to tensor. All tensors for the buffer's
                keys must share the same leading dimension ``n``, and ``n``
                must not exceed the buffer size. The keys of the first stored
                batch define the buffer's keys.

        Raises:
            ValueError: if ``data_dict`` is empty, if ``n`` exceeds the buffer
                size, or if the tensors disagree on their leading dimension.
                Validation happens before any slot is written, so a rejected
                batch leaves the buffer unchanged.
        """
        if not data_dict:
            raise ValueError("data_dict must contain at least one tensor")

        if (self._data_buf is None):
            self._init_data_buf(data_dict)

        n = next(iter(data_dict.values())).shape[0]
        buffer_size = self.get_buffer_size()
        if n > buffer_size:
            raise ValueError(
                f"batch of {n} entries exceeds buffer size {buffer_size}")

        for key in self._data_buf:
            curr_n = data_dict[key].shape[0]
            if curr_n != n:
                raise ValueError(
                    f"all tensors must share the leading dimension: "
                    f"expected {n}, got {curr_n} for key '{key}'")

        # Write up to the end of the storage, then wrap the rest to the front.
        store_n = min(n, buffer_size - self._head)
        remainder = n - store_n
        for key, curr_buf in self._data_buf.items():
            data = data_dict[key]
            curr_buf[self._head:(self._head + store_n)] = data[:store_n]
            if remainder > 0:
                curr_buf[0:remainder] = data[store_n:]

        self._head = (self._head + n) % buffer_size
        self._total_count += n

        return

    def sample(self, n):
        """Return a dict with ``n`` entries per key drawn from the filled slots.

        Indices are read sequentially (wrapping modulo the buffer size) from a
        shuffled permutation of slot positions. The permutation is reshuffled
        once ``buffer_size`` indices have been consumed.

        Raises:
            RuntimeError: if the buffer holds no entries.
        """
        total_count = self.get_total_count()
        if total_count == 0:
            raise RuntimeError("cannot sample from an empty ReplayBuffer")
        buffer_size = self.get_buffer_size()

        idx = torch.arange(self._sample_head, self._sample_head + n)
        idx = idx % buffer_size
        rand_idx = self._sample_idx[idx]
        if (total_count < buffer_size):
            # Before the first wrap only slots [0, head) hold data, so fold the
            # permutation into that range. NOTE: this slightly favours lower
            # slots when buffer_size is not a multiple of head.
            rand_idx = rand_idx % self._head

        samples = dict()
        for k, v in self._data_buf.items():
            samples[k] = v[rand_idx]

        self._sample_head += n
        if (self._sample_head >= buffer_size):
            self._reset_sample_idx()

        return samples

    def _reset_sample_idx(self):
        """Draw a fresh permutation of slot indices (in place) and rewind."""
        buffer_size = self.get_buffer_size()
        self._sample_idx[:] = torch.randperm(buffer_size)
        self._sample_head = 0
        return

    def _init_data_buf(self, data_dict):
        """Allocate zero-filled storage shaped (buffer_size, *entry_shape) per key.

        Storage uses torch's default dtype (no dtype is passed), so incoming
        data is cast to it when stored.
        """
        buffer_size = self.get_buffer_size()
        self._data_buf = dict()

        for k, v in data_dict.items():
            v_shape = v.shape[1:]
            self._data_buf[k] = torch.zeros((buffer_size,) + v_shape, device=self._device)

        return

================================================
FILE: timechamber/cfg/config.yaml
================================================
# Task name - used to pick the class to load
task_name: ${task.name}
# experiment name. defaults to name of training config
experiment: ''

# if set to positive integer, overrides the default number of environments
num_envs: ''

# seed - set to -1 to choose random seed
seed: 42
# set to True for deterministic performance
torch_deterministic: False

# set the maximum number of learning iterations to train for. overrides default per-environment setting
max_iterations: ''

# set minibatch_size
minibatch_size: 32768

## Device config
#  'physx' or 'flex'
physics_engine: 'physx'
# whether to use cpu or gpu pipeline
pipeline: 'gpu'
use_gpu: True
use_gpu_pipeline: True
# device for running physics simulation
sim_device: 'cuda:0'
# device to run RL
rl_device: 'cuda:0'
graphics_device_id: 0
device_type: cuda

## PhysX arguments
num_threads: 4 # Number of worker threads per scene used by PhysX - for CPU PhysX only.
solver_type: 1 # 0: pgs, 1: tgs
num_subscenes: 4 # Splits the simulation into N physics scenes and runs each one in a separate thread

# RLGames Arguments
# test - if set, run policy in inference mode (requires setting checkpoint to load)
test: False
# used to set checkpoint path
checkpoint: ''
op_checkpoint: ''
player_pool_type: ''
num_agents: 2

# HRL Arguments
motion_file: 'tasks/data/motions/reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy'

# set to True to use multi-gpu horovod training
multi_gpu: False

wandb_activate: False
wandb_group: ''
wandb_name: ${train.params.config.name}
wandb_entity: ''
wandb_project: 'timechamber'
capture_video: False
capture_video_freq: 1464
capture_video_len: 100
force_render: True

# disables rendering
headless: True

# set default task and default training config based on task
defaults:
  - task: MA_Humanoid_Strike
  - train: ${task}HRL
  - hydra/job_logging: disabled

# set the directory where the output files get saved
hydra:
  output_subdir: null
  run:
    dir: .



================================================
FILE: timechamber/cfg/task/MA_Ant_Battle.yaml
================================================
# used to create the object
name: MA_Ant_Battle

physics_engine: ${..physics_engine}

# if given, will override the device setting in gym.
env:
  #  numEnvs: ${...num_envs}
  numEnvs: ${resolve_default:4096,${...num_envs}}
  numAgents: ${...num_agents}
  # rgb color of Ant body
  color: [ [ 0.97, 0.38, 0.06 ],[ 0.24, 0.38, 0.06 ],[ 0.56, 0.85, 0.25 ],[ 0.56, 0.85, 0.25 ],[ 0.14, 0.97, 0.24 ],[ 0.63, 0.2, 0.87 ] ]
  envSpacing: 6
  borderlineSpace: 3
  episodeLength: 1000
  enableDebugVis: False
  controlFrequencyInv: 1
  clipActions: 1.0
  clipObservations: 5.0
  actionScale: 0.5
  control:
    # PD Drive parameters:
    stiffness: 85.0  # [N*m/rad]
    damping: 2.0     # [N*m*s/rad]
    actionScale: 0.5
    controlFrequencyInv: 1 # 60 Hz

  # reward parameters
  headingWeight: 0.5
  upWeight: 0.1

  # cost parameters
  terminationHeight: 0.31
  dofVelocityScale: 0.2
  jointsAtLimitCost: -0.1

  plane:
    staticFriction: 1.0
    dynamicFriction: 1.0
    restitution: 0.0

  asset:
    assetFileName: "mjcf/nv_ant.xml"

  # set to True if you use camera sensors in the environment
  enableCameraSensors: False

sim:
  dt: 0.0166 # 1/60 s
  substeps: 2
  up_axis: "z"
  use_gpu_pipeline: ${eq:${...pipeline},"gpu"}
  gravity: [ 0.0, 0.0, -9.81 ]
  physx:
    num_threads: ${....num_threads}
    solver_type: ${....solver_type}
    use_gpu: ${contains:"cuda",${....sim_device}} # set to False to run on CPU
    num_position_iterations: 4
    num_velocity_iterations: 0
    contact_offset: 0.02
    rest_offset: 0.0
    bounce_threshold_velocity: 0.2
    max_depenetration_velocity: 10.0
    default_buffer_size_multiplier: 5.0
    max_gpu_contact_pairs: 8388608 # 8*1024*1024
    num_subscenes: ${....num_subscenes}
    contact_collection: 0 # 0: CC_NEVER (don't collect contact info), 1: CC_LAST_SUBSTEP (collect only contacts on last substep), 2: CC_ALL_SUBSTEPS (default - all contacts)

task:
  randomize: False
  randomization_params:
    # specify which attributes to randomize for each actor type and property
    frequency: 600   # Define how many environment steps between generating new randomizations
    observations:
      range: [ 0, .002 ] # range for the white noise
      operation: "additive"
      distribution: "gaussian"
    actions:
      range: [ 0., .02 ]
      operation: "additive"
      distribution: "gaussian"
    actor_params:
      ant:
        color: True
        rigid_body_properties:
          mass:
            range: [ 0.5, 1.5 ]
            operation: "scaling"
            distribution: "uniform"
            setup_only: True # Property will only be randomized once before simulation is started. See Domain Randomization Documentation for more info.
        dof_properties:
          damping:
            range: [ 0.5, 1.5 ]
            operation: "scaling"
            distribution: "uniform"
          stiffness:
            range: [ 0.5, 1.5 ]
            operation: "scaling"
            distribution: "uniform"
          lower:
            range: [ 0, 0.01 ]
            operation: "additive"
            distribution: "gaussian"
          upper:
            range: [ 0, 0.01 ]
            operation: "additive"
            distribution: "gaussian"


================================================
FILE: timechamber/cfg/task/MA_Ant_Sumo.yaml
================================================
# used to create the object
name: MA_Ant_Sumo

physics_engine: ${..physics_engine}

# if given, will override the device setting in gym.
env:
#  numEnvs: ${...num_envs}
  numEnvs: ${resolve_default:4096,${...num_envs}}
  numAgents: ${...num_agents}
  envSpacing: 6
  borderlineSpace: 3
  episodeLength: 1000
  enableDebugVis: False
  controlFrequencyInv: 1
  clipActions: 1.0
  clipObservations: 5.0
  actionScale: 0.5
  control:
    # PD Drive parameters:
    stiffness: 85.0  # [N*m/rad]
    damping: 2.0     # [N*m*s/rad]
    actionScale: 0.5
    controlFrequencyInv: 1 # 60 Hz

  # reward parameters
  headingWeight: 0.5
  upWeight: 0.1

  # cost parameters
  terminationHeight: 0.31
  dofVelocityScale: 0.2
  jointsAtLimitCost: -0.1

  plane:
    staticFriction: 1.0
    dynamicFriction: 1.0
    restitution: 0.0

  asset:
    assetFileName: "mjcf/nv_ant.xml"

  # set to True if you use camera sensors in the environment
  enableCameraSensors: False

sim:
  dt: 0.0166 # 1/60 s
  substeps: 2
  up_axis: "z"
  use_gpu_pipeline: ${eq:${...pipeline},"gpu"}
  gravity: [0.0, 0.0, -9.81]
  physx:
    num_threads: ${....num_threads}
    solver_type: ${....solver_type}
    use_gpu: ${contains:"cuda",${....sim_device}} # set to False to run on CPU
    num_position_iterations: 4
    num_velocity_iterations: 0
    contact_offset: 0.02
    rest_offset: 0.0
    bounce_threshold_velocity: 0.2
    max_depenetration_velocity: 10.0
    default_buffer_size_multiplier: 5.0
    max_gpu_contact_pairs: 8388608 # 8*1024*1024
    num_subscenes: ${....num_subscenes}
    contact_collection: 0 # 0: CC_NEVER (don't collect contact info), 1: CC_LAST_SUBSTEP (collect only contacts on last substep), 2: CC_ALL_SUBSTEPS (default - all contacts)

task:
  randomize: False
  randomization_params:
    # specify which attributes to randomize for each actor type and property
    frequency: 600   # Define how many environment steps between generating new randomizations
    observations:
      range: [0, .002] # range for the white noise
      operation: "additive"
      distribution: "gaussian"
    actions:
      range: [0., .02]
      operation: "additive"
      distribution: "gaussian"
    actor_params:
      ant:
        color: True
        rigid_body_properties:
          mass:
            range: [0.5, 1.5]
            operation: "scaling"
            distribution: "uniform"
            setup_only: True # Property will only be randomized once before simulation is started. See Domain Randomization Documentation for more info.
        dof_properties:
          damping:
            range: [0.5, 1.5]
            operation: "scaling"
            distribution: "uniform"
          stiffness:
            range: [0.5, 1.5]
            operation: "scaling"
            distribution: "uniform"
          lower:
            range: [0, 0.01]
            operation: "additive"
            distribution: "gaussian"
          upper:
            range: [0, 0.01]
            operation: "additive"
            distribution: "gaussian"


================================================
FILE: timechamber/cfg/task/MA_Humanoid_Strike.yaml
================================================
name: MA_Humanoid_Strike

physics_engine: ${..physics_engine}

# if given, will override the device setting in gym. 
env: 
  numEnvs: ${resolve_default:4096,${...num_envs}}
  envSpacing: 6
  episodeLength: 1500
  borderlineSpace: 3.0
  numAgents: 2
  isFlagrun: False
  enableDebugVis: False
  
  pdControl: True
  powerScale: 1.0
  controlFrequencyInv: 2 # 30 Hz
  stateInit: "Default"
  hybridInitProb: 0.5
  numAMPObsSteps: 10
  
  localRootObs: True
  keyBodies: ["right_hand", "left_hand", "right_foot", "left_foot", "sword", "shield"]
  contactBodies: ["right_foot", "left_foot"]
  # forceBod
Download .txt
gitextract_rvpupy7y/

├── .gitattributes
├── .gitignore
├── LICENSE
├── LISENCE/
│   └── isaacgymenvs/
│       └── LICENSE
├── README.md
├── assets/
│   └── mjcf/
│       └── nv_ant.xml
├── docs/
│   └── environments.md
├── setup.py
└── timechamber/
    ├── __init__.py
    ├── ase/
    │   ├── ase_agent.py
    │   ├── ase_models.py
    │   ├── ase_network_builder.py
    │   ├── ase_players.py
    │   ├── hrl_agent.py
    │   ├── hrl_models.py
    │   ├── hrl_network_builder.py
    │   ├── hrl_players.py
    │   └── utils/
    │       ├── amp_agent.py
    │       ├── amp_datasets.py
    │       ├── amp_models.py
    │       ├── amp_network_builder.py
    │       ├── amp_players.py
    │       ├── common_agent.py
    │       ├── common_player.py
    │       └── replay_buffer.py
    ├── cfg/
    │   ├── config.yaml
    │   ├── task/
    │   │   ├── MA_Ant_Battle.yaml
    │   │   ├── MA_Ant_Sumo.yaml
    │   │   └── MA_Humanoid_Strike.yaml
    │   └── train/
    │       ├── MA_Ant_BattlePPO.yaml
    │       ├── MA_Ant_SumoPPO.yaml
    │       ├── MA_Humanoid_StrikeHRL.yaml
    │       └── base/
    │           └── ase_humanoid_hrl.yaml
    ├── learning/
    │   ├── common_agent.py
    │   ├── common_player.py
    │   ├── hrl_sp_agent.py
    │   ├── hrl_sp_player.py
    │   ├── pfsp_player_pool.py
    │   ├── ppo_sp_agent.py
    │   ├── ppo_sp_player.py
    │   ├── replay_buffer.py
    │   ├── vectorized_models.py
    │   └── vectorized_network_builder.py
    ├── models/
    │   ├── Humanoid_Strike/
    │   │   ├── policy.pth
    │   │   └── policy_op.pth
    │   ├── ant_battle_2agents/
    │   │   └── policy.pth
    │   ├── ant_battle_3agents/
    │   │   └── policy.pth
    │   └── ant_sumo/
    │       └── policy.pth
    ├── tasks/
    │   ├── __init__.py
    │   ├── ase_humanoid_base/
    │   │   ├── base_task.py
    │   │   ├── humanoid.py
    │   │   ├── humanoid_amp.py
    │   │   ├── humanoid_amp_task.py
    │   │   └── poselib/
    │   │       ├── README.md
    │   │       ├── data/
    │   │       │   ├── 01_01_cmu.fbx
    │   │       │   ├── 07_01_cmu.fbx
    │   │       │   ├── 08_02_cmu.fbx
    │   │       │   ├── 09_11_cmu.fbx
    │   │       │   ├── 49_08_cmu.fbx
    │   │       │   ├── 55_01_cmu.fbx
    │   │       │   ├── amp_humanoid_tpose.npy
    │   │       │   ├── cmu_tpose.npy
    │   │       │   ├── configs/
    │   │       │   │   ├── retarget_cmu_to_amp.json
    │   │       │   │   └── retarget_sfu_to_amp.json
    │   │       │   └── sfu_tpose.npy
    │   │       ├── fbx_importer.py
    │   │       ├── generate_amp_humanoid_tpose.py
    │   │       ├── mjcf_importer.py
    │   │       ├── poselib/
    │   │       │   ├── __init__.py
    │   │       │   ├── core/
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── backend/
    │   │       │   │   │   ├── __init__.py
    │   │       │   │   │   ├── abstract.py
    │   │       │   │   │   └── logger.py
    │   │       │   │   ├── rotation3d.py
    │   │       │   │   ├── tensor_utils.py
    │   │       │   │   └── tests/
    │   │       │   │       ├── __init__.py
    │   │       │   │       └── test_rotation.py
    │   │       │   ├── skeleton/
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── backend/
    │   │       │   │   │   ├── __init__.py
    │   │       │   │   │   └── fbx/
    │   │       │   │   │       ├── __init__.py
    │   │       │   │   │       ├── fbx_backend.py
    │   │       │   │   │       └── fbx_read_wrapper.py
    │   │       │   │   └── skeleton3d.py
    │   │       │   └── visualization/
    │   │       │       ├── __init__.py
    │   │       │       ├── common.py
    │   │       │       ├── core.py
    │   │       │       ├── plt_plotter.py
    │   │       │       ├── simple_plotter_tasks.py
    │   │       │       ├── skeleton_plotter_tasks.py
    │   │       │       └── tests/
    │   │       │           ├── __init__.py
    │   │       │           └── test_plotter.py
    │   │       └── retarget_motion.py
    │   ├── base/
    │   │   ├── __init__.py
    │   │   ├── ma_vec_task.py
    │   │   └── vec_task.py
    │   ├── data/
    │   │   ├── assets/
    │   │   │   └── mjcf/
    │   │   │       └── amp_humanoid_sword_shield.xml
    │   │   ├── models/
    │   │   │   └── llc_reallusion_sword_shield.pth
    │   │   └── motions/
    │   │       └── reallusion_sword_shield/
    │   │           ├── README.txt
    │   │           ├── RL_Avatar_Atk_2xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo03_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo04_Motion.npy
    │   │           ├── RL_Avatar_Atk_2xCombo05_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo03_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo04_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo05_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo06_Motion.npy
    │   │           ├── RL_Avatar_Atk_3xCombo07_Motion.npy
    │   │           ├── RL_Avatar_Atk_4xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Atk_4xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Atk_4xCombo03_Motion.npy
    │   │           ├── RL_Avatar_Atk_Jump_Motion.npy
    │   │           ├── RL_Avatar_Atk_Kick_Motion.npy
    │   │           ├── RL_Avatar_Atk_ShieldCharge_Motion.npy
    │   │           ├── RL_Avatar_Atk_ShieldSwipe01_Motion.npy
    │   │           ├── RL_Avatar_Atk_ShieldSwipe02_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashDown_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashLeft_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashRight_Motion.npy
    │   │           ├── RL_Avatar_Atk_SlashUp_Motion.npy
    │   │           ├── RL_Avatar_Atk_Spin_Motion.npy
    │   │           ├── RL_Avatar_Atk_Stab_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk01_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk02_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk03_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk04_Motion.npy
    │   │           ├── RL_Avatar_Counter_Atk05_Motion.npy
    │   │           ├── RL_Avatar_Dodge_Backward_Motion.npy
    │   │           ├── RL_Avatar_Dodgle_Left_Motion.npy
    │   │           ├── RL_Avatar_Dodgle_Right_Motion.npy
    │   │           ├── RL_Avatar_Fall_Backward_Motion.npy
    │   │           ├── RL_Avatar_Fall_Left_Motion.npy
    │   │           ├── RL_Avatar_Fall_Right_Motion.npy
    │   │           ├── RL_Avatar_Fall_SpinLeft_Motion.npy
    │   │           ├── RL_Avatar_Fall_SpinRight_Motion.npy
    │   │           ├── RL_Avatar_Idle_Alert(0)_Motion.npy
    │   │           ├── RL_Avatar_Idle_Alert_Motion.npy
    │   │           ├── RL_Avatar_Idle_Battle(0)_Motion.npy
    │   │           ├── RL_Avatar_Idle_Battle_Motion.npy
    │   │           ├── RL_Avatar_Idle_Ready(0)_Motion.npy
    │   │           ├── RL_Avatar_Idle_Ready_Motion.npy
    │   │           ├── RL_Avatar_Kill_2xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Kill_2xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Kill_3xCombo01_Motion.npy
    │   │           ├── RL_Avatar_Kill_3xCombo02_Motion.npy
    │   │           ├── RL_Avatar_Kill_4xCombo01_Motion.npy
    │   │           ├── RL_Avatar_RunBackward_Motion.npy
    │   │           ├── RL_Avatar_RunForward_Motion.npy
    │   │           ├── RL_Avatar_RunLeft_Motion.npy
    │   │           ├── RL_Avatar_RunRight_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockBackward_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockCrouch_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockDown_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockLeft_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockRight_Motion.npy
    │   │           ├── RL_Avatar_Shield_BlockUp_Motion.npy
    │   │           ├── RL_Avatar_Standoff_Circle_Motion.npy
    │   │           ├── RL_Avatar_Standoff_Feint_Motion.npy
    │   │           ├── RL_Avatar_Standoff_Swing_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward01_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward02_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward03_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryBackward04_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryCrouch_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryDown_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryLeft_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryRight_Motion.npy
    │   │           ├── RL_Avatar_Sword_ParryUp_Motion.npy
    │   │           ├── RL_Avatar_Taunt_PoundChest_Motion.npy
    │   │           ├── RL_Avatar_Taunt_Roar_Motion.npy
    │   │           ├── RL_Avatar_Taunt_ShieldKnock_Motion.npy
    │   │           ├── RL_Avatar_TurnLeft180_Motion.npy
    │   │           ├── RL_Avatar_TurnLeft90_Motion.npy
    │   │           ├── RL_Avatar_TurnRight180_Motion.npy
    │   │           ├── RL_Avatar_TurnRight90_Motion.npy
    │   │           ├── RL_Avatar_WalkBackward01_Motion.npy
    │   │           ├── RL_Avatar_WalkBackward02_Motion.npy
    │   │           ├── RL_Avatar_WalkForward01_Motion.npy
    │   │           ├── RL_Avatar_WalkForward02_Motion.npy
    │   │           ├── RL_Avatar_WalkLeft01_Motion.npy
    │   │           ├── RL_Avatar_WalkLeft02_Motion.npy
    │   │           ├── RL_Avatar_WalkRight01_Motion.npy
    │   │           ├── RL_Avatar_WalkRight02_Motion.npy
    │   │           └── dataset_reallusion_sword_shield.yaml
    │   ├── ma_ant_battle.py
    │   ├── ma_ant_sumo.py
    │   └── ma_humanoid_strike.py
    ├── train.py
    └── utils/
        ├── config.py
        ├── gym_util.py
        ├── logger.py
        ├── motion_lib.py
        ├── reformat.py
        ├── rlgames_utils.py
        ├── torch_jit_utils.py
        ├── torch_utils.py
        ├── utils.py
        ├── vec_task.py
        └── vec_task_wrappers.py
Download .txt
SYMBOL INDEX (972 symbols across 60 files)

FILE: timechamber/__init__.py
  function make (line 14) | def make(

FILE: timechamber/ase/ase_agent.py
  class ASEAgent (line 40) | class ASEAgent(amp_agent.AMPAgent):
    method __init__ (line 41) | def __init__(self, base_name, config):
    method init_tensors (line 45) | def init_tensors(self):
    method play_steps (line 64) | def play_steps(self):
    method get_action_values (line 145) | def get_action_values(self, obs_dict, ase_latents, rand_action_probs):
    method prepare_dataset (line 178) | def prepare_dataset(self, batch_dict):
    method calc_gradients (line 186) | def calc_gradients(self, input_dict):
    method env_reset (line 337) | def env_reset(self, env_ids=None):
    method _reset_latent_step_count (line 350) | def _reset_latent_step_count(self, env_ids):
    method _load_config_params (line 355) | def _load_config_params(self, config):
    method _build_net_config (line 374) | def _build_net_config(self):
    method _reset_latents (line 379) | def _reset_latents(self, env_ids):
    method _sample_latents (line 389) | def _sample_latents(self, n):
    method _update_latents (line 393) | def _update_latents(self):
    method _eval_actor (line 408) | def _eval_actor(self, obs, ase_latents):
    method _eval_critic (line 412) | def _eval_critic(self, obs_dict, ase_latents):
    method _calc_amp_rewards (line 422) | def _calc_amp_rewards(self, amp_obs, ase_latents):
    method _calc_enc_rewards (line 431) | def _calc_enc_rewards(self, amp_obs, ase_latents):
    method _enc_loss (line 440) | def _enc_loss(self, enc_pred, ase_latent, enc_obs, loss_mask):
    method _diversity_loss (line 472) | def _diversity_loss(self, obs, action_params, ase_latents):
    method _calc_enc_error (line 496) | def _calc_enc_error(self, enc_pred, ase_latent):
    method _enable_enc_grad_penalty (line 501) | def _enable_enc_grad_penalty(self):
    method _enable_amp_diversity_bonus (line 504) | def _enable_amp_diversity_bonus(self):
    method _eval_enc (line 507) | def _eval_enc(self, amp_obs):
    method _combine_rewards (line 511) | def _combine_rewards(self, task_rewards, amp_rewards):
    method _record_train_batch_info (line 519) | def _record_train_batch_info(self, batch_dict, train_info):
    method _log_train_info (line 524) | def _log_train_info(self, train_info, frame):
    method _change_char_color (line 541) | def _change_char_color(self, env_ids):
    method _amp_debug (line 552) | def _amp_debug(self, info, ase_latents):

FILE: timechamber/ase/ase_models.py
  class ModelASEContinuous (line 31) | class ModelASEContinuous(amp_models.ModelAMPContinuous):
    method __init__ (line 32) | def __init__(self, network):
    method build (line 36) | def build(self, config):
    class Network (line 49) | class Network(amp_models.ModelAMPContinuous.Network):
      method __init__ (line 50) | def __init__(self, a2c_network, obs_shape, normalize_value, normaliz...
      method forward (line 58) | def forward(self, input_dict):
      method eval_actor (line 69) | def eval_actor(self, obs, ase_latents, use_hidden_latents=False):
      method eval_critic (line 74) | def eval_critic(self, obs, ase_latents, use_hidden_latents=False):

FILE: timechamber/ase/ase_network_builder.py
  class LatentType (line 42) | class LatentType(enum.Enum):
  class ASEBuilder (line 46) | class ASEBuilder(amp_network_builder.AMPBuilder):
    method __init__ (line 47) | def __init__(self, **kwargs):
    class Network (line 51) | class Network(amp_network_builder.AMPBuilder.Network):
      method __init__ (line 52) | def __init__(self, params, **kwargs):
      method load (line 120) | def load(self, params):
      method forward (line 130) | def forward(self, obs_dict):
      method eval_critic (line 143) | def eval_critic(self, obs, ase_latents, use_hidden_latents=False):
      method eval_actor (line 151) | def eval_actor(self, obs, ase_latents, use_hidden_latents=False):
      method get_enc_weights (line 174) | def get_enc_weights(self):
      method _build_actor_critic_net (line 183) | def _build_actor_critic_net(self, input_shape, ase_latent_shape):
      method _build_enc (line 213) | def _build_enc(self, input_shape):
      method eval_enc (line 242) | def eval_enc(self, amp_obs):
      method sample_latents (line 249) | def sample_latents(self, n):
    method build (line 255) | def build(self, name, **kwargs):
  class AMPMLPNet (line 260) | class AMPMLPNet(torch.nn.Module):
    method __init__ (line 261) | def __init__(self, obs_size, ase_latent_size, units, activation, initi...
    method forward (line 283) | def forward(self, obs, latent, skip_style):
    method init_params (line 289) | def init_params(self):
    method get_out_size (line 297) | def get_out_size(self):
  class AMPStyleCatNet1 (line 301) | class AMPStyleCatNet1(torch.nn.Module):
    method __init__ (line 302) | def __init__(self, obs_size, ase_latent_size, units, activation,
    method forward (line 333) | def forward(self, obs, latent, skip_style):
    method eval_style (line 348) | def eval_style(self, latent):
    method init_params (line 354) | def init_params(self):
    method get_out_size (line 366) | def get_out_size(self):
    method _build_style_mlp (line 370) | def _build_style_mlp(self, style_units, input_size):

FILE: timechamber/ase/ase_players.py
  class ASEPlayer (line 38) | class ASEPlayer(amp_players.AMPPlayerContinuous):
    method __init__ (line 39) | def __init__(self, params):
    method run (line 58) | def run(self):
    method get_action (line 63) | def get_action(self, obs_dict, is_determenistic=False):
    method env_reset (line 91) | def env_reset(self, env_ids=None):
    method _build_net_config (line 96) | def _build_net_config(self):
    method _reset_latents (line 101) | def _reset_latents(self, done_env_ids=None):
    method _update_latents (line 112) | def _update_latents(self):
    method _reset_latent_step_count (line 126) | def _reset_latent_step_count(self):
    method _calc_amp_rewards (line 130) | def _calc_amp_rewards(self, amp_obs, ase_latents):
    method _calc_enc_rewards (line 139) | def _calc_enc_rewards(self, amp_obs, ase_latents):
    method _calc_enc_error (line 148) | def _calc_enc_error(self, enc_pred, ase_latent):
    method _eval_enc (line 153) | def _eval_enc(self, amp_obs):
    method _amp_debug (line 157) | def _amp_debug(self, info):
    method _change_char_color (line 173) | def _change_char_color(self, env_ids):

FILE: timechamber/ase/hrl_agent.py
  class HRLAgent (line 56) | class HRLAgent(common_agent.CommonAgent):
    method __init__ (line 57) | def __init__(self, base_name, params):
    method env_step (line 75) | def env_step(self, actions):
    method cast_obs (line 117) | def cast_obs(self, obs):
    method preprocess_actions (line 122) | def preprocess_actions(self, actions):
    method play_steps (line 128) | def play_steps(self):
    method _load_config_params (line 198) | def _load_config_params(self, config):
    method _get_mean_rewards (line 205) | def _get_mean_rewards(self):
    method _setup_action_space (line 210) | def _setup_action_space(self):
    method init_tensors (line 215) | def init_tensors(self):
    method _build_llc (line 235) | def _build_llc(self, config_params, checkpoint_file):
    method _build_llc_agent_config (line 243) | def _build_llc_agent_config(self, config_params, network=None):
    method _compute_llc_action (line 259) | def _compute_llc_action(self, obs, actions):
    method _extract_llc_obs (line 270) | def _extract_llc_obs(self, obs):
    method _calc_disc_reward (line 275) | def _calc_disc_reward(self, amp_obs):
    method _combine_rewards (line 279) | def _combine_rewards(self, task_rewards, disc_rewards):
    method _record_train_batch_info (line 286) | def _record_train_batch_info(self, batch_dict, train_info):
    method _log_train_info (line 291) | def _log_train_info(self, train_info, frame):

FILE: timechamber/ase/hrl_models.py
  class ModelHRLContinuous (line 32) | class ModelHRLContinuous(ModelA2CContinuousLogStd):
    method __init__ (line 33) | def __init__(self, network):
    method build (line 37) | def build(self, config):
    class Network (line 49) | class Network(ModelA2CContinuousLogStd.Network):
      method __init__ (line 50) | def __init__(self, a2c_network, obs_shape, normalize_value, normaliz...
      method eval_critic (line 58) | def eval_critic(self, obs):

FILE: timechamber/ase/hrl_network_builder.py
  class HRLBuilder (line 36) | class HRLBuilder(network_builder.A2CBuilder):
    method __init__ (line 37) | def __init__(self, **kwargs):
    class Network (line 41) | class Network(network_builder.A2CBuilder.Network):
      method __init__ (line 42) | def __init__(self, params, **kwargs):
      method forward (line 54) | def forward(self, obs_dict):
      method eval_critic (line 59) | def eval_critic(self, obs):
    method build (line 66) | def build(self, name, **kwargs):

FILE: timechamber/ase/hrl_players.py
  class HRLPlayer (line 47) | class HRLPlayer(common_player.CommonPlayer):
    method __init__ (line 48) | def __init__(self, params):
    method get_action (line 66) | def get_action(self, obs_dict, is_determenistic = False):
    method run (line 92) | def run(self):
    method env_step (line 198) | def env_step(self, env, obs_dict, action):
    method _build_llc (line 238) | def _build_llc(self, config_params, checkpoint_file):
    method _build_llc_agent_config (line 246) | def _build_llc_agent_config(self, config_params, network=None):
    method _setup_action_space (line 261) | def _setup_action_space(self):
    method _compute_llc_action (line 266) | def _compute_llc_action(self, obs, actions):
    method _extract_llc_obs (line 276) | def _extract_llc_obs(self, obs):
    method _calc_disc_reward (line 281) | def _calc_disc_reward(self, amp_obs):

FILE: timechamber/ase/utils/amp_agent.py
  class AMPAgent (line 49) | class AMPAgent(common_agent.CommonAgent):
    method __init__ (line 50) | def __init__(self, base_name, params):
    method init_tensors (line 58) | def init_tensors(self):
    method set_eval (line 63) | def set_eval(self):
    method set_train (line 69) | def set_train(self):
    method get_stats_weights (line 75) | def get_stats_weights(self):
    method set_stats_weights (line 82) | def set_stats_weights(self, weights):
    method play_steps (line 89) | def play_steps(self):
    method get_action_values (line 167) | def get_action_values(self, obs_dict, rand_action_probs):
    method prepare_dataset (line 199) | def prepare_dataset(self, batch_dict):
    method train_epoch (line 209) | def train_epoch(self):
    method calc_gradients (line 294) | def calc_gradients(self, input_dict):
    method _load_config_params (line 420) | def _load_config_params(self, config):
    method _build_net_config (line 447) | def _build_net_config(self):
    method _build_rand_action_probs (line 452) | def _build_rand_action_probs(self):
    method _init_train (line 465) | def _init_train(self):
    method _disc_loss (line 470) | def _disc_loss(self, disc_agent_logit, disc_demo_logit, obs_demo):
    method _disc_loss_neg (line 509) | def _disc_loss_neg(self, disc_logits):
    method _disc_loss_pos (line 514) | def _disc_loss_pos(self, disc_logits):
    method _compute_disc_acc (line 519) | def _compute_disc_acc(self, disc_agent_logit, disc_demo_logit):
    method _fetch_amp_obs_demo (line 526) | def _fetch_amp_obs_demo(self, num_samples):
    method _build_amp_buffers (line 530) | def _build_amp_buffers(self):
    method _init_amp_demo_buf (line 548) | def _init_amp_demo_buf(self):
    method _update_amp_demos (line 558) | def _update_amp_demos(self):
    method _preproc_amp_obs (line 563) | def _preproc_amp_obs(self, amp_obs):
    method _combine_rewards (line 568) | def _combine_rewards(self, task_rewards, amp_rewards):
    method _eval_disc (line 575) | def _eval_disc(self, amp_obs):
    method _calc_advs (line 579) | def _calc_advs(self, batch_dict):
    method _calc_amp_rewards (line 591) | def _calc_amp_rewards(self, amp_obs):
    method _calc_disc_rewards (line 598) | def _calc_disc_rewards(self, amp_obs):
    method _store_replay_amp_obs (line 607) | def _store_replay_amp_obs(self, amp_obs):
    method _record_train_batch_info (line 624) | def _record_train_batch_info(self, batch_dict, train_info):
    method _log_train_info (line 629) | def _log_train_info(self, train_info, frame):
    method _amp_debug (line 646) | def _amp_debug(self, info):

FILE: timechamber/ase/utils/amp_datasets.py
  class AMPDataset (line 32) | class AMPDataset(datasets.PPODataset):
    method __init__ (line 33) | def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, de...
    method update_mu_sigma (line 38) | def update_mu_sigma(self, mu, sigma):
    method _get_item (line 42) | def _get_item(self, idx):
    method _shuffle_idx_buf (line 57) | def _shuffle_idx_buf(self):

FILE: timechamber/ase/utils/amp_models.py
  class ModelAMPContinuous (line 33) | class ModelAMPContinuous(ModelA2CContinuousLogStd):
    method __init__ (line 34) | def __init__(self, network):
    method build (line 38) | def build(self, config):
    class Network (line 51) | class Network(ModelA2CContinuousLogStd.Network):
      method __init__ (line 52) | def __init__(self, a2c_network, obs_shape, normalize_value, normaliz...
      method forward (line 59) | def forward(self, input_dict):
      method eval_actor (line 78) | def eval_actor(self, obs):
      method eval_critic (line 83) | def eval_critic(self, obs):

FILE: timechamber/ase/utils/amp_network_builder.py
  class AMPBuilder (line 39) | class AMPBuilder(network_builder.A2CBuilder):
    method __init__ (line 40) | def __init__(self, **kwargs):
    class Network (line 44) | class Network(network_builder.A2CBuilder.Network):
      method __init__ (line 45) | def __init__(self, params, **kwargs):
      method load (line 60) | def load(self, params):
      method forward (line 68) | def forward(self, obs_dict):
      method eval_actor (line 79) | def eval_actor(self, obs):
      method eval_critic (line 102) | def eval_critic(self, obs):
      method eval_disc (line 109) | def eval_disc(self, amp_obs):
      method get_disc_logit_weights (line 114) | def get_disc_logit_weights(self):
      method get_disc_weights (line 117) | def get_disc_weights(self):
      method _build_disc (line 126) | def _build_disc(self, input_shape):
    method build (line 152) | def build(self, name, **kwargs):

FILE: timechamber/ase/utils/amp_players.py
  class AMPPlayerContinuous (line 36) | class AMPPlayerContinuous(common_player.CommonPlayer):
    method __init__ (line 37) | def __init__(self, params):
    method restore (line 45) | def restore(self, fn):
    method _build_net (line 53) | def _build_net(self, config):
    method _post_step (line 62) | def _post_step(self, info):
    method _build_net_config (line 68) | def _build_net_config(self):
    method _amp_debug (line 76) | def _amp_debug(self, info):
    method _preproc_amp_obs (line 90) | def _preproc_amp_obs(self, amp_obs):
    method _eval_disc (line 95) | def _eval_disc(self, amp_obs):
    method _calc_amp_rewards (line 99) | def _calc_amp_rewards(self, amp_obs):
    method _calc_disc_rewards (line 106) | def _calc_disc_rewards(self, amp_obs):

FILE: timechamber/ase/utils/common_agent.py
  class CommonAgent (line 54) | class CommonAgent(a2c_continuous.A2CAgent):
    method __init__ (line 55) | def __init__(self, base_name, params):
    method init_tensors (line 105) | def init_tensors(self):
    method train (line 113) | def train(self):
    method set_full_state_weights (line 188) | def set_full_state_weights(self, weights):
    method restore (line 203) | def restore(self, fn):
    method train_epoch (line 210) | def train_epoch(self):
    method play_steps (line 283) | def play_steps(self):
    method prepare_dataset (line 348) | def prepare_dataset(self, batch_dict):
    method calc_gradients (line 394) | def calc_gradients(self, input_dict):
    method discount_values (line 478) | def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_va...
    method env_reset (line 492) | def env_reset(self, env_ids=None):
    method bound_loss (line 497) | def bound_loss(self, mu):
    method _get_mean_rewards (line 507) | def _get_mean_rewards(self):
    method _load_config_params (line 510) | def _load_config_params(self, config):
    method _build_net_config (line 514) | def _build_net_config(self):
    method _setup_action_space (line 526) | def _setup_action_space(self):
    method _init_train (line 535) | def _init_train(self):
    method _eval_critic (line 538) | def _eval_critic(self, obs_dict):
    method _actor_loss (line 546) | def _actor_loss(self, old_action_log_probs_batch, action_log_probs, ad...
    method _critic_loss (line 562) | def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_...
    method _calc_advs (line 577) | def _calc_advs(self, batch_dict):
    method _record_train_batch_info (line 589) | def _record_train_batch_info(self, batch_dict, train_info):
    method _log_train_info (line 592) | def _log_train_info(self, train_info, frame):

FILE: timechamber/ase/utils/common_player.py
  class CommonPlayer (line 39) | class CommonPlayer(players.PpoPlayerContinuous):
    method __init__ (line 40) | def __init__(self, params):
    method run (line 56) | def run(self):
    method get_action (line 162) | def get_action(self, obs_dict, is_determenistic = False):
    method env_step (line 166) | def env_step(self, env, actions):
    method _build_net (line 183) | def _build_net(self, config):
    method env_reset (line 194) | def env_reset(self, env_ids=None):
    method _post_step (line 198) | def _post_step(self, info):
    method _build_net_config (line 201) | def _build_net_config(self):
    method restore (line 212) | def restore(self, fn):
    method _setup_action_space (line 222) | def _setup_action_space(self):

FILE: timechamber/ase/utils/replay_buffer.py
  class ReplayBuffer (line 31) | class ReplayBuffer():
    method __init__ (line 32) | def __init__(self, buffer_size, device):
    method reset (line 43) | def reset(self):
    method get_buffer_size (line 49) | def get_buffer_size(self):
    method get_total_count (line 52) | def get_total_count(self):
    method store (line 55) | def store(self, data_dict):
    method sample (line 79) | def sample(self, n):
    method _reset_sample_idx (line 99) | def _reset_sample_idx(self):
    method _init_data_buf (line 105) | def _init_data_buf(self, data_dict):

FILE: timechamber/learning/common_agent.py
  class CommonAgent (line 26) | class CommonAgent(a2c_continuous.A2CAgent):
    method __init__ (line 28) | def __init__(self, base_name, params):
    method init_tensors (line 76) | def init_tensors(self):
    method train (line 84) | def train(self):
    method train_epoch (line 155) | def train_epoch(self):
    method play_steps (line 228) | def play_steps(self):
    method calc_gradients (line 290) | def calc_gradients(self, input_dict):
    method discount_values (line 384) | def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_va...
    method bound_loss (line 398) | def bound_loss(self, mu):
    method _load_config_params (line 408) | def _load_config_params(self, config):
    method _build_net_config (line 412) | def _build_net_config(self):
    method _setup_action_space (line 424) | def _setup_action_space(self):
    method _init_train (line 433) | def _init_train(self):
    method _env_reset_done (line 436) | def _env_reset_done(self):
    method _eval_critic (line 440) | def _eval_critic(self, obs_dict):
    method _actor_loss (line 453) | def _actor_loss(self, old_action_log_probs_batch, action_log_probs, ad...
    method _critic_loss (line 474) | def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_...
    method _record_train_batch_info (line 489) | def _record_train_batch_info(self, batch_dict, train_info):
    method _log_train_info (line 492) | def _log_train_info(self, train_info, frame):

FILE: timechamber/learning/common_player.py
  class CommonPlayer (line 11) | class CommonPlayer(players.PpoPlayerContinuous):
    method __init__ (line 13) | def __init__(self, params):
    method run (line 28) | def run(self):
    method obs_to_torch (line 128) | def obs_to_torch(self, obs):
    method get_action (line 135) | def get_action(self, obs_dict, is_determenistic = False):
    method _build_net (line 139) | def _build_net(self, config):
    method _env_reset_done (line 147) | def _env_reset_done(self):
    method _post_step (line 151) | def _post_step(self, info):
    method _build_net_config (line 154) | def _build_net_config(self):
    method _setup_action_space (line 166) | def _setup_action_space(self):

FILE: timechamber/learning/hrl_sp_agent.py
  class HRLSPAgent (line 21) | class HRLSPAgent(hrl_agent.HRLAgent):
    method __init__ (line 22) | def __init__(self, base_name, params):
    method _build_player_pool (line 56) | def _build_player_pool(self, params):
    method play_steps (line 67) | def play_steps(self):
    method env_step (line 153) | def env_step(self, ego_actions, op_actions):
    method env_reset (line 223) | def env_reset(self, env_ids=None):
    method train (line 231) | def train(self):
    method update_metric (line 344) | def update_metric(self):
    method get_action_values (line 360) | def get_action_values(self, obs, is_op=False):
    method restore (line 390) | def restore(self, fn):
    method resample_op (line 397) | def resample_op(self, resample_indices):
    method resample_batch (line 408) | def resample_batch(self):
    method restore_op (line 420) | def restore_op(self, fn):
    method check_update_opponent (line 428) | def check_update_opponent(self, win_rate):
    method create_model (line 438) | def create_model(self):
    method update_player_pool (line 443) | def update_player_pool(self, model, player_idx):

FILE: timechamber/learning/hrl_sp_player.py
  class HRLSPPlayer (line 19) | class HRLSPPlayer(hrl_players.HRLPlayer):
    method __init__ (line 20) | def __init__(self, params):
    method restore (line 52) | def restore(self, load_dir):
    method restore_op (line 80) | def restore_op(self, load_dir):
    method _alloc_env_indices (line 104) | def _alloc_env_indices(self):
    method _build_player_pool (line 120) | def _build_player_pool(self, params, player_num):
    method _update_rating (line 138) | def _update_rating(self, info, env_indices):
    method run (line 159) | def run(self):
    method _plot_elo_curve (line 253) | def _plot_elo_curve(self):
    method get_action (line 277) | def get_action(self, obs, is_determenistic=False, is_op=False):
    method _norm_policy_timestep (line 308) | def _norm_policy_timestep(self):
    method env_reset (line 321) | def env_reset(self, env, env_ids=None):
    method env_step (line 328) | def env_step(self, env, obs_dict, ego_actions, op_actions):
    method create_model (line 396) | def create_model(self):
    method load_model (line 401) | def load_model(self, fn):

FILE: timechamber/learning/pfsp_player_pool.py
  function player_inference_thread (line 12) | def player_inference_thread(model, input_dict, res_dict, env_indices, pr...
  function player_inference_process (line 22) | def player_inference_process(pipe, queue, barrier):
  class SinglePlayer (line 54) | class SinglePlayer:
    method __init__ (line 55) | def __init__(self, player_idx, model, device, obs_batch_len=0, rating=...
    method __call__ (line 71) | def __call__(self, input_dict):
    method reset_envs (line 74) | def reset_envs(self):
    method remove_envs (line 77) | def remove_envs(self, env_indices):
    method add_envs (line 80) | def add_envs(self, env_indices):
    method clear_envs (line 83) | def clear_envs(self):
    method update_metric (line 86) | def update_metric(self, wins, loses, draws):
    method clear_metric (line 97) | def clear_metric(self):
    method win_rate (line 103) | def win_rate(self):
    method games_num (line 110) | def games_num(self):
  class PFSPPlayerPool (line 114) | class PFSPPlayerPool:
    method __init__ (line 115) | def __init__(self, max_length, device):
    method add_player (line 127) | def add_player(self, player):
    method sample_player (line 135) | def sample_player(self, weight='linear'):
    method update_player_metric (line 141) | def update_player_metric(self, infos):
    method clear_player_metric (line 145) | def clear_player_metric(self):
    method inference (line 149) | def inference(self, input_dict, res_dict, processed_obs):
  class PFSPPlayerVectorizedPool (line 159) | class PFSPPlayerVectorizedPool(PFSPPlayerPool):
    method __init__ (line 160) | def __init__(self, max_length, device, vector_model_config, params):
    method inference (line 175) | def inference(self, input_dict, res_dict, processed_obs):
    method add_player (line 186) | def add_player(self, player):
  class PFSPPlayerThreadPool (line 192) | class PFSPPlayerThreadPool(PFSPPlayerPool):
    method __init__ (line 193) | def __init__(self, max_length, device):
    method inference (line 197) | def inference(self, input_dict, res_dict, processed_obs):
  class PFSPPlayerProcessPool (line 205) | class PFSPPlayerProcessPool(PFSPPlayerPool):
    method __init__ (line 206) | def __init__(self, max_length, device):
    method _init_inference_processes (line 216) | def _init_inference_processes(self):
    method add_player (line 230) | def add_player(self, player):
    method inference (line 247) | def inference(self, input_dict, res_dict, processed_obs):
    method __del__ (line 258) | def __del__(self):

FILE: timechamber/learning/ppo_sp_agent.py
  class SPAgent (line 22) | class SPAgent(a2c_continuous.A2CAgent):
    method __init__ (line 23) | def __init__(self, base_name, params):
    method _build_player_pool (line 57) | def _build_player_pool(self, params):
    method play_steps (line 74) | def play_steps(self):
    method env_step (line 145) | def env_step(self, actions):
    method env_reset (line 160) | def env_reset(self, env_ids=None):
    method train (line 167) | def train(self):
    method update_metric (line 266) | def update_metric(self):
    method get_action_values (line 282) | def get_action_values(self, obs, is_op=False):
    method resample_op (line 312) | def resample_op(self, resample_indices):
    method resample_batch (line 323) | def resample_batch(self):
    method restore_op (line 335) | def restore_op(self, fn):
    method check_update_opponent (line 341) | def check_update_opponent(self, win_rate):
    method create_model (line 351) | def create_model(self):
    method update_player_pool (line 356) | def update_player_pool(self, model, player_idx):

FILE: timechamber/learning/ppo_sp_player.py
  function rescale_actions (line 18) | def rescale_actions(low, high, action):
  class SPPlayer (line 25) | class SPPlayer(BasePlayer):
    method __init__ (line 26) | def __init__(self, params):
    method restore (line 61) | def restore(self, load_dir):
    method restore_op (line 89) | def restore_op(self, load_dir):
    method _alloc_env_indices (line 113) | def _alloc_env_indices(self):
    method _build_player_pool (line 129) | def _build_player_pool(self, params, player_num):
    method _update_rating (line 147) | def _update_rating(self, info, env_indices):
    method run (line 170) | def run(self):
    method _plot_elo_curve (line 265) | def _plot_elo_curve(self):
    method get_action (line 292) | def get_action(self, obs, is_determenistic=False, is_op=False):
    method _norm_policy_timestep (line 328) | def _norm_policy_timestep(self):
    method env_reset (line 341) | def env_reset(self, env, done_indices=None):
    method env_step (line 348) | def env_step(self, env, actions):
    method create_model (line 365) | def create_model(self):
    method load_model (line 370) | def load_model(self, fn):

FILE: timechamber/learning/replay_buffer.py
  class ReplayBuffer (line 6) | class ReplayBuffer():
    method __init__ (line 7) | def __init__(self, buffer_size, device):
    method reset (line 18) | def reset(self):
    method get_buffer_size (line 24) | def get_buffer_size(self):
    method get_total_count (line 27) | def get_total_count(self):
    method store (line 30) | def store(self, data_dict):
    method sample (line 54) | def sample(self, n):
    method _reset_sample_idx (line 74) | def _reset_sample_idx(self):
    method _init_data_buf (line 80) | def _init_data_buf(self, data_dict):

FILE: timechamber/learning/vectorized_models.py
  class VectorizedRunningMeanStd (line 8) | class VectorizedRunningMeanStd(RunningMeanStd):
    method __init__ (line 9) | def __init__(self, insize, population_size, epsilon=1e-05, per_channel...
    method _update_mean_var_count_from_moments (line 34) | def _update_mean_var_count_from_moments(self, mean, var, count, batch_...
    method forward (line 45) | def forward(self, input, unnorm=False, mask=None):
  class ModelVectorizedA2C (line 86) | class ModelVectorizedA2C(ModelA2CContinuousLogStd):
    method __init__ (line 87) | def __init__(self, network):
    method build (line 91) | def build(self, config):
    class Network (line 105) | class Network(ModelA2CContinuousLogStd.Network):
      method __init__ (line 106) | def __init__(self, a2c_network, population_size, obs_shape, normaliz...
      method update (line 118) | def update(self, population_idx, network):

FILE: timechamber/learning/vectorized_network_builder.py
  class VectorizedLinearLayer (line 7) | class VectorizedLinearLayer(torch.nn.Module):
    method __init__ (line 10) | def __init__(
    method forward (line 43) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class VectorizedA2CBuilder (line 50) | class VectorizedA2CBuilder(network_builder.A2CBuilder):
    method __init__ (line 51) | def __init__(self, **kwargs):
    class Network (line 55) | class Network(network_builder.A2CBuilder.Network):
      method __init__ (line 56) | def __init__(self, params, **kwargs):
      method _build_vectorized_mlp (line 72) | def _build_vectorized_mlp(self,
      method _build_mlp (line 87) | def _build_mlp(self,
      method forward (line 97) | def forward(self, obs_dict):  # implement continues situation
      method load (line 109) | def load(self, params):
    method build (line 112) | def build(self, name, **kwargs):

FILE: timechamber/tasks/ase_humanoid_base/base_task.py
  class BaseTask (line 22) | class BaseTask():
    method __init__ (line 24) | def __init__(self, cfg, enable_camera_sensors=False):
    method set_sim_params_up_axis (line 110) | def set_sim_params_up_axis(self, sim_params, axis):
    method create_sim (line 119) | def create_sim(self, compute_device, graphics_device, physics_engine, ...
    method step (line 127) | def step(self, actions):
    method get_states (line 147) | def get_states(self):
    method render (line 150) | def render(self, sync_frame_time=False):
    method get_actor_params_info (line 174) | def get_actor_params_info(self, dr_params, env):
    method apply_randomizations (line 212) | def apply_randomizations(self, dr_params):
    method pre_physics_step (line 408) | def pre_physics_step(self, actions):
    method _physics_step (line 411) | def _physics_step(self):
    method post_physics_step (line 417) | def post_physics_step(self):
  function get_attr_val_from_sample (line 421) | def get_attr_val_from_sample(sample, offset, prop, attr):

FILE: timechamber/tasks/ase_humanoid_base/humanoid.py
  class Humanoid (line 41) | class Humanoid(BaseTask):
    method __init__ (line 42) | def __init__(self, cfg, sim_params, physics_engine, device_type, devic...
    method get_obs_size (line 157) | def get_obs_size(self):
    method get_action_size (line 160) | def get_action_size(self):
    method get_num_actors_per_env (line 163) | def get_num_actors_per_env(self):
    method _add_circle_borderline (line 167) | def _add_circle_borderline(self, env):
    method _add_rectangle_borderline (line 180) | def _add_rectangle_borderline(self, env):
    method allocate_buffers (line 232) | def allocate_buffers(self):
    method create_sim (line 254) | def create_sim(self):
    method reset (line 262) | def reset(self, env_ids=None):
    method set_char_color (line 268) | def set_char_color(self, col, env_ids):
    method _reset_envs (line 279) | def _reset_envs(self, env_ids):
    method _reset_env_tensors (line 287) | def _reset_env_tensors(self, env_ids):
    method _create_ground_plane (line 304) | def _create_ground_plane(self):
    method _setup_character_props (line 313) | def _setup_character_props(self, key_bodies):
    method _build_termination_heights (line 337) | def _build_termination_heights(self):
    method _create_envs (line 355) | def _create_envs(self, num_envs, spacing, num_per_row):
    method _build_env (line 442) | def _build_env(self, env_id, env_ptr, humanoid_asset, humanoid_asset_op):
    method _build_pd_action_offset_scale (line 488) | def _build_pd_action_offset_scale(self):
    method _get_humanoid_collision_filter (line 534) | def _get_humanoid_collision_filter(self):
    method _compute_reward (line 537) | def _compute_reward(self, actions):
    method _compute_reset (line 541) | def _compute_reset(self):
    method _refresh_sim_tensors (line 548) | def _refresh_sim_tensors(self):
    method _compute_observations (line 558) | def _compute_observations(self):
    method _compute_humanoid_obs (line 566) | def _compute_humanoid_obs(self):
    method _reset_actors (line 585) | def _reset_actors(self, env_ids):
    method pre_physics_step (line 594) | def pre_physics_step(self, actions):
    method post_physics_step (line 612) | def post_physics_step(self):
    method render (line 628) | def render(self, sync_frame_time=False):
    method _build_key_body_ids_tensor (line 633) | def _build_key_body_ids_tensor(self, key_body_names):
    method _build_contact_body_ids_tensor (line 646) | def _build_contact_body_ids_tensor(self, contact_body_names):
    method _action_to_pd_targets (line 659) | def _action_to_pd_targets(self, action):
    method _update_debug_viz (line 663) | def _update_debug_viz(self):
  function dof_to_obs (line 672) | def dof_to_obs(pose, dof_obs_size, dof_offsets):
  function compute_humanoid_observations (line 704) | def compute_humanoid_observations(root_pos, root_rot, root_vel, root_ang...
  function compute_humanoid_observations_max (line 741) | def compute_humanoid_observations_max(body_pos, body_rot, body_vel, body...
  function expand_env_ids (line 789) | def expand_env_ids(env_ids, n_agents):
  function compute_humanoid_reward (line 798) | def compute_humanoid_reward(obs_buf):
  function compute_humanoid_reset (line 804) | def compute_humanoid_reset(reset_buf, progress_buf, contact_buf, contact...

FILE: timechamber/tasks/ase_humanoid_base/humanoid_amp.py
  class HumanoidAMP (line 43) | class HumanoidAMP(Humanoid):
    class StateInit (line 44) | class StateInit(Enum):
    method __init__ (line 50) | def __init__(self, cfg, sim_params, physics_engine, device_type, devic...
    method post_physics_step (line 78) | def post_physics_step(self):
    method get_num_amp_obs (line 89) | def get_num_amp_obs(self):
    method fetch_amp_obs_demo (line 92) | def fetch_amp_obs_demo(self, num_samples):
    method build_amp_obs_demo (line 107) | def build_amp_obs_demo(self, motion_ids, motion_times0):
    method _build_amp_obs_demo_buf (line 125) | def _build_amp_obs_demo_buf(self, num_samples):
    method _setup_character_props (line 129) | def _setup_character_props(self, key_bodies):
    method _load_motion (line 145) | def _load_motion(self, motion_file):
    method _reset_envs (line 154) | def _reset_envs(self, env_ids):
    method _reset_actors (line 162) | def _reset_actors(self, env_ids):
    method _reset_default (line 174) | def _reset_default(self, env_ids):
    method _reset_ref_state_init (line 182) | def _reset_ref_state_init(self, env_ids):
    method _reset_hybrid_state_init (line 210) | def _reset_hybrid_state_init(self, env_ids):
    method _init_amp_obs (line 225) | def _init_amp_obs(self, env_ids):
    method _init_amp_obs_default (line 237) | def _init_amp_obs_default(self, env_ids):
    method _init_amp_obs_ref (line 242) | def _init_amp_obs_ref(self, env_ids, motion_ids, motion_times):
    method _set_env_state (line 260) | def _set_env_state(self, env_ids, root_pos, root_rot, dof_pos, root_ve...
    method _update_hist_amp_obs (line 270) | def _update_hist_amp_obs(self, env_ids=None):
    method _compute_amp_observations (line 277) | def _compute_amp_observations(self, env_ids=None):
  function build_amp_observations (line 303) | def build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, d...

FILE: timechamber/tasks/ase_humanoid_base/humanoid_amp_task.py
  class HumanoidAMPTask (line 33) | class HumanoidAMPTask(humanoid_amp.HumanoidAMP):
    method __init__ (line 34) | def __init__(self, cfg, sim_params, physics_engine, device_type, devic...
    method get_obs_size (line 45) | def get_obs_size(self):
    method get_task_obs_size (line 52) | def get_task_obs_size(self):
    method pre_physics_step (line 55) | def pre_physics_step(self, actions):
    method render (line 60) | def render(self, sync_frame_time=False):
    method _update_task (line 67) | def _update_task(self):
    method _reset_envs (line 70) | def _reset_envs(self, env_ids):
    method _reset_task (line 75) | def _reset_task(self, env_ids):
    method _compute_observations (line 78) | def _compute_observations(self):
    method _compute_task_obs (line 102) | def _compute_task_obs(self, env_ids=None):
    method _compute_reward (line 105) | def _compute_reward(self, actions):
    method _draw_task (line 108) | def _draw_task(self):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/abstract.py
  function register (line 40) | def register(name):
  function _get_cls (line 50) | def _get_cls(name):
  class NumpyEncoder (line 55) | class NumpyEncoder(json.JSONEncoder):
    method default (line 58) | def default(self, obj):
  function json_numpy_obj_hook (line 83) | def json_numpy_obj_hook(dct):
  class Serializable (line 90) | class Serializable:
    method from_dict (line 97) | def from_dict(cls, dict_repr, *args, **kwargs):
    method to_dict (line 108) | def to_dict(self):
    method from_file (line 116) | def from_file(cls, path, *args, **kwargs):
    method to_file (line 136) | def to_file(self, path: str) -> None:

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/core/rotation3d.py
  function quat_mul (line 37) | def quat_mul(a, b):
  function quat_pos (line 53) | def quat_pos(x):
  function quat_abs (line 64) | def quat_abs(x):
  function quat_unit (line 73) | def quat_unit(x):
  function quat_conjugate (line 82) | def quat_conjugate(x):
  function quat_real (line 90) | def quat_real(x):
  function quat_imaginary (line 98) | def quat_imaginary(x):
  function quat_norm_check (line 106) | def quat_norm_check(x):
  function quat_normalize (line 117) | def quat_normalize(q):
  function quat_from_xyz (line 126) | def quat_from_xyz(xyz):
  function quat_identity (line 136) | def quat_identity(shape: List[int]):
  function quat_from_angle_axis (line 147) | def quat_from_angle_axis(angle, axis, degree: bool = False):
  function quat_from_rotation_matrix (line 171) | def quat_from_rotation_matrix(m):
  function quat_mul_norm (line 221) | def quat_mul_norm(x, y):
  function quat_rotate (line 230) | def quat_rotate(rot, vec):
  function quat_inverse (line 239) | def quat_inverse(x):
  function quat_identity_like (line 247) | def quat_identity_like(x):
  function quat_angle_axis (line 255) | def quat_angle_axis(x):
  function quat_yaw_rotation (line 268) | def quat_yaw_rotation(x, z_up: bool = True):
  function transform_from_rotation_translation (line 289) | def transform_from_rotation_translation(
  function transform_identity (line 305) | def transform_identity(shape: List[int]):
  function transform_rotation (line 316) | def transform_rotation(x):
  function transform_translation (line 322) | def transform_translation(x):
  function transform_inverse (line 328) | def transform_inverse(x):
  function transform_identity_like (line 339) | def transform_identity_like(x):
  function transform_mul (line 347) | def transform_mul(x, y):
  function transform_apply (line 360) | def transform_apply(rot, vec):
  function rot_matrix_det (line 369) | def rot_matrix_det(x):
  function rot_matrix_integrity_check (line 384) | def rot_matrix_integrity_check(x):
  function rot_matrix_from_quaternion (line 399) | def rot_matrix_from_quaternion(q):
  function euclidean_to_rotation_matrix (line 427) | def euclidean_to_rotation_matrix(x):
  function euclidean_integrity_check (line 435) | def euclidean_integrity_check(x):
  function euclidean_translation (line 442) | def euclidean_translation(x):
  function euclidean_inverse (line 450) | def euclidean_inverse(x):
  function euclidean_to_transform (line 462) | def euclidean_to_transform(transformation_matrix):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tensor_utils.py
  class TensorUtils (line 13) | class TensorUtils(Serializable):
    method from_dict (line 15) | def from_dict(cls, dict_repr, *args, **kwargs):
    method to_dict (line 25) | def to_dict(self):
  function tensor_to_dict (line 32) | def tensor_to_dict(x):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/fbx_backend.py
  function fbx_to_npy (line 48) | def fbx_to_npy(file_name_in, root_joint_name, fps):
  function _get_frame_count (line 152) | def _get_frame_count(fbx_scene):
  function _get_animation_curve (line 177) | def _get_animation_curve(joint, fbx_scene):
  function _get_skeleton (line 233) | def _get_skeleton(root_joint):
  function _recursive_to_list (line 258) | def _recursive_to_list(array):
  function parse_fbx (line 273) | def parse_fbx(file_name_in, root_joint_name, fps):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/fbx_read_wrapper.py
  function fbx_to_array (line 25) | def fbx_to_array(fbx_file_path, root_joint, fps):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/skeleton3d.py
  class SkeletonTree (line 42) | class SkeletonTree(Serializable):
    method __init__ (line 99) | def __init__(self, node_names, parent_indices, local_translation):
    method __len__ (line 116) | def __len__(self):
    method __iter__ (line 120) | def __iter__(self):
    method __getitem__ (line 124) | def __getitem__(self, item):
    method __repr__ (line 128) | def __repr__(self):
    method _indent (line 138) | def _indent(self, s):
    method node_names (line 142) | def node_names(self):
    method parent_indices (line 146) | def parent_indices(self):
    method local_translation (line 150) | def local_translation(self):
    method num_joints (line 154) | def num_joints(self):
    method from_dict (line 159) | def from_dict(cls, dict_repr, *args, **kwargs):
    method to_dict (line 166) | def to_dict(self):
    method from_mjcf (line 176) | def from_mjcf(cls, path: str) -> "SkeletonTree":
    method parent_of (line 222) | def parent_of(self, node_name):
    method index (line 231) | def index(self, node_name):
    method drop_nodes_by_names (line 240) | def drop_nodes_by_names(
    method keep_nodes_by_names (line 283) | def keep_nodes_by_names(
  class SkeletonState (line 290) | class SkeletonState(Serializable):
    method __init__ (line 360) | def __init__(self, tensor_backend, skeleton_tree, is_local):
    method __len__ (line 365) | def __len__(self):
    method rotation (line 369) | def rotation(self):
    method _local_rotation (line 377) | def _local_rotation(self):
    method _global_rotation (line 384) | def _global_rotation(self):
    method is_local (line 391) | def is_local(self):
    method invariant_property (line 399) | def invariant_property(self):
    method num_joints (line 403) | def num_joints(self):
    method skeleton_tree (line 411) | def skeleton_tree(self):
    method root_translation (line 419) | def root_translation(self):
    method global_transformation (line 431) | def global_transformation(self):
    method global_rotation (line 455) | def global_rotation(self):
    method global_translation (line 468) | def global_translation(self):
    method global_translation_xy (line 475) | def global_translation_xy(self):
    method global_translation_xz (line 482) | def global_translation_xz(self):
    method local_rotation (line 490) | def local_rotation(self):
    method local_transformation (line 513) | def local_transformation(self):
    method local_translation (line 523) | def local_translation(self):
    method root_translation_xy (line 542) | def root_translation_xy(self):
    method global_root_rotation (line 549) | def global_root_rotation(self):
    method global_root_yaw_rotation (line 556) | def global_root_yaw_rotation(self):
    method local_translation_to_root (line 564) | def local_translation_to_root(self):
    method local_rotation_to_root (line 573) | def local_rotation_to_root(self):
    method compute_forward_vector (line 580) | def compute_forward_vector(
    method _to_state_vector (line 620) | def _to_state_vector(rot, rt):
    method from_dict (line 630) | def from_dict(
    method to_dict (line 641) | def to_dict(self) -> OrderedDict:
    method from_rotation_and_root_translation (line 652) | def from_rotation_and_root_translation(cls, skeleton_tree, r, t, is_lo...
    method zero_pose (line 675) | def zero_pose(cls, skeleton_tree):
    method local_repr (line 690) | def local_repr(self):
    method global_repr (line 706) | def global_repr(self):
    method _get_pairwise_average_translation (line 722) | def _get_pairwise_average_translation(self):
    method _transfer_to (line 734) | def _transfer_to(self, new_skeleton_tree: SkeletonTree):
    method drop_nodes_by_names (line 743) | def drop_nodes_by_names(
    method keep_nodes_by_names (line 766) | def keep_nodes_by_names(
    method _remapped_to (line 785) | def _remapped_to(
    method retarget_to (line 814) | def retarget_to(
    method retarget_to_by_tpose (line 978) | def retarget_to_by_tpose(
  class SkeletonMotion (line 1026) | class SkeletonMotion(SkeletonState):
    method __init__ (line 1027) | def __init__(self, tensor_backend, skeleton_tree, is_local, fps, *args...
    method clone (line 1031) | def clone(self):
    method invariant_property (line 1037) | def invariant_property(self):
    method global_velocity (line 1045) | def global_velocity(self):
    method global_angular_velocity (line 1053) | def global_angular_velocity(self):
    method fps (line 1061) | def fps(self):
    method time_delta (line 1066) | def time_delta(self):
    method global_root_velocity (line 1071) | def global_root_velocity(self):
    method global_root_angular_velocity (line 1076) | def global_root_angular_velocity(self):
    method from_state_vector_and_velocity (line 1081) | def from_state_vector_and_velocity(
    method from_skeleton_state (line 1118) | def from_skeleton_state(
    method _to_state_vector (line 1151) | def _to_state_vector(rot, rt, vel, avel):
    method from_dict (line 1160) | def from_dict(
    method to_dict (line 1178) | def to_dict(self) -> OrderedDict:
    method from_fbx (line 1192) | def from_fbx(
    method _compute_velocity (line 1251) | def _compute_velocity(p, time_delta, guassian_filter=True):
    method _compute_angular_velocity (line 1261) | def _compute_angular_velocity(r, time_delta: float, guassian_filter=Tr...
    method crop (line 1276) | def crop(self, start: int, end: int, fps: Optional[int] = None):
    method retarget_to (line 1311) | def retarget_to(
    method retarget_to_by_tpose (line 1373) | def retarget_to_by_tpose(

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/common.py
  function plot_skeleton_state (line 36) | def plot_skeleton_state(skeleton_state, task_name=""):
  function plot_skeleton_states (line 51) | def plot_skeleton_states(skeleton_state, skip_n=1, task_name=""):
  function plot_skeleton_motion (line 72) | def plot_skeleton_motion(skeleton_motion, skip_n=1, task_name=""):
  function plot_skeleton_motion_interactive_base (line 94) | def plot_skeleton_motion_interactive_base(skeleton_motion, task_name=""):
  function plot_skeleton_motion_interactive (line 189) | def plot_skeleton_motion_interactive(skeleton_motion, task_name=""):
  function plot_skeleton_motion_interactive_multiple (line 202) | def plot_skeleton_motion_interactive_multiple(*callables, sync=True):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/core.py
  class BasePlotterTask (line 36) | class BasePlotterTask(object):
    method __init__ (line 40) | def __init__(self, task_name: str, task_type: str) -> None:
    method task_name (line 45) | def task_name(self):
    method task_type (line 49) | def task_type(self):
    method get_scoped_name (line 52) | def get_scoped_name(self, name):
    method __iter__ (line 55) | def __iter__(self):
  class BasePlotterTasks (line 61) | class BasePlotterTasks(object):
    method __init__ (line 62) | def __init__(self, tasks) -> None:
    method __iter__ (line 65) | def __iter__(self):
  class BasePlotter (line 70) | class BasePlotter(object):
    method __init__ (line 77) | def __init__(self, task: BasePlotterTask) -> None:
    method task_primitives (line 82) | def task_primitives(self):
    method create (line 85) | def create(self, task: BasePlotterTask) -> None:
    method update (line 91) | def update(self) -> None:
    method _update_impl (line 95) | def _update_impl(self, task_list: List[BasePlotterTask]) -> None:
    method _create_impl (line 98) | def _create_impl(self, task_list: List[BasePlotterTask]) -> None:

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/plt_plotter.py
  class Matplotlib2DPlotter (line 44) | class Matplotlib2DPlotter(BasePlotter):
    method __init__ (line 53) | def __init__(self, task: "BasePlotterTask") -> None:
    method ax (line 73) | def ax(self):
    method fig (line 77) | def fig(self):
    method show (line 80) | def show(self):
    method _min (line 83) | def _min(self, x, y):
    method _max (line 90) | def _max(self, x, y):
    method _init_lim (line 97) | def _init_lim(self):
    method _update_lim (line 103) | def _update_lim(self, xs, ys):
    method _set_lim (line 109) | def _set_lim(self):
    method _lines_extract_xy_impl (line 121) | def _lines_extract_xy_impl(index, lines_task):
    method _trail_extract_xy_impl (line 125) | def _trail_extract_xy_impl(index, trail_task):
    method _lines_create_impl (line 128) | def _lines_create_impl(self, lines_task):
    method _lines_update_impl (line 140) | def _lines_update_impl(self, lines_task):
    method _dots_create_impl (line 149) | def _dots_create_impl(self, dots_task):
    method _dots_update_impl (line 161) | def _dots_update_impl(self, dots_task):
    method _trail_create_impl (line 167) | def _trail_create_impl(self, trail_task):
    method _trail_update_impl (line 180) | def _trail_update_impl(self, trail_task):
    method _create_impl (line 189) | def _create_impl(self, task_list):
    method _update_impl (line 194) | def _update_impl(self, task_list):
    method _set_aspect_equal_2d (line 199) | def _set_aspect_equal_2d(self, zero_centered=True):
    method _draw (line 221) | def _draw(self):
  class Matplotlib3DPlotter (line 229) | class Matplotlib3DPlotter(BasePlotter):
    method __init__ (line 238) | def __init__(self, task: "BasePlotterTask") -> None:
    method ax (line 257) | def ax(self):
    method fig (line 261) | def fig(self):
    method show (line 264) | def show(self):
    method _min (line 267) | def _min(self, x, y):
    method _max (line 274) | def _max(self, x, y):
    method _init_lim (line 281) | def _init_lim(self):
    method _update_lim (line 289) | def _update_lim(self, xs, ys, zs):
    method _set_lim (line 297) | def _set_lim(self):
    method _lines_extract_xyz_impl (line 312) | def _lines_extract_xyz_impl(index, lines_task):
    method _trail_extract_xyz_impl (line 316) | def _trail_extract_xyz_impl(index, trail_task):
    method _lines_create_impl (line 323) | def _lines_create_impl(self, lines_task):
    method _lines_update_impl (line 335) | def _lines_update_impl(self, lines_task):
    method _dots_create_impl (line 345) | def _dots_create_impl(self, dots_task):
    method _dots_update_impl (line 358) | def _dots_update_impl(self, dots_task):
    method _trail_create_impl (line 365) | def _trail_create_impl(self, trail_task):
    method _trail_update_impl (line 378) | def _trail_update_impl(self, trail_task):
    method _create_impl (line 388) | def _create_impl(self, task_list):
    method _update_impl (line 393) | def _update_impl(self, task_list):
    method _set_aspect_equal_3d (line 398) | def _set_aspect_equal_3d(self):
    method _draw (line 419) | def _draw(self):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/simple_plotter_tasks.py
  class DrawXDLines (line 38) | class DrawXDLines(BasePlotterTask):
    method __init__ (line 45) | def __init__(
    method influence_lim (line 62) | def influence_lim(self) -> bool:
    method raw_data (line 66) | def raw_data(self):
    method color (line 70) | def color(self):
    method line_width (line 74) | def line_width(self):
    method alpha (line 78) | def alpha(self):
    method dim (line 82) | def dim(self):
    method name (line 86) | def name(self):
    method update (line 89) | def update(self, lines):
    method __getitem__ (line 94) | def __getitem__(self, index):
    method __len__ (line 97) | def __len__(self):
    method __iter__ (line 100) | def __iter__(self):
  class DrawXDDots (line 104) | class DrawXDDots(BasePlotterTask):
    method __init__ (line 111) | def __init__(
    method update (line 127) | def update(self, dots):
    method __getitem__ (line 132) | def __getitem__(self, index):
    method __len__ (line 135) | def __len__(self):
    method __iter__ (line 138) | def __iter__(self):
    method influence_lim (line 142) | def influence_lim(self) -> bool:
    method raw_data (line 146) | def raw_data(self):
    method color (line 150) | def color(self):
    method marker_size (line 154) | def marker_size(self):
    method alpha (line 158) | def alpha(self):
    method dim (line 162) | def dim(self):
    method name (line 166) | def name(self):
  class DrawXDTrail (line 170) | class DrawXDTrail(DrawXDDots):
    method line_width (line 172) | def line_width(self):
    method name (line 176) | def name(self):
  class Draw2DLines (line 180) | class Draw2DLines(DrawXDLines):
    method dim (line 182) | def dim(self):
  class Draw3DLines (line 186) | class Draw3DLines(DrawXDLines):
    method dim (line 188) | def dim(self):
  class Draw2DDots (line 192) | class Draw2DDots(DrawXDDots):
    method dim (line 194) | def dim(self):
  class Draw3DDots (line 198) | class Draw3DDots(DrawXDDots):
    method dim (line 200) | def dim(self):
  class Draw2DTrail (line 204) | class Draw2DTrail(DrawXDTrail):
    method dim (line 206) | def dim(self):
  class Draw3DTrail (line 210) | class Draw3DTrail(DrawXDTrail):
    method dim (line 212) | def dim(self):

FILE: timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/skeleton_plotter_tasks.py
  class Draw3DSkeletonState (line 40) | class Draw3DSkeletonState(BasePlotterTask):
    method __init__ (line 44) | def __init__(
    method name (line 62) | def name(self):
    method update (line 65) | def update(self, skeleton_state) -> None:
    method _get_lines_and_dots (line 69) | def _get_lines_and_dots(skeleton_state):
    method _update (line 86) | def _update(self, lines, dots) -> None:
    method __iter__ (line 90) | def __iter__(self):
  class Draw3DSkeletonMotion (line 95) | class Draw3DSkeletonMotion(BasePlotterTask):
    method __init__ (line 96) | def __init__(
    method name (line 155) | def name(self):
    method update (line 158) | def update(self, frame_index=None, reset_trail=False, skeleton_motion=...
    method _get_vel_and_avel (line 182) | def _get_vel_and_avel(skeleton_motion):
    method _update (line 193) | def _update(self, vel_lines, avel_lines) -> None:
    method __iter__ (line 197) | def __iter__(self):
  class Draw3DSkeletonMotions (line 204) | class Draw3DSkeletonMotions(BasePlotterTask):
    method __init__ (line 205) | def __init__(self, skeleton_motion_tasks) -> None:
    method name (line 209) | def name(self):
    method update (line 212) | def update(self, frame_index) -> None:
    method __iter__ (line 215) | def __iter__(self):

FILE: timechamber/tasks/ase_humanoid_base/poselib/retarget_motion.py
  function project_joints (line 52) | def project_joints(motion):
  function main (line 206) | def main():

FILE: timechamber/tasks/base/ma_vec_task.py
  class MA_VecTask (line 50) | class MA_VecTask(Env):
    method __init__ (line 52) | def __init__(self, config, rl_device, sim_device, graphics_device_id, ...
    method set_viewer (line 103) | def set_viewer(self):
    method allocate_buffers (line 132) | def allocate_buffers(self):
    method set_sim_params_up_axis (line 157) | def set_sim_params_up_axis(self, sim_params: gymapi.SimParams, axis: s...
    method create_sim (line 174) | def create_sim(self, compute_device: int, graphics_device: int, physic...
    method get_state (line 192) | def get_state(self):
    method pre_physics_step (line 197) | def pre_physics_step(self, actions: torch.Tensor):
    method post_physics_step (line 205) | def post_physics_step(self):
    method step (line 208) | def step(self, actions: torch.Tensor) -> Tuple[Dict[str, torch.Tensor]...
    method zero_actions (line 250) | def zero_actions(self) -> torch.Tensor:
    method reset (line 261) | def reset(self, env_ids=None) -> torch.Tensor:
    method _reset_envs (line 275) | def _reset_envs(self, env_ids):
    method reset_done (line 282) | def reset_done(self):
    method render (line 298) | def render(self):
    method __parse_sim_params (line 328) | def __parse_sim_params(self, physics_engine: str, config_sim: Dict[str...
    method get_actor_params_info (line 382) | def get_actor_params_info(self, dr_params: Dict[str, Any], env):
    method apply_randomizations (line 424) | def apply_randomizations(self, dr_params):

FILE: timechamber/tasks/base/vec_task.py
  function _create_sim_once (line 50) | def _create_sim_once(gym, *args, **kwargs):
  class Env (line 59) | class Env(ABC):
    method __init__ (line 60) | def __init__(self, config: Dict[str, Any], rl_device: str, sim_device:...
    method allocate_buffers (line 110) | def allocate_buffers(self):
    method step (line 114) | def step(self, actions: torch.Tensor) -> Tuple[Dict[str, torch.Tensor]...
    method reset (line 125) | def reset(self)-> Dict[str, torch.Tensor]:
    method reset_idx (line 132) | def reset_idx(self, env_ids: torch.Tensor):
    method observation_space (line 139) | def observation_space(self) -> gym.Space:
    method action_space (line 144) | def action_space(self) -> gym.Space:
    method num_envs (line 149) | def num_envs(self) -> int:
    method num_acts (line 154) | def num_acts(self) -> int:
    method num_obs (line 159) | def num_obs(self) -> int:
  class VecTask (line 164) | class VecTask(Env):
    method __init__ (line 168) | def __init__(self, config, rl_device, sim_device, graphics_device_id, ...
    method set_viewer (line 224) | def set_viewer(self):
    method allocate_buffers (line 253) | def allocate_buffers(self):
    method create_sim (line 278) | def create_sim(self, compute_device: int, graphics_device: int, physic...
    method get_state (line 296) | def get_state(self):
    method pre_physics_step (line 301) | def pre_physics_step(self, actions: torch.Tensor):
    method post_physics_step (line 309) | def post_physics_step(self):
    method step (line 312) | def step(self, actions: torch.Tensor) -> Tuple[Dict[str, torch.Tensor]...
    method zero_actions (line 360) | def zero_actions(self) -> torch.Tensor:
    method reset_idx (line 370) | def reset_idx(self, env_idx):
    method reset (line 376) | def reset(self):
    method reset_done (line 390) | def reset_done(self):
    method render (line 407) | def render(self, mode="rgb_array"):
    method __parse_sim_params (line 441) | def __parse_sim_params(self, physics_engine: str, config_sim: Dict[str...
    method get_actor_params_info (line 495) | def get_actor_params_info(self, dr_params: Dict[str, Any], env):
    method apply_randomizations (line 537) | def apply_randomizations(self, dr_params):

FILE: timechamber/tasks/ma_ant_battle.py
  class MA_Ant_Battle (line 12) | class MA_Ant_Battle(MA_VecTask):
    method __init__ (line 14) | def __init__(self, cfg, sim_device, rl_device, graphics_device_id, hea...
    method allocate_buffers (line 102) | def allocate_buffers(self):
    method create_sim (line 122) | def create_sim(self):
    method _add_circle_borderline (line 142) | def _add_circle_borderline(self, env, radius):
    method _create_ground_plane (line 147) | def _create_ground_plane(self):
    method _create_envs (line 154) | def _create_envs(self, num_envs, spacing, num_per_row):
    method compute_reward (line 244) | def compute_reward(self, actions):
    method compute_observations (line 268) | def compute_observations(self):
    method reset_idx (line 287) | def reset_idx(self, env_ids):
    method pre_physics_step (line 328) | def pre_physics_step(self, actions):
    method post_physics_step (line 349) | def post_physics_step(self):
    method get_number_of_agents (line 368) | def get_number_of_agents(self):
    method zero_actions (line 372) | def zero_actions(self) -> torch.Tensor:
    method clear_count (line 383) | def clear_count(self):
  function expand_env_ids (line 394) | def expand_env_ids(env_ids, n_agents):
  function compute_ant_reward (line 405) | def compute_ant_reward(
  function compute_ant_observations (line 465) | def compute_ant_observations(
  function randomize_rotation (line 501) | def randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):

FILE: timechamber/tasks/ma_ant_sumo.py
  class MA_Ant_Sumo (line 19) | class MA_Ant_Sumo(MA_VecTask):
    method __init__ (line 21) | def __init__(self, cfg, sim_device, rl_device, graphics_device_id, hea...
    method allocate_buffers (line 113) | def allocate_buffers(self):
    method create_sim (line 130) | def create_sim(self):
    method _add_circle_borderline (line 142) | def _add_circle_borderline(self, env):
    method _create_ground_plane (line 155) | def _create_ground_plane(self):
    method _create_envs (line 162) | def _create_envs(self, num_envs, spacing, num_per_row):
    method compute_reward (line 271) | def compute_reward(self, actions):
    method compute_observations (line 298) | def compute_observations(self):
    method reset_idx (line 326) | def reset_idx(self, env_ids):
    method pre_physics_step (line 374) | def pre_physics_step(self, actions):
    method post_physics_step (line 388) | def post_physics_step(self):
    method get_number_of_agents (line 396) | def get_number_of_agents(self):
    method zero_actions (line 400) | def zero_actions(self) -> torch.Tensor:
    method clear_count (line 411) | def clear_count(self):
  function expand_env_ids (line 423) | def expand_env_ids(env_ids, n_agents):
  function compute_move_reward (line 433) | def compute_move_reward(
  function compute_ant_reward (line 449) | def compute_ant_reward(
  function compute_ant_observations (line 508) | def compute_ant_observations(
  function randomize_rotation (line 529) | def randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):

FILE: timechamber/tasks/ma_humanoid_strike.py
  class HumanoidStrike (line 41) | class HumanoidStrike(humanoid_amp_task.HumanoidAMPTask):
    method __init__ (line 42) | def __init__(self, cfg, sim_params, physics_engine, device_type, devic...
    method get_task_obs_size (line 80) | def get_task_obs_size(self):
    method _create_envs (line 86) | def _create_envs(self, num_envs, spacing, num_per_row):
    method _build_env (line 91) | def _build_env(self, env_id, env_ptr, humanoid_asset, humanoid_asset_op):
    method _build_body_ids_tensor (line 95) | def _build_body_ids_tensor(self, env_ptr, actor_handle, body_names):
    method _reset_actors (line 108) | def _reset_actors(self, env_ids):
    method _reset_env_tensors (line 141) | def _reset_env_tensors(self, env_ids):
    method pre_physics_step (line 147) | def pre_physics_step(self, actions):
    method post_physics_step (line 154) | def post_physics_step(self):
    method _compute_observations (line 159) | def _compute_observations(self):
    method _compute_task_obs (line 170) | def _compute_task_obs(self):
    method _compute_reward (line 197) | def _compute_reward(self, actions):
    method _compute_reset (line 233) | def _compute_reset(self):
  function compute_strike_observations (line 255) | def compute_strike_observations(root_states, root_states_op, body_pos, b...
  function compute_strike_reward (line 343) | def compute_strike_reward(root_states, root_states_op, body_pos, body_an...
  function compute_humanoid_reset (line 522) | def compute_humanoid_reset(reset_buf, progress_buf, ego_to_op_damage, op...
  function expand_env_ids (line 593) | def expand_env_ids(env_ids, n_agents):

FILE: timechamber/train.py
  function launch_rlg_hydra (line 65) | def launch_rlg_hydra(cfg: DictConfig):

FILE: timechamber/utils/config.py
  function parse_sim_params (line 14) | def parse_sim_params(args, cfg):

FILE: timechamber/utils/gym_util.py
  function setup_gym_viewer (line 36) | def setup_gym_viewer(config):
  function initialize_gym (line 42) | def initialize_gym(config):
  function configure_gym (line 51) | def configure_gym(gym, config):
  function parse_states_from_reference_states (line 102) | def parse_states_from_reference_states(reference_states, progress):
  function parse_states_from_reference_states_with_motion_id (line 123) | def parse_states_from_reference_states_with_motion_id(precomputed_state,
  function parse_dof_state_with_motion_id (line 139) | def parse_dof_state_with_motion_id(precomputed_state, dof_state,
  function get_flatten_ids (line 152) | def get_flatten_ids(precomputed_state):
  function parse_states_from_reference_states_with_global_id (line 168) | def parse_states_from_reference_states_with_global_id(precomputed_state,
  function get_robot_states_from_torch_tensor (line 181) | def get_robot_states_from_torch_tensor(config, ts, global_quats, vels, a...
  function get_xyzoffset (line 235) | def get_xyzoffset(start_ts, end_ts, root_yaw_inv):

FILE: timechamber/utils/logger.py
  class _MyFormatter (line 19) | class _MyFormatter(logging.Formatter):
    method format (line 25) | def format(self, record):
  class GLOBAL_PATH (line 57) | class GLOBAL_PATH(object):
    method __init__ (line 59) | def __init__(self, path=None):
    method _set_path (line 64) | def _set_path(self, path):
    method _get_path (line 67) | def _get_path(self):
  function set_file_handler (line 74) | def set_file_handler(path=None, prefix='', time_str=''):
  function _get_path (line 107) | def _get_path():

FILE: timechamber/utils/motion_lib.py
  class Patch (line 46) | class Patch:
    method numpy (line 47) | def numpy(self):
  class DeviceCache (line 55) | class DeviceCache:
    method __init__ (line 56) | def __init__(self, obj, device):
    method __getattr__ (line 87) | def __getattr__(self, string):
  class MotionLib (line 92) | class MotionLib():
    method __init__ (line 93) | def __init__(self, motion_file, dof_body_ids, dof_offsets,
    method num_motions (line 119) | def num_motions(self):
    method get_total_length (line 122) | def get_total_length(self):
    method get_motion (line 125) | def get_motion(self, motion_id):
    method sample_motions (line 128) | def sample_motions(self, n):
    method sample_time (line 136) | def sample_time(self, motion_ids, truncate_time=None):
    method get_motion_length (line 148) | def get_motion_length(self, motion_ids):
    method get_motion_state (line 151) | def get_motion_state(self, motion_ids, motion_times):
    method _load_motions (line 202) | def _load_motions(self, motion_file):
    method _fetch_motion_files (line 266) | def _fetch_motion_files(self, motion_file):
    method _calc_frame_blend (line 291) | def _calc_frame_blend(self, time, len, num_frames, dt):
    method _get_num_bodies (line 302) | def _get_num_bodies(self):
    method _compute_motion_dof_vels (line 307) | def _compute_motion_dof_vels(self, motion):
    method _local_rotation_to_dof (line 324) | def _local_rotation_to_dof(self, local_rot):
    method _local_rotation_to_dof_vel (line 354) | def _local_rotation_to_dof_vel(self, local_rot0, local_rot1, dt):

FILE: timechamber/utils/reformat.py
  function omegaconf_to_dict (line 32) | def omegaconf_to_dict(d: DictConfig)->Dict:
  function print_dict (line 42) | def print_dict(val, nesting: int = -4, start: bool = True):

FILE: timechamber/utils/rlgames_utils.py
  function get_rlgames_env_creator (line 45) | def get_rlgames_env_creator(
  class RLGPUAlgoObserver (line 122) | class RLGPUAlgoObserver(AlgoObserver):
    method __init__ (line 125) | def __init__(self):
    method after_init (line 128) | def after_init(self, algo):
    method process_infos (line 135) | def process_infos(self, infos, done_indices):
    method after_clear_stats (line 148) | def after_clear_stats(self):
    method after_print_stats (line 151) | def after_print_stats(self, frame, epoch_num, total_time):
  class RLGPUEnv (line 178) | class RLGPUEnv(vecenv.IVecEnv):
    method __init__ (line 179) | def __init__(self, config_name, num_actors, **kwargs):
    method step (line 189) | def step(self, action):
    method reset (line 198) | def reset(self, env_ids=None):
    method get_number_of_agents (line 204) | def get_number_of_agents(self):
    method get_env_info (line 207) | def get_env_info(self):

FILE: timechamber/utils/torch_jit_utils.py
  function compute_heading_and_up (line 35) | def compute_heading_and_up(
  function compute_rot (line 53) | def compute_rot(torso_quat, velocity, ang_velocity, targets, torso_posit...
  function quat_axis (line 67) | def quat_axis(q, axis=0):
  function scale_transform (line 80) | def scale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.T...
  function unscale_transform (line 101) | def unscale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch...
  function saturate (line 121) | def saturate(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) ...
  function quat_diff_rad (line 142) | def quat_diff_rad(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
  function local_to_world_space (line 164) | def local_to_world_space(pos_offset_local: torch.Tensor, pose_global: to...
  function normalise_quat_in_pose (line 185) | def normalise_quat_in_pose(pose):
  function my_quat_rotate (line 199) | def my_quat_rotate(q, v):
  function quat_to_angle_axis (line 211) | def quat_to_angle_axis(q):
  function angle_axis_to_exp_map (line 234) | def angle_axis_to_exp_map(angle, axis):
  function quat_to_exp_map (line 242) | def quat_to_exp_map(q):
  function quat_to_tan_norm (line 251) | def quat_to_tan_norm(q):
  function euler_xyz_to_exp_map (line 266) | def euler_xyz_to_exp_map(roll, pitch, yaw):
  function exp_map_to_angle_axis (line 273) | def exp_map_to_angle_axis(exp_map):
  function exp_map_to_quat (line 292) | def exp_map_to_quat(exp_map):
  function slerp (line 298) | def slerp(q0, q1, t):
  function calc_heading (line 333) | def calc_heading(q):
  function calc_heading_quat (line 346) | def calc_heading_quat(q):
  function calc_heading_quat_inv (line 359) | def calc_heading_quat_inv(q):

FILE: timechamber/utils/torch_utils.py
  function quat_to_angle_axis (line 35) | def quat_to_angle_axis(q):
  function angle_axis_to_exp_map (line 58) | def angle_axis_to_exp_map(angle, axis):
  function quat_to_exp_map (line 66) | def quat_to_exp_map(q):
  function quat_to_tan_norm (line 75) | def quat_to_tan_norm(q):
  function euler_xyz_to_exp_map (line 90) | def euler_xyz_to_exp_map(roll, pitch, yaw):
  function exp_map_to_angle_axis (line 97) | def exp_map_to_angle_axis(exp_map):
  function exp_map_to_quat (line 116) | def exp_map_to_quat(exp_map):
  function slerp (line 122) | def slerp(q0, q1, t):
  function calc_heading (line 146) | def calc_heading(q):
  function calc_heading_quat (line 159) | def calc_heading_quat(q):
  function calc_heading_quat_inv (line 172) | def calc_heading_quat_inv(q):

FILE: timechamber/utils/utils.py
  function set_np_formatting (line 40) | def set_np_formatting():
  function set_seed (line 47) | def set_seed(seed, torch_deterministic=False, rank=0):
  function load_check (line 77) | def load_check(checkpoint, normalize_input: bool, normalize_value: bool):
  function safe_filesystem_op (line 93) | def safe_filesystem_op(func, *args, **kwargs):
  function safe_load (line 110) | def safe_load(filename, device=None):
  function load_checkpoint (line 116) | def load_checkpoint(filename, device=None):
  function print_actor_info (line 121) | def print_actor_info(gym, env, actor_handle):
  function print_asset_info (line 182) | def print_asset_info(asset, name, gym):

FILE: timechamber/utils/vec_task.py
  class VecTask (line 17) | class VecTask():
    method __init__ (line 18) | def __init__(self, task, rl_device, clip_observations=5.0, clip_action...
    method step (line 37) | def step(self, actions):
    method reset (line 40) | def reset(self):
    method get_number_of_agents (line 43) | def get_number_of_agents(self):
    method observation_space (line 47) | def observation_space(self):
    method action_space (line 51) | def action_space(self):
    method num_envs (line 55) | def num_envs(self):
    method num_acts (line 59) | def num_acts(self):
    method num_obs (line 63) | def num_obs(self):
  class VecTaskCPU (line 68) | class VecTaskCPU(VecTask):
    method __init__ (line 69) | def __init__(self, task, rl_device, sync_frame_time=False, clip_observ...
    method step (line 73) | def step(self, actions):
    method reset (line 83) | def reset(self):
  class VecTaskGPU (line 93) | class VecTaskGPU(VecTask):
    method __init__ (line 94) | def __init__(self, task, rl_device, clip_observations=5.0, clip_action...
    method step (line 101) | def step(self, actions):
    method reset (line 110) | def reset(self):
  class VecTaskPython (line 121) | class VecTaskPython(VecTask):
    method get_state (line 123) | def get_state(self):
    method step (line 126) | def step(self, actions):
    method reset (line 133) | def reset(self):

FILE: timechamber/utils/vec_task_wrappers.py
  class VecTaskCPUWrapper (line 34) | class VecTaskCPUWrapper(VecTaskCPU):
    method __init__ (line 35) | def __init__(self, task, rl_device, sync_frame_time=False, clip_observ...
  class VecTaskGPUWrapper (line 39) | class VecTaskGPUWrapper(VecTaskGPU):
    method __init__ (line 40) | def __init__(self, task, rl_device, clip_observations=5.0, clip_action...
  class VecTaskPythonWrapper (line 45) | class VecTaskPythonWrapper(VecTaskPython):
    method __init__ (line 46) | def __init__(self, task, rl_device, clip_observations=5.0, clip_action...
    method reset (line 54) | def reset(self, env_ids=None):
    method amp_observation_space (line 59) | def amp_observation_space(self):
    method fetch_amp_obs_demo (line 62) | def fetch_amp_obs_demo(self, num_samples):
Condensed preview — 201 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (861K chars).
[
  {
    "path": ".gitattributes",
    "chars": 0,
    "preview": ""
  },
  {
    "path": ".gitignore",
    "chars": 521,
    "preview": "videos\n/timechamber/logs\n*train_dir*\n*ige_logs*\n*.egg-info\n/.vs\n/.vscode\n/_package\n/shaders\n._tmptext.txt\n__pycache__/\n/"
  },
  {
    "path": "LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) 2022 MIT Inspir.ai\n\nPermission is hereby granted, free of charge, to any person obtaining a c"
  },
  {
    "path": "LISENCE/isaacgymenvs/LICENSE",
    "chars": 1557,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "README.md",
    "chars": 8411,
    "preview": "# TimeChamber: A Massively Parallel Large Scale Self-Play Framework\n\n****\n\n**TimeChamber** is a large scale self-play fr"
  },
  {
    "path": "assets/mjcf/nv_ant.xml",
    "chars": 5160,
    "preview": "<mujoco model=\"ant\">\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name="
  },
  {
    "path": "docs/environments.md",
    "chars": 7443,
    "preview": "## Environments\n\nWe provide a detailed description of the environment here.\n\n### Humanoid Strike\n\nHumanoid Strike is a 3"
  },
  {
    "path": "setup.py",
    "chars": 1150,
    "preview": "\"\"\"Installation script for the 'timechamber' python package.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ "
  },
  {
    "path": "timechamber/__init__.py",
    "chars": 2194,
    "preview": "import hydra\nfrom hydra import compose, initialize\nfrom hydra.core.hydra_config import HydraConfig\nfrom omegaconf import"
  },
  {
    "path": "timechamber/ase/ase_agent.py",
    "chars": 23370,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/ase_models.py",
    "chars": 3653,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/ase_network_builder.py",
    "chars": 14573,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/ase_players.py",
    "chars": 7018,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/hrl_agent.py",
    "chars": 12346,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/hrl_models.py",
    "chars": 3056,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/hrl_network_builder.py",
    "chars": 3005,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/hrl_players.py",
    "chars": 10959,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/amp_agent.py",
    "chars": 27339,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/amp_datasets.py",
    "chars": 2563,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/amp_models.py",
    "chars": 3969,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/amp_network_builder.py",
    "chars": 5950,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/amp_players.py",
    "chars": 4451,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/common_agent.py",
    "chars": 24428,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/common_player.py",
    "chars": 9123,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/ase/utils/replay_buffer.py",
    "chars": 3986,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/cfg/config.yaml",
    "chars": 1948,
    "preview": "# Task name - used to pick the class to load\ntask_name: ${task.name}\n# experiment name. defaults to name of training con"
  },
  {
    "path": "timechamber/cfg/task/MA_Ant_Battle.yaml",
    "chars": 3205,
    "preview": "# used to create the object\nname: MA_Ant_Battle\n\nphysics_engine: ${..physics_engine}\n\n# if given, will override the devi"
  },
  {
    "path": "timechamber/cfg/task/MA_Ant_Sumo.yaml",
    "chars": 3017,
    "preview": "# used to create the object\nname: MA_Ant_Sumo\n\nphysics_engine: ${..physics_engine}\n\n# if given, will override the device"
  },
  {
    "path": "timechamber/cfg/task/MA_Humanoid_Strike.yaml",
    "chars": 1500,
    "preview": "name: MA_Humanoid_Strike\n\nphysics_engine: ${..physics_engine}\n\n# if given, will override the device setting in gym. \nenv"
  },
  {
    "path": "timechamber/cfg/train/MA_Ant_BattlePPO.yaml",
    "chars": 1961,
    "preview": "params:\n  seed: ${...seed}\n\n  algo:\n    name: self_play_continuous\n\n  model:\n    name: continuous_a2c_logstd\n\n  network:"
  },
  {
    "path": "timechamber/cfg/train/MA_Ant_SumoPPO.yaml",
    "chars": 1989,
    "preview": "params:\n  seed: ${...seed}\n\n  algo:\n    name: self_play_continuous\n\n  model:\n    name: continuous_a2c_logstd\n\n  network:"
  },
  {
    "path": "timechamber/cfg/train/MA_Humanoid_StrikeHRL.yaml",
    "chars": 2118,
    "preview": "params:\n  seed: ${...seed}\n\n  algo:\n    name: self_play_hrl\n\n  model:\n    name: hrl\n\n  network:\n    name: hrl\n    separa"
  },
  {
    "path": "timechamber/cfg/train/base/ase_humanoid_hrl.yaml",
    "chars": 2158,
    "preview": "params:\n  seed: -1\n\n  algo:\n    name: ase\n\n  model:\n    name: ase\n\n  network:\n    name: ase\n    separate: True\n\n    spac"
  },
  {
    "path": "timechamber/learning/common_agent.py",
    "chars": 20332,
    "preview": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\nimport copy\nfrom datetime import datetime\nfrom gym import spaces"
  },
  {
    "path": "timechamber/learning/common_player.py",
    "chars": 6068,
    "preview": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\n\nimport torch \n\nfrom rl_games.algos_torch import players\nfrom rl"
  },
  {
    "path": "timechamber/learning/hrl_sp_agent.py",
    "chars": 20364,
    "preview": "import copy\nfrom collections import OrderedDict\nfrom datetime import datetime\nfrom gym import spaces\nimport numpy as np\n"
  },
  {
    "path": "timechamber/learning/hrl_sp_player.py",
    "chars": 18469,
    "preview": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\nimport os\nimport time\nimport torch\nimport numpy as np\nfrom rl_ga"
  },
  {
    "path": "timechamber/learning/pfsp_player_pool.py",
    "chars": 9976,
    "preview": "import collections\n\nimport random\nimport torch\nimport torch.multiprocessing as mp\nimport dill\n# import time\nfrom rl_game"
  },
  {
    "path": "timechamber/learning/ppo_sp_agent.py",
    "chars": 17388,
    "preview": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\n\nimport copy\nfrom datetime import datetime\nfrom gym import space"
  },
  {
    "path": "timechamber/learning/ppo_sp_player.py",
    "chars": 17614,
    "preview": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\nimport os\nimport time\nimport torch\nimport numpy as np\nfrom rl_ga"
  },
  {
    "path": "timechamber/learning/replay_buffer.py",
    "chars": 2473,
    "preview": "# License: see [LICENSE, LICENSES/isaacgymenvs/LICENSE]\n\nimport torch\n\n\nclass ReplayBuffer():\n    def __init__(self, buf"
  },
  {
    "path": "timechamber/learning/vectorized_models.py",
    "chars": 6307,
    "preview": "import torch\nimport torch.nn as nn\nfrom rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs\nf"
  },
  {
    "path": "timechamber/learning/vectorized_network_builder.py",
    "chars": 4496,
    "preview": "import torch\nimport torch.nn as nn\nimport math\nfrom rl_games.algos_torch import network_builder\n\n\nclass VectorizedLinear"
  },
  {
    "path": "timechamber/tasks/__init__.py",
    "chars": 1859,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/base_task.py",
    "chars": 20058,
    "preview": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/humanoid.py",
    "chars": 38712,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/humanoid_amp.py",
    "chars": 15783,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/humanoid_amp_task.py",
    "chars": 3868,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/README.md",
    "chars": 5585,
    "preview": "# poselib\n\n`poselib` is a library for loading, manipulating, and retargeting skeleton poses and motions. It is separated"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/data/configs/retarget_cmu_to_amp.json",
    "chars": 890,
    "preview": "{\n    \"source_motion\": \"data/01_01_cmu.npy\",\n    \"target_motion_path\": \"data/01_01_cmu_amp.npy\",\n    \"source_tpose\": \"da"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/data/configs/retarget_sfu_to_amp.json",
    "chars": 897,
    "preview": "{\n    \"source_motion\": \"data/0005_Jogging001.npy\",\n    \"target_motion_path\": \"data/0005_Jogging001_amp.npy\",\n    \"source"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/fbx_importer.py",
    "chars": 2119,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/generate_amp_humanoid_tpose.py",
    "chars": 2816,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/mjcf_importer.py",
    "chars": 2003,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/__init__.py",
    "chars": 1603,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/__init__.py",
    "chars": 521,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/__init__.py",
    "chars": 488,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/abstract.py",
    "chars": 5159,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/backend/logger.py",
    "chars": 1930,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/rotation3d.py",
    "chars": 13466,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tensor_utils.py",
    "chars": 1420,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tests/__init__.py",
    "chars": 1557,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/core/tests/test_rotation.py",
    "chars": 3256,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/__init__.py",
    "chars": 1557,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/__init__.py",
    "chars": 423,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/fbx_backend.py",
    "chars": 10372,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/backend/fbx/fbx_read_wrapper.py",
    "chars": 1253,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/skeleton/skeleton3d.py",
    "chars": 57916,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/__init__.py",
    "chars": 423,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/common.py",
    "chars": 8107,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/core.py",
    "chars": 3700,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/plt_plotter.py",
    "chars": 14522,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/simple_plotter_tasks.py",
    "chars": 5246,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/skeleton_plotter_tasks.py",
    "chars": 7974,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/tests/__init__.py",
    "chars": 423,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/poselib/visualization/tests/test_plotter.py",
    "chars": 1011,
    "preview": "# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/tasks/ase_humanoid_base/poselib/retarget_motion.py",
    "chars": 15546,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/base/__init__.py",
    "chars": 1558,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/base/ma_vec_task.py",
    "chars": 29077,
    "preview": "# Copyright (c) 2018-2021, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/base/vec_task.py",
    "chars": 33440,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/tasks/data/assets/mjcf/amp_humanoid_sword_shield.xml",
    "chars": 14103,
    "preview": "<mujoco model=\"humanoid\">\n\n  <statistic extent=\"2\" center=\"0 0 1\"/>\n\n  <option timestep=\"0.00555\"/>\n\n  <default>\n    <mo"
  },
  {
    "path": "timechamber/tasks/data/motions/reallusion_sword_shield/README.txt",
    "chars": 290,
    "preview": "This motion data is provided courtesy of Reallusion,\nstrictly for noncommercial use. The original motion data\nis availab"
  },
  {
    "path": "timechamber/tasks/data/motions/reallusion_sword_shield/dataset_reallusion_sword_shield.yaml",
    "chars": 6121,
    "preview": "motions:\n  - file: \"RL_Avatar_Atk_2xCombo01_Motion.npy\"\n    weight: 0.00724638\n  - file: \"RL_Avatar_Atk_2xCombo02_Motion"
  },
  {
    "path": "timechamber/tasks/ma_ant_battle.py",
    "chars": 25391,
    "preview": "from typing import Tuple\nimport os\n\nimport torch\nfrom isaacgym import gymtorch\nfrom isaacgym.gymtorch import *\n\nfrom tim"
  },
  {
    "path": "timechamber/tasks/ma_ant_sumo.py",
    "chars": 24855,
    "preview": "from typing import Tuple\nimport numpy as np\nimport os\nimport math\nimport torch\nimport random\n\nfrom isaacgym import gymto"
  },
  {
    "path": "timechamber/tasks/ma_humanoid_strike.py",
    "chars": 28787,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/train.py",
    "chars": 8062,
    "preview": "# train.py\n# Script to train policies in Isaac Gym\n#\n# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved"
  },
  {
    "path": "timechamber/utils/config.py",
    "chars": 1503,
    "preview": "import os\nimport sys\nimport yaml\n\nfrom isaacgym import gymapi\nfrom isaacgym import gymutil\n\nimport numpy as np\nimport ra"
  },
  {
    "path": "timechamber/utils/gym_util.py",
    "chars": 9881,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/logger.py",
    "chars": 3186,
    "preview": "# -----------------------------------------------------------------------------\n#   @brief:\n#       The logger here will"
  },
  {
    "path": "timechamber/utils/motion_lib.py",
    "chars": 14347,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/reformat.py",
    "chars": 2314,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/rlgames_utils.py",
    "chars": 9465,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/torch_jit_utils.py",
    "chars": 12280,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/torch_utils.py",
    "chars": 6110,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/utils.py",
    "chars": 7901,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  },
  {
    "path": "timechamber/utils/vec_task.py",
    "chars": 5253,
    "preview": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n# NVIDIA CORPORATION and its licensors retain all intell"
  },
  {
    "path": "timechamber/utils/vec_task_wrappers.py",
    "chars": 2934,
    "preview": "# Copyright (c) 2018-2022, NVIDIA Corporation\n# All rights reserved.\n#\n# Redistribution and use in source and binary for"
  }
]

// ... and 102 more files (download for full content)

About this extraction

This page contains the full source code of the inspirai/TimeChamber GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 201 files (120.5 MB), approximately 199.9k tokens, and a symbol index with 972 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!