Full Code of google-research/dads for AI

Repository: google-research/dads
Branch: master
Commit: abc37f532c26
Files: 32
Total size: 225.1 KB

Directory structure:
dads/

├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs/
│   ├── ant_xy_offpolicy.txt
│   ├── ant_xy_onpolicy.txt
│   ├── dkitty_randomized_xy_offpolicy.txt
│   ├── humanoid_offpolicy.txt
│   ├── humanoid_onpolicy.txt
│   └── template_config.txt
├── env.yml
├── envs/
│   ├── assets/
│   │   ├── ant.xml
│   │   ├── ant_footsensor.xml
│   │   ├── half_cheetah.xml
│   │   ├── humanoid.xml
│   │   └── point.xml
│   ├── dclaw.py
│   ├── dkitty_redesign.py
│   ├── gym_mujoco/
│   │   ├── ant.py
│   │   ├── half_cheetah.py
│   │   ├── humanoid.py
│   │   └── point_mass.py
│   ├── hand_block.py
│   ├── skill_wrapper.py
│   └── video_wrapper.py
├── lib/
│   ├── py_tf_policy.py
│   └── py_uniform_replay_buffer.py
└── unsupervised_skill_learning/
    ├── dads_agent.py
    ├── dads_off.py
    ├── skill_discriminator.py
    └── skill_dynamics.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Generated files
*.egg-info/
.idea*
*__pycache__*
.ipynb_checkpoints*
*.pyc
*.DS_Store
*.mp4
*.json
output/
saved_models/
env_test.py
dkitty_eval.sh
experiments/
dads_token.txt


================================================
FILE: AUTHORS
================================================
# This is the list of authors for copyright purposes.
Google LLC
Archit Sharma
Shixiang Gu
Sergey Levine
Vikash Kumar
Karol Hausman

================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).

================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
# Dynamics-Aware Discovery of Skills (DADS)
This repository is the open-source implementation of Dynamics-Aware Unsupervised Discovery of Skills ([project page][website], [arXiv][paper]). We propose a skill-discovery method which can learn skills for different agents without any rewards, while simultaneously learning a dynamics model for those skills, which can be leveraged for model-based control on downstream tasks. This work was published at the International Conference on Learning Representations ([ICLR][iclr]), 2020.

We have also included an improved off-policy version of DADS, coined off-DADS. The details have been released in [Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning][rss_arxiv].

In case of problems, contact Archit Sharma.

## Table of Contents

* [Setup](#setup)
* [Usage](#usage)
* [Citation](#citation)
* [Disclaimer](#disclaimer)

## Setup

#### (1) Setup MuJoCo
Download and set up [mujoco][mujoco] in `~/.mujoco`. Set the `LD_LIBRARY_PATH` in your `~/.bashrc`:
```
export LD_LIBRARY_PATH=$HOME/.mujoco/mjpro150/bin:$LD_LIBRARY_PATH
```

#### (2) Setup environment
Clone the repository and set up the [conda][conda] environment to run the DADS code:
```
cd <path_to_dads>
conda env create -f env.yml
conda activate dads-env
```

## Usage
We give a high-level explanation of how to use the code. More details pertaining to hyperparameters can be found in `configs/template_config.txt`, `dads_off.py`, and Appendix A of the [paper][paper].

Every training run requires an experiment logging directory and a configuration file, which can be created starting from `configs/template_config.txt`. There are two phases: (a) training, where new skills are learned along with their skill-dynamics models, and (b) evaluation, where the learned skills are evaluated on the task associated with the environment.

For training, ensure `--run_train=1` is set in the configuration file. For on-policy optimization, set `--clear_buffer_every_iter=1` and ensure the replay buffer size is larger than the number of steps collected in every iteration. For off-policy optimization (described in [off-DADS][rss_arxiv]), set `--clear_buffer_every_iter=0`. Set the environment name (ensure the environment is listed in `get_environment()` in `dads_off.py`). To change the observation used by skill-dynamics (for example, to learn in the x-y space), set `--reduced_observation` and configure `process_observation()` in `dads_off.py` accordingly. The skill space can be configured to be discrete or continuous. The optimization parameters can be tweaked; reasonable default values have been provided (more details in the [paper][paper]).
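
As a minimal sketch, the two optimization modes differ mainly in the buffer and collection flags below (flag names and values are taken from `configs/ant_xy_onpolicy.txt` and `configs/ant_xy_offpolicy.txt`; they are illustrative starting points, not tuned for other environments):
```
# on-policy: clear the buffer every iteration and collect a full batch each time
--clear_buffer_every_iter=1
--initial_collect_steps=0
--collect_steps=2000

# off-policy: keep a persistent buffer and collect fewer steps per iteration
# --clear_buffer_every_iter=0
# --initial_collect_steps=2000
# --collect_steps=500
# --replay_buffer_capacity=10000
```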

For evaluation, ensure `--run_eval=1` is set and that the experiment directory points to the directory in which training happened. Set `--num_evals` if you want to record videos of randomly sampled skills from the prior distribution. After that, the script will use the learned models to execute MPC in the latent space to optimize the task reward. By default, the code calls `get_environment()` to load `FLAGS.environment + '_goal'`, and goes through the list of goal coordinates specified in the eval section of the script.
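
For example, an evaluation run on a previously trained Ant experiment might look like the following (a sketch only; it assumes the standard absl-flags behaviour that flags passed after `--flagfile` override the values in the file):
```
cd <path_to_dads>
python unsupervised_skill_learning/dads_off.py \
  --logdir=<path_for_experiment_logs> \
  --flagfile=configs/ant_xy_offpolicy.txt \
  --run_train=0 --run_eval=1
```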

We have provided configuration files in `configs/` to reproduce the results from the experiments in the [paper][paper]. Goal evaluation is currently only set up for the MuJoCo Ant environment. The goal distribution can be changed in the evaluation part of `dads_off.py`.

```
cd <path_to_dads>
python unsupervised_skill_learning/dads_off.py --logdir=<path_for_experiment_logs> --flagfile=configs/<config_name>.txt
```

The specified experiment log directory will contain the TensorBoard summaries, the saved checkpoints and the skill-evaluation videos.
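
Training progress can be monitored by pointing TensorBoard at the same directory, for example:
```
tensorboard --logdir=<path_for_experiment_logs>
```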

## Citation
To cite [Dynamics-Aware Unsupervised Discovery of Skills][paper]:
```
@article{sharma2019dynamics,
  title={Dynamics-aware unsupervised discovery of skills},
  author={Sharma, Archit and Gu, Shixiang and Levine, Sergey and Kumar, Vikash and Hausman, Karol},
  journal={arXiv preprint arXiv:1907.01657},
  year={2019}
}
```
To cite off-DADS and [Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning][rss_arxiv]:
```
@article{sharma2020emergent,
    title={Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning},
    author={Sharma, Archit and Ahn, Michael and Levine, Sergey and Kumar, Vikash and Hausman, Karol and Gu, Shixiang},
    journal={arXiv preprint arXiv:2004.12974},
    year={2020}
}
```
## Disclaimer
This is not an officially supported Google product.

[website]: https://sites.google.com/corp/view/dads-skill 
[paper]: https://arxiv.org/abs/1907.01657
[iclr]: https://openreview.net/forum?id=HJgLZR4KvH
[mujoco]: http://www.mujoco.org/
[conda]: https://docs.conda.io/en/latest/miniconda.html
[rss_arxiv]: https://arxiv.org/abs/2004.12974


================================================
FILE: configs/ant_xy_offpolicy.txt
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### TRAINING HYPERPARAMETERS -------------------
--run_train=1

# metadata flags
--save_model=dads
--save_freq=50
--record_freq=100
--vid_name=skill

# optimization hyperparameters
--replay_buffer_capacity=10000

# (set clear_buffer_every_iter=1 for on-policy optimization)
--clear_buffer_every_iter=0
--initial_collect_steps=2000
--collect_steps=500
--num_epochs=10000

# skill dynamics optimization hyperparameters
--skill_dyn_train_steps=8
--skill_dynamics_lr=3e-4
--skill_dyn_batch_size=256

# agent hyperparameters
--agent_gamma=0.99
--agent_lr=3e-4
--agent_entropy=0.1
--agent_train_steps=64
--agent_batch_size=256

# (optional, do not change for on-policy) relabelling or off-policy corrections
--skill_dynamics_relabel_type=importance_sampling
--num_samples_for_relabelling=1
--is_clip_eps=10.

# (optional) skills can be resampled within the episodes, relative to max_env_steps
--min_steps_before_resample=2000
--resample_prob=0.02

# (optional) configure skill dynamics training samples to be only from the current policy
--train_skill_dynamics_on_policy=0

### SHARED HYPERPARAMETERS ---------------------
--environment=Ant-v1
--max_env_steps=200
--reduced_observation=2

# define the type of skills being learnt
--num_skills=2
--skill_type=cont_uniform
--random_skills=100
--num_evals=3

# (optional) policy, critic and skill dynamics
--hidden_layer_size=512

# (optional) skill dynamics hyperparameters
--graph_type=default
--num_components=4
--fix_variance=1
--normalize_data=1

# (optional) clip sampled actions
--action_clipping=1.

# (optional) debugging
--debug=0

### EVALUATION HYPERPARAMETERS -----------------
--run_eval=0

# MPC hyperparameters
--planning_horizon=1
--primitive_horizon=10
--num_candidate_sequences=50
--refine_steps=10
--mppi_gamma=10
--prior_type=normal
--smoothing_beta=0.9
--top_primitives=5


### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------
# DKitty hyperparameters
--expose_last_action=1
--expose_upright=1
--robot_noise_ratio=0.0
--root_noise_ratio=0.0
--upright_threshold=0.95
--scale_root_position=1
--randomize_hfield=0.0

# DKitty/DClaw
--observation_omission_size=0

# Cube Manipulation hyperparameters
--randomized_initial_distribution=1
--horizontal_wrist_constraint=0.3
--vertical_wrist_constraint=1.0


================================================
FILE: configs/ant_xy_onpolicy.txt
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### TRAINING HYPERPARAMETERS -------------------
--run_train=1

# metadata flags
--save_model=dads
--save_freq=50
--record_freq=100
--vid_name=skill

# optimization hyperparameters
--replay_buffer_capacity=100000

# (set clear_buffer_every_iter=1 for on-policy)
--clear_buffer_every_iter=1
--initial_collect_steps=0
--collect_steps=2000
--num_epochs=10000

# skill dynamics optimization hyperparameters
--skill_dyn_train_steps=32
--skill_dynamics_lr=3e-4
--skill_dyn_batch_size=256

# agent hyperparameters
--agent_gamma=0.995
--agent_lr=3e-4
--agent_entropy=0.1
--agent_train_steps=64
--agent_batch_size=256

# (optional, do not change for on-policy) relabelling or off-policy corrections
--skill_dynamics_relabel_type=importance_sampling
--num_samples_for_relabelling=1
--is_clip_eps=1.

# (optional) skills can be resampled within the episodes, relative to max_env_steps
--min_steps_before_resample=2000
--resample_prob=0.02

# (optional) configure skill dynamics training samples to be only from the current policy
--train_skill_dynamics_on_policy=0

### SHARED HYPERPARAMETERS ---------------------
--environment=Ant-v1
--max_env_steps=200
--reduced_observation=2

# define the type of skills being learnt
--num_skills=2
--skill_type=cont_uniform
--random_skills=100
--num_evals=3

# (optional) policy, critic and skill dynamics
--hidden_layer_size=512

# (optional) skill dynamics hyperparameters
--graph_type=default
--num_components=4
--fix_variance=1
--normalize_data=1

# (optional) clip sampled actions
--action_clipping=1.

# (optional) debugging
--debug=0

### EVALUATION HYPERPARAMETERS -----------------
--run_eval=0

# MPC hyperparameters
--planning_horizon=1
--primitive_horizon=10
--num_candidate_sequences=50
--refine_steps=10
--mppi_gamma=10
--prior_type=normal
--smoothing_beta=0.9
--top_primitives=5


### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------
# DKitty hyperparameters
--expose_last_action=1
--expose_upright=1
--robot_noise_ratio=0.0
--root_noise_ratio=0.0
--upright_threshold=0.95
--scale_root_position=1
--randomize_hfield=0.0

# DKitty/DClaw
--observation_omission_size=0

# Cube Manipulation hyperparameters
--randomized_initial_distribution=1
--horizontal_wrist_constraint=0.3
--vertical_wrist_constraint=1.0


================================================
FILE: configs/dkitty_randomized_xy_offpolicy.txt
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### TRAINING HYPERPARAMETERS -------------------
--run_train=1

# metadata flags
--save_model=dads
--save_freq=50
--record_freq=100
--vid_name=skill

# optimization hyperparameters
--replay_buffer_capacity=10000

# (set clear_buffer_every_iter=1 for on-policy)
--clear_buffer_every_iter=0
--initial_collect_steps=2000
--collect_steps=500
--num_epochs=1000

# skill dynamics optimization hyperparameters
--skill_dyn_train_steps=8
--skill_dynamics_lr=3e-4
--skill_dyn_batch_size=256

# agent hyperparameters
--agent_gamma=0.99
--agent_lr=3e-4
--agent_entropy=0.1
--agent_train_steps=64
--agent_batch_size=256

# (optional, do not change for on-policy) relabelling or off-policy corrections
--skill_dynamics_relabel_type=importance_sampling
--num_samples_for_relabelling=1
--is_clip_eps=10.

# (optional) skills can be resampled within the episodes, relative to max_env_steps
--min_steps_before_resample=2000
--resample_prob=0.02

# (optional) configure skill dynamics training samples to be only from the current policy
--train_skill_dynamics_on_policy=0

### SHARED HYPERPARAMETERS ---------------------
--environment=DKitty_randomized
--max_env_steps=200
--reduced_observation=2

# define the type of skills being learnt
--num_skills=2
--skill_type=cont_uniform
--random_skills=100
--num_evals=3

# (optional) policy, critic and skill dynamics
--hidden_layer_size=512

# (optional) skill dynamics hyperparameters
--graph_type=default
--num_components=4
--fix_variance=1
--normalize_data=1

# (optional) clip sampled actions
--action_clipping=1.

# (optional) debugging
--debug=0

### EVALUATION HYPERPARAMETERS -----------------
--run_eval=0

# MPC hyperparameters
--planning_horizon=1
--primitive_horizon=10
--num_candidate_sequences=50
--refine_steps=10
--mppi_gamma=10
--prior_type=normal
--smoothing_beta=0.9
--top_primitives=5


### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------
# DKitty hyperparameters
--expose_last_action=1
--expose_upright=1
--robot_noise_ratio=0.0
--root_noise_ratio=0.0
--upright_threshold=0.95
--scale_root_position=1
--randomize_hfield=0.02

# DKitty/DClaw
--observation_omission_size=2

# Cube Manipulation hyperparameters
--randomized_initial_distribution=1
--horizontal_wrist_constraint=0.3
--vertical_wrist_constraint=1.0


================================================
FILE: configs/humanoid_offpolicy.txt
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### TRAINING HYPERPARAMETERS -------------------
--run_train=1

# metadata flags
--save_model=dads
--save_freq=50
--record_freq=100
--vid_name=skill

# optimization hyperparameters
--replay_buffer_capacity=10000

# (set clear_buffer_every_iter=1 for on-policy)
--clear_buffer_every_iter=0
--initial_collect_steps=5000
--collect_steps=2000
--num_epochs=100000

# skill dynamics optimization hyperparameters
--skill_dyn_train_steps=16
--skill_dynamics_lr=3e-4
--skill_dyn_batch_size=256

# agent hyperparameters
--agent_gamma=0.995
--agent_lr=3e-4
--agent_entropy=0.1
--agent_train_steps=128
--agent_batch_size=256

# (optional, do not change for on-policy) relabelling or off-policy corrections
--skill_dynamics_relabel_type=importance_sampling
--num_samples_for_relabelling=1
--is_clip_eps=1.

# (optional) skills can be resampled within the episodes, relative to max_env_steps
--min_steps_before_resample=2000
--resample_prob=0.0

# (optional) configure skill dynamics training samples to be only from the current policy
--train_skill_dynamics_on_policy=0

### SHARED HYPERPARAMETERS ---------------------
--environment=Humanoid-v1
--max_env_steps=1000
--reduced_observation=0

# define the type of skills being learnt
--num_skills=5
--skill_type=cont_uniform
--random_skills=100

# number of skill-video evaluations
--num_evals=3

# (optional) policy, critic and skill dynamics
--hidden_layer_size=1024

# (optional) skill dynamics hyperparameters
--graph_type=default
--num_components=4
--fix_variance=1
--normalize_data=1

# (optional) clip sampled actions
--action_clipping=1.

# (optional) debugging
--debug=0

### EVALUATION HYPERPARAMETERS -----------------
--run_eval=0

# MPC hyperparameters
--planning_horizon=1
--primitive_horizon=10
--num_candidate_sequences=50
--refine_steps=10
--mppi_gamma=10
--prior_type=normal
--smoothing_beta=0.9
--top_primitives=5


### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------
# DKitty hyperparameters
--expose_last_action=1
--expose_upright=1
--robot_noise_ratio=0.0
--root_noise_ratio=0.0
--upright_threshold=0.95
--scale_root_position=1
--randomize_hfield=0.0

# DKitty/DClaw
--observation_omission_size=0

# Cube Manipulation hyperparameters
--randomized_initial_distribution=1
--horizontal_wrist_constraint=0.3
--vertical_wrist_constraint=1.0


================================================
FILE: configs/humanoid_onpolicy.txt
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### TRAINING HYPERPARAMETERS -------------------
--run_train=1

# metadata flags
--save_model=dads
--save_freq=50
--record_freq=100
--vid_name=skill

# optimization hyperparameters
--replay_buffer_capacity=100000

# (set clear_buffer_every_iter=1 for on-policy)
--clear_buffer_every_iter=1
--initial_collect_steps=0
--collect_steps=4000
--num_epochs=100000

# skill dynamics optimization hyperparameters
--skill_dyn_train_steps=32
--skill_dynamics_lr=3e-4
--skill_dyn_batch_size=256

# agent hyperparameters
--agent_gamma=0.995
--agent_lr=3e-4
--agent_entropy=0.1
--agent_train_steps=64
--agent_batch_size=256

# (optional, do not change for on-policy) relabelling or off-policy corrections
--skill_dynamics_relabel_type=importance_sampling
--num_samples_for_relabelling=1
--is_clip_eps=1.

# (optional) skills can be resampled within the episodes, relative to max_env_steps
--min_steps_before_resample=2000
--resample_prob=0.0

# (optional) configure skill dynamics training samples to be only from the current policy
--train_skill_dynamics_on_policy=0

### SHARED HYPERPARAMETERS ---------------------
--environment=Humanoid-v1
--max_env_steps=1000
--reduced_observation=0

# define the type of skills being learnt
--num_skills=5
--skill_type=cont_uniform
--random_skills=100

# number of skill-video evaluations
--num_evals=3

# (optional) policy, critic and skill dynamics
--hidden_layer_size=1024

# (optional) skill dynamics hyperparameters
--graph_type=default
--num_components=4
--fix_variance=1
--normalize_data=1

# (optional) clip sampled actions
--action_clipping=1.

# (optional) debugging
--debug=0

### EVALUATION HYPERPARAMETERS -----------------
--run_eval=0

# MPC hyperparameters
--planning_horizon=1
--primitive_horizon=10
--num_candidate_sequences=50
--refine_steps=10
--mppi_gamma=10
--prior_type=normal
--smoothing_beta=0.9
--top_primitives=5


### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------
# DKitty hyperparameters
--expose_last_action=1
--expose_upright=1
--robot_noise_ratio=0.0
--root_noise_ratio=0.0
--upright_threshold=0.95
--scale_root_position=1
--randomize_hfield=0.0

# DKitty/DClaw
--observation_omission_size=0

# Cube Manipulation hyperparameters
--randomized_initial_distribution=1
--horizontal_wrist_constraint=0.3
--vertical_wrist_constraint=1.0


================================================
FILE: configs/template_config.txt
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### TRAINING HYPERPARAMETERS -------------------
--run_train=0

# metadata flags
--save_model=dads
--save_freq=50
--record_freq=100
--vid_name=skill

# optimization hyperparameters
--replay_buffer_capacity=100000

# (set clear_buffer_every_iter=1 for on-policy)
--clear_buffer_every_iter=0
--initial_collect_steps=2000
--collect_steps=1000
--num_epochs=100

# skill dynamics optimization hyperparameters
--skill_dyn_train_steps=16
--skill_dynamics_lr=3e-4
--skill_dyn_batch_size=256

# agent hyperparameters
--agent_gamma=0.99
--agent_lr=3e-4
--agent_entropy=0.1
--agent_train_steps=64
--agent_batch_size=256

# (optional, do not change for on-policy) relabelling or off-policy corrections
--skill_dynamics_relabel_type=importance_sampling
--num_samples_for_relabelling=1
--is_clip_eps=1.

# (optional) skills can be resampled within the episodes, relative to max_env_steps
--min_steps_before_resample=2000
--resample_prob=0.02

# (optional) configure skill dynamics training samples to be only from the current policy
--train_skill_dynamics_on_policy=0

### SHARED HYPERPARAMETERS ---------------------
--environment=<set_some_environment>
--max_env_steps=200
--reduced_observation=0

# define the type of skills being learnt
--num_skills=2
--skill_type=cont_uniform
--random_skills=100

# number of skill-video evaluations
--num_evals=3

# (optional) policy, critic and skill dynamics
--hidden_layer_size=512

# (optional) skill dynamics hyperparameters
--graph_type=default
--num_components=4
--fix_variance=1
--normalize_data=1

# (optional) clip sampled actions
--action_clipping=1.

# (optional) debugging
--debug=0

### EVALUATION HYPERPARAMETERS -----------------
--run_eval=0

# MPC hyperparameters
--planning_horizon=1
--primitive_horizon=10
--num_candidate_sequences=50
--refine_steps=10
--mppi_gamma=10
--prior_type=normal
--smoothing_beta=0.9
--top_primitives=5


### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------
# DKitty hyperparameters
--expose_last_action=1
--expose_upright=1
--robot_noise_ratio=0.0
--root_noise_ratio=0.0
--upright_threshold=0.95
--scale_root_position=1
--randomize_hfield=0.0

# DKitty/DClaw
--observation_omission_size=0

# Cube Manipulation hyperparameters
--randomized_initial_distribution=1
--horizontal_wrist_constraint=0.3
--vertical_wrist_constraint=1.0


================================================
FILE: env.yml
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: dads-env
channels:
- defaults
- conda-forge
dependencies:
- python=3.6.8
- pip>=18.1
- conda>=4.6.7
- pip:
  - numpy<2.0,>=1.16.0
  - tensorflow-probability==0.10.0
  - tensorflow==2.2.0
  - tf-agents==0.4.0
  - tensorflow-estimator==2.2.0
  - gym==0.11.0
  - matplotlib==3.0.2
  - robel==0.1.2
  - mujoco-py==2.0.2.5
  - click
  - transforms3d


================================================
FILE: envs/assets/ant.xml
================================================
<!-- ======================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
====================================================== -->

<mujoco model="ant">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.01"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
        <body name="aux_1" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="front_right_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
        <body name="aux_2" pos="-0.2 0.2 0">
          <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 0.2 0">
            <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
        <body name="aux_3" pos="-0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 -0.2 0">
            <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="right_back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
        <body name="aux_4" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
  </actuator>
</mujoco>


================================================
FILE: envs/assets/ant_footsensor.xml
================================================
<!-- ======================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
====================================================== -->

<mujoco model="ant">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.01"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
        <body name="aux_1" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
            <site name='front_left_leg' pos="0.4 0.4 0.0" type='sphere' size='.1' rgba='1 1 0 .5'/>
          </body>
        </body>
      </body>
      <body name="front_right_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
        <body name="aux_2" pos="-0.2 0.2 0">
          <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 0.2 0">
            <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
            <site name='front_right_leg' pos="-0.4 0.4 0.0" type='sphere' size='.1' rgba='1 1 0 .5'/>
          </body>
        </body>
      </body>
      <body name="back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
        <body name="aux_3" pos="-0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 -0.2 0">
            <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
            <site name='left_back_leg' pos="-0.4 -0.4 0.0" type='sphere' size='.1' rgba='1 1 0 .5'/>
          </body>
        </body>
      </body>
      <body name="right_back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
        <body name="aux_4" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
            <site name='right_back_leg' pos="0.4 -0.4 0.0" type='sphere' size='.1' rgba='1 1 0 .5'/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
  </actuator>

  <sensor>
    <touch name='front_left_leg' site='front_left_leg'/>
    <touch name='front_right_leg' site='front_right_leg'/>
    <touch name='left_back_leg' site='left_back_leg'/>
    <touch name='right_back_leg' site='right_back_leg'/>
  </sensor>
</mujoco>


================================================
FILE: envs/assets/half_cheetah.xml
================================================
<!-- ======================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
====================================================== -->

<!-- Cheetah Model

    The state space is populated with joints in the order that they are
    defined in this file. The actuators also operate on joints.

    State-Space (name/joint/parameter):
        - rootx     slider      position (m)
        - rootz     slider      position (m)
        - rooty     hinge       angle (rad)
        - bthigh    hinge       angle (rad)
        - bshin     hinge       angle (rad)
        - bfoot     hinge       angle (rad)
        - fthigh    hinge       angle (rad)
        - fshin     hinge       angle (rad)
        - ffoot     hinge       angle (rad)
        - rootx     slider      velocity (m/s)
        - rootz     slider      velocity (m/s)
        - rooty     hinge       angular velocity (rad/s)
        - bthigh    hinge       angular velocity (rad/s)
        - bshin     hinge       angular velocity (rad/s)
        - bfoot     hinge       angular velocity (rad/s)
        - fthigh    hinge       angular velocity (rad/s)
        - fshin     hinge       angular velocity (rad/s)
        - ffoot     hinge       angular velocity (rad/s)

    Actuators (name/actuator/parameter):
        - bthigh    hinge       torque (N m)
        - bshin     hinge       torque (N m)
        - bfoot     hinge       torque (N m)
        - fthigh    hinge       torque (N m)
        - fshin     hinge       torque (N m)
        - ffoot     hinge       torque (N m)

-->
<mujoco model="cheetah">
  <compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14"/>
  <default>
    <joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8"/>
    <geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
  </default>
  <size nstack="300000" nuser_geom="1"/>
  <option gravity="0 0 -9.81" timestep="0.01"/>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 .7">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty" pos="0 0 0" stiffness="0" type="hinge"/>
      <geom fromto="-.5 0 0 .5 0 0" name="torso" size="0.046" type="capsule"/>
      <geom axisangle="0 1 0 .87" name="head" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
      <!-- <site name='tip'  pos='.15 0 .11'/>-->
      <body name="bthigh" pos="-.5 0 0">
        <joint axis="0 1 0" damping="6" name="bthigh" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
        <geom axisangle="0 1 0 -3.8" name="bthigh" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
        <body name="bshin" pos=".16 0 -.25">
          <joint axis="0 1 0" damping="4.5" name="bshin" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
          <geom axisangle="0 1 0 -2.03" name="bshin" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
          <body name="bfoot" pos="-.28 0 -.14">
            <joint axis="0 1 0" damping="3" name="bfoot" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
            <geom axisangle="0 1 0 -.27" name="bfoot" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="fthigh" pos=".5 0 0">
        <joint axis="0 1 0" damping="4.5" name="fthigh" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
        <geom axisangle="0 1 0 .52" name="fthigh" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
        <body name="fshin" pos="-.14 0 -.24">
          <joint axis="0 1 0" damping="3" name="fshin" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
          <geom axisangle="0 1 0 -.6" name="fshin" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
          <body name="ffoot" pos=".13 0 -.18">
            <joint axis="0 1 0" damping="1.5" name="ffoot" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
            <geom axisangle="0 1 0 -.6" name="ffoot" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor gear="120" joint="bthigh" name="bthigh"/>
    <motor gear="90" joint="bshin" name="bshin"/>
    <motor gear="60" joint="bfoot" name="bfoot"/>
    <motor gear="120" joint="fthigh" name="fthigh"/>
    <motor gear="60" joint="fshin" name="fshin"/>
    <motor gear="30" joint="ffoot" name="ffoot"/>
  </actuator>
</mujoco>


================================================
FILE: envs/assets/humanoid.xml
================================================
<!-- ======================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
====================================================== -->

<mujoco model="humanoid">
    <compiler angle="degree" inertiafromgeom="true"/>
    <default>
        <joint armature="1" damping="1" limited="true"/>
        <geom conaffinity="1" condim="1" contype="1" margin="0.001" material="geom" rgba="0.8 0.6 .4 1"/>
        <motor ctrllimited="true" ctrlrange="-.4 .4"/>
    </default>
    <option integrator="RK4" iterations="50" solver="PGS" timestep="0.003">
        <!-- <flags solverstat="enable" energy="enable"/>-->
    </option>
    <size nkey="5" nuser_geom="1"/>
    <visual>
        <map fogend="5" fogstart="3"/>
    </visual>
    <asset>
        <texture builtin="gradient" height="100" rgb1=".4 .5 .6" rgb2="0 0 0" type="skybox" width="100"/>
        <!-- <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>-->
        <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
        <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
        <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
        <material name="geom" texture="texgeom" texuniform="true"/>
    </asset>
    <worldbody>
        <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
        <geom condim="3" friction="1 .1 .1" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="20 20 0.125" type="plane"/>
        <!-- <geom condim="3" material="MatPlane" name="floor" pos="0 0 0" size="10 10 0.125" type="plane"/>-->
        <body name="torso" pos="0 0 1.4">
            <camera name="track" mode="trackcom" pos="0 -4 0" xyaxes="1 0 0 0 0 1"/>
            <joint armature="0" damping="0" limited="false" name="root" pos="0 0 0" stiffness="0" type="free"/>
            <geom fromto="0 -.07 0 0 .07 0" name="torso1" size="0.07" type="capsule"/>
            <geom name="head" pos="0 0 .19" size=".09" type="sphere" user="258"/>
            <geom fromto="-.01 -.06 -.12 -.01 .06 -.12" name="uwaist" size="0.06" type="capsule"/>
            <body name="lwaist" pos="-.01 0 -0.260" quat="1.000 0 -0.002 0">
                <geom fromto="0 -.06 0 0 .06 0" name="lwaist" size="0.06" type="capsule"/>
                <joint armature="0.02" axis="0 0 1" damping="5" name="abdomen_z" pos="0 0 0.065" range="-45 45" stiffness="20" type="hinge"/>
                <joint armature="0.02" axis="0 1 0" damping="5" name="abdomen_y" pos="0 0 0.065" range="-75 30" stiffness="10" type="hinge"/>
                <body name="pelvis" pos="0 0 -0.165" quat="1.000 0 -0.002 0">
                    <joint armature="0.02" axis="1 0 0" damping="5" name="abdomen_x" pos="0 0 0.1" range="-35 35" stiffness="10" type="hinge"/>
                    <geom fromto="-.02 -.07 0 -.02 .07 0" name="butt" size="0.09" type="capsule"/>
                    <body name="right_thigh" pos="0 -0.1 -0.04">
                        <joint armature="0.01" axis="1 0 0" damping="5" name="right_hip_x" pos="0 0 0" range="-25 5" stiffness="10" type="hinge"/>
                        <joint armature="0.01" axis="0 0 1" damping="5" name="right_hip_z" pos="0 0 0" range="-60 35" stiffness="10" type="hinge"/>
                        <joint armature="0.0080" axis="0 1 0" damping="5" name="right_hip_y" pos="0 0 0" range="-110 20" stiffness="20" type="hinge"/>
                        <geom fromto="0 0 0 0 0.01 -.34" name="right_thigh1" size="0.06" type="capsule"/>
                        <body name="right_shin" pos="0 0.01 -0.403">
                            <joint armature="0.0060" axis="0 -1 0" name="right_knee" pos="0 0 .02" range="-160 -2" type="hinge"/>
                            <geom fromto="0 0 0 0 0 -.3" name="right_shin1" size="0.049" type="capsule"/>
                            <body name="right_foot" pos="0 0 -0.45">
                                <geom name="right_foot" pos="0 0 0.1" size="0.075" type="sphere" user="0"/>
                            </body>
                        </body>
                    </body>
                    <body name="left_thigh" pos="0 0.1 -0.04">
                        <joint armature="0.01" axis="-1 0 0" damping="5" name="left_hip_x" pos="0 0 0" range="-25 5" stiffness="10" type="hinge"/>
                        <joint armature="0.01" axis="0 0 -1" damping="5" name="left_hip_z" pos="0 0 0" range="-60 35" stiffness="10" type="hinge"/>
                        <joint armature="0.01" axis="0 1 0" damping="5" name="left_hip_y" pos="0 0 0" range="-120 20" stiffness="20" type="hinge"/>
                        <geom fromto="0 0 0 0 -0.01 -.34" name="left_thigh1" size="0.06" type="capsule"/>
                        <body name="left_shin" pos="0 -0.01 -0.403">
                            <joint armature="0.0060" axis="0 -1 0" name="left_knee" pos="0 0 .02" range="-160 -2" stiffness="1" type="hinge"/>
                            <geom fromto="0 0 0 0 0 -.3" name="left_shin1" size="0.049" type="capsule"/>
                            <body name="left_foot" pos="0 0 -0.45">
                                <geom name="left_foot" type="sphere" size="0.075" pos="0 0 0.1" user="0" />
                            </body>
                        </body>
                    </body>
                </body>
            </body>
            <body name="right_upper_arm" pos="0 -0.17 0.06">
                <joint armature="0.0068" axis="2 1 1" name="right_shoulder1" pos="0 0 0" range="-85 60" stiffness="1" type="hinge"/>
                <joint armature="0.0051" axis="0 -1 1" name="right_shoulder2" pos="0 0 0" range="-85 60" stiffness="1" type="hinge"/>
                <geom fromto="0 0 0 .16 -.16 -.16" name="right_uarm1" size="0.04 0.16" type="capsule"/>
                <body name="right_lower_arm" pos=".18 -.18 -.18">
                    <joint armature="0.0028" axis="0 -1 1" name="right_elbow" pos="0 0 0" range="-90 50" stiffness="0" type="hinge"/>
                    <geom fromto="0.01 0.01 0.01 .17 .17 .17" name="right_larm" size="0.031" type="capsule"/>
                    <geom name="right_hand" pos=".18 .18 .18" size="0.04" type="sphere"/>
                    <camera pos="0 0 0"/>
                </body>
            </body>
            <body name="left_upper_arm" pos="0 0.17 0.06">
                <joint armature="0.0068" axis="2 -1 1" name="left_shoulder1" pos="0 0 0" range="-60 85" stiffness="1" type="hinge"/>
                <joint armature="0.0051" axis="0 1 1" name="left_shoulder2" pos="0 0 0" range="-60 85" stiffness="1" type="hinge"/>
                <geom fromto="0 0 0 .16 .16 -.16" name="left_uarm1" size="0.04 0.16" type="capsule"/>
                <body name="left_lower_arm" pos=".18 .18 -.18">
                    <joint armature="0.0028" axis="0 -1 -1" name="left_elbow" pos="0 0 0" range="-90 50" stiffness="0" type="hinge"/>
                    <geom fromto="0.01 -0.01 0.01 .17 -.17 .17" name="left_larm" size="0.031" type="capsule"/>
                    <geom name="left_hand" pos=".18 -.18 .18" size="0.04" type="sphere"/>
                </body>
            </body>
        </body>
    </worldbody>
    <tendon>
        <fixed name="left_hipknee">
            <joint coef="-1" joint="left_hip_y"/>
            <joint coef="1" joint="left_knee"/>
        </fixed>
        <fixed name="right_hipknee">
            <joint coef="-1" joint="right_hip_y"/>
            <joint coef="1" joint="right_knee"/>
        </fixed>
    </tendon>

    <actuator>
        <motor gear="100" joint="abdomen_y" name="abdomen_y"/>
        <motor gear="100" joint="abdomen_z" name="abdomen_z"/>
        <motor gear="100" joint="abdomen_x" name="abdomen_x"/>
        <motor gear="100" joint="right_hip_x" name="right_hip_x"/>
        <motor gear="100" joint="right_hip_z" name="right_hip_z"/>
        <motor gear="300" joint="right_hip_y" name="right_hip_y"/>
        <motor gear="200" joint="right_knee" name="right_knee"/>
        <motor gear="100" joint="left_hip_x" name="left_hip_x"/>
        <motor gear="100" joint="left_hip_z" name="left_hip_z"/>
        <motor gear="300" joint="left_hip_y" name="left_hip_y"/>
        <motor gear="200" joint="left_knee" name="left_knee"/>
        <motor gear="25" joint="right_shoulder1" name="right_shoulder1"/>
        <motor gear="25" joint="right_shoulder2" name="right_shoulder2"/>
        <motor gear="25" joint="right_elbow" name="right_elbow"/>
        <motor gear="25" joint="left_shoulder1" name="left_shoulder1"/>
        <motor gear="25" joint="left_shoulder2" name="left_shoulder2"/>
        <motor gear="25" joint="left_elbow" name="left_elbow"/>
    </actuator>
</mujoco>


================================================
FILE: envs/assets/point.xml
================================================
<!-- ======================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
====================================================== -->

<mujoco>
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.02"/>
  <default>
    <joint armature="0" damping="0" limited="false"/>
    <geom conaffinity="0" condim="3" density="100" friction="1 0.5 0.5" margin="0" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 0">
      <geom name="pointbody" pos="0 0 0.5" size="0.5" type="sphere"/>
      <geom name="pointarrow" pos="0.6 0 0.5" size="0.5 0.1 0.1" type="box"/>
      <joint axis="1 0 0" name="ballx" pos="0 0 0" type="slide"/>
      <joint axis="0 1 0" name="bally" pos="0 0 0" type="slide"/>
      <joint axis="0 0 1" limited="false" name="rot" pos="0 0 0" type="hinge"/>
    </body>
  </worldbody>
  <actuator>
    <!-- Those are just dummy actuators for providing ranges -->
    <motor ctrllimited="true" ctrlrange="-1 1" joint="ballx"/>
    <motor ctrllimited="true" ctrlrange="-0.25 0.25" joint="rot"/>
  </actuator>
</mujoco>


================================================
FILE: envs/dclaw.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Turn tasks with DClaw robots.

This is a single rotation of an object from an initial angle to a target angle.
"""

import abc
import collections
from typing import Dict, Optional, Sequence

import numpy as np

from robel.components.robot.dynamixel_robot import DynamixelRobotState
from robel.dclaw.base_env import BaseDClawObjectEnv
from robel.simulation.randomize import SimRandomizer
from robel.utils.configurable import configurable
from robel.utils.resources import get_asset_path

# The observation keys that are concatenated as the environment observation.
DEFAULT_OBSERVATION_KEYS = (
    'object_x',
    'object_y',
    'claw_qpos',
    'last_action',
)

# Reset pose for the claw joints.
RESET_POSE = [0, -np.pi / 3, np.pi / 3] * 3

DCLAW3_ASSET_PATH = 'robel/dclaw/assets/dclaw3xh_valve3_v0.xml'


class BaseDClawTurn(BaseDClawObjectEnv, metaclass=abc.ABCMeta):
    """Shared logic for DClaw turn tasks."""

    def __init__(self,
                 asset_path: str = DCLAW3_ASSET_PATH,
                 observation_keys: Sequence[str] = DEFAULT_OBSERVATION_KEYS,
                 frame_skip: int = 40,
                 **kwargs):
        """Initializes the environment.

        Args:
            asset_path: The XML model file to load.
            observation_keys: The keys in `get_obs_dict` to concatenate as the
                observations returned by `step` and `reset`.
            frame_skip: The number of simulation steps per environment step.
            interactive: If True, allows the hardware guide motor to freely
                rotate and its current angle is used as the goal.
            success_threshold: The difference threshold (in radians) of the
                object position and the goal position within which we consider
                as a success.
        """
        super().__init__(
            sim_model=get_asset_path(asset_path),
            observation_keys=observation_keys,
            frame_skip=frame_skip,
            **kwargs)

        self._desired_claw_pos = RESET_POSE

        # The following are modified (possibly every reset) by subclasses.
        self._initial_object_pos = 0
        self._initial_object_vel = 0

    def _reset(self):
        """Resets the environment."""
        self._reset_dclaw_and_object(
            claw_pos=RESET_POSE,
            object_pos=self._initial_object_pos,
            object_vel=self._initial_object_vel)

    def _step(self, action: np.ndarray):
        """Applies an action to the robot."""
        self.robot.step({
            'dclaw': action,
        })

    def get_obs_dict(self) -> Dict[str, np.ndarray]:
        """Returns the current observation of the environment.

        Returns:
            A dictionary of observation values. This should be an ordered
            dictionary if `observation_keys` isn't set.
        """
        claw_state, object_state = self.robot.get_state(
            ['dclaw', 'object'])

        obs_dict = collections.OrderedDict((
            ('claw_qpos', claw_state.qpos),
            ('claw_qvel', claw_state.qvel),
            ('object_x', np.cos(object_state.qpos)),
            ('object_y', np.sin(object_state.qpos)),
            ('object_qvel', object_state.qvel),
            ('last_action', self._get_last_action()),
        ))
        # Add hardware-specific state if present.
        if isinstance(claw_state, DynamixelRobotState):
            obs_dict['claw_current'] = claw_state.current

        return obs_dict

    def get_reward_dict(
            self,
            action: np.ndarray,
            obs_dict: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Returns the reward for the given action and observation."""
        reward_dict = collections.OrderedDict(())
        return reward_dict

    def get_score_dict(
            self,
            obs_dict: Dict[str, np.ndarray],
            reward_dict: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Returns a standardized measure of success for the environment."""
        return collections.OrderedDict(())

    def get_done(
            self,
            obs_dict: Dict[str, np.ndarray],
            reward_dict: Dict[str, np.ndarray],
    ) -> np.ndarray:
        """Returns whether the episode should terminate."""
        return np.zeros_like([0], dtype=bool)


@configurable(pickleable=True)
class DClawTurnRandom(BaseDClawTurn):
    """Turns the object with a random initial and random target position."""

    def _reset(self):
        # Initial position is +/- 60 degrees.
        self._initial_object_pos = self.np_random.uniform(
            low=-np.pi / 3, high=np.pi / 3)
        super()._reset()


@configurable(pickleable=True)
class DClawTurnRandomDynamics(DClawTurnRandom):
    """Turns the object with a random initial and random target position.

    The dynamics of the simulation are randomized each episode.
    """

    def __init__(self,
                 *args,
                 sim_observation_noise: Optional[float] = 0.05,
                 **kwargs):
        super().__init__(
            *args, sim_observation_noise=sim_observation_noise, **kwargs)
        self._randomizer = SimRandomizer(self)
        self._dof_indices = (
            self.robot.get_config('dclaw').qvel_indices.tolist() +
            self.robot.get_config('object').qvel_indices.tolist())

    def _reset(self):
        # Randomize joint dynamics.
        self._randomizer.randomize_dofs(
            self._dof_indices,
            damping_range=(0.005, 0.1),
            friction_loss_range=(0.001, 0.005),
        )
        self._randomizer.randomize_actuators(
            all_same=True,
            kp_range=(1, 3),
        )
        # Randomize friction on all geoms in the scene.
        self._randomizer.randomize_geoms(
            all_same=True,
            friction_slide_range=(0.8, 1.2),
            friction_spin_range=(0.003, 0.007),
            friction_roll_range=(0.00005, 0.00015),
        )
        self._randomizer.randomize_bodies(
            ['mount'],
            position_perturb_range=(-0.01, 0.01),
        )
        self._randomizer.randomize_geoms(
            ['mount'],
            color_range=(0.2, 0.9),
        )
        self._randomizer.randomize_geoms(
            parent_body_names=['valve'],
            color_range=(0.2, 0.9),
        )
        super()._reset()
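
# A minimal usage sketch (not part of the original file; it assumes the ROBEL
# package and its MuJoCo assets are installed). The reward/score dicts above
# are intentionally empty because DADS supplies its own intrinsic reward, so
# this environment only provides observations and dynamics.
if __name__ == '__main__':
    env = DClawTurnRandom()
    obs = env.reset()  # the valve starts at a random angle in [-60, 60] degrees
    obs, reward, done, info = env.step(env.action_space.sample())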


================================================
FILE: envs/dkitty_redesign.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DKitty redesign
"""

import abc
import collections
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy as np

from robel.components.tracking import TrackerState
from robel.dkitty.base_env import BaseDKittyUprightEnv
from robel.simulation.randomize import SimRandomizer
from robel.utils.configurable import configurable
from robel.utils.math_utils import calculate_cosine
from robel.utils.resources import get_asset_path

DKITTY_ASSET_PATH = 'robel/dkitty/assets/dkitty_walk-v0.xml'

DEFAULT_OBSERVATION_KEYS = (
    'root_pos',
    'root_euler',
    'kitty_qpos',
    # 'root_vel',
    # 'root_angular_vel',
    'kitty_qvel',
    'last_action',
    'upright',
)


class BaseDKittyWalk(BaseDKittyUprightEnv, metaclass=abc.ABCMeta):
    """Shared logic for DKitty walk tasks."""

    def __init__(
            self,
            asset_path: str = DKITTY_ASSET_PATH,
            observation_keys: Sequence[str] = DEFAULT_OBSERVATION_KEYS,
            device_path: Optional[str] = None,
            torso_tracker_id: Optional[Union[str, int]] = None,
            frame_skip: int = 40,
            sticky_action_probability: float = 0.,
            upright_threshold: float = 0.9,
            upright_reward: float = 1,
            falling_reward: float = -500,
            expose_last_action: bool = True,
            expose_upright: bool = True,
            robot_noise_ratio: float = 0.05,
            **kwargs):
        """Initializes the environment.

        Args:
            asset_path: The XML model file to load.
            observation_keys: The keys in `get_obs_dict` to concatenate as the
                observations returned by `step` and `reset`.
            device_path: The device path to Dynamixel hardware.
            torso_tracker_id: The device index or serial of the tracking device
                for the D'Kitty torso.
            frame_skip: The number of simulation steps per environment step.
            sticky_action_probability: Repeat previous action with this
                probability. Default 0 (no sticky actions).
            upright_threshold: The threshold (in [0, 1]) above which the D'Kitty
                is considered to be upright. If the cosine similarity of the
                D'Kitty's z-axis with the global z-axis is below this threshold,
                the D'Kitty is considered to have fallen.
            upright_reward: The reward multiplier for uprightedness.
            falling_reward: The reward multiplier for falling.
        """
        self._expose_last_action = expose_last_action
        self._expose_upright = expose_upright
        observation_keys = observation_keys[:-2]
        if self._expose_last_action:
            observation_keys += ('last_action',)
        if self._expose_upright:
            observation_keys += ('upright',)

        # robot_config = self.get_robot_config(device_path)
        # if 'sim_observation_noise' in robot_config.keys():
        #     robot_config['sim_observation_noise'] = robot_noise_ratio
 
        super().__init__(
            sim_model=get_asset_path(asset_path),
            # robot_config=robot_config,
            # tracker_config=self.get_tracker_config(
            #     torso=torso_tracker_id,
            # ),
            observation_keys=observation_keys,
            frame_skip=frame_skip,
            upright_threshold=upright_threshold,
            upright_reward=upright_reward,
            falling_reward=falling_reward,
            **kwargs)

        self._last_action = np.zeros(12)
        self._sticky_action_probability = sticky_action_probability
        self._time_step = 0

    def _reset(self):
        """Resets the environment."""
        self._reset_dkitty_standing()

        # Set the tracker locations.
        self.tracker.set_state({
            'torso': TrackerState(pos=np.zeros(3), rot=np.identity(3)),
        })

        self._time_step = 0

    def _step(self, action: np.ndarray):
        """Applies an action to the robot."""
        self._time_step += 1

        # Sticky actions
        rand = self.np_random.uniform() < self._sticky_action_probability
        action_to_apply = np.where(rand, self._last_action, action)

        # Apply action.
        self.robot.step({
            'dkitty': action_to_apply,
        })
        # Save the action to add to the observation.
        self._last_action = action

    def get_obs_dict(self) -> Dict[str, np.ndarray]:
        """Returns the current observation of the environment.

        Returns:
            A dictionary of observation values. This should be an ordered
            dictionary if `observation_keys` isn't set.
        """
        robot_state = self.robot.get_state('dkitty')
        torso_track_state = self.tracker.get_state(
            ['torso'])[0]
        obs_dict = (('root_pos', torso_track_state.pos),
                    ('root_euler', torso_track_state.rot_euler),
                    ('root_vel', torso_track_state.vel),
                    ('root_angular_vel', torso_track_state.angular_vel),
                    ('kitty_qpos', robot_state.qpos),
                    ('kitty_qvel', robot_state.qvel))

        if self._expose_last_action:
            obs_dict += (('last_action', self._last_action),)

        # Add observation terms relating to being upright.
        if self._expose_upright:
            obs_dict += (*self._get_upright_obs(torso_track_state).items(),)

        return collections.OrderedDict(obs_dict)

    def get_reward_dict(
            self,
            action: np.ndarray,
            obs_dict: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Returns the reward for the given action and observation."""
        reward_dict = collections.OrderedDict(())
        return reward_dict

    def get_score_dict(
            self,
            obs_dict: Dict[str, np.ndarray],
            reward_dict: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Returns a standardized measure of success for the environment."""
        return collections.OrderedDict(())

@configurable(pickleable=True)
class DKittyRandomDynamics(BaseDKittyWalk):
    """Walk straight towards a random location."""

    def __init__(self, *args, randomize_hfield=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self._randomizer = SimRandomizer(self)
        self._randomize_hfield = randomize_hfield
        self._dof_indices = (
            self.robot.get_config('dkitty').qvel_indices.tolist())

    def _reset(self):
        """Resets the environment."""
        # Randomize joint dynamics.
        self._randomizer.randomize_dofs(
            self._dof_indices,
            all_same=True,
            damping_range=(0.1, 0.2),
            friction_loss_range=(0.001, 0.005),
        )
        self._randomizer.randomize_actuators(
            all_same=True,
            kp_range=(2.8, 3.2),
        )
        # Randomize friction on all geoms in the scene.
        self._randomizer.randomize_geoms(
            all_same=True,
            friction_slide_range=(0.8, 1.2),
            friction_spin_range=(0.003, 0.007),
            friction_roll_range=(0.00005, 0.00015),
        )
        # Generate a random height field.
        self._randomizer.randomize_global(
            total_mass_range=(1.6, 2.0),
            height_field_range=(0, self._randomize_hfield),
        )
        # if self._randomize_hfield > 0.0:
        #     self.sim_scene.upload_height_field(0)
        super()._reset()
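
# A minimal usage sketch (not part of the original file; assumes the ROBEL
# package is installed). Every reset re-randomizes joint damping, actuator
# gains, geom friction, total mass and, optionally, the height field before
# standing the D'Kitty back up.
if __name__ == '__main__':
    env = DKittyRandomDynamics(randomize_hfield=0.02)
    obs = env.reset()  # dynamics are re-randomized inside this call
    obs, reward, done, info = env.step(env.action_space.sample())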


================================================
FILE: envs/gym_mujoco/ant.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os

from gym import utils
import numpy as np
from gym.envs.mujoco import mujoco_env

def q_inv(a):
  # Conjugate of a unit quaternion [w, x, y, z] (equal to its inverse).
  return [a[0], -a[1], -a[2], -a[3]]


def q_mult(a, b):  # multiply two quaternions (Hamilton product)
  w = a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3]
  i = a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2]
  j = a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1]
  k = a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0]
  return [w, i, j, k]

# pylint: disable=missing-docstring
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):

  def __init__(self,
               task="forward",
               goal=None,
               expose_all_qpos=False,
               expose_body_coms=None,
               expose_body_comvels=None,
               expose_foot_sensors=False,
               use_alt_path=False,
               model_path="ant.xml"):
    self._task = task
    self._goal = goal
    self._expose_all_qpos = expose_all_qpos
    self._expose_body_coms = expose_body_coms
    self._expose_body_comvels = expose_body_comvels
    self._expose_foot_sensors = expose_foot_sensors
    self._body_com_indices = {}
    self._body_comvel_indices = {}

    # Settings from
    # https://github.com/openai/gym/blob/master/gym/envs/__init__.py

    xml_path = "envs/assets/"
    model_path = os.path.abspath(os.path.join(xml_path, model_path))
    mujoco_env.MujocoEnv.__init__(self, model_path, 5)
    utils.EzPickle.__init__(self)

  def compute_reward(self, ob, next_ob, action=None):
    xposbefore = ob[:, 0]
    yposbefore = ob[:, 1]
    xposafter = next_ob[:, 0]
    yposafter = next_ob[:, 1]

    forward_reward = (xposafter - xposbefore) / self.dt
    sideward_reward = (yposafter - yposbefore) / self.dt

    if action is not None:
      ctrl_cost = .5 * np.square(action).sum(axis=1)
      survive_reward = 1.0
    else:
      ctrl_cost = 0.
      survive_reward = 0.
    if self._task == "forward":
      reward = forward_reward - ctrl_cost + survive_reward
    elif self._task == "backward":
      reward = -forward_reward - ctrl_cost + survive_reward
    elif self._task == "left":
      reward = sideward_reward - ctrl_cost + survive_reward
    elif self._task == "right":
      reward = -sideward_reward - ctrl_cost + survive_reward
    elif self._task == "goal":
      reward = -np.linalg.norm(
          np.array([xposafter, yposafter]).T - self._goal, axis=1)

    return reward

  def step(self, a):
    xposbefore = self.get_body_com("torso")[0]
    yposbefore = self.sim.data.qpos.flat[1]
    self.do_simulation(a, self.frame_skip)
    xposafter = self.get_body_com("torso")[0]
    yposafter = self.sim.data.qpos.flat[1]

    forward_reward = (xposafter - xposbefore) / self.dt
    sideward_reward = (yposafter - yposbefore) / self.dt

    ctrl_cost = .5 * np.square(a).sum()
    survive_reward = 1.0
    if self._task == "forward":
      reward = forward_reward - ctrl_cost + survive_reward
    elif self._task == "backward":
      reward = -forward_reward - ctrl_cost + survive_reward
    elif self._task == "left":
      reward = sideward_reward - ctrl_cost + survive_reward
    elif self._task == "right":
      reward = -sideward_reward - ctrl_cost + survive_reward
    elif self._task == "goal":
      reward = -np.linalg.norm(np.array([xposafter, yposafter]) - self._goal)
    elif self._task == "motion":
      reward = np.max(np.abs(np.array([forward_reward, sideward_reward
                                      ]))) - ctrl_cost + survive_reward

    state = self.state_vector()
    notdone = np.isfinite(state).all()
    done = not notdone
    ob = self._get_obs()
    return ob, reward, done, dict(
        reward_forward=forward_reward,
        reward_sideward=sideward_reward,
        reward_ctrl=-ctrl_cost,
        reward_survive=survive_reward)

  def _get_obs(self):
    # No cfrc (external contact force) observation
    if self._expose_all_qpos:
      obs = np.concatenate([
          self.sim.data.qpos.flat[:15],
          self.sim.data.qvel.flat[:14],
      ])
    else:
      obs = np.concatenate([
          self.sim.data.qpos.flat[2:15],
          self.sim.data.qvel.flat[:14],
      ])

    if self._expose_body_coms is not None:
      for name in self._expose_body_coms:
        com = self.get_body_com(name)
        if name not in self._body_com_indices:
          indices = range(len(obs), len(obs) + len(com))
          self._body_com_indices[name] = indices
        obs = np.concatenate([obs, com])

    if self._expose_body_comvels is not None:
      for name in self._expose_body_comvels:
        comvel = self.get_body_comvel(name)
        if name not in self._body_comvel_indices:
          indices = range(len(obs), len(obs) + len(comvel))
          self._body_comvel_indices[name] = indices
        obs = np.concatenate([obs, comvel])

    if self._expose_foot_sensors:
      obs = np.concatenate([obs, self.sim.data.sensordata])
    return obs

  def reset_model(self):
    qpos = self.init_qpos + self.np_random.uniform(
        size=self.sim.model.nq, low=-.1, high=.1)
    qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .1

    qpos[15:] = self.init_qpos[15:]
    qvel[14:] = 0.

    self.set_state(qpos, qvel)
    return self._get_obs()

  def viewer_setup(self):
    self.viewer.cam.distance = self.model.stat.extent * 2.5

  def get_ori(self):
    ori = [0, 1, 0, 0]
    rot = self.sim.data.qpos[3:7]  # take the quaternion
    ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3]  # project onto x-y plane
    ori = math.atan2(ori[1], ori[0])
    return ori

  @property
  def body_com_indices(self):
    return self._body_com_indices

  @property
  def body_comvel_indices(self):
    return self._body_comvel_indices
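
# A minimal usage sketch (not part of the original file). It assumes MuJoCo and
# mujoco-py are installed and that the process runs from the repository root so
# the relative "envs/assets/" model path resolves; expose_all_qpos=True keeps
# the root x-y position in the observation, which the ant x-y skill configs
# rely on.
if __name__ == '__main__':
  env = AntEnv(task='motion', expose_all_qpos=True)
  obs = env.reset()
  for _ in range(10):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
      obs = env.reset()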


================================================
FILE: envs/gym_mujoco/half_cheetah.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from gym import utils
import numpy as np
from gym.envs.mujoco import mujoco_env


class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):

  def __init__(self,
               expose_all_qpos=False,
               task='default',
               target_velocity=None,
               model_path='half_cheetah.xml'):
    # Settings from
    # https://github.com/openai/gym/blob/master/gym/envs/__init__.py
    self._expose_all_qpos = expose_all_qpos
    self._task = task
    self._target_velocity = target_velocity

    xml_path = "envs/assets/"
    model_path = os.path.abspath(os.path.join(xml_path, model_path))

    mujoco_env.MujocoEnv.__init__(
        self,
        model_path,
        5)
    utils.EzPickle.__init__(self)

  def step(self, action):
    xposbefore = self.sim.data.qpos[0]
    self.do_simulation(action, self.frame_skip)
    xposafter = self.sim.data.qpos[0]
    xvelafter = self.sim.data.qvel[0]
    ob = self._get_obs()
    reward_ctrl = -0.1 * np.square(action).sum()

    if self._task == 'default':
      reward_vel = 0.
      reward_run = (xposafter - xposbefore) / self.dt
      reward = reward_ctrl + reward_run
    elif self._task == 'target_velocity':
      reward_run = 0.
      reward_vel = -(self._target_velocity - xvelafter)**2
      reward = reward_ctrl + reward_vel
    elif self._task == 'run_back':
      reward_vel = 0.
      reward_run = (xposbefore - xposafter) / self.dt
      reward = reward_ctrl + reward_run

    done = False
    return ob, reward, done, dict(
        reward_run=reward_run, reward_ctrl=reward_ctrl, reward_vel=reward_vel)

  def _get_obs(self):
    if self._expose_all_qpos:
      return np.concatenate(
          [self.sim.data.qpos.flat, self.sim.data.qvel.flat])
    return np.concatenate([
        self.sim.data.qpos.flat[1:],
        self.sim.data.qvel.flat,
    ])

  def reset_model(self):
    qpos = self.init_qpos + self.np_random.uniform(
        low=-.1, high=.1, size=self.sim.model.nq)
    qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .1
    self.set_state(qpos, qvel)
    return self._get_obs()

  def viewer_setup(self):
    self.viewer.cam.distance = self.model.stat.extent * 0.5
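
# A minimal usage sketch (not part of the original file; assumes MuJoCo and the
# repository root as the working directory). The 'default' task rewards forward
# velocity minus a control cost, 'run_back' rewards backward velocity, and
# 'target_velocity' penalizes squared deviation from the requested velocity.
if __name__ == '__main__':
  env = HalfCheetahEnv(task='default', expose_all_qpos=True)
  obs = env.reset()
  obs, reward, done, info = env.step(env.action_space.sample())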


================================================
FILE: envs/gym_mujoco/humanoid.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from gym import utils
import numpy as np
from gym.envs.mujoco import mujoco_env


def mass_center(sim):
  mass = np.expand_dims(sim.model.body_mass, 1)
  xpos = sim.data.xipos
  return (np.sum(mass * xpos, 0) / np.sum(mass))[0]


# pylint: disable=missing-docstring
class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):

  def __init__(self, 
               expose_all_qpos=False,
               model_path='humanoid.xml',
               task=None,
               goal=None):

    self._task = task
    self._goal = goal
    if self._task == "follow_goals":
      self._goal_list = [
          np.array([3.0, -0.5]),
          np.array([6.0, 8.0]),
          np.array([12.0, 12.0]),
      ]
      self._goal = self._goal_list[0]
      print("Following a trajectory of goals:", self._goal_list)

    self._expose_all_qpos = expose_all_qpos
    xml_path = "envs/assets/"
    model_path = os.path.abspath(os.path.join(xml_path, model_path))
    mujoco_env.MujocoEnv.__init__(self, model_path, 5)
    utils.EzPickle.__init__(self)

  def _get_obs(self):
    data = self.sim.data
    if self._expose_all_qpos:
      return np.concatenate([
          data.qpos.flat, data.qvel.flat,
          # data.cinert.flat, data.cvel.flat,
          # data.qfrc_actuator.flat, data.cfrc_ext.flat
      ])
    return np.concatenate([
        data.qpos.flat[2:], data.qvel.flat, data.cinert.flat, data.cvel.flat,
        data.qfrc_actuator.flat, data.cfrc_ext.flat
    ])

  def compute_reward(self, ob, next_ob, action=None):
    xposbefore = ob[:, 0]
    yposbefore = ob[:, 1]
    xposafter = next_ob[:, 0]
    yposafter = next_ob[:, 1]

    forward_reward = (xposafter - xposbefore) / self.dt
    sideward_reward = (yposafter - yposbefore) / self.dt

    if action is not None:
      ctrl_cost = .5 * np.square(action).sum(axis=1)
      survive_reward = 1.0
    else:
      ctrl_cost = 0.
      survive_reward = 0.
    if self._task == "forward":
      reward = forward_reward - ctrl_cost + survive_reward
    elif self._task == "backward":
      reward = -forward_reward - ctrl_cost + survive_reward
    elif self._task == "left":
      reward = sideward_reward - ctrl_cost + survive_reward
    elif self._task == "right":
      reward = -sideward_reward - ctrl_cost + survive_reward
    elif self._task in ["goal", "follow_goals"]:
      reward = -np.linalg.norm(
          np.array([xposafter, yposafter]).T - self._goal, axis=1)
    elif self._task in ["sparse_goal"]:
      reward = (-np.linalg.norm(
          np.array([xposafter, yposafter]).T - self._goal, axis=1) >
                -0.3).astype(np.float32)
    return reward

  def step(self, a):
    pos_before = mass_center(self.sim)
    self.do_simulation(a, self.frame_skip)
    pos_after = mass_center(self.sim)
    alive_bonus = 5.0
    data = self.sim.data
    lin_vel_cost = 0.25 * (
        pos_after - pos_before) / self.sim.model.opt.timestep
    quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
    quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum()
    quad_impact_cost = min(quad_impact_cost, 10)
    reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus

    if self._task == "follow_goals":
      xposafter = self.sim.data.qpos.flat[0]
      yposafter = self.sim.data.qpos.flat[1]
      reward = -np.linalg.norm(np.array([xposafter, yposafter]).T - self._goal)
      # update goal
      if np.abs(reward) < 0.5:
        self._goal = self._goal_list[0]
        self._goal_list = self._goal_list[1:]
        print("Goal Updated:", self._goal)

    elif self._task == "goal":
      xposafter = self.sim.data.qpos.flat[0]
      yposafter = self.sim.data.qpos.flat[1]
      reward = -np.linalg.norm(np.array([xposafter, yposafter]).T - self._goal)

    qpos = self.sim.data.qpos
    done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
    return self._get_obs(), reward, done, dict(
        reward_linvel=lin_vel_cost,
        reward_quadctrl=-quad_ctrl_cost,
        reward_alive=alive_bonus,
        reward_impact=-quad_impact_cost)

  def reset_model(self):
    c = 0.01
    self.set_state(
        self.init_qpos + self.np_random.uniform(
            low=-c, high=c, size=self.sim.model.nq),
        self.init_qvel + self.np_random.uniform(
            low=-c,
            high=c,
            size=self.sim.model.nv,
        ))

    if self._task == "follow_goals":
      self._goal = self._goal_list[0]
      self._goal_list = self._goal_list[1:]
      print("Current goal:", self._goal)

    return self._get_obs()

  def viewer_setup(self):
    self.viewer.cam.distance = self.model.stat.extent * 2.0
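
# A minimal usage sketch (not part of the original file; assumes MuJoCo and the
# repository root as the working directory). With task="follow_goals" the
# reward is the negative distance to the current x-y goal, and the goal
# advances through the fixed list above once the torso comes within 0.5 of it.
if __name__ == '__main__':
  env = HumanoidEnv(task='follow_goals', expose_all_qpos=True)
  obs = env.reset()
  obs, reward, done, info = env.step(env.action_space.sample())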


================================================
FILE: envs/gym_mujoco/point_mass.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os

from gym import utils
import numpy as np
from gym.envs.mujoco import mujoco_env


# pylint: disable=missing-docstring
class PointMassEnv(mujoco_env.MujocoEnv, utils.EzPickle):

  def __init__(self,
               target=None,
               wiggly_weight=0.,
               alt_xml=False,
               expose_velocity=True,
               expose_goal=True,
               use_simulator=False,
               model_path='point.xml'):
    self._sample_target = target
    self.goal = None
    if self._sample_target is not None:
      self.goal = np.array([1.0, 1.0])

    self._expose_velocity = expose_velocity
    self._expose_goal = expose_goal
    self._use_simulator = use_simulator
    self._wiggly_weight = abs(wiggly_weight)
    self._wiggle_direction = +1 if wiggly_weight > 0. else -1

    xml_path = "envs/assets/"
    model_path = os.path.abspath(os.path.join(xml_path, model_path))

    if self._use_simulator:
      mujoco_env.MujocoEnv.__init__(self, model_path, 5)
    else:
      mujoco_env.MujocoEnv.__init__(self, model_path, 1)
    utils.EzPickle.__init__(self)

  def step(self, action):
    if self._use_simulator:
      self.do_simulation(action, self.frame_skip)
    else:
      force = 0.2 * action[0]
      rot = 1.0 * action[1]
      qpos = self.sim.data.qpos.flat.copy()
      qpos[2] += rot
      ori = qpos[2]
      dx = math.cos(ori) * force
      dy = math.sin(ori) * force
      qpos[0] = np.clip(qpos[0] + dx, -2, 2)
      qpos[1] = np.clip(qpos[1] + dy, -2, 2)
      qvel = self.sim.data.qvel.flat.copy()
      self.set_state(qpos, qvel)

    ob = self._get_obs()
    if self._sample_target is not None and self.goal is not None:
      reward = -np.linalg.norm(self.sim.data.qpos.flat[:2] - self.goal)**2
    else:
      reward = 0.

    if self._wiggly_weight > 0.:
      reward = (np.exp(-((-reward)**0.5))**(1. - self._wiggly_weight)) * (
          max(self._wiggle_direction * action[1], 0)**self._wiggly_weight)
    done = False
    return ob, reward, done, None

  def _get_obs(self):
    new_obs = [self.sim.data.qpos.flat]
    if self._expose_velocity:
      new_obs += [self.sim.data.qvel.flat]
    if self._expose_goal and self.goal is not None:
      new_obs += [self.goal]
    return np.concatenate(new_obs)

  def reset_model(self):
    qpos = self.init_qpos + np.append(
        self.np_random.uniform(low=-.2, high=.2, size=2),
        self.np_random.uniform(-np.pi, np.pi, size=1))
    qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .01
    if self._sample_target is not None:
      self.goal = self._sample_target(qpos[:2])
    self.set_state(qpos, qvel)
    return self._get_obs()

  # only works when goal is not exposed
  def set_qpos(self, state):
    qvel = np.copy(self.sim.data.qvel.flat)
    self.set_state(state, qvel)

  def viewer_setup(self):
    self.viewer.cam.distance = self.model.stat.extent * 0.5
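
# A minimal usage sketch (not part of the original file; assumes the repository
# root as the working directory). With use_simulator=False (the default),
# actions are integrated kinematically: action[0] scales a step along the
# current heading and action[1] rotates the heading, bypassing the MuJoCo
# dynamics entirely.
if __name__ == '__main__':
  env = PointMassEnv(expose_goal=False)
  obs = env.reset()
  obs, reward, done, _ = env.step(env.action_space.sample())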


================================================
FILE: envs/hand_block.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import gym
import os
from gym import spaces
from gym.envs.robotics.hand.manipulate import ManipulateEnv
import mujoco_py

MANIPULATE_BLOCK_XML = os.path.join('hand', 'manipulate_block.xml')

class HandBlockCustomEnv(ManipulateEnv):
	def __init__(self,
				 model_path=MANIPULATE_BLOCK_XML,
				 target_position='random',
				 target_rotation='xyz',
				 reward_type='sparse',
				 horizontal_wrist_constraint=1.0,
				 vertical_wrist_constraint=1.0,
				 **kwargs):
		ManipulateEnv.__init__(self,
			 model_path=MANIPULATE_BLOCK_XML,
			 target_position=target_position,
			 target_rotation=target_rotation,
			 target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
			 reward_type=reward_type,
			 **kwargs)

		self._viewers = {}

		# constraining the movement of wrist (vertical movement more important than horizontal)
		self.action_space.low[0] = -horizontal_wrist_constraint
		self.action_space.high[0] = horizontal_wrist_constraint
		self.action_space.low[1] = -vertical_wrist_constraint
		self.action_space.high[1] = vertical_wrist_constraint

	def _get_viewer(self, mode):
		self.viewer = self._viewers.get(mode)
		if self.viewer is None:
			if mode == 'human':
				self.viewer = mujoco_py.MjViewer(self.sim)
			elif mode == 'rgb_array':
				self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1)
				self._viewer_setup()
				self._viewers[mode] = self.viewer
		return self.viewer

	def _viewer_setup(self):
		body_id = self.sim.model.body_name2id('robot0:palm')
		lookat = self.sim.data.body_xpos[body_id]
		for idx, value in enumerate(lookat):
			self.viewer.cam.lookat[idx] = value
		self.viewer.cam.distance = 0.5
		self.viewer.cam.azimuth = 55.
		self.viewer.cam.elevation = -25.

	def step(self, action):
		
		def is_on_palm():
			self.sim.forward()
			cube_middle_idx = self.sim.model.site_name2id('object:center')
			cube_middle_pos = self.sim.data.site_xpos[cube_middle_idx]
			is_on_palm = (cube_middle_pos[2] > 0.04)
			return is_on_palm

		obs, reward, done, info = super().step(action)
		done = not is_on_palm()
		return obs, reward, done, info

	def render(self, mode='human', width=500, height=500):
		self._render_callback()
		if mode == 'rgb_array':
			self._get_viewer(mode).render(width, height)
			# window size used for old mujoco-py:
			data = self._get_viewer(mode).read_pixels(width, height, depth=False)
			# original image is upside-down, so flip it
			return data[::-1, :, :]
		elif mode == 'human':
			self._get_viewer(mode).render()
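
# A minimal usage sketch (not part of the original file; assumes gym's robotics
# environments and MuJoCo are installed). The constructor narrows the wrist
# action bounds, and an episode terminates as soon as the block leaves the
# palm.
if __name__ == '__main__':
	env = HandBlockCustomEnv(
		horizontal_wrist_constraint=0.3, vertical_wrist_constraint=0.3)
	obs = env.reset()  # dict observation, as in other gym robotics envs
	obs, reward, done, info = env.step(env.action_space.sample())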


================================================
FILE: envs/skill_wrapper.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import gym
from gym import Wrapper

class SkillWrapper(Wrapper):

  def __init__(
      self,
      env,
      # skill type and dimension
      num_latent_skills=None,
      skill_type='discrete_uniform',
      # execute an episode with the same predefined skill, does not resample
      preset_skill=None,
      # resample skills within episode
      min_steps_before_resample=10,
      resample_prob=0.):

    super(SkillWrapper, self).__init__(env)
    self._skill_type = skill_type
    if num_latent_skills is None:
      self._num_skills = 0
    else:
      self._num_skills = num_latent_skills
    self._preset_skill = preset_skill

    # attributes for controlling skill resampling
    self._min_steps_before_resample = min_steps_before_resample
    self._resample_prob = resample_prob

    if isinstance(self.env.observation_space, gym.spaces.Dict):
      size = self.env.observation_space.spaces['observation'].shape[0] + self._num_skills
    else:
      size = self.env.observation_space.shape[0] + self._num_skills
    self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(size,), dtype='float32')

  def _remake_time_step(self, cur_obs):
    if isinstance(self.env.observation_space, gym.spaces.Dict):
      cur_obs = cur_obs['observation']

    if self._num_skills == 0:
      return cur_obs
    else:
      return np.concatenate([cur_obs, self.skill])

  def _set_skill(self):
    if self._num_skills:
      if self._preset_skill is not None:
        self.skill = self._preset_skill
        print('Skill:', self.skill)
      elif self._skill_type == 'discrete_uniform':
        self.skill = np.random.multinomial(
            1, [1. / self._num_skills] * self._num_skills)
      elif self._skill_type == 'gaussian':
        self.skill = np.random.multivariate_normal(
            np.zeros(self._num_skills), np.eye(self._num_skills))
      elif self._skill_type == 'cont_uniform':
        self.skill = np.random.uniform(
            low=-1.0, high=1.0, size=self._num_skills)

  def reset(self):
    cur_obs = self.env.reset()
    self._set_skill()
    self._step_count = 0
    return self._remake_time_step(cur_obs)

  def step(self, action):
    cur_obs, reward, done, info = self.env.step(action)
    self._step_count += 1
    if (self._preset_skill is None and
        self._step_count >= self._min_steps_before_resample and
        np.random.random() < self._resample_prob):
      self._set_skill()
      self._step_count = 0
    return self._remake_time_step(cur_obs), reward, done, info

  def close(self):
    return self.env.close()
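
# A minimal usage sketch (not part of the original file; the wrapped Pendulum
# environment is a hypothetical stand-in, any gym env works). The wrapper
# samples a skill at reset, appends it to every observation, and may resample
# it mid-episode after min_steps_before_resample steps with probability
# resample_prob per step.
if __name__ == '__main__':
  env = SkillWrapper(gym.make('Pendulum-v0'),
                     num_latent_skills=2,
                     skill_type='cont_uniform',
                     min_steps_before_resample=500,
                     resample_prob=0.02)
  obs = env.reset()  # base observation with a 2-dim skill vector appended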


================================================
FILE: envs/video_wrapper.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import gym
from gym import Wrapper
from gym.wrappers.monitoring import video_recorder

class VideoWrapper(Wrapper):

  def __init__(self, env, base_path, base_name=None, new_video_every_reset=False):
    super(VideoWrapper, self).__init__(env)

    self._base_path = base_path
    self._base_name = base_name

    self._new_video_every_reset = new_video_every_reset
    if self._new_video_every_reset:
      self._counter = 0
      self._recorder = None
    else:
      if self._base_name is not None:
        self._vid_name = os.path.join(self._base_path, self._base_name)
      else:
        self._vid_name = self._base_path
      self._recorder = video_recorder.VideoRecorder(self.env, path=self._vid_name + '.mp4')

  def reset(self):
    if self._new_video_every_reset:
      if self._recorder is not None:
        self._recorder.close()

      self._counter += 1
      if self._base_name is not None:
        self._vid_name = os.path.join(self._base_path, self._base_name + '_' + str(self._counter))
      else:
        self._vid_name = self._base_path + '_' + str(self._counter)

      self._recorder = video_recorder.VideoRecorder(self.env, path=self._vid_name + '.mp4')

    return self.env.reset()

  def step(self, action):
    self._recorder.capture_frame()
    return self.env.step(action)

  def close(self):
    self._recorder.encoder.proc.stdin.flush()
    self._recorder.close()
    return self.env.close()
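
# A minimal usage sketch (not part of the original file; the base environment
# and paths are hypothetical). With new_video_every_reset=True, every reset
# closes the previous recorder and starts a new numbered .mp4 under base_path.
if __name__ == '__main__':
  env = VideoWrapper(gym.make('Pendulum-v0'), base_path='/tmp',
                     base_name='rollout', new_video_every_reset=True)
  obs = env.reset()  # opens /tmp/rollout_1.mp4
  for _ in range(50):
    obs, _, done, _ = env.step(env.action_space.sample())
    if done:
      obs = env.reset()
  env.close()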

================================================
FILE: lib/py_tf_policy.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Converts TensorFlow Policies into Python Policies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging

import tensorflow as tf
from tf_agents.policies import py_policy
from tf_agents.policies import tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import policy_step
from tf_agents.utils import common
from tf_agents.utils import nest_utils
from tf_agents.utils import session_utils


class PyTFPolicy(py_policy.Base, session_utils.SessionUser):
  """Exposes a Python policy as wrapper over a TF Policy."""

  # TODO(damienv): currently, the initial policy state must be batched
  # if batch_size is given. Without losing too much generality, the initial
  # policy state could be the same for every element in the batch.
  # In that case, the initial policy state could be given with no batch
  # dimension.
  # TODO(sfishman): Remove batch_size param entirely.
  def __init__(self, policy, batch_size=None, seed=None):
    """Initializes a new `PyTFPolicy`.

    Args:
      policy: A TF Policy implementing `tf_policy.Base`.
      batch_size: (deprecated)
      seed: Seed to use if policy performs random actions (optional).
    """
    if not isinstance(policy, tf_policy.Base):
      logging.warning('Policy should implement tf_policy.Base')

    if batch_size is not None:
      logging.warning('In PyTFPolicy constructor, `batch_size` is deprecated, '
                      'this parameter has no effect. This argument will be '
                      'removed on 2019-05-01')

    time_step_spec = tensor_spec.to_nest_array_spec(policy.time_step_spec)
    action_spec = tensor_spec.to_nest_array_spec(policy.action_spec)
    super(PyTFPolicy, self).__init__(
        time_step_spec, action_spec, policy_state_spec=(), info_spec=())

    self._tf_policy = policy
    self.session = None

    self._policy_state_spec = tensor_spec.to_nest_array_spec(
        self._tf_policy.policy_state_spec)

    self._batch_size = None
    self._batched = None
    self._seed = seed
    self._built = False

  def _construct(self, batch_size, graph):
    """Construct the agent graph through placeholders."""

    self._batch_size = batch_size
    self._batched = batch_size is not None

    outer_dims = [self._batch_size] if self._batched else [1]
    with graph.as_default():
      self._time_step = tensor_spec.to_nest_placeholder(
          self._tf_policy.time_step_spec, outer_dims=outer_dims)
      self._tf_initial_state = self._tf_policy.get_initial_state(
          batch_size=self._batch_size or 1)

      self._policy_state = tf.nest.map_structure(
          lambda ps: tf.compat.v1.placeholder(  # pylint: disable=g-long-lambda
              ps.dtype,
              ps.shape,
              name='policy_state'),
          self._tf_initial_state)
      self._action_step = self._tf_policy.action(
          self._time_step, self._policy_state, seed=self._seed)

      self._actions = tensor_spec.to_nest_placeholder(
          self._tf_policy.action_spec, outer_dims=outer_dims)
      self._action_distribution = self._tf_policy.distribution(
          self._time_step, policy_state=self._policy_state).action
      self._log_prob = common.log_probability(self._action_distribution,
                                              self._actions,
                                              self._tf_policy.action_spec)

  def initialize(self, batch_size, graph=None):
    if self._built:
      raise RuntimeError('PyTFPolicy can only be initialized once.')

    if not graph:
      graph = tf.compat.v1.get_default_graph()

    self._construct(batch_size, graph)
    var_list = tf.nest.flatten(self._tf_policy.variables())
    common.initialize_uninitialized_variables(self.session, var_list)
    self._built = True

  def save(self, policy_dir=None, graph=None):
    if not self._built:
      raise RuntimeError('PyTFPolicy has not been initialized yet.')

    if not graph:
      graph = tf.compat.v1.get_default_graph()

    with graph.as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      policy_checkpointer = common.Checkpointer(
          ckpt_dir=policy_dir, policy=self._tf_policy, global_step=global_step)
      policy_checkpointer.initialize_or_restore(self.session)
      with self.session.as_default():
        policy_checkpointer.save(global_step)

  def restore(self, policy_dir, graph=None, assert_consumed=True):
    """Restores the policy from the checkpoint.

    Args:
      policy_dir: Directory with the checkpoint.
      graph: A graph, inside which the policy is restored (optional).
      assert_consumed: If true, contents of the checkpoint will be checked
        for a match against graph variables.

    Returns:
      step: Global step associated with the restored policy checkpoint.

    Raises:
      RuntimeError: if the policy is not initialized.
      AssertionError: if the checkpoint contains variables which do not have
        matching names in the graph, and assert_consumed is set to True.

    """

    if not self._built:
      raise RuntimeError(
          'PyTFPolicy must be initialized before being restored.')
    if not graph:
      graph = tf.compat.v1.get_default_graph()

    with graph.as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      policy_checkpointer = common.Checkpointer(
          ckpt_dir=policy_dir, policy=self._tf_policy, global_step=global_step)
      status = policy_checkpointer.initialize_or_restore(self.session)
      with self.session.as_default():
        if assert_consumed:
          status.assert_consumed()
        status.run_restore_ops()
      return self.session.run(global_step)

  def _build_from_time_step(self, time_step):
    outer_shape = nest_utils.get_outer_array_shape(time_step,
                                                   self._time_step_spec)
    if len(outer_shape) == 1:
      self.initialize(outer_shape[0])
    elif not outer_shape:
      self.initialize(None)
    else:
      raise ValueError(
          'Cannot handle more than one outer dimension. Saw {} outer '
          'dimensions: {}'.format(len(outer_shape), outer_shape))

  def _get_initial_state(self, batch_size):
    if not self._built:
      self.initialize(batch_size)
    if batch_size != self._batch_size:
      raise ValueError(
          '`batch_size` argument is different from the batch size provided '
          'previously. Expected {}, but saw {}.'.format(self._batch_size,
                                                        batch_size))
    return self.session.run(self._tf_initial_state)

  def _action(self, time_step, policy_state):
    if not self._built:
      self._build_from_time_step(time_step)

    batch_size = None
    if time_step.step_type.shape:
      batch_size = time_step.step_type.shape[0]
    if self._batch_size != batch_size:
      raise ValueError(
          'The batch size of time_step is different from the batch size '
          'provided previously. Expected {}, but saw {}.'.format(
              self._batch_size, batch_size))

    if not self._batched:
      # Since policy_state is given in a batched form from the policy and we
      # simply have to send it back, we do not need to worry about it; only
      # update time_step.
      time_step = nest_utils.batch_nested_array(time_step)

    tf.nest.assert_same_structure(self._time_step, time_step)
    feed_dict = {self._time_step: time_step}
    if policy_state is not None:
      # Flatten policy_state to handle specs that are not hashable due to lists.
      for state_ph, state in zip(
          tf.nest.flatten(self._policy_state), tf.nest.flatten(policy_state)):
        feed_dict[state_ph] = state

    action_step = self.session.run(self._action_step, feed_dict)
    action, state, info = action_step

    if not self._batched:
      action, info = nest_utils.unbatch_nested_array([action, info])

    return policy_step.PolicyStep(action, state, info)

  def log_prob(self, time_step, action_step, policy_state=None):
    if not self._built:
      self._build_from_time_step(time_step)
    tf.nest.assert_same_structure(self._time_step, time_step)
    tf.nest.assert_same_structure(self._actions, action_step)
    feed_dict = {self._time_step: time_step, self._actions: action_step}
    if policy_state is not None:
      feed_dict[self._policy_state] = policy_state
    return self.session.run(self._log_prob, feed_dict)
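

# A minimal usage sketch, assuming TF1-style graph/session execution; the
# observation/action specs and the small actor network below are illustrative
# placeholders, not the specs used elsewhere in this repository.
if __name__ == '__main__':
  import numpy as np
  from tf_agents.networks import actor_distribution_network
  from tf_agents.policies import actor_policy
  from tf_agents.trajectories import time_step as ts

  tf.compat.v1.disable_eager_execution()  # harmless under TF1

  obs_spec = tensor_spec.TensorSpec([4], tf.float32)
  act_spec = tensor_spec.BoundedTensorSpec([2], tf.float32, -1.0, 1.0)
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      obs_spec, act_spec, fc_layer_params=(32,))
  tf_actor_policy = actor_policy.ActorPolicy(
      ts.time_step_spec(obs_spec), act_spec, actor_network=actor_net)

  py_policy = PyTFPolicy(tf_actor_policy)
  py_policy.session = tf.compat.v1.Session()
  py_policy.initialize(batch_size=None)

  # query an action for a single (unbatched) numpy observation
  action_step = py_policy.action(ts.restart(np.zeros(4, dtype=np.float32)))
  print(action_step.action)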


================================================
FILE: lib/py_uniform_replay_buffer.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Uniform replay buffer in Python.

The base class provides all the functionalities of a uniform replay buffer:
  - add samples in a First In First Out way.
  - read samples uniformly.

Subclasses can override the encoding hooks (_encoded_data_spec, _encode,
_decode, _on_delete) to compress the observations when they have some partial
overlap (e.g. when using frame stacking).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import numpy as np
import tensorflow as tf
from tf_agents.replay_buffers import replay_buffer
from tf_agents.specs import array_spec
from tf_agents.utils import nest_utils
from tf_agents.utils import numpy_storage


class PyUniformReplayBuffer(replay_buffer.ReplayBuffer):
  """A Python-based replay buffer that supports uniform sampling.

  Writing and reading to this replay buffer is thread safe.

  This replay buffer can be subclassed to change the encoding used for the
  underlying storage by overriding _encoded_data_spec, _encode, _decode, and
  _on_delete.
  """

  def __init__(self, data_spec, capacity):
    """Creates a PyUniformReplayBuffer.

    Args:
      data_spec: An ArraySpec or a list/tuple/nest of ArraySpecs describing a
        single item that can be stored in this buffer.
      capacity: The maximum number of items that can be stored in the buffer.
    """
    super(PyUniformReplayBuffer, self).__init__(data_spec, capacity)

    self._storage = numpy_storage.NumpyStorage(self._encoded_data_spec(),
                                               capacity)
    self._lock = threading.Lock()
    self._np_state = numpy_storage.NumpyState()

    # Adding elements to the replay buffer is done in a circular way.
    # Keep track of the actual size of the replay buffer and the location
    # at which to add new elements.
    self._np_state.size = np.int64(0)
    self._np_state.cur_id = np.int64(0)

    # Total number of items that went through the replay buffer.
    self._np_state.item_count = np.int64(0)

  def _encoded_data_spec(self):
    """Spec of data items after encoding using _encode."""
    return self._data_spec

  def _encode(self, item):
    """Encodes an item (before adding it to the buffer)."""
    return item

  def _decode(self, item):
    """Decodes an item."""
    return item

  def _on_delete(self, encoded_item):
    """Do any necessary cleanup."""
    pass

  @property
  def size(self):
    return self._np_state.size

  def _add_batch(self, items):
    outer_shape = nest_utils.get_outer_array_shape(items, self._data_spec)
    if outer_shape[0] != 1:
      raise NotImplementedError('PyUniformReplayBuffer only supports a batch '
                                'size of 1, but received `items` with batch '
                                'size {}.'.format(outer_shape[0]))

    item = nest_utils.unbatch_nested_array(items)
    with self._lock:
      if self._np_state.size == self._capacity:
        # If we are at capacity, we are deleting element cur_id.
        self._on_delete(self._storage.get(self._np_state.cur_id))
      self._storage.set(self._np_state.cur_id, self._encode(item))
      self._np_state.size = np.minimum(self._np_state.size + 1,
                                       self._capacity)
      self._np_state.cur_id = (self._np_state.cur_id + 1) % self._capacity
      self._np_state.item_count += 1

  def _get_next(self,
                sample_batch_size=None,
                num_steps=None,
                time_stacked=True):
    num_steps_value = num_steps if num_steps is not None else 1
    def get_single():
      """Gets a single item from the replay buffer."""
      with self._lock:
        if self._np_state.size <= 0:
          def empty_item(spec):
            return np.empty(spec.shape, dtype=spec.dtype)
          if num_steps is not None:
            item = [tf.nest.map_structure(empty_item, self.data_spec)
                    for n in range(num_steps)]
            if time_stacked:
              item = nest_utils.stack_nested_arrays(item)
          else:
            item = tf.nest.map_structure(empty_item, self.data_spec)
          return item
        idx = np.random.randint(self._np_state.size - num_steps_value + 1)
        if self._np_state.size == self._capacity:
          # If the buffer is full, add cur_id (head of circular buffer) so that
          # we sample from the range [cur_id, cur_id + size - num_steps_value].
          # We will modulo the size below.
          idx += self._np_state.cur_id

        if num_steps is not None:
          # TODO(b/120242830): Try getting data from numpy in one shot rather
          # than num_steps_value.
          item = [self._decode(self._storage.get((idx + n) % self._capacity))
                  for n in range(num_steps)]
        else:
          item = self._decode(self._storage.get(idx % self._capacity))

      if num_steps is not None and time_stacked:
        item = nest_utils.stack_nested_arrays(item)
      return item

    if sample_batch_size is None:
      return get_single()
    else:
      samples = [get_single() for _ in range(sample_batch_size)]
      return nest_utils.stack_nested_arrays(samples)

  def _as_dataset(self, sample_batch_size=None, num_steps=None,
                  num_parallel_calls=None):
    if num_parallel_calls is not None:
      raise NotImplementedError('PyUniformReplayBuffer does not support '
                                'num_parallel_calls (must be None).')

    data_spec = self._data_spec
    if sample_batch_size is not None:
      data_spec = array_spec.add_outer_dims_nest(
          data_spec, (sample_batch_size,))
    if num_steps is not None:
      data_spec = (data_spec,) * num_steps
    shapes = tuple(s.shape for s in tf.nest.flatten(data_spec))
    dtypes = tuple(s.dtype for s in tf.nest.flatten(data_spec))

    def generator_fn():
      while True:
        if sample_batch_size is not None:
          batch = [self._get_next(num_steps=num_steps, time_stacked=False)
                   for _ in range(sample_batch_size)]
          item = nest_utils.stack_nested_arrays(batch)
        else:
          item = self._get_next(num_steps=num_steps, time_stacked=False)
        yield tuple(tf.nest.flatten(item))

    def time_stack(*structures):
      time_axis = 0 if sample_batch_size is None else 1
      return tf.nest.map_structure(
          lambda *elements: tf.stack(elements, axis=time_axis), *structures)

    ds = tf.data.Dataset.from_generator(
        generator_fn, dtypes,
        shapes).map(lambda *items: tf.nest.pack_sequence_as(data_spec, items))
    if num_steps is not None:
      return ds.map(time_stack)
    else:
      return ds

  def _gather_all(self):
    data = [self._decode(self._storage.get(idx))
            for idx in range(self._capacity)]
    stacked = nest_utils.stack_nested_arrays(data)
    batched = tf.nest.map_structure(lambda t: np.expand_dims(t, 0), stacked)
    return batched

  def _clear(self):
    self._np_state.size = np.int64(0)
    self._np_state.cur_id = np.int64(0)

  def gather_all_transitions(self):
    num_steps_value = 2

    def get_single(idx):
      """Gets the idx item from the replay buffer."""
      with self._lock:
        if self._np_state.size <= idx:

          def empty_item(spec):
            return np.empty(spec.shape, dtype=spec.dtype)

          item = [
              tf.nest.map_structure(empty_item, self.data_spec)
              for n in range(num_steps_value)
          ]
          item = nest_utils.stack_nested_arrays(item)
          return item

        if self._np_state.size == self._capacity:
          # If the buffer is full, add cur_id (head of circular buffer) so that
          # we sample from the range [cur_id, cur_id + size - num_steps_value].
          # We will modulo the size below.
          idx += self._np_state.cur_id

        item = [
            self._decode(self._storage.get((idx + n) % self._capacity))
            for n in range(num_steps_value)
        ]

      item = nest_utils.stack_nested_arrays(item)
      return item

    samples = [
        get_single(idx)
        for idx in range(self._np_state.size - num_steps_value + 1)
    ]
    return nest_utils.stack_nested_arrays(samples)
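

# A minimal usage sketch: items are added one (batch-size 1) element at a time
# and read back either by uniform sampling via get_next() or in one shot via
# gather_all_transitions(). The ArraySpec below is only an illustrative
# placeholder.
if __name__ == '__main__':
  spec = array_spec.ArraySpec(shape=(3,), dtype=np.float32, name='obs')
  buffer_ = PyUniformReplayBuffer(spec, capacity=100)
  for i in range(10):
    item = np.full((3,), float(i), dtype=np.float32)
    buffer_.add_batch(nest_utils.batch_nested_array(item))
  print(buffer_.size)                            # -> 10
  print(buffer_.get_next(num_steps=2))           # a uniformly sampled 2-step slice
  print(buffer_.gather_all_transitions().shape)  # -> (9, 2, 3)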


================================================
FILE: unsupervised_skill_learning/dads_agent.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TF-Agents Class for DADS. Builds on top of the SAC agent."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import sys
sys.path.append(os.path.abspath('./'))

import numpy as np
import tensorflow as tf

from tf_agents.agents.sac import sac_agent

import skill_dynamics

nest = tf.nest


class DADSAgent(sac_agent.SacAgent):

  def __init__(self,
               save_directory,
               skill_dynamics_observation_size,
               observation_modify_fn=None,
               restrict_input_size=0,
               latent_size=2,
               latent_prior='cont_uniform',
               prior_samples=100,
               fc_layer_params=(256, 256),
               normalize_observations=True,
               network_type='default',
               num_mixture_components=4,
               fix_variance=True,
               skill_dynamics_learning_rate=3e-4,
               reweigh_batches=False,
               agent_graph=None,
               skill_dynamics_graph=None,
               *sac_args,
               **sac_kwargs):
    self._skill_dynamics_learning_rate = skill_dynamics_learning_rate
    self._latent_size = latent_size
    self._latent_prior = latent_prior
    self._prior_samples = prior_samples
    self._save_directory = save_directory
    self._restrict_input_size = restrict_input_size
    self._process_observation = observation_modify_fn

    if agent_graph is None:
      self._graph = tf.compat.v1.get_default_graph()
    else:
      self._graph = agent_graph

    if skill_dynamics_graph is None:
      skill_dynamics_graph = self._graph

    # instantiate the skill dynamics
    self._skill_dynamics = skill_dynamics.SkillDynamics(
        observation_size=skill_dynamics_observation_size,
        action_size=self._latent_size,
        restrict_observation=self._restrict_input_size,
        normalize_observations=normalize_observations,
        fc_layer_params=fc_layer_params,
        network_type=network_type,
        num_components=num_mixture_components,
        fix_variance=fix_variance,
        reweigh_batches=reweigh_batches,
        graph=skill_dynamics_graph)

    super(DADSAgent, self).__init__(*sac_args, **sac_kwargs)
    self._placeholders_in_place = False

  def compute_dads_reward(self, input_obs, cur_skill, target_obs):
    if self._process_observation is not None:
      input_obs, target_obs = self._process_observation(
          input_obs), self._process_observation(target_obs)

    num_reps = self._prior_samples if self._prior_samples > 0 else self._latent_size - 1
    input_obs_altz = np.concatenate([input_obs] * num_reps, axis=0)
    target_obs_altz = np.concatenate([target_obs] * num_reps, axis=0)

    # for marginalization of the denominator
    if self._latent_prior == 'discrete_uniform' and not self._prior_samples:
      alt_skill = np.concatenate(
          [np.roll(cur_skill, i, axis=1) for i in range(1, num_reps + 1)],
          axis=0)
    elif self._latent_prior == 'discrete_uniform':
      alt_skill = np.random.multinomial(
          1, [1. / self._latent_size] * self._latent_size,
          size=input_obs_altz.shape[0])
    elif self._latent_prior == 'gaussian':
      alt_skill = np.random.multivariate_normal(
          np.zeros(self._latent_size),
          np.eye(self._latent_size),
          size=input_obs_altz.shape[0])
    elif self._latent_prior == 'cont_uniform':
      alt_skill = np.random.uniform(
          low=-1.0, high=1.0, size=(input_obs_altz.shape[0], self._latent_size))

    logp = self._skill_dynamics.get_log_prob(input_obs, cur_skill, target_obs)

    # denominator may require more memory than that of a GPU, break computation
    split_group = 20 * 4000
    if input_obs_altz.shape[0] <= split_group:
      logp_altz = self._skill_dynamics.get_log_prob(input_obs_altz, alt_skill,
                                                    target_obs_altz)
    else:
      logp_altz = []
      for split_idx in range(input_obs_altz.shape[0] // split_group):
        start_split = split_idx * split_group
        end_split = (split_idx + 1) * split_group
        logp_altz.append(
            self._skill_dynamics.get_log_prob(
                input_obs_altz[start_split:end_split],
                alt_skill[start_split:end_split],
                target_obs_altz[start_split:end_split]))
      if input_obs_altz.shape[0] % split_group:
        start_split = input_obs_altz.shape[0] % split_group
        logp_altz.append(
            self._skill_dynamics.get_log_prob(input_obs_altz[-start_split:],
                                              alt_skill[-start_split:],
                                              target_obs_altz[-start_split:]))
      logp_altz = np.concatenate(logp_altz)
    logp_altz = np.array(np.array_split(logp_altz, num_reps))

    # final DADS reward
    intrinsic_reward = np.log(num_reps + 1) - np.log(1 + np.exp(
        np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))

    return intrinsic_reward, {'logp': logp, 'logp_altz': logp_altz.flatten()}

  def get_experience_placeholder(self):
    self._placeholders_in_place = True
    self._placeholders = []
    for item in nest.flatten(self.collect_data_spec):
      self._placeholders += [
          tf.compat.v1.placeholder(
              item.dtype,
              shape=(None, 2) if len(item.shape) == 0 else
              (None, 2, item.shape[-1]),
              name=item.name)
      ]
    self._policy_experience_ph = nest.pack_sequence_as(self.collect_data_spec,
                                                       self._placeholders)
    return self._policy_experience_ph

  def build_agent_graph(self):
    with self._graph.as_default():
      self.get_experience_placeholder()
      self.agent_train_op = self.train(self._policy_experience_ph)
      self.summary_ops = tf.compat.v1.summary.all_v2_summary_ops()
      return self.agent_train_op

  def build_skill_dynamics_graph(self):
    self._skill_dynamics.make_placeholders()
    self._skill_dynamics.build_graph()
    self._skill_dynamics.increase_prob_op(
        learning_rate=self._skill_dynamics_learning_rate)

  def create_savers(self):
    self._skill_dynamics.create_saver(
        save_prefix=os.path.join(self._save_directory, 'dynamics'))

  def set_sessions(self, initialize_or_restore_skill_dynamics, session=None):
    if session is not None:
      self._session = session
    else:
      self._session = tf.compat.v1.Session(graph=self._graph)
    self._skill_dynamics.set_session(
        initialize_or_restore_variables=initialize_or_restore_skill_dynamics,
        session=session)

  def save_variables(self, global_step):
    self._skill_dynamics.save_variables(global_step=global_step)

  def _get_dict(self, trajectories, batch_size=-1):
    tf.nest.assert_same_structure(self.collect_data_spec, trajectories)
    if batch_size > 0:
      shuffled_batch = np.random.permutation(
          trajectories.observation.shape[0])[:batch_size]
    else:
      shuffled_batch = np.arange(trajectories.observation.shape[0])

    return_dict = {}

    for placeholder, val in zip(self._placeholders, nest.flatten(trajectories)):
      return_dict[placeholder] = val[shuffled_batch]

    return return_dict

  def train_loop(self,
                 trajectories,
                 recompute_reward=False,
                 batch_size=-1,
                 num_steps=1):
    if not self._placeholders_in_place:
      return

    if recompute_reward:
      input_obs = trajectories.observation[:, 0, :-self._latent_size]
      cur_skill = trajectories.observation[:, 0, -self._latent_size:]
      target_obs = trajectories.observation[:, 1, :-self._latent_size]
      new_reward, info = self.compute_dads_reward(input_obs, cur_skill,
                                                  target_obs)
      trajectories = trajectories._replace(
          reward=np.concatenate(
              [np.expand_dims(new_reward, axis=1), trajectories.reward[:, 1:]],
              axis=1))

    # TODO(architsh): all agent specs should be the same as the env specs;
    # shift preprocessing to the actor/critic networks.
    if self._restrict_input_size > 0:
      trajectories = trajectories._replace(
          observation=trajectories.observation[:, :,
                                               self._restrict_input_size:])

    for _ in range(num_steps):
      self._session.run([self.agent_train_op, self.summary_ops],
                        feed_dict=self._get_dict(
                            trajectories, batch_size=batch_size))

    if recompute_reward:
      return new_reward, info
    else:
      return None, None

  @property
  def skill_dynamics(self):
    return self._skill_dynamics
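

# A toy numpy sketch of the intrinsic reward computed in compute_dads_reward:
# log q(s'|s, z) under the sampled skill is compared against log q(s'|s, z')
# for `num_reps` alternative skills z' drawn from the prior. The quadratic
# `toy_log_prob` below is only a stand-in for the learned skill-dynamics model.
if __name__ == '__main__':
  def toy_log_prob(obs, skill, next_obs):
    # pretend the model predicts next_obs = obs + skill with unit variance
    return -0.5 * np.sum((next_obs - (obs + skill))**2, axis=-1)

  batch, latent_size, num_reps = 4, 2, 16
  obs = np.random.randn(batch, latent_size)
  skill = np.random.uniform(-1., 1., size=(batch, latent_size))
  next_obs = obs + skill + 0.1 * np.random.randn(batch, latent_size)

  logp = toy_log_prob(obs, skill, next_obs)                        # (batch,)
  alt_skills = np.random.uniform(
      -1., 1., size=(num_reps, batch, latent_size))
  logp_altz = toy_log_prob(obs[None], alt_skills, next_obs[None])  # (num_reps, batch)

  intrinsic_reward = np.log(num_reps + 1) - np.log(1 + np.exp(
      np.clip(logp_altz - logp[None], -50, 50)).sum(axis=0))
  print(intrinsic_reward)  # higher when the sampled skill explains s' best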


================================================
FILE: unsupervised_skill_learning/dads_off.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import pickle as pkl
import os
import io
from absl import flags, logging
import functools

import sys
sys.path.append(os.path.abspath('./'))

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.environments import suite_mujoco
from tf_agents.trajectories import time_step as ts
from tf_agents.environments.suite_gym import wrap_env
from tf_agents.trajectories.trajectory import from_transition, to_transition
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import normal_projection_network
from tf_agents.policies import ou_noise_policy
from tf_agents.trajectories import policy_step
# from tf_agents.policies import py_tf_policy
# from tf_agents.replay_buffers import py_uniform_replay_buffer
from tf_agents.specs import array_spec
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
from tf_agents.utils import nest_utils

import dads_agent

from envs import skill_wrapper
from envs import video_wrapper
from envs.gym_mujoco import ant
from envs.gym_mujoco import half_cheetah
from envs.gym_mujoco import humanoid
from envs.gym_mujoco import point_mass

from envs import dclaw
from envs import dkitty_redesign
from envs import hand_block

from lib import py_tf_policy
from lib import py_uniform_replay_buffer

FLAGS = flags.FLAGS
nest = tf.nest

# general hyperparameters
flags.DEFINE_string('logdir', '~/tmp/dads', 'Directory for saving experiment data')

# environment hyperparameters
flags.DEFINE_string('environment', 'point_mass', 'Name of the environment')
flags.DEFINE_integer('max_env_steps', 200,
                     'Maximum number of steps in one episode')
flags.DEFINE_integer('reduced_observation', 0,
                     'Predict dynamics in a reduced observation space')
flags.DEFINE_integer(
    'min_steps_before_resample', 50,
    'Minimum number of steps to execute before resampling skill')
flags.DEFINE_float('resample_prob', 0.,
                   'Probability of resampling the skill, which randomizes the '
                   'number of timesteps before the skill is resampled')

# need to set save_model and save_freq
flags.DEFINE_string(
    'save_model', None,
    'Name to save the model with, None implies the models are not saved.')
flags.DEFINE_integer('save_freq', 100, 'Saving frequency for checkpoints')
flags.DEFINE_string(
    'vid_name', None,
    'Base name for videos being saved, None implies videos are not recorded')
flags.DEFINE_integer('record_freq', 100,
                     'Video recording frequency within the training loop')

# final evaluation after training is done
flags.DEFINE_integer('run_eval', 0, 'Evaluate learnt skills')

# evaluation type
flags.DEFINE_integer('num_evals', 0, 'Number of skills to evaluate')
flags.DEFINE_integer('deterministic_eval', 0,
                  'Evaluate all skills, only works for discrete skills')

# training
flags.DEFINE_integer('run_train', 0, 'Train the agent')
flags.DEFINE_integer('num_epochs', 500, 'Number of training epochs')

# skill latent space
flags.DEFINE_integer('num_skills', 2, 'Number of skills to learn')
flags.DEFINE_string('skill_type', 'cont_uniform',
                    'Type of skill and the prior over it')
# network size hyperparameter
flags.DEFINE_integer(
    'hidden_layer_size', 512,
    'Hidden layer size, shared by actors, critics and dynamics')

# reward structure
flags.DEFINE_integer(
    'random_skills', 0,
    'Number of skills to sample randomly for approximating mutual information')

# optimization hyperparameters
flags.DEFINE_integer('replay_buffer_capacity', int(1e6),
                     'Capacity of the replay buffer')
flags.DEFINE_integer(
    'clear_buffer_every_iter', 0,
    'Clear the replay buffer every iteration to simulate on-policy training; '
    'use larger collect steps and train steps'
)
flags.DEFINE_integer(
    'initial_collect_steps', 2000,
    'Steps collected initially before training to populate the buffer')
flags.DEFINE_integer('collect_steps', 200, 'Steps collected per agent update')

# relabelling
flags.DEFINE_string('agent_relabel_type', None,
                    'Type of skill relabelling used for agent')
flags.DEFINE_integer(
    'train_skill_dynamics_on_policy', 0,
    'Train skill-dynamics on on-policy data, while the agent trains off-policy')
flags.DEFINE_string('skill_dynamics_relabel_type', None,
                    'Type of skill relabelling used for skill-dynamics')
flags.DEFINE_integer(
    'num_samples_for_relabelling', 100,
    'Number of samples from prior for relabelling the current skill when using policy relabelling'
)
flags.DEFINE_float(
    'is_clip_eps', 0.,
    'PPO style clipping epsilon to constrain importance sampling weights to (1-eps, 1+eps)'
)
flags.DEFINE_float(
    'action_clipping', 1.,
    'Clip actions to (-eps, eps) per dimension to avoid difficulties with tanh')
flags.DEFINE_integer('debug_skill_relabelling', 0,
                     'analysis of skill relabelling')

# skill dynamics optimization hyperparamaters
flags.DEFINE_integer('skill_dyn_train_steps', 8,
                     'Number of discriminator train steps on a batch of data')
flags.DEFINE_float('skill_dynamics_lr', 3e-4,
                   'Learning rate for increasing the log-likelihood')
flags.DEFINE_integer('skill_dyn_batch_size', 256,
                     'Batch size for discriminator updates')
# agent optimization hyperparameters
flags.DEFINE_integer('agent_batch_size', 256, 'Batch size for agent updates')
flags.DEFINE_integer('agent_train_steps', 128,
                     'Number of update steps per iteration')
flags.DEFINE_float('agent_lr', 3e-4, 'Learning rate for the agent')

# SAC hyperparameters
flags.DEFINE_float('agent_entropy', 0.1, 'Entropy regularization coefficient')
flags.DEFINE_float('agent_gamma', 0.99, 'Reward discount factor')
flags.DEFINE_string(
    'collect_policy', 'default',
    'Can use the OUNoisePolicy to collect experience for better exploration')

# skill-dynamics hyperparameters
flags.DEFINE_string(
    'graph_type', 'default',
    'Process the skill input separately for more representational power')
flags.DEFINE_integer('num_components', 4,
                     'Number of components for Mixture of Gaussians')
flags.DEFINE_integer('fix_variance', 1,
                     'Fix the variance of output distribution')
flags.DEFINE_integer('normalize_data', 1, 'Maintain running averages')

# debug
flags.DEFINE_integer('debug', 0, 'Creates extra summaries')

# DKitty
flags.DEFINE_integer('expose_last_action', 1, 'Add the last action to the observation')
flags.DEFINE_integer('expose_upright', 1, 'Add the upright angle to the observation')
flags.DEFINE_float('upright_threshold', 0.9, 'Uprightness threshold below which the DKitty episode is terminated')
flags.DEFINE_float('robot_noise_ratio', 0.05, 'Noise ratio for robot joints')
flags.DEFINE_float('root_noise_ratio', 0.002, 'Noise ratio for root position')
flags.DEFINE_float('scale_root_position', 1, 'Multiply the root coordinates to magnify the change')
flags.DEFINE_integer('run_on_hardware', 0, 'Flag for hardware runs')
flags.DEFINE_float('randomize_hfield', 0.0, 'Randomize terrain for better DKitty transfer')
flags.DEFINE_integer('observation_omission_size', 2, 'Dimensions to be omitted from policy input')

# Manipulation Environments
flags.DEFINE_integer('randomized_initial_distribution', 1, 'Randomize the initial distribution or not')
flags.DEFINE_float('horizontal_wrist_constraint', 1.0, 'Action space constraint to restrict horizontal motion of the wrist')
flags.DEFINE_float('vertical_wrist_constraint', 1.0, 'Action space constraint to restrict vertical motion of the wrist')

# MPC hyperparameters
flags.DEFINE_integer('planning_horizon', 1, 'Number of primitives to plan in the future')
flags.DEFINE_integer('primitive_horizon', 1, 'Horizon for every primitive')
flags.DEFINE_integer('num_candidate_sequences', 50, 'Number of candidate sequences sampled from the proposal distribution')
flags.DEFINE_integer('refine_steps', 10, 'Number of optimization steps')
flags.DEFINE_float('mppi_gamma', 10.0, 'MPPI weighting hyperparameter')
flags.DEFINE_string('prior_type', 'normal', 'Uniform or Gaussian prior for candidate skill(s)')
flags.DEFINE_float('smoothing_beta', 0.9, 'Smoothing coefficient for candidate skill sequences')
flags.DEFINE_integer('top_primitives', 5, 'Optimization parameter when using uniform prior (CEM style)')
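
# An illustrative invocation sketch; the flag values below are placeholders
# (mostly the defaults above), see the configs/ directory for complete
# settings:
#
#   python unsupervised_skill_learning/dads_off.py \
#     --logdir=~/tmp/dads \
#     --environment=Ant-v1 \
#     --max_env_steps=200 \
#     --num_skills=2 --skill_type=cont_uniform \
#     --run_train=1 --num_epochs=500 \
#     --initial_collect_steps=2000 --collect_steps=200 \
#     --agent_batch_size=256 --agent_train_steps=128 \
#     --skill_dyn_batch_size=256 --skill_dyn_train_steps=8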

# global variables for this script
observation_omit_size = 0
goal_coord = np.array([10., 10.])
sample_count = 0
iter_count = 0
episode_size_buffer = []
episode_return_buffer = []

# add a flag for state dependent std
def _normal_projection_net(action_spec, init_means_output_factor=0.1):
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)

def get_environment(env_name='point_mass'):
  global observation_omit_size
  if env_name == 'Ant-v1':
    env = ant.AntEnv(
        expose_all_qpos=True,
        task='motion')
    observation_omit_size = 2
  elif env_name == 'Ant-v1_goal':
    observation_omit_size = 2
    return wrap_env(
        ant.AntEnv(
            task='goal',
            goal=goal_coord,
            expose_all_qpos=True),
        max_episode_steps=FLAGS.max_env_steps)
  elif env_name == 'Ant-v1_foot_sensor':
    env = ant.AntEnv(
        expose_all_qpos=True,
        model_path='ant_footsensor.xml',
        expose_foot_sensors=True)
    observation_omit_size = 2
  elif env_name == 'HalfCheetah-v1':
    env = half_cheetah.HalfCheetahEnv(expose_all_qpos=True, task='motion')
    observation_omit_size = 1
  elif env_name == 'Humanoid-v1':
    env = humanoid.HumanoidEnv(expose_all_qpos=True)
    observation_omit_size = 2
  elif env_name == 'point_mass':
    env = point_mass.PointMassEnv(expose_goal=False, expose_velocity=False)
    observation_omit_size = 2
  elif env_name == 'DClaw':
    env = dclaw.DClawTurnRandom()
    observation_omit_size = FLAGS.observation_omission_size
  elif env_name == 'DClaw_randomized':
    env = dclaw.DClawTurnRandomDynamics()
    observation_omit_size = FLAGS.observation_omission_size
  elif env_name == 'DKitty_redesign':
    env = dkitty_redesign.BaseDKittyWalk(
        expose_last_action=FLAGS.expose_last_action,
        expose_upright=FLAGS.expose_upright,
        robot_noise_ratio=FLAGS.robot_noise_ratio,
        upright_threshold=FLAGS.upright_threshold)
    observation_omit_size = FLAGS.observation_omission_size
  elif env_name == 'DKitty_randomized':
    env = dkitty_redesign.DKittyRandomDynamics(
        randomize_hfield=FLAGS.randomize_hfield,
        expose_last_action=FLAGS.expose_last_action,
        expose_upright=FLAGS.expose_upright,
        robot_noise_ratio=FLAGS.robot_noise_ratio,
        upright_threshold=FLAGS.upright_threshold)
    observation_omit_size = FLAGS.observation_omission_size
  elif env_name == 'HandBlock':
    observation_omit_size = 0
    env = hand_block.HandBlockCustomEnv(
        horizontal_wrist_constraint=FLAGS.horizontal_wrist_constraint,
        vertical_wrist_constraint=FLAGS.vertical_wrist_constraint,
        randomize_initial_position=bool(FLAGS.randomized_initial_distribution),
        randomize_initial_rotation=bool(FLAGS.randomized_initial_distribution))
  else:
    # note this is already wrapped, no need to wrap again
    env = suite_mujoco.load(env_name)
  return env

def hide_coords(time_step):
  global observation_omit_size
  if observation_omit_size > 0:
    sans_coords = time_step.observation[observation_omit_size:]
    return time_step._replace(observation=sans_coords)

  return time_step


def relabel_skill(trajectory_sample,
                  relabel_type=None,
                  cur_policy=None,
                  cur_skill_dynamics=None):
  global observation_omit_size
  if relabel_type is None or ('importance_sampling' in relabel_type and
                              FLAGS.is_clip_eps <= 1.0):
    return trajectory_sample, None

  # trajectory.to_transition, but for numpy arrays
  next_trajectory = nest.map_structure(lambda x: x[:, 1:], trajectory_sample)
  trajectory = nest.map_structure(lambda x: x[:, :-1], trajectory_sample)
  action_steps = policy_step.PolicyStep(
      action=trajectory.action, state=(), info=trajectory.policy_info)
  time_steps = ts.TimeStep(
      trajectory.step_type,
      reward=nest.map_structure(np.zeros_like, trajectory.reward),  # unknown
      discount=np.zeros_like(trajectory.discount),  # unknown
      observation=trajectory.observation)
  next_time_steps = ts.TimeStep(
      step_type=trajectory.next_step_type,
      reward=trajectory.reward,
      discount=trajectory.discount,
      observation=next_trajectory.observation)
  time_steps, action_steps, next_time_steps = nest.map_structure(
      lambda t: np.squeeze(t, axis=1),
      (time_steps, action_steps, next_time_steps))

  # just return the importance sampling weights for the given batch
  if 'importance_sampling' in relabel_type:
    old_log_probs = policy_step.get_log_probability(action_steps.info)
    is_weights = []
    for idx in range(time_steps.observation.shape[0]):
      cur_time_step = nest.map_structure(lambda x: x[idx:idx + 1], time_steps)
      cur_time_step = cur_time_step._replace(
          observation=cur_time_step.observation[:, observation_omit_size:])
      old_log_prob = old_log_probs[idx]
      cur_log_prob = cur_policy.log_prob(cur_time_step,
                                         action_steps.action[idx:idx + 1])[0]
      is_weights.append(
          np.clip(
              np.exp(cur_log_prob - old_log_prob), 1. / FLAGS.is_clip_eps,
              FLAGS.is_clip_eps))

    is_weights = np.array(is_weights)
    if relabel_type == 'normalized_importance_sampling':
      is_weights = is_weights / is_weights.mean()

    return trajectory_sample, is_weights

  new_observation = np.zeros(time_steps.observation.shape)
  for idx in range(time_steps.observation.shape[0]):
    alt_time_steps = nest.map_structure(
        lambda t: np.stack([t[idx]] * FLAGS.num_samples_for_relabelling),
        time_steps)

    # sample possible skills for relabelling from the prior
    if FLAGS.skill_type == 'cont_uniform':
      # always ensure that the original skill is one of the possible options for relabelling
      alt_skills = np.concatenate([
          np.random.uniform(
              low=-1.0,
              high=1.0,
              size=(FLAGS.num_samples_for_relabelling - 1, FLAGS.num_skills)),
          alt_time_steps.observation[:1, -FLAGS.num_skills:]
      ])

    # choose the skill which gives the highest log-probability to the current action
    if relabel_type == 'policy':
      cur_action = np.stack([action_steps.action[idx, :]] *
                            FLAGS.num_samples_for_relabelling)
      alt_time_steps = alt_time_steps._replace(
          observation=np.concatenate([
              alt_time_steps
              .observation[:,
                           observation_omit_size:-FLAGS.num_skills], alt_skills
          ],
                                     axis=1))
      action_log_probs = cur_policy.log_prob(alt_time_steps, cur_action)
      if FLAGS.debug_skill_relabelling:
        print('\n action_log_probs analysis----', idx,
              time_steps.observation[idx, -FLAGS.num_skills:])
        print('number of skills with higher log-probs:',
              np.sum(action_log_probs >= action_log_probs[-1]))
        print('Skills with log-probs higher than actual skill:')
        skill_dist = []
        for skill_idx in range(FLAGS.num_samples_for_relabelling):
          if action_log_probs[skill_idx] >= action_log_probs[-1]:
            print(alt_skills[skill_idx])
            skill_dist.append(
                np.linalg.norm(alt_skills[skill_idx] - alt_skills[-1]))
        print('average distance of skills with higher-log-prob:',
              np.mean(skill_dist))
      max_skill_idx = np.argmax(action_log_probs)

    # choose the skill which gets the highest log-probability under the dynamics posterior
    elif relabel_type == 'dynamics_posterior':
      cur_observations = alt_time_steps.observation[:, :-FLAGS.num_skills]
      next_observations = np.stack(
          [next_time_steps.observation[idx, :-FLAGS.num_skills]] *
          FLAGS.num_samples_for_relabelling)

      # the max over posterior log-probability is exactly the max over the
      # log-prob of the transition under skill-dynamics
      posterior_log_probs = cur_skill_dynamics.get_log_prob(
          process_observation(cur_observations), alt_skills,
          process_observation(next_observations))
      if FLAGS.debug_skill_relabelling:
        print('\n dynamics_log_probs analysis----', idx,
              time_steps.observation[idx, -FLAGS.num_skills:])
        print('number of skills with higher log-probs:',
              np.sum(posterior_log_probs >= posterior_log_probs[-1]))
        print('Skills with log-probs higher than actual skill:')
        skill_dist = []
        for skill_idx in range(FLAGS.num_samples_for_relabelling):
          if posterior_log_probs[skill_idx] >= posterior_log_probs[-1]:
            print(alt_skills[skill_idx])
            skill_dist.append(
                np.linalg.norm(alt_skills[skill_idx] - alt_skills[-1]))
        print('average distance of skills with higher-log-prob:',
              np.mean(skill_dist))

      max_skill_idx = np.argmax(posterior_log_probs)

    # make the new observation with the relabelled skill
    relabelled_skill = alt_skills[max_skill_idx]
    new_observation[idx] = np.concatenate(
        [time_steps.observation[idx, :-FLAGS.num_skills], relabelled_skill])

  traj_observation = np.copy(trajectory_sample.observation)
  traj_observation[:, 0] = new_observation
  new_trajectory_sample = trajectory_sample._replace(
      observation=traj_observation)

  return new_trajectory_sample, None


# hard-coding the state-space for dynamics
def process_observation(observation):

  def _shape_based_observation_processing(observation, dim_idx):
    if len(observation.shape) == 1:
      return observation[dim_idx:dim_idx + 1]
    elif len(observation.shape) == 2:
      return observation[:, dim_idx:dim_idx + 1]
    elif len(observation.shape) == 3:
      return observation[:, :, dim_idx:dim_idx + 1]

  # for consistent use
  if FLAGS.reduced_observation == 0:
    return observation

  # process observation for dynamics with reduced observation space
  if FLAGS.environment == 'HalfCheetah-v1':
    qpos_dim = 9
  elif FLAGS.environment == 'Ant-v1':
    qpos_dim = 15
  elif FLAGS.environment == 'Humanoid-v1':
    qpos_dim = 26
  elif 'DKitty' in FLAGS.environment:
    qpos_dim = 36

  # x-axis
  if FLAGS.reduced_observation in [1, 5]:
    red_obs = [_shape_based_observation_processing(observation, 0)]
  # x-y plane
  elif FLAGS.reduced_observation in [2, 6]:
    if FLAGS.environment == 'Ant-v1' or 'DKitty' in FLAGS.environment or 'DClaw' in FLAGS.environment:
      red_obs = [
          _shape_based_observation_processing(observation, 0),
          _shape_based_observation_processing(observation, 1)
      ]
    else:
      red_obs = [
          _shape_based_observation_processing(observation, 0),
          _shape_based_observation_processing(observation, qpos_dim)
      ]
  # x-y plane, x-y velocities
  elif FLAGS.reduced_observation in [4, 8]:
    if FLAGS.reduced_observation == 4 and 'DKittyPush' in FLAGS.environment:
      # position of the agent + relative position of the box
      red_obs = [
          _shape_based_observation_processing(observation, 0),
          _shape_based_observation_processing(observation, 1),
          _shape_based_observation_processing(observation, 3),
          _shape_based_observation_processing(observation, 4)
      ]
    elif FLAGS.environment in ['Ant-v1']:
      red_obs = [
          _shape_based_observation_processing(observation, 0),
          _shape_based_observation_processing(observation, 1),
          _shape_based_observation_processing(observation, qpos_dim),
          _shape_based_observation_processing(observation, qpos_dim + 1)
      ]

  # (x, y, orientation), works only for ant, point_mass
  elif FLAGS.reduced_observation == 3:
    if FLAGS.environment in ['Ant-v1', 'point_mass']:
      red_obs = [
          _shape_based_observation_processing(observation, 0),
          _shape_based_observation_processing(observation, 1),
          _shape_based_observation_processing(observation,
                                              observation.shape[1] - 1)
      ]
    # x, y, z of the center of the block
    elif FLAGS.environment in ['HandBlock']:
      red_obs = [
          _shape_based_observation_processing(observation, 
                                              observation.shape[-1] - 7),
          _shape_based_observation_processing(observation, 
                                              observation.shape[-1] - 6),
          _shape_based_observation_processing(observation,
                                              observation.shape[-1] - 5)
      ]

  if FLAGS.reduced_observation in [5, 6, 8]:
    red_obs += [
        _shape_based_observation_processing(observation,
                                            observation.shape[1] - idx)
        for idx in range(1, 5)
    ]

  if FLAGS.reduced_observation == 36 and 'DKitty' in FLAGS.environment:
    red_obs = [
        _shape_based_observation_processing(observation, idx)
        for idx in range(qpos_dim)
    ]

  # x, y, z and the rotation quaternion
  if FLAGS.reduced_observation == 7 and FLAGS.environment == 'HandBlock':
    red_obs = [
        _shape_based_observation_processing(observation, observation.shape[-1] - idx)
        for idx in range(1, 8)
    ][::-1]

  # the rotation quaternion
  if FLAGS.reduced_observation == 4 and FLAGS.environment == 'HandBlock':
    red_obs = [
        _shape_based_observation_processing(observation, observation.shape[-1] - idx)
        for idx in range(1, 5)
    ][::-1]

  if isinstance(observation, np.ndarray):
    input_obs = np.concatenate(red_obs, axis=len(observation.shape) - 1)
  elif isinstance(observation, tf.Tensor):
    input_obs = tf.concat(red_obs, axis=len(observation.shape) - 1)
  return input_obs
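
# For example, with --environment=Ant-v1 and --reduced_observation=2, the
# skill-dynamics model sees only the ant's global (x, y) position:
#   process_observation(obs) == np.concatenate([obs[0:1], obs[1:2]])
# while the policy input is the full observation minus the first
# `observation_omit_size` coordinates (see hide_coords above).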


def collect_experience(py_env,
                       time_step,
                       collect_policy,
                       buffer_list,
                       num_steps=1):

  episode_sizes = []
  extrinsic_reward = []
  step_idx = 0
  cur_return = 0.
  for step_idx in range(num_steps):
    if time_step.is_last():
      episode_sizes.append(step_idx)
      extrinsic_reward.append(cur_return)
      cur_return = 0.

    action_step = collect_policy.action(hide_coords(time_step))

    if FLAGS.action_clipping < 1.:
      action_step = action_step._replace(
          action=np.clip(action_step.action, -FLAGS.action_clipping,
                         FLAGS.action_clipping))

    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:
      cur_action_log_prob = collect_policy.log_prob(
          nest_utils.batch_nested_array(hide_coords(time_step)),
          np.expand_dims(action_step.action, 0))
      action_step = action_step._replace(
          info=policy_step.set_log_probability(action_step.info,
                                               cur_action_log_prob))

    next_time_step = py_env.step(action_step.action)
    cur_return += next_time_step.reward

    # all modifications to observations and training will be done within the agent
    for buffer_ in buffer_list:
      buffer_.add_batch(
          from_transition(
              nest_utils.batch_nested_array(time_step),
              nest_utils.batch_nested_array(action_step),
              nest_utils.batch_nested_array(next_time_step)))

    time_step = next_time_step

  # carry-over calculation for the next collection cycle
  episode_sizes.append(step_idx + 1)
  extrinsic_reward.append(cur_return)
  for idx in range(1, len(episode_sizes)):
    episode_sizes[-idx] -= episode_sizes[-idx - 1]

  return time_step, {
      'episode_sizes': episode_sizes,
      'episode_return': extrinsic_reward
  }
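
# Worked example for the bookkeeping above: with num_steps=10 and episodes
# terminating at step_idx 3 and 7, episode_sizes is first [3, 7, 10]
# (cumulative step counts at each episode boundary plus the final count), and
# the backwards differencing turns it into per-segment lengths [3, 4, 3]; the
# first and last entries are partial episodes carried over across calls.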


def run_on_env(env,
               policy,
               dynamics=None,
               predict_trajectory_steps=0,
               return_data=False,
               close_environment=True):
  time_step = env.reset()
  data = []

  if not return_data:
    extrinsic_reward = []
  while not time_step.is_last():
    action_step = policy.action(hide_coords(time_step))
    if FLAGS.action_clipping < 1.:
      action_step = action_step._replace(
          action=np.clip(action_step.action, -FLAGS.action_clipping,
                         FLAGS.action_clipping))

    env_action = action_step.action
    next_time_step = env.step(env_action)

    skill_size = FLAGS.num_skills
    if skill_size > 0:
      cur_observation = time_step.observation[:-skill_size]
      cur_skill = time_step.observation[-skill_size:]
      next_observation = next_time_step.observation[:-skill_size]
    else:
      cur_observation = time_step.observation
      next_observation = next_time_step.observation

    if dynamics is not None:
      if FLAGS.reduced_observation:
        cur_observation, next_observation = process_observation(
            cur_observation), process_observation(next_observation)
      logp = dynamics.get_log_prob(
          np.expand_dims(cur_observation, 0), np.expand_dims(cur_skill, 0),
          np.expand_dims(next_observation, 0))

      cur_predicted_state = np.expand_dims(cur_observation, 0)
      skill_expanded = np.expand_dims(cur_skill, 0)
      cur_predicted_trajectory = [cur_predicted_state[0]]
      for _ in range(predict_trajectory_steps):
        next_predicted_state = dynamics.predict_state(cur_predicted_state,
                                                      skill_expanded)
        cur_predicted_trajectory.append(next_predicted_state[0])
        cur_predicted_state = next_predicted_state
    else:
      logp = ()
      cur_predicted_trajectory = []

    if return_data:
      data.append([
          cur_observation, action_step.action, logp, next_time_step.reward,
          np.array(cur_predicted_trajectory)
      ])
    else:
      extrinsic_reward.append([next_time_step.reward])

    time_step = next_time_step

  if close_environment:
    env.close()

  if return_data:
    return data
  else:
    return extrinsic_reward


def eval_loop(eval_dir,
              eval_policy,
              dynamics=None,
              vid_name=None,
              plot_name=None):
  metadata = tf.io.gfile.GFile(
      os.path.join(eval_dir, 'metadata.txt'), 'a')
  if FLAGS.num_skills == 0:
    num_evals = FLAGS.num_evals
  elif FLAGS.deterministic_eval:
    num_evals = FLAGS.num_skills
  else:
    num_evals = FLAGS.num_evals

  if plot_name is not None:
    # color_map = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    color_map = ['b', 'g', 'r', 'c', 'm', 'y']
    style_map = []
    for line_style in ['-', '--', '-.', ':']:
      style_map += [color + line_style for color in color_map]

    plt.xlim(-15, 15)
    plt.ylim(-15, 15)
    # all_trajectories = []
    # all_predicted_trajectories = []

  for idx in range(num_evals):
    if FLAGS.num_skills > 0:
      if FLAGS.deterministic_eval:
        preset_skill = np.zeros(FLAGS.num_skills, dtype=np.int64)
        preset_skill[idx] = 1
      elif FLAGS.skill_type == 'discrete_uniform':
        preset_skill = np.random.multinomial(1, [1. / FLAGS.num_skills] *
                                             FLAGS.num_skills)
      elif FLAGS.skill_type == 'gaussian':
        preset_skill = np.random.multivariate_normal(
            np.zeros(FLAGS.num_skills), np.eye(FLAGS.num_skills))
      elif FLAGS.skill_type == 'cont_uniform':
        preset_skill = np.random.uniform(
            low=-1.0, high=1.0, size=FLAGS.num_skills)
      elif FLAGS.skill_type == 'multivariate_bernoulli':
        preset_skill = np.random.binomial(1, 0.5, size=FLAGS.num_skills)
    else:
      preset_skill = None

    eval_env = get_environment(env_name=FLAGS.environment)
    eval_env = wrap_env(
        skill_wrapper.SkillWrapper(
            eval_env,
            num_latent_skills=FLAGS.num_skills,
            skill_type=FLAGS.skill_type,
            preset_skill=preset_skill,
            min_steps_before_resample=FLAGS.min_steps_before_resample,
            resample_prob=FLAGS.resample_prob),
        max_episode_steps=FLAGS.max_env_steps)

    # record videos for sampled trajectories
    if vid_name is not None:
      full_vid_name = vid_name + '_' + str(idx)
      eval_env = video_wrapper.VideoWrapper(eval_env, base_path=eval_dir, base_name=full_vid_name)

    mean_reward = 0.
    per_skill_evaluations = 1
    predict_trajectory_steps = 0
    # trajectories_per_skill = []
    # predicted_trajectories_per_skill = []
    for eval_idx in range(per_skill_evaluations):
      eval_trajectory = run_on_env(
          eval_env,
          eval_policy,
          dynamics=dynamics,
          predict_trajectory_steps=predict_trajectory_steps,
          return_data=True,
          close_environment=True if eval_idx == per_skill_evaluations -
          1 else False)

      trajectory_coordinates = np.array([
          eval_trajectory[step_idx][0][:2]
          for step_idx in range(len(eval_trajectory))
      ])

      # trajectory_states = np.array([
      #     eval_trajectory[step_idx][0]
      #     for step_idx in range(len(eval_trajectory))
      # ])
      # trajectories_per_skill.append(trajectory_states)
      if plot_name is not None:
        plt.plot(
            trajectory_coordinates[:, 0],
            trajectory_coordinates[:, 1],
            style_map[idx % len(style_map)],
            label=(str(idx) if eval_idx == 0 else None))
        # plt.plot(
        #     trajectory_coordinates[0, 0],
        #     trajectory_coordinates[0, 1],
        #     marker='o',
        #     color=style_map[idx % len(style_map)][0])
        if predict_trajectory_steps > 0:
          # predicted_states = np.array([
          #     eval_trajectory[step_idx][-1]
          #     for step_idx in range(len(eval_trajectory))
          # ])
          # predicted_trajectories_per_skill.append(predicted_states)
          for step_idx in range(len(eval_trajectory)):
            if step_idx % 20 == 0:
              plt.plot(eval_trajectory[step_idx][-1][:, 0],
                       eval_trajectory[step_idx][-1][:, 1], 'k:')

      mean_reward += np.mean([
          eval_trajectory[step_idx][-1]
          for step_idx in range(len(eval_trajectory))
      ])
      metadata.write(
          str(idx) + ' ' + str(preset_skill) + ' ' +
          str(trajectory_coordinates[-1, :]) + '\n')

    # all_predicted_trajectories.append(
    #     np.stack(predicted_trajectories_per_skill))
    # all_trajectories.append(np.stack(trajectories_per_skill))

  # all_predicted_trajectories = np.stack(all_predicted_trajectories)
  # all_trajectories = np.stack(all_trajectories)
  # print(all_trajectories.shape, all_predicted_trajectories.shape)
  # pkl.dump(
  #     all_trajectories,
  #     tf.io.gfile.GFile(
  #         os.path.join(vid_dir, 'skill_dynamics_full_obs_r100_actual_trajectories.pkl'),
  #         'wb'))
  # pkl.dump(
  #     all_predicted_trajectories,
  #     tf.io.gfile.GFile(
  #         os.path.join(vid_dir, 'skill_dynamics_full_obs_r100_predicted_trajectories.pkl'),
  #         'wb'))
  if plot_name is not None:
    full_image_name = plot_name + '.png'

    # render into an in-memory buffer so the figure can be written out
    # through tf.io.gfile (e.g. to CNS)
    buf = io.BytesIO()
    # plt.title('Trajectories in Continuous Skill Space')
    plt.savefig(buf, dpi=600, bbox_inches='tight')
    buf.seek(0)
    image = tf.io.gfile.GFile(os.path.join(eval_dir, full_image_name), 'w')
    image.write(buf.read(-1))

    # clear before next plot
    plt.clf()
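

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repository): the skill priors used
# by eval_loop above, restated as a standalone helper so the sampling
# conventions are explicit. `num_skills`, `skill_type` and `deterministic`
# mirror FLAGS.num_skills, FLAGS.skill_type and FLAGS.deterministic_eval.
def _example_sample_preset_skill(num_skills, skill_type, deterministic=False,
                                 skill_idx=0):
  import numpy as np  # imported locally to keep the sketch self-contained
  if deterministic:
    # one-hot skill, useful for sweeping over every discrete skill in turn
    skill = np.zeros(num_skills, dtype=np.int64)
    skill[skill_idx] = 1
    return skill
  if skill_type == 'discrete_uniform':
    return np.random.multinomial(1, [1. / num_skills] * num_skills)
  if skill_type == 'gaussian':
    return np.random.multivariate_normal(
        np.zeros(num_skills), np.eye(num_skills))
  if skill_type == 'cont_uniform':
    return np.random.uniform(low=-1.0, high=1.0, size=num_skills)
  if skill_type == 'multivariate_bernoulli':
    return np.random.binomial(1, 0.5, size=num_skills)
  raise ValueError('unknown skill_type: %s' % skill_type)
# -----------------------------------------------------------------------------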


# discrete primitives only, useful with skill-dynamics
def eval_planning(env,
                  dynamics,
                  policy,
                  latent_action_space_size,
                  episode_horizon,
                  planning_horizon=1,
                  primitive_horizon=10,
                  **kwargs):
  """env: tf-agents environment without the skill wrapper."""
  global goal_coord

  # assumes a discrete (one-hot) latent skill space for the high-level actions
  high_level_action_space = np.eye(latent_action_space_size)
  time_step = env.reset()

  actual_reward = 0.
  actual_coords = [np.expand_dims(time_step.observation[:2], 0)]
  predicted_coords = []

  # planning loop
  for _ in range(episode_horizon // primitive_horizon):
    running_reward = np.zeros(latent_action_space_size)
    running_cur_state = np.array([process_observation(time_step.observation)] *
                                 latent_action_space_size)
    cur_coord_predicted = [np.expand_dims(running_cur_state[:, :2], 1)]

    # simulate all high level actions for K steps
    for _ in range(planning_horizon):
      predicted_next_state = dynamics.predict_state(running_cur_state,
                                                    high_level_action_space)
      cur_coord_predicted.append(np.expand_dims(predicted_next_state[:, :2], 1))

      # accumulate the predicted reward and advance the simulated state
      running_reward += env.compute_reward(running_cur_state,
                                           predicted_next_state)
      running_cur_state = predicted_next_state

    predicted_coords.append(np.concatenate(cur_coord_predicted, axis=1))

    selected_high_level_action = np.argmax(running_reward)
    for _ in range(primitive_horizon):
      # concatenated observation
      skill_concat_observation = np.concatenate([
          time_step.observation,
          high_level_action_space[selected_high_level_action]
      ],
                                                axis=0)
      next_time_step = env.step(
          np.clip(
              policy.action(
                  hide_coords(
                      time_step._replace(
                          observation=skill_concat_observation))).action,
              -FLAGS.action_clipping, FLAGS.action_clipping))
      actual_reward += next_time_step.reward

      # prepare for next iteration
      time_step = next_time_step
      actual_coords.append(np.expand_dims(time_step.observation[:2], 0))

  actual_coords = np.concatenate(actual_coords)
  return actual_reward, actual_coords, predicted_coords
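

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repository): the core idea behind
# eval_planning above, on a toy problem. A hypothetical linear skill-dynamics
# model (next_state = state + skill_effect) stands in for the learnt
# SkillDynamics.predict_state, and the reward is negative distance to a goal,
# analogous to env.compute_reward in the x-y navigation tasks.
def _example_greedy_skill_selection():
  import numpy as np  # imported locally to keep the sketch self-contained
  num_skills, state_dim = 4, 2
  goal = np.array([5., 0.])
  # hypothetical effect of each one-hot skill on the (x, y) state
  skill_effects = np.array([[1., 0.], [-1., 0.], [0., 1.], [0., -1.]])
  state = np.zeros(state_dim)
  skills = np.eye(num_skills)

  def predict_state(states, skills):  # stand-in for dynamics.predict_state
    return states + skills.dot(skill_effects)

  def compute_reward(states, next_states):  # stand-in for env.compute_reward
    return -np.linalg.norm(next_states - goal, axis=-1)

  # score every discrete skill by its predicted one-step reward, pick the best
  running_states = np.tile(state, (num_skills, 1))
  rewards = compute_reward(running_states,
                           predict_state(running_states, skills))
  return int(np.argmax(rewards))  # skill 0 (move towards +x) for this goal
# -----------------------------------------------------------------------------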


def eval_mppi(
    env,
    dynamics,
    policy,
    latent_action_space_size,
    episode_horizon,
    planning_horizon=1,
    primitive_horizon=10,
    num_candidate_sequences=50,
    refine_steps=10,
    mppi_gamma=10,
    prior_type='normal',
    smoothing_beta=0.9,
    # no need to change generally
    sparsify_rewards=False,
    # only for uniform prior mode
    top_primitives=5):
  """env: tf-agents environment without the skill wrapper.

     dynamics: skill-dynamics model learnt by DADS.
     policy: skill-conditioned policy learnt by DADS.
     planning_horizon: number of latent skills to plan in the future.
     primitive_horizon: number of steps each skill is executed for.
     num_candidate_sequences: number of samples executed from the prior per
     refining step of planning.
     refine_steps: number of steps for which the plan is iterated upon before
     execution (number of optimization steps).
     mppi_gamma: MPPI parameter for reweighing rewards.
     prior_type: 'normal' implies MPPI, 'uniform' implies a CEM like algorithm
     (not tested).
     smoothing_beta: for planning_horizon > 1, the every sampled plan is
     smoothed using EMA. (0-> no smoothing, 1-> perfectly smoothed)
     sparsify_rewards: converts a dense reward problem into a sparse reward
     (avoid using).
     top_primitives: number of elites to choose, if using CEM (not tested).
  """

  step_idx = 0

  def _smooth_primitive_sequences(primitive_sequences):
    # exponential moving average along the planning horizon
    for planning_idx in range(1, primitive_sequences.shape[1]):
      primitive_sequences[:, planning_idx, :] = (
          smoothing_beta * primitive_sequences[:, planning_idx - 1, :] +
          (1. - smoothing_beta) * primitive_sequences[:, planning_idx, :])

    return primitive_sequences

  def _get_init_primitive_parameters():
    if prior_type == 'normal':
      prior_mean = functools.partial(
          np.random.multivariate_normal,
          mean=np.zeros(latent_action_space_size),
          cov=np.diag(np.ones(latent_action_space_size)))
      prior_cov = lambda: 1.5 * np.diag(np.ones(latent_action_space_size))
      return [prior_mean(), prior_cov()]

    elif prior_type == 'uniform':
      prior_low = lambda: np.array([-1.] * latent_action_space_size)
      prior_high = lambda: np.array([1.] * latent_action_space_size)
      return [prior_low(), prior_high()]

  def _sample_primitives(params):
    if prior_type == 'normal':
      sample = np.random.multivariate_normal(*params)
    elif prior_type == 'uniform':
      sample = np.random.uniform(*params)
    return np.clip(sample, -1., 1.)

  # update new primitive means for horizon sequence
  def _update_parameters(candidates, reward, primitive_parameters):
    # standard MPPI update: exponentiated-reward weighted average of the candidates
    if prior_type == 'normal':
      reward = np.exp(mppi_gamma * (reward - np.max(reward)))
      reward = reward / (reward.sum() + 1e-10)
      new_means = (candidates.T * reward).T.sum(axis=0)

      for planning_idx in range(candidates.shape[1]):
        primitive_parameters[planning_idx][0] = new_means[planning_idx]

    # TODO(architsh): closer to cross-entropy/shooting method, figure out a better update
    elif prior_type == 'uniform':
      chosen_candidates = candidates[np.argsort(reward)[-top_primitives:]]
      candidates_min = np.min(chosen_candidates, axis=0)
      candidates_max = np.max(chosen_candidates, axis=0)

      for planning_idx in range(candidates.shape[1]):
        primitive_parameters[planning_idx][0] = candidates_min[planning_idx]
        primitive_parameters[planning_idx][1] = candidates_max[planning_idx]

  def _get_expected_primitive(params):
    if prior_type == 'normal':
      return params[0]
    elif prior_type == 'uniform':
      return (params[0] + params[1]) / 2

  time_step = env.reset()
  actual_coords = [np.expand_dims(time_step.observation[:2], 0)]
  actual_reward = 0.
  distance_to_goal_array = []

  primitive_parameters = []
  chosen_primitives = []
  for _ in range(planning_horizon):
    primitive_parameters.append(_get_init_primitive_parameters())

  for _ in range(episode_horizon // primitive_horizon):
    for _ in range(refine_steps):
      # generate candidates sequences for primitives
      candidate_primitive_sequences = []
      for _ in range(num_candidate_sequences):
        candidate_primitive_sequences.append([
            _sample_primitives(primitive_parameters[planning_idx])
            for planning_idx in range(planning_horizon)
        ])

      candidate_primitive_sequences = np.array(candidate_primitive_sequences)
      candidate_primitive_sequences = _smooth_primitive_sequences(
          candidate_primitive_sequences)

      running_cur_state = np.array(
          [process_observation(time_step.observation)] *
          num_candidate_sequences)
      running_reward = np.zeros(num_candidate_sequences)
      for planning_idx in range(planning_horizon):
        cur_primitives = candidate_primitive_sequences[:, planning_idx, :]
        for _ in range(primitive_horizon):
          predicted_next_state = dynamics.predict_state(running_cur_state,
                                                        cur_primitives)

          # accumulate the predicted reward and advance the simulated state
          dense_reward = env.compute_reward(running_cur_state,
                                            predicted_next_state)
          # modification for sparse_reward
          if sparsify_rewards:
            sparse_reward = 5.0 * (dense_reward > -2) + 0.0 * (
                dense_reward <= -2)
            running_reward += sparse_reward
          else:
            running_reward += dense_reward

          running_cur_state = predicted_next_state

      _update_parameters(candidate_primitive_sequences, running_reward,
                         primitive_parameters)

    chosen_primitive = _get_expected_primitive(primitive_parameters[0])
    chosen_primitives.append(chosen_primitive)

    # a loop just to check what the chosen primitive is expected to do
    # running_cur_state = np.array([process_observation(time_step.observation)])
    # for _ in range(primitive_horizon):
    #   predicted_next_state = dynamics.predict_state(
    #       running_cur_state, np.expand_dims(chosen_primitive, 0))
    #   running_cur_state = predicted_next_state
    # print('Predicted next co-ordinates:', running_cur_state[0, :2])

    for _ in range(primitive_horizon):
      # concatenated observation
      skill_concat_observation = np.concatenate(
          [time_step.observation, chosen_primitive], axis=0)
      next_time_step = env.step(
          np.clip(
              policy.action(
                  hide_coords(
                      time_step._replace(
                          observation=skill_concat_observation))).action,
              -FLAGS.action_clipping, FLAGS.action_clipping))
      actual_reward += next_time_step.reward
      distance_to_goal_array.append(next_time_step.reward)
      # prepare for next iteration
      time_step = next_time_step
      actual_coords.append(np.expand_dims(time_step.observation[:2], 0))
      step_idx += 1
      # print(step_idx)
    # print('Actual next co-ordinates:', actual_coords[-1])

    primitive_parameters.pop(0)
    primitive_parameters.append(_get_init_primitive_parameters())

  actual_coords = np.concatenate(actual_coords)
  return actual_reward, actual_coords, np.array(
      chosen_primitives), distance_to_goal_array
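

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repository): the two numerical
# pieces of eval_mppi above in isolation -- the EMA smoothing applied to
# sampled skill sequences (_smooth_primitive_sequences) and the exponentiated-
# reward update of the prior means (_update_parameters with prior_type='normal').
def _example_mppi_pieces(candidates, rewards, mppi_gamma=10., smoothing_beta=0.9):
  """candidates: [num_candidates, planning_horizon, skill_dim] sampled plans.
     rewards: [num_candidates] predicted returns, e.g. from rolling the
     skill-dynamics model forward over the smoothed plans (rollout not shown).
  """
  import numpy as np  # imported locally to keep the sketch self-contained
  smoothed = np.array(candidates, dtype=float)
  # EMA smoothing across the planning horizon (0 -> no smoothing, 1 -> constant plan)
  for idx in range(1, smoothed.shape[1]):
    smoothed[:, idx] = (smoothing_beta * smoothed[:, idx - 1] +
                        (1. - smoothing_beta) * smoothed[:, idx])
  # MPPI reweighting: subtracting the max keeps the exponentials stable
  weights = np.exp(mppi_gamma * (rewards - np.max(rewards)))
  weights = weights / (weights.sum() + 1e-10)
  # reward-weighted average over candidates -> new prior mean per planning step
  return np.einsum('c,chd->hd', weights, smoothed)
# Example: _example_mppi_pieces(np.random.randn(50, 3, 2), np.random.randn(50))
# -----------------------------------------------------------------------------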


def main(_):
  # setting up
  start_time = time.time()
  tf.compat.v1.enable_resource_variables()
  tf.compat.v1.disable_eager_execution()
  logging.set_verbosity(logging.INFO)
  global observation_omit_size, goal_coord, sample_count, iter_count, episode_size_buffer, episode_return_buffer

  root_dir = os.path.abspath(os.path.expanduser(FLAGS.logdir))
  if not tf.io.gfile.exists(root_dir):
    tf.io.gfile.makedirs(root_dir)
  log_dir = os.path.join(root_dir, FLAGS.environment)
  
  if not tf.io.gfile.exists(log_dir):
    tf.io.gfile.makedirs(log_dir)
  save_dir = os.path.join(log_dir, 'models')
  if not tf.io.gfile.exists(save_dir):
    tf.io.gfile.makedirs(save_dir)

  print('directory for recording experiment data:', log_dir)

  # restore bookkeeping counters if training was previously paused and resumed
  try:
    sample_count = np.load(os.path.join(log_dir, 'sample_count.npy')).tolist()
    iter_count = np.load(os.path.join(log_dir, 'iter_count.npy')).tolist()
    episode_size_buffer = np.load(os.path.join(log_dir, 'episode_size_buffer.npy')).tolist()
    episode_return_buffer = np.load(os.path.join(log_dir, 'episode_return_buffer.npy')).tolist()
  except Exception:  # fresh run: no saved counters to restore
    sample_count = 0
    iter_count = 0
    episode_size_buffer = []
    episode_return_buffer = []

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      os.path.join(log_dir, 'train', 'in_graph_data'), flush_millis=10 * 1000)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(True):
    # environment related stuff
    py_env = get_environment(env_name=FLAGS.environment)
    py_env = wrap_env(
        skill_wrapper.SkillWrapper(
            py_env,
            num_latent_skills=FLAGS.num_skills,
            skill_type=FLAGS.skill_type,
            preset_skill=None,
            min_steps_before_resample=FLAGS.min_steps_before_resample,
            resample_prob=FLAGS.resample_prob),
        max_episode_steps=FLAGS.max_env_steps)

    # all specifications required for all networks and agents
    py_action_spec = py_env.action_spec()
    tf_action_spec = tensor_spec.from_spec(
        py_action_spec)  # policy, critic action spec
    env_obs_spec = py_env.observation_spec()
    py_env_time_step_spec = ts.time_step_spec(
        env_obs_spec)  # replay buffer time_step spec
    if observation_omit_size > 0:
      agent_obs_spec = array_spec.BoundedArraySpec(
          (env_obs_spec.shape[0] - observation_omit_size,),
          env_obs_spec.dtype,
          minimum=env_obs_spec.minimum,
          maximum=env_obs_spec.maximum,
          name=env_obs_spec.name)  # policy, critic observation spec
    else:
      agent_obs_spec = env_obs_spec
    py_agent_time_step_spec = ts.time_step_spec(
        agent_obs_spec)  # policy, critic time_step spec
    tf_agent_time_step_spec = tensor_spec.from_spec(py_agent_time_step_spec)
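    # Editorial note (not in the original source): when observation_omit_size > 0
    # the leading observation dimensions (e.g. the global x-y coordinates) are
    # hidden from the policy and critic, while the full observation is still
    # stored in the replay buffer and visible to skill-dynamics; see hide_coords
    # and process_observation earlier in this file.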

    if not FLAGS.reduced_observation:
      skill_dynamics_observation_size = (
          py_env_time_step_spec.observation.shape[0] - FLAGS.num_skills)
    else:
      skill_dynamics_observation_size = FLAGS.reduced_observation

    # TODO(architsh): Shift co-ordinate hiding to actor_net and critic_net (good for further image-based processing as well)
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_agent_time_step_spec.observation,
        tf_action_spec,
        fc_layer_params=(FLAGS.hidden_layer_size,) * 2,
        continuous_projection_net=_normal_projection_net)

    critic_net = critic_network.CriticNetwork(
        (tf_agent_time_step_spec.observation, tf_action_spec),
        observation_fc_layer_params=None,
        action_fc_layer_params=None,
        joint_fc_layer_params=(FLAGS.hidden_layer_size,) * 2)

    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:
      reweigh_batches_flag = True
    else:
      reweigh_batches_flag = False

    agent = dads_agent.DADSAgent(
        # DADS parameters
        save_dir,
        skill_dynamics_observation_size,
        observation_modify_fn=process_observation,
        restrict_input_size=observation_omit_size,
        latent_size=FLAGS.num_skills,
        latent_prior=FLAGS.skill_type,
        prior_samples=FLAGS.random_skills,
        fc_layer_params=(FLAGS.hidden_layer_size,) * 2,
        normalize_observations=FLAGS.normalize_data,
        network_type=FLAGS.graph_type,
        num_mixture_components=FLAGS.num_components,
        fix_variance=FLAGS.fix_variance,
        reweigh_batches=reweigh_batches_flag,
        skill_dynamics_learning_rate=FLAGS.skill_dynamics_lr,
        # SAC parameters
        time_step_spec=tf_agent_time_step_spec,
        action_spec=tf_action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        target_update_tau=0.005,
        target_update_period=1,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.agent_lr),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.agent_lr),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.agent_lr),
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=FLAGS.agent_gamma,
        reward_scale_factor=1. /
        (FLAGS.agent_entropy + 1e-12),
        gradient_clipping=None,
        debug_summaries=FLAGS.debug,
        train_step_counter=global_step)
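    # Editorial note (not in the original source): reward_scale_factor multiplies
    # the intrinsic DADS reward before the SAC update, so dividing by
    # FLAGS.agent_entropy acts as an inverse entropy temperature; the 1e-12 only
    # guards against division by zero.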

    # evaluation policy
    eval_policy = py_tf_policy.PyTFPolicy(agent.policy)

    # collection policy
    if FLAGS.collect_policy == 'default':
      collect_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)
    elif FLAGS.collect_policy == 'ou_noise':
      collect_policy = py_tf_policy.PyTFPolicy(
          ou_noise_policy.OUNoisePolicy(
              agent.collect_policy, ou_stddev=0.2, ou_damping=0.15))

    # relabelling policy deals with batches of data, unlike collect and eval
    relabel_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)

    # constructing a replay buffer, need a python spec
    policy_step_spec = policy_step.PolicyStep(
        action=py_action_spec, state=(), info=())

    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:
      policy_step_spec = policy_step_spec._replace(
          info=policy_step.set_log_probability(
              policy_step_spec.info,
              array_spec.ArraySpec(
                  shape=(), dtype=np.float32, name='action_log_prob')))

    trajectory_spec = from_transition(py_env_time_step_spec, policy_step_spec,
                                      py_env_time_step_spec)
    capacity = FLAGS.replay_buffer_capacity
    # for all the data collected
    rbuffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=capacity, data_spec=trajectory_spec)

    if FLAGS.train_skill_dynamics_on_policy:
      # for on-policy data (if something special is required)
      on_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
          capacity=FLAGS.initial_collect_steps + FLAGS.collect_steps + 10,
          data_spec=trajectory_spec)

    # insert experience manually with relabelled rewards and skills
    agent.build_agent_graph()
    agent.build_skill_dynamics_graph()
    agent.create_savers()

    # saving this way requires the checkpointers to live outside the agent object
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(save_dir, 'agent'),
        agent=agent,
        global_step=global_step)
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(save_dir, 'policy'),
        policy=agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(save_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=rbuffer)

    setup_time = time.time() - start_time
    print('Setup time:', setup_time)

    with tf.compat.v1.Session().as_default() as sess:
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      agent.set_sessions(
          initialize_or_restore_skill_dynamics=True, session=sess)

      meta_start_time = time.time()
      if FLAGS.run_train:

        train_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(log_dir, 'train'), sess.graph)
        common.initialize_uninitialized_variables(sess)
        sess.run(train_summary_writer.init())

        time_step = py_env.reset()
        episode_size_buffer.append(0)
        episode_return_buffer.append(0.)

        # maintain running buffers of episode lengths and returns
        def _process_episodic_data(ep_buffer, cur_data):
          ep_buffer[-1] += cur_data[0]
          ep_buffer += cur_data[1:]

          # keep only the most recent episodes; delete in place so the caller's
          # list is actually truncated (reassigning the parameter would only
          # rebind the local name)
          if len(ep_buffer) > 101:
            del ep_buffer[:-101]

        # remove invalid transitions from the replay buffer
        def _filter_trajectories(trajectory):
          # two consecutive samples in the buffer might not have been consecutive in the episode
          valid_indices = (trajectory.step_type[:, 0] != 2)
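          # Editorial note (not in the original source): in tf-agents,
          # ts.StepType.LAST == 2, so this keeps only transitions whose first
          # step does not end an episode; pairs that straddle an episode
          # boundary are not valid (s, s') training samples.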

          return nest.map_structure(lambda x: x[valid_indices], trajectory)

        if iter_count == 0:
          start_time = time.time()
          time_step, collect_info = collect_experience(
              py_env,
              time_step,
              collect_policy,
              buffer_list=[rbuffer] if not FLAGS.train_skill_dynamics_on_policy
              else [rbuffer, on_buffer],
              num_steps=FLAGS.initial_collect_steps)
          _process_episodic_data(episode_size_buffer,
                                 collect_info['episode_sizes'])
          _process_episodic_data(episode_return_buffer,
                                 collect_info['episode_return'])
          sample_count += FLAGS.initial_collect_steps
          initial_collect_time = time.time() - start_time
          print('Initial data collection time:', initial_collect_time)

        agent_end_train_time = time.time()
        while iter_count < FLAGS.num_epochs:
          print('iteration index:', iter_count)

          # model save
          if FLAGS.save_model is not None and iter_count % FLAGS.save_freq == 0:
            print('Saving checkpoints and counters')
            train_checkpointer.save(global_step=iter_count)
            policy_checkpointer.save(global_step=iter_count)
            rb_checkpointer.save(global_step=iter_count)
            agent.save_variables(global_step=iter_count)

            np.save(os.path.join(log_dir, 'sample_count'), sample_count)
            np.save(os.path.join(log_dir, 'episode_size_buffer'), episode_size_buffer)
            np.save(os.path.join(log_dir, 'episode_return_buffer'), episode_return_buffer)
            np.save(os.path.join(log_dir, 'iter_count'), iter_count)

          collect_start_time = time.time()
          print('intermediate time:', collect_start_time - agent_end_train_time)
          time_step, collect_info = collect_experience(
              py_env,
              time_step,
              collect_policy,
              buffer_list=[rbuffer] if not FLAGS.train_skill_dynamics_on_policy
              else [rbuffer, on_buffer],
              num_steps=FLAGS.collect_steps)
          sample_count += FLAGS.collect_steps
          _process_episodic_data(episode_size_buffer,
                                 collect_info['episode_sizes'])
          _process_episodic_data(episode_return_buffer,
                                 collect_info['episode_return'])
          collect_end_time = time.time()
          print('Iter collection time:', collect_end_time - collect_start_time)

          # only for debugging skill relabelling
          if iter_count >= 1 and FLAGS.debug_skill_relabelling:
            trajectory_sample = rbuffer.get_next(
                sample_batch_size=5, num_steps=2)
            trajectory_sample = _filter_trajectories(trajectory_sample)
            # trajectory_sample, _ = relabel_skill(
            #     trajectory_sample,
            #     relabel_type='policy',
            #     cur_policy=relabel_policy,
            #     cur_skill_dynamics=agent.skill_dynamics)
            trajectory_sample, is_weights = relabel_skill(
                trajectory_sample,
                relabel_type='importance_sampling',
                cur_policy=relabel_policy,
                cur_skill_dynamics=agent.skill_dynamics)
            print(is_weights)

          skill_dynamics_buffer = rbuffer
          if FLAGS.train_skill_dynamics_on_policy:
            skill_dynamics_buffer = on_buffer

          # TODO(architsh): clear_buffer_every_iter needs to fix these as well
          for _ in range(1 if FLAGS.clear_buffer_every_iter else
                         FLAGS.skill_dyn_train_steps):
            if FLAGS.clear_buffer_every_iter:
              trajectory_sample = rbuffer.gather_all_transitions()
            else:
              trajectory_sample = skill_dynamics_buffer.get_next(
                  sample_batch_size=FLAGS.skill_dyn_batch_size, num_steps=2)
            trajectory_sample = _filter_trajectories(trajectory_sample)

            # is_weights is None usually, unless relabelling involves importance_sampling
            trajectory_sample, is_weights = relabel_skill(
                trajectory_sample,
                relabel_type=FLAGS.skill_dynamics_relabel_type,
                cur_policy=relabel_policy,
                cur_skill_dynamics=agent.skill_dynamics)
            input_obs = process_observation(
                trajectory_sample.observation[:, 0, :-FLAGS.num_skills])
            cur_skill = trajectory_sample.observation[:, 0, -FLAGS.num_skills:]
            target_obs = process_observation(
                trajectory_sample.observation[:, 1, :-FLAGS.num_skills])
            if FLAGS.clear_buffer_every_iter:
              agent.skill_dynamics.train(
                  input_obs,
                  cur_skill,
                  target_obs,
                  batch_size=FLAGS.skill_dyn_batch_size,
                  batch_weights=is_weights,
                  num_steps=FLAGS.skill_dyn_train_steps)
            else:
              agent.skill_dynamics.train(
                  input_obs,
                  cur_skill,
                  target_obs,
                  batch_size=-1,
                  batch_weights=is_weights,
                  num_steps=1)

          if FLAGS.train_skill_dynamics_on_policy:
            on_buffer.clear()

          skill_dynamics_end_train_time = time.time()
          print('skill_dynamics train time:',
                skill_dynamics_end_train_time - collect_end_time)

          running_dads_reward, running_logp, running_logp_altz = [], [], []

          # agent train loop analysis
          within_agent_train_time = time.time()
          sampling_time_arr, filtering_time_arr, relabelling_time_arr, train_time_arr = [], [], [], []
          for _ in range(
              1 if FLAGS.clear_buffer_every_iter else FLAGS.agent_train_steps):
            if FLAGS.clear_buffer_every_iter:
              trajectory_sample = rbuffer.gather_all_transitions()
            else:
              trajectory_sample = rbuffer.get_next(
                  sample_batch_size=FLAGS.agent_batch_size, num_steps=2)

            buffer_sampling_time = time.time()
            sampling_time_arr.append(buffer_sampling_time -
                                     within_agent_train_time)
            trajectory_sample = _filter_trajectories(trajectory_sample)

            filtering_time = time.time()
            filtering_time_arr.append(filtering_time - buffer_sampling_time)
            trajectory_sample, _ = relabel_skill(
                trajectory_sample,
                relabel_type=FLAGS.agent_relabel_type,
                cur_policy=relabel_policy,
                cur_skill_dynamics=agent.skill_dynamics)
            relabelling_time = time.time()
            relabelling_time_arr.append(relabelling_time - filtering_time)

            # drop the log-prob info so the trajectory structure matches what
            # the agent's train assertions expect
            if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type:
              trajectory_sample = trajectory_sample._replace(policy_info=())

            if not FLAGS.clear_buffer_every_iter:
              dads_reward, info = agent.train_loop(
                  trajectory_sample,
                  recompute_reward=True,  # turn False for normal SAC training
                  batch_size=-1,
                  num_steps=1)
            else:
              dads_reward, info = agent.train_loop(
                  trajectory_sample,
                  recompute_reward=True,  # turn False for normal SAC training
                  batch_size=FLAGS.agent_batch_size,
                  num_steps=FLAGS.agent_train_steps)

            within_agent_train_time = time.time()
            train_time_arr.append(within_agent_train_time - relabelling_time)
            if dads_reward is not None:
              running_dads_reward.append(dads_reward)
              running_logp.append(info['logp'])
              running_logp_altz.append(info['logp_altz'])

          agent_end_train_time = time.time()
          print('agent train time:',
                agent_end_train_time - skill_dynamics_end_train_time)
          print('\t sampling time:', np.sum(sampling_time_arr))
          print('\t filtering_time:', np.sum(filtering_time_arr))
          print('\t relabelling time:', np.sum(relabelling_time_arr))
          print('\t train_time:', np.sum(train_time_arr))

          if len(episode_size_buffer) > 1:
            train_writer.add_summary(
                tf.compat.v1.Summary(value=[
                    tf.compat.v1.Summary.Value(
                        tag='episode_size',
                        simple_value=np.mean(episode_size_buffer[:-1]))
                ]), sample_count)
          if len(episode_return_buffer) > 1:
            train_writer.add_summary(
                tf.compat.v1.Summary(value=[
                    tf.compat.v1.Summary.Value(
                        tag='episode_return',
                        simple_value=np.mean(episode_return_buffer[:-1]))
                ]), sample_count)
          train_writer.add_summary(
              tf.compat.v1.Summary(value=[
                  tf.compat.v1.Summary.Value(
                      tag='dads/reward',
                      simple_value=np.mean(
                          np.concatenate(running_dads_reward)))
              ]), sample_count)

          train_writer.add_summary(
      
[... preview truncated: the remainder of unsupervised_skill_learning/dads_off.py is omitted here; see the full .txt download ...]
SYMBOL INDEX (153 symbols across 15 files)

FILE: envs/dclaw.py
  class BaseDClawTurn (line 46) | class BaseDClawTurn(BaseDClawObjectEnv, metaclass=abc.ABCMeta):
    method __init__ (line 49) | def __init__(self,
    method _reset (line 79) | def _reset(self):
    method _step (line 86) | def _step(self, action: np.ndarray):
    method get_obs_dict (line 92) | def get_obs_dict(self) -> Dict[str, np.ndarray]:
    method get_reward_dict (line 116) | def get_reward_dict(
    method get_score_dict (line 125) | def get_score_dict(
    method get_done (line 133) | def get_done(
  class DClawTurnRandom (line 143) | class DClawTurnRandom(BaseDClawTurn):
    method _reset (line 146) | def _reset(self):
  class DClawTurnRandomDynamics (line 154) | class DClawTurnRandomDynamics(DClawTurnRandom):
    method __init__ (line 160) | def __init__(self,
    method _reset (line 171) | def _reset(self):

FILE: envs/dkitty_redesign.py
  class BaseDKittyWalk (line 45) | class BaseDKittyWalk(BaseDKittyUprightEnv, metaclass=abc.ABCMeta):
    method __init__ (line 48) | def __init__(
    method _reset (line 111) | def _reset(self):
    method _step (line 122) | def _step(self, action: np.ndarray):
    method get_obs_dict (line 137) | def get_obs_dict(self) -> Dict[str, np.ndarray]:
    method get_reward_dict (line 163) | def get_reward_dict(
    method get_score_dict (line 172) | def get_score_dict(
  class DKittyRandomDynamics (line 181) | class DKittyRandomDynamics(BaseDKittyWalk):
    method __init__ (line 184) | def __init__(self, *args, randomize_hfield=0.0, **kwargs):
    method _reset (line 191) | def _reset(self):

FILE: envs/gym_mujoco/ant.py
  function q_inv (line 25) | def q_inv(a):
  function q_mult (line 29) | def q_mult(a, b):  # multiply two quaternion
  class AntEnv (line 37) | class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 39) | def __init__(self,
    method compute_reward (line 65) | def compute_reward(self, ob, next_ob, action=None):
    method step (line 91) | def step(self, a):
    method _get_obs (line 127) | def _get_obs(self):
    method reset_model (line 160) | def reset_model(self):
    method viewer_setup (line 171) | def viewer_setup(self):
    method get_ori (line 174) | def get_ori(self):
    method body_com_indices (line 182) | def body_com_indices(self):
    method body_comvel_indices (line 186) | def body_comvel_indices(self):

FILE: envs/gym_mujoco/half_cheetah.py
  class HalfCheetahEnv (line 26) | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 28) | def __init__(self,
    method step (line 48) | def step(self, action):
    method _get_obs (line 72) | def _get_obs(self):
    method reset_model (line 81) | def reset_model(self):
    method viewer_setup (line 88) | def viewer_setup(self):

FILE: envs/gym_mujoco/humanoid.py
  function mass_center (line 26) | def mass_center(sim):
  class HumanoidEnv (line 33) | class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 35) | def __init__(self,
    method _get_obs (line 58) | def _get_obs(self):
    method compute_reward (line 71) | def compute_reward(self, ob, next_ob, action=None):
    method step (line 100) | def step(self, a):
    method reset_model (line 136) | def reset_model(self):
    method viewer_setup (line 154) | def viewer_setup(self):

FILE: envs/gym_mujoco/point_mass.py
  class PointMassEnv (line 28) | class PointMassEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 30) | def __init__(self,
    method step (line 57) | def step(self, action):
    method _get_obs (line 85) | def _get_obs(self):
    method reset_model (line 93) | def reset_model(self):
    method set_qpos (line 104) | def set_qpos(self, state):
    method viewer_setup (line 108) | def viewer_setup(self):

FILE: envs/hand_block.py
  class HandBlockCustomEnv (line 24) | class HandBlockCustomEnv(ManipulateEnv):
    method __init__ (line 25) | def __init__(self,
    method _get_viewer (line 49) | def _get_viewer(self, mode):
    method _viewer_setup (line 60) | def _viewer_setup(self):
    method step (line 69) | def step(self, action):
    method render (line 82) | def render(self, mode='human', width=500, height=500):

FILE: envs/skill_wrapper.py
  class SkillWrapper (line 24) | class SkillWrapper(Wrapper):
    method __init__ (line 26) | def __init__(
    method _remake_time_step (line 56) | def _remake_time_step(self, cur_obs):
    method _set_skill (line 65) | def _set_skill(self):
    method reset (line 80) | def reset(self):
    method step (line 86) | def step(self, action):
    method close (line 95) | def close(self):

FILE: envs/video_wrapper.py
  class VideoWrapper (line 25) | class VideoWrapper(Wrapper):
    method __init__ (line 27) | def __init__(self, env, base_path, base_name=None, new_video_every_res...
    method reset (line 44) | def reset(self):
    method step (line 59) | def step(self, action):
    method close (line 63) | def close(self):

FILE: lib/py_tf_policy.py
  class PyTFPolicy (line 32) | class PyTFPolicy(py_policy.Base, session_utils.SessionUser):
    method __init__ (line 41) | def __init__(self, policy, batch_size=None, seed=None):
    method _construct (line 73) | def _construct(self, batch_size, graph):
    method initialize (line 103) | def initialize(self, batch_size, graph=None):
    method save (line 115) | def save(self, policy_dir=None, graph=None):
    method restore (line 130) | def restore(self, policy_dir, graph=None, assert_consumed=True):
    method _build_from_time_step (line 166) | def _build_from_time_step(self, time_step):
    method _get_initial_state (line 178) | def _get_initial_state(self, batch_size):
    method _action (line 188) | def _action(self, time_step, policy_state):
    method log_prob (line 223) | def log_prob(self, time_step, action_step, policy_state=None):

FILE: lib/py_uniform_replay_buffer.py
  class PyUniformReplayBuffer (line 39) | class PyUniformReplayBuffer(replay_buffer.ReplayBuffer):
    method __init__ (line 49) | def __init__(self, data_spec, capacity):
    method _encoded_data_spec (line 73) | def _encoded_data_spec(self):
    method _encode (line 77) | def _encode(self, item):
    method _decode (line 81) | def _decode(self, item):
    method _on_delete (line 85) | def _on_delete(self, encoded_item):
    method size (line 90) | def size(self):
    method _add_batch (line 93) | def _add_batch(self, items):
    method _get_next (line 111) | def _get_next(self,
    method _as_dataset (line 155) | def _as_dataset(self, sample_batch_size=None, num_steps=None,
    method _gather_all (line 193) | def _gather_all(self):
    method _clear (line 200) | def _clear(self):
    method gather_all_transitions (line 204) | def gather_all_transitions(self):

FILE: unsupervised_skill_learning/dads_agent.py
  class DADSAgent (line 36) | class DADSAgent(sac_agent.SacAgent):
    method __init__ (line 38) | def __init__(self,
    method compute_dads_reward (line 89) | def compute_dads_reward(self, input_obs, cur_skill, target_obs):
    method get_experience_placeholder (line 148) | def get_experience_placeholder(self):
    method build_agent_graph (line 163) | def build_agent_graph(self):
    method build_skill_dynamics_graph (line 170) | def build_skill_dynamics_graph(self):
    method create_savers (line 176) | def create_savers(self):
    method set_sessions (line 180) | def set_sessions(self, initialize_or_restore_skill_dynamics, session=N...
    method save_variables (line 189) | def save_variables(self, global_step):
    method _get_dict (line 192) | def _get_dict(self, trajectories, batch_size=-1):
    method train_loop (line 207) | def train_loop(self,
    method skill_dynamics (line 243) | def skill_dynamics(self):

FILE: unsupervised_skill_learning/dads_off.py
  function _normal_projection_net (line 226) | def _normal_projection_net(action_spec, init_means_output_factor=0.1):
  function get_environment (line 235) | def get_environment(env_name='point_mass'):
  function hide_coords (line 298) | def hide_coords(time_step):
  function relabel_skill (line 307) | def relabel_skill(trajectory_sample,
  function process_observation (line 444) | def process_observation(observation):
  function collect_experience (line 555) | def collect_experience(py_env,
  function run_on_env (line 611) | def run_on_env(env,
  function eval_loop (line 680) | def eval_loop(eval_dir,
  function eval_planning (line 829) | def eval_planning(env,
  function eval_mppi (line 893) | def eval_mppi(
  function main (line 1078) | def main(_):

FILE: unsupervised_skill_learning/skill_discriminator.py
  class SkillDiscriminator (line 28) | class SkillDiscriminator:
    method __init__ (line 30) | def __init__(
    method _get_distributions (line 73) | def _get_distributions(self, out):
    method _default_graph (line 121) | def _default_graph(self, timesteps):
    method _get_dict (line 133) | def _get_dict(self,
    method make_placeholders (line 161) | def make_placeholders(self):
    method set_session (line 176) | def set_session(self, session=None, initialize_or_restore_variables=Fa...
    method build_graph (line 200) | def build_graph(self,
    method increase_prob_op (line 235) | def increase_prob_op(self, learning_rate=3e-4):
    method decrease_prob_op (line 244) | def decrease_prob_op(self, learning_rate=3e-4):
    method train (line 254) | def train(self,
    method get_log_probs (line 279) | def get_log_probs(self, timesteps, skills, next_timesteps=None):
    method create_saver (line 290) | def create_saver(self, save_prefix):
    method save_variables (line 302) | def save_variables(self, global_step):
    method restore_variables (line 311) | def restore_variables(self):

FILE: unsupervised_skill_learning/skill_dynamics.py
  class SkillDynamics (line 28) | class SkillDynamics:
    method __init__ (line 30) | def __init__(
    method _get_distribution (line 77) | def _get_distribution(self, out):
    method _graph_with_separate_skill_pipe (line 129) | def _graph_with_separate_skill_pipe(self, timesteps, actions):
    method _default_graph (line 164) | def _default_graph(self, timesteps, actions):
    method _get_dict (line 176) | def _get_dict(self,
    method _get_run_dict (line 213) | def _get_run_dict(self, input_data, input_actions):
    method make_placeholders (line 223) | def make_placeholders(self):
    method set_session (line 240) | def set_session(self, session=None, initialize_or_restore_variables=Fa...
    method build_graph (line 264) | def build_graph(self,
    method increase_prob_op (line 307) | def increase_prob_op(self, learning_rate=3e-4, weights=None):
    method decrease_prob_op (line 328) | def decrease_prob_op(self, learning_rate=3e-4, weights=None):
    method create_saver (line 346) | def create_saver(self, save_prefix):
    method save_variables (line 359) | def save_variables(self, global_step):
    method restore_variables (line 368) | def restore_variables(self):
    method train (line 373) | def train(self,
    method get_log_prob (line 400) | def get_log_prob(self, timesteps, actions, next_timesteps):
    method predict_state (line 409) | def predict_state(self, timesteps, actions):
Condensed preview — 32 files, each entry listing the file path, character count, and a short content snippet (the full structured content is ~241K characters).
[
  {
    "path": ".gitignore",
    "chars": 178,
    "preview": "# Generated files\n*.egg-info/\n.idea*\n*__pycache__*\n.ipynb_checkpoints*\n*.pyc\n*.DS_Store\n*.mp4\n*.json\noutput/\nsaved_model"
  },
  {
    "path": "AUTHORS",
    "chars": 131,
    "preview": "# This is the list of authors for copyright purposes.\nGoogle LLC\nArchit Sharma\nShixiang Gu\nSergey Levine\nVikash Kumar\nKa"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1100,
    "preview": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guid"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 4887,
    "preview": "# Dynamics-Aware Discovery of Skills (DADS)\nThis repository is the open-source implementation of Dynamics-Aware Unsuperv"
  },
  {
    "path": "configs/ant_xy_offpolicy.txt",
    "chars": 2844,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "configs/ant_xy_onpolicy.txt",
    "chars": 2825,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "configs/dkitty_randomized_xy_offpolicy.txt",
    "chars": 2836,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "configs/humanoid_offpolicy.txt",
    "chars": 2872,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "configs/humanoid_onpolicy.txt",
    "chars": 2869,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "configs/template_config.txt",
    "chars": 2878,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "env.yml",
    "chars": 927,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/assets/ant.xml",
    "chars": 5629,
    "preview": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache Li"
  },
  {
    "path": "envs/assets/ant_footsensor.xml",
    "chars": 6285,
    "preview": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache Li"
  },
  {
    "path": "envs/assets/half_cheetah.xml",
    "chars": 6311,
    "preview": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache Li"
  },
  {
    "path": "envs/assets/humanoid.xml",
    "chars": 9577,
    "preview": "<!-- ======================================================\r\n# Copyright 2019 Google LLC\r\n#\r\n# Licensed under the Apache"
  },
  {
    "path": "envs/assets/point.xml",
    "chars": 2510,
    "preview": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache Li"
  },
  {
    "path": "envs/dclaw.py",
    "chars": 6932,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/dkitty_redesign.py",
    "chars": 8068,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/gym_mujoco/ant.py",
    "chars": 6307,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/gym_mujoco/half_cheetah.py",
    "chars": 2860,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/gym_mujoco/humanoid.py",
    "chars": 5227,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/gym_mujoco/point_mass.py",
    "chars": 3575,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/hand_block.py",
    "chars": 3106,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/skill_wrapper.py",
    "chars": 3259,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "envs/video_wrapper.py",
    "chars": 2120,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "lib/py_tf_policy.py",
    "chars": 9073,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "lib/py_uniform_replay_buffer.py",
    "chars": 8796,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "unsupervised_skill_learning/dads_agent.py",
    "chars": 9314,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "unsupervised_skill_learning/dads_off.py",
    "chars": 69071,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "unsupervised_skill_learning/skill_discriminator.py",
    "chars": 10897,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "unsupervised_skill_learning/skill_dynamics.py",
    "chars": 15850,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  }
]

About this extraction

This page contains the full source code of the google-research/dads GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 32 files (225.1 KB), approximately 58.3k tokens, and a symbol index with 153 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
