[
  {
    "path": ".gitignore",
    "content": "# Generated files\n*.egg-info/\n.idea*\n*__pycache__*\n.ipynb_checkpoints*\n*.pyc\n*.DS_Store\n*.mp4\n*.json\noutput/\nsaved_models/\nenv_test.py\ndkitty_eval.sh\nexperiments/\ndads_token.txt\n"
  },
  {
    "path": "AUTHORS",
    "content": "# This is the list of authors for copyright purposes.\nGoogle LLC\nArchit Sharma\nShixiang Gu\nSergey Levine\nVikash Kumar\nKarol Hausman"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guidelines you need to follow.\n\n## Contributor License Agreement\n\nContributions to this project must be accompanied by a Contributor License\nAgreement. You (or your employer) retain the copyright to your contribution;\nthis simply gives us permission to use and redistribute your contributions as\npart of the project. Head over to <https://cla.developers.google.com/> to see\nyour current agreements on file or to sign a new one.\n\nYou generally only need to submit a CLA once, so if you've already submitted one\n(even if it was for a different project), you probably don't need to do it\nagain.\n\n## Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more\ninformation on using pull requests.\n\n## Community Guidelines\n\nThis project follows\n[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/)."
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "# Dynamics-Aware Discovery of Skills (DADS)\nThis repository is the open-source implementation of Dynamics-Aware Unsupervised Discovery of Skills ([project page][website], [arXiv][paper]). We propose an skill-discovery method which can learn skills for different agents without any rewards, while simultaneously learning dynamics model for the skills which can be leveraged for model-based control on the downstream task. This work was published in International Conference of Learning Representations ([ICLR][iclr]), 2020.\n\nWe have also included an improved off-policy version of DADS, coined off-DADS. The details have been released in [Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning][rss_arxiv].\n\nIn case of problems, contact Archit Sharma.\n\n## Table of Contents\n\n* [Setup](#setup)\n* [Usage](#usage)\n* [Citation](#citation)\n* [Disclaimer](#disclaimer)\n\n## Setup\n\n#### (1) Setup MuJoCo\nDownload and setup [mujoco][mujoco] in `~/.mujoco`. Set the `LD_LIBRARY_PATH` in your `~/.bashrc`:\n```\nLD_LIBRARY_PATH='~/.mujoco/mjpro150/bin':$LD_LIBRARY_PATH\n```\n\n#### (2) Setup environment\nClone the repository and setup up the [conda][conda] environment to run DADS code:\n```\ncd <path_to_dads>\nconda env create -f env.yml\nconda activate dads-env\n```\n\n## Usage\nWe give a high-level explanation of how to use the code. More details pertaining to hyperparameters can be found in the the `configs/template_config.txt`, `dads_off.py` and the Appendix A of [paper][paper].\n\nEvery training run will require an experimental logging directory and a configuration file, which can be created started from the `configs/template_config.txt`. There are two phases: (a) Training where the new skills are learnt along with their skill-dynamics models and (b) evaluation where the learnt skills are evaluated on the task associated with the environment.\n\nFor training, ensure `--run_train=1` is set in the configuration file. For on-policy optimization, set `--clear_buffer_every_iter=1` and ensure the replay buffer size is bigger than the number of steps collected in every iteration. For off-policy optimization (details yet to be released), set `--clear_buffer_every_iter=0`. Set the environment name (ensure the environment is listed in `get_environment()` in `dads_off.py`). To change the observation for skill-dynamics (for example to learn in x-y space), set `--reduced_observation` and correspondingly configure `process_observation()` in `dads_off.py`. The skill space can be configured to be discrete or continuous. The optimization parameters can be tweaked, and some basic values have been set in (more details in the [paper][paper]). \n\nFor evaluation, ensure `--run_eval=1` and the experimental directory points to the same directory in which the training happened. Set `--num_evals` if you want to record videos of randomly sampled skills from the prior distribution. After that, the script will use the learned models to execute MPC on the latent space to optimize for the task-reward. By default, the code will call `get_environment()` to load `FLAGS.environment + '_goal'`, and will go through the list of goal-coordinates specified in the eval section of the script.\n\nWe have provided the configuration files in `configs/` to reproduce results from the experiments in the [paper][paper]. Goal evaluation is currently only setup for MuJoCo Ant environement. 
For evaluation, ensure `--run_eval=1` is set and that the experimental directory points to the same directory in which the training happened. Set `--num_evals` if you want to record videos of randomly sampled skills from the prior distribution. After that, the script will use the learned models to execute MPC in the latent space to optimize for the task reward. By default, the code will call `get_environment()` to load `FLAGS.environment + '_goal'`, and will go through the list of goal coordinates specified in the eval section of the script.\n\n
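The MPC hyperparameters in the configs (`--planning_horizon`, `--num_candidate_sequences`, `--refine_steps`, `--mppi_gamma`, etc.) control an MPPI-style planner. As a rough, hypothetical sketch (not the exact implementation in `dads_off.py`), one refinement step over candidate skill plans could look like:\n\n```python\nimport numpy as np\n\n# Hypothetical MPPI-style refinement sketch. It assumes a learned\n# skill-dynamics model predict_next_state(state, skill) and a task\n# reward reward_fn(state); names and shapes are illustrative only.\ndef refine_plan(plan_mean, state, predict_next_state, reward_fn,\n                num_candidate_sequences=50, refine_steps=10,\n                primitive_horizon=10, mppi_gamma=10.0):\n    for _ in range(refine_steps):\n        # sample candidate skill plans around the current mean\n        candidates = plan_mean + np.random.normal(\n            size=(num_candidate_sequences,) + plan_mean.shape)\n        returns = np.zeros(num_candidate_sequences)\n        for i, skills in enumerate(candidates):\n            s = state\n            for skill in skills:  # one skill per planning step\n                for _ in range(primitive_horizon):\n                    s = predict_next_state(s, skill)\n                returns[i] += reward_fn(s)\n        # exponentially weight candidates by their return (MPPI update)\n        weights = np.exp(mppi_gamma * (returns - returns.max()))\n        weights /= weights.sum()\n        plan_mean = (weights[:, None, None] * candidates).sum(axis=0)\n    return plan_mean\n```\n\n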
We have provided configuration files in `configs/` to reproduce results from the experiments in the [paper][paper]. Goal evaluation is currently only set up for the MuJoCo Ant environment. The goal distribution can be changed in `dads_off.py` in the evaluation part of the script.\n\n```\ncd <path_to_dads>\npython unsupervised_skill_learning/dads_off.py --logdir=<path_for_experiment_logs> --flagfile=configs/<config_name>.txt\n```\n\nThe specified experimental log directory will contain the TensorBoard files, the saved checkpoints and the skill-evaluation videos.\n\n## Citation\nTo cite [Dynamics-Aware Unsupervised Discovery of Skills][paper]:\n```\n@article{sharma2019dynamics,\n  title={Dynamics-aware unsupervised discovery of skills},\n  author={Sharma, Archit and Gu, Shixiang and Levine, Sergey and Kumar, Vikash and Hausman, Karol},\n  journal={arXiv preprint arXiv:1907.01657},\n  year={2019}\n}\n```\nTo cite off-DADS and [Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning][rss_arxiv]:\n```\n@article{sharma2020emergent,\n  title={Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning},\n  author={Sharma, Archit and Ahn, Michael and Levine, Sergey and Kumar, Vikash and Hausman, Karol and Gu, Shixiang},\n  journal={arXiv preprint arXiv:2004.12974},\n  year={2020}\n}\n```\n## Disclaimer\nThis is not an officially supported Google product.\n\n[website]: https://sites.google.com/corp/view/dads-skill\n[paper]: https://arxiv.org/abs/1907.01657\n[iclr]: https://openreview.net/forum?id=HJgLZR4KvH\n[mujoco]: http://www.mujoco.org/\n[conda]: https://docs.conda.io/en/latest/miniconda.html\n[rss_arxiv]: https://arxiv.org/abs/2004.12974\n"
  },
  {
    "path": "configs/ant_xy_offpolicy.txt",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n### TRAINING HYPERPARAMETERS -------------------\n--run_train=1\n\n# metadata flags\n--save_model=dads\n--save_freq=50\n--record_freq=100\n--vid_name=skill\n\n# optimization hyperparmaters\n--replay_buffer_capacity=10000\n\n# (set clear_buffer_every_iter=1 for on-policy optimization)\n--clear_buffer_every_iter=0\n--initial_collect_steps=2000\n--collect_steps=500\n--num_epochs=10000\n\n# skill dynamics optimization hyperparameters\n--skill_dyn_train_steps=8\n--skill_dynamics_lr=3e-4\n--skill_dyn_batch_size=256\n\n# agent hyperparameters\n--agent_gamma=0.99\n--agent_lr=3e-4\n--agent_entropy=0.1\n--agent_train_steps=64\n--agent_batch_size=256\n\n# (optional, do not change for on-policy) relabelling or off-policy corrections\n--skill_dynamics_relabel_type=importance_sampling\n--num_samples_for_relabelling=1\n--is_clip_eps=10.\n\n# (optional) skills can be resampled within the episodes, relative to max_env_steps\n--min_steps_before_resample=2000\n--resample_prob=0.02\n\n# (optional) configure skill dynamics training samples to be only from the current policy\n--train_skill_dynamics_on_policy=0\n\n### SHARED HYPERPARAMETERS ---------------------\n--environment=Ant-v1\n--max_env_steps=200\n--reduced_observation=2\n\n# define the type of skills being learnt\n--num_skills=2\n--skill_type=cont_uniform\n--random_skills=100\n--num_evals=3\n\n# (optional) policy, critic and skill dynamics\n--hidden_layer_size=512\n\n# (optional) skill dynamics hyperparameters\n--graph_type=default\n--num_components=4\n--fix_variance=1\n--normalize_data=1\n\n# (optional) clip sampled actions\n--action_clipping=1.\n\n# (optional) debugging\n--debug=0\n\n### EVALUATION HYPERPARAMETERS -----------------\n--run_eval=0\n\n# MPC hyperparameters\n--planning_horizon=1\n--primitive_horizon=10\n--num_candidate_sequences=50\n--refine_steps=10\n--mppi_gamma=10\n--prior_type=normal\n--smoothing_beta=0.9\n--top_primitives=5\n\n\n### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------\n# DKitty hyperparameters\n--expose_last_action=1\n--expose_upright=1\n--robot_noise_ratio=0.0\n--root_noise_ratio=0.0\n--upright_threshold=0.95\n--scale_root_position=1\n--randomize_hfield=0.0\n\n# DKitty/DClaw\n--observation_omission_size=0\n\n# Cube Manipulation hyperparameters\n--randomized_initial_distribution=1\n--horizontal_wrist_constraint=0.3\n--vertical_wrist_constraint=1.0\n"
  },
  {
    "path": "configs/ant_xy_onpolicy.txt",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n### TRAINING HYPERPARAMETERS -------------------\n--run_train=1\n\n# metadata flags\n--save_model=dads\n--save_freq=50\n--record_freq=100\n--vid_name=skill\n\n# optimization hyperparmaters\n--replay_buffer_capacity=100000\n\n# (set clear_buffer_iter=1 for on-policy)\n--clear_buffer_every_iter=1\n--initial_collect_steps=0\n--collect_steps=2000\n--num_epochs=10000\n\n# skill dynamics optimization hyperparameters\n--skill_dyn_train_steps=32\n--skill_dynamics_lr=3e-4\n--skill_dyn_batch_size=256\n\n# agent hyperparameters\n--agent_gamma=0.995\n--agent_lr=3e-4\n--agent_entropy=0.1\n--agent_train_steps=64\n--agent_batch_size=256\n\n# (optional, do not change for on-policy) relabelling or off-policy corrections\n--skill_dynamics_relabel_type=importance_sampling\n--num_samples_for_relabelling=1\n--is_clip_eps=1.\n\n# (optional) skills can be resampled within the episodes, relative to max_env_steps\n--min_steps_before_resample=2000\n--resample_prob=0.02\n\n# (optional) configure skill dynamics training samples to be only from the current policy\n--train_skill_dynamics_on_policy=0\n\n### SHARED HYPERPARAMETERS ---------------------\n--environment=Ant-v1\n--max_env_steps=200\n--reduced_observation=2\n\n# define the type of skills being learnt\n--num_skills=2\n--skill_type=cont_uniform\n--random_skills=100\n--num_evals=3\n\n# (optional) policy, critic and skill dynamics\n--hidden_layer_size=512\n\n# (optional) skill dynamics hyperparameters\n--graph_type=default\n--num_components=4\n--fix_variance=1\n--normalize_data=1\n\n# (optional) clip sampled actions\n--action_clipping=1.\n\n# (optional) debugging\n--debug=0\n\n### EVALUATION HYPERPARAMETERS -----------------\n--run_eval=0\n\n# MPC hyperparameters\n--planning_horizon=1\n--primitive_horizon=10\n--num_candidate_sequences=50\n--refine_steps=10\n--mppi_gamma=10\n--prior_type=normal\n--smoothing_beta=0.9\n--top_primitives=5\n\n\n### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------\n# DKitty hyperparameters\n--expose_last_action=1\n--expose_upright=1\n--robot_noise_ratio=0.0\n--root_noise_ratio=0.0\n--upright_threshold=0.95\n--scale_root_position=1\n--randomize_hfield=0.0\n\n# DKitty/DClaw\n--observation_omission_size=0\n\n# Cube Manipulation hyperparameters\n--randomized_initial_distribution=1\n--horizontal_wrist_constraint=0.3\n--vertical_wrist_constraint=1.0\n"
  },
  {
    "path": "configs/dkitty_randomized_xy_offpolicy.txt",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n### TRAINING HYPERPARAMETERS -------------------\n--run_train=1\n\n# metadata flags\n--save_model=dads\n--save_freq=50\n--record_freq=100\n--vid_name=skill\n\n# optimization hyperparmaters\n--replay_buffer_capacity=10000\n\n# (set clear_buffer_iter=1 for on-policy)\n--clear_buffer_every_iter=0\n--initial_collect_steps=2000\n--collect_steps=500\n--num_epochs=1000\n\n# skill dynamics optimization hyperparameters\n--skill_dyn_train_steps=8\n--skill_dynamics_lr=3e-4\n--skill_dyn_batch_size=256\n\n# agent hyperparameters\n--agent_gamma=0.99\n--agent_lr=3e-4\n--agent_entropy=0.1\n--agent_train_steps=64\n--agent_batch_size=256\n\n# (optional, do not change for on-policy) relabelling or off-policy corrections\n--skill_dynamics_relabel_type=importance_sampling\n--num_samples_for_relabelling=1\n--is_clip_eps=10.\n\n# (optional) skills can be resampled within the episodes, relative to max_env_steps\n--min_steps_before_resample=2000\n--resample_prob=0.02\n\n# (optional) configure skill dynamics training samples to be only from the current policy\n--train_skill_dynamics_on_policy=0\n\n### SHARED HYPERPARAMETERS ---------------------\n--environment=DKitty_randomized\n--max_env_steps=200\n--reduced_observation=2\n\n# define the type of skills being learnt\n--num_skills=2\n--skill_type=cont_uniform\n--random_skills=100\n--num_evals=3\n\n# (optional) policy, critic and skill dynamics\n--hidden_layer_size=512\n\n# (optional) skill dynamics hyperparameters\n--graph_type=default\n--num_components=4\n--fix_variance=1\n--normalize_data=1\n\n# (optional) clip sampled actions\n--action_clipping=1.\n\n# (optional) debugging\n--debug=0\n\n### EVALUATION HYPERPARAMETERS -----------------\n--run_eval=0\n\n# MPC hyperparameters\n--planning_horizon=1\n--primitive_horizon=10\n--num_candidate_sequences=50\n--refine_steps=10\n--mppi_gamma=10\n--prior_type=normal\n--smoothing_beta=0.9\n--top_primitives=5\n\n\n### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------\n# DKitty hyperparameters\n--expose_last_action=1\n--expose_upright=1\n--robot_noise_ratio=0.0\n--root_noise_ratio=0.0\n--upright_threshold=0.95\n--scale_root_position=1\n--randomize_hfield=0.02\n\n# DKitty/DClaw\n--observation_omission_size=2\n\n# Cube Manipulation hyperparameters\n--randomized_initial_distribution=1\n--horizontal_wrist_constraint=0.3\n--vertical_wrist_constraint=1.0\n"
  },
  {
    "path": "configs/humanoid_offpolicy.txt",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n### TRAINING HYPERPARAMETERS -------------------\n--run_train=1\n\n# metadata flags\n--save_model=dads\n--save_freq=50\n--record_freq=100\n--vid_name=skill\n\n# optimization hyperparmaters\n--replay_buffer_capacity=10000\n\n# (set clear_buffer_iter=1 for on-policy)\n--clear_buffer_every_iter=0\n--initial_collect_steps=5000\n--collect_steps=2000\n--num_epochs=100000\n\n# skill dynamics optimization hyperparameters\n--skill_dyn_train_steps=16\n--skill_dynamics_lr=3e-4\n--skill_dyn_batch_size=256\n\n# agent hyperparameters\n--agent_gamma=0.995\n--agent_lr=3e-4\n--agent_entropy=0.1\n--agent_train_steps=128\n--agent_batch_size=256\n\n# (optional, do not change for on-policy) relabelling or off-policy corrections\n--skill_dynamics_relabel_type=importance_sampling\n--num_samples_for_relabelling=1\n--is_clip_eps=1.\n\n# (optional) skills can be resampled within the episodes, relative to max_env_steps\n--min_steps_before_resample=2000\n--resample_prob=0.0\n\n# (optional) configure skill dynamics training samples to be only from the current policy\n--train_skill_dynamics_on_policy=0\n\n### SHARED HYPERPARAMETERS ---------------------\n--environment=Humanoid-v1\n--max_env_steps=1000\n--reduced_observation=0\n\n# define the type of skills being learnt\n--num_skills=5\n--skill_type=cont_uniform\n--random_skills=100\n\n# number of skill-video evaluations\n--num_evals=3\n\n# (optional) policy, critic and skill dynamics\n--hidden_layer_size=1024\n\n# (optional) skill dynamics hyperparameters\n--graph_type=default\n--num_components=4\n--fix_variance=1\n--normalize_data=1\n\n# (optional) clip sampled actions\n--action_clipping=1.\n\n# (optional) debugging\n--debug=0\n\n### EVALUATION HYPERPARAMETERS -----------------\n--run_eval=0\n\n# MPC hyperparameters\n--planning_horizon=1\n--primitive_horizon=10\n--num_candidate_sequences=50\n--refine_steps=10\n--mppi_gamma=10\n--prior_type=normal\n--smoothing_beta=0.9\n--top_primitives=5\n\n\n### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------\n# DKitty hyperparameters\n--expose_last_action=1\n--expose_upright=1\n--robot_noise_ratio=0.0\n--root_noise_ratio=0.0\n--upright_threshold=0.95\n--scale_root_position=1\n--randomize_hfield=0.0\n\n# DKitty/DClaw\n--observation_omission_size=0\n\n# Cube Manipulation hyperparameters\n--randomized_initial_distribution=1\n--horizontal_wrist_constraint=0.3\n--vertical_wrist_constraint=1.0\n"
  },
  {
    "path": "configs/humanoid_onpolicy.txt",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n### TRAINING HYPERPARAMETERS -------------------\n--run_train=1\n\n# metadata flags\n--save_model=dads\n--save_freq=50\n--record_freq=100\n--vid_name=skill\n\n# optimization hyperparmaters\n--replay_buffer_capacity=100000\n\n# (set clear_buffer_iter=1 for on-policy)\n--clear_buffer_every_iter=1\n--initial_collect_steps=0\n--collect_steps=4000\n--num_epochs=100000\n\n# skill dynamics optimization hyperparameters\n--skill_dyn_train_steps=32\n--skill_dynamics_lr=3e-4\n--skill_dyn_batch_size=256\n\n# agent hyperparameters\n--agent_gamma=0.995\n--agent_lr=3e-4\n--agent_entropy=0.1\n--agent_train_steps=64\n--agent_batch_size=256\n\n# (optional, do not change for on-policy) relabelling or off-policy corrections\n--skill_dynamics_relabel_type=importance_sampling\n--num_samples_for_relabelling=1\n--is_clip_eps=1.\n\n# (optional) skills can be resampled within the episodes, relative to max_env_steps\n--min_steps_before_resample=2000\n--resample_prob=0.0\n\n# (optional) configure skill dynamics training samples to be only from the current policy\n--train_skill_dynamics_on_policy=0\n\n### SHARED HYPERPARAMETERS ---------------------\n--environment=Humanoid-v1\n--max_env_steps=1000\n--reduced_observation=0\n\n# define the type of skills being learnt\n--num_skills=5\n--skill_type=cont_uniform\n--random_skills=100\n\n# number of skill-video evaluations\n--num_evals=3\n\n# (optional) policy, critic and skill dynamics\n--hidden_layer_size=1024\n\n# (optional) skill dynamics hyperparameters\n--graph_type=default\n--num_components=4\n--fix_variance=1\n--normalize_data=1\n\n# (optional) clip sampled actions\n--action_clipping=1.\n\n# (optional) debugging\n--debug=0\n\n### EVALUATION HYPERPARAMETERS -----------------\n--run_eval=0\n\n# MPC hyperparameters\n--planning_horizon=1\n--primitive_horizon=10\n--num_candidate_sequences=50\n--refine_steps=10\n--mppi_gamma=10\n--prior_type=normal\n--smoothing_beta=0.9\n--top_primitives=5\n\n\n### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------\n# DKitty hyperparameters\n--expose_last_action=1\n--expose_upright=1\n--robot_noise_ratio=0.0\n--root_noise_ratio=0.0\n--upright_threshold=0.95\n--scale_root_position=1\n--randomize_hfield=0.0\n\n# DKitty/DClaw\n--observation_omission_size=0\n\n# Cube Manipulation hyperparameters\n--randomized_initial_distribution=1\n--horizontal_wrist_constraint=0.3\n--vertical_wrist_constraint=1.0\n"
  },
  {
    "path": "configs/template_config.txt",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n### TRAINING HYPERPARAMETERS -------------------\n--run_train=0\n\n# metadata flags\n--save_model=dads\n--save_freq=50\n--record_freq=100\n--vid_name=skill\n\n# optimization hyperparmaters\n--replay_buffer_capacity=100000\n\n# (set clear_buffer_iter=1 for on-policy)\n--clear_buffer_every_iter=0\n--initial_collect_steps=2000\n--collect_steps=1000\n--num_epochs=100\n\n# skill dynamics optimization hyperparameters\n--skill_dyn_train_steps=16\n--skill_dynamics_lr=3e-4\n--skill_dyn_batch_size=256\n\n# agent hyperparameters\n--agent_gamma=0.99\n--agent_lr=3e-4\n--agent_entropy=0.1\n--agent_train_steps=64\n--agent_batch_size=256\n\n# (optional, do not change for on-policy) relabelling or off-policy corrections\n--skill_dynamics_relabel_type=importance_sampling\n--num_samples_for_relabelling=1\n--is_clip_eps=1.\n\n# (optional) skills can be resampled within the episodes, relative to max_env_steps\n--min_steps_before_resample=2000\n--resample_prob=0.02\n\n# (optional) configure skill dynamics training samples to be only from the current policy\n--train_skill_dynamics_on_policy=0\n\n### SHARED HYPERPARAMETERS ---------------------\n--environment=<set_some_environment>\n--max_env_steps=200\n--reduced_observation=0\n\n# define the type of skills being learnt\n--num_skills=2\n--skill_type=cont_uniform\n--random_skills=100\n\n# number of skill-video evaluations\n--num_evals=3\n\n# (optional) policy, critic and skill dynamics\n--hidden_layer_size=512\n\n# (optional) skill dynamics hyperparameters\n--graph_type=default\n--num_components=4\n--fix_variance=1\n--normalize_data=1\n\n# (optional) clip sampled actions\n--action_clipping=1.\n\n# (optional) debugging\n--debug=0\n\n### EVALUATION HYPERPARAMETERS -----------------\n--run_eval=0\n\n# MPC hyperparameters\n--planning_horizon=1\n--primitive_horizon=10\n--num_candidate_sequences=50\n--refine_steps=10\n--mppi_gamma=10\n--prior_type=normal\n--smoothing_beta=0.9\n--top_primitives=5\n\n\n### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS --------\n# DKitty hyperparameters\n--expose_last_action=1\n--expose_upright=1\n--robot_noise_ratio=0.0\n--root_noise_ratio=0.0\n--upright_threshold=0.95\n--scale_root_position=1\n--randomize_hfield=0.0\n\n# DKitty/DClaw\n--observation_omission_size=0\n\n# Cube Manipulation hyperparameters\n--randomized_initial_distribution=1\n--horizontal_wrist_constraint=0.3\n--vertical_wrist_constraint=1.0\n"
  },
  {
    "path": "env.yml",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: dads-env\nchannels:\n- defaults\n- conda-forge\ndependencies:\n- python=3.6.8\n- pip>=18.1\n- conda>=4.6.7\n- pip:\n  - numpy<2.0,>=1.16.0\n  - tensorflow-probability==0.10.0\n  - tensorflow==2.2.0\n  - tf-agents==0.4.0\n  - tensorflow-estimator==2.2.0\n  - gym==0.11.0\n  - matplotlib==3.0.2\n  - robel==0.1.2\n  - mujoco-py==2.0.2.5\n  - click\n  - transforms3d\n"
  },
  {
    "path": "envs/assets/ant.xml",
    "content": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n====================================================== -->\n\n<mujoco model=\"ant\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" timestep=\"0.01\"/>\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name=\"init_qpos\"/>\n  </custom>\n  <default>\n    <joint armature=\"1\" damping=\"1\" limited=\"true\"/>\n    <geom conaffinity=\"0\" condim=\"3\" density=\"5.0\" friction=\"1 0.5 0.5\" margin=\"0.01\" rgba=\"0.8 0.6 0.4 1\"/>\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 0 0.75\">\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <geom name=\"torso_geom\" pos=\"0 0 0\" size=\"0.25\" type=\"sphere\"/>\n      <joint armature=\"0\" damping=\"0\" limited=\"false\" margin=\"0.01\" name=\"root\" pos=\"0 0 0\" type=\"free\"/>\n      <body name=\"front_left_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"aux_1_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_1\" pos=\"0.2 0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_1\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"left_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle_1\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" name=\"left_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"front_right_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"aux_2_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_2\" pos=\"-0.2 0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_2\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          
<geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"-0.2 0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle_2\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"back_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"aux_3_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_3\" pos=\"-0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_3\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"-0.2 -0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle_3\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"right_back_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"aux_4_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_4\" pos=\"0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_4\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"rightback_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 -0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle_4\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" name=\"fourth_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_3\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_3\" gear=\"150\"/>\n  </actuator>\n</mujoco>\n"
  },
  {
    "path": "envs/assets/ant_footsensor.xml",
    "content": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n====================================================== -->\n\n<mujoco model=\"ant\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" timestep=\"0.01\"/>\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name=\"init_qpos\"/>\n  </custom>\n  <default>\n    <joint armature=\"1\" damping=\"1\" limited=\"true\"/>\n    <geom conaffinity=\"0\" condim=\"3\" density=\"5.0\" friction=\"1 0.5 0.5\" margin=\"0.01\" rgba=\"0.8 0.6 0.4 1\"/>\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 0 0.75\">\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <geom name=\"torso_geom\" pos=\"0 0 0\" size=\"0.25\" type=\"sphere\"/>\n      <joint armature=\"0\" damping=\"0\" limited=\"false\" margin=\"0.01\" name=\"root\" pos=\"0 0 0\" type=\"free\"/>\n      <body name=\"front_left_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"aux_1_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_1\" pos=\"0.2 0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_1\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"left_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle_1\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" name=\"left_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n            <site name='front_left_leg' pos=\"0.4 0.4 0.0\" type='sphere' size='.1' rgba='1 1 0 .5'/>\n          </body>\n        </body>\n      </body>\n      <body name=\"front_right_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"aux_2_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_2\" pos=\"-0.2 0.2 0\">\n          
<joint axis=\"0 0 1\" name=\"hip_2\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"-0.2 0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle_2\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n            <site name='front_right_leg' pos=\"-0.4 0.4 0.0\" type='sphere' size='.1' rgba='1 1 0 .5'/>\n          </body>\n        </body>\n      </body>\n      <body name=\"back_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"aux_3_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_3\" pos=\"-0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_3\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"-0.2 -0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle_3\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n            <site name='left_back_leg' pos=\"-0.4 -0.4 0.0\" type='sphere' size='.1' rgba='1 1 0 .5'/>\n          </body>\n        </body>\n      </body>\n      <body name=\"right_back_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"aux_4_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_4\" pos=\"0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_4\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"rightback_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 -0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle_4\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" name=\"fourth_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n            <site name='right_back_leg' pos=\"0.4 -0.4 0.0\" type='sphere' size='.1' rgba='1 1 0 .5'/>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_3\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_3\" gear=\"150\"/>\n  </actuator>\n\n  <sensor>\n    <touch name='front_left_leg' site='front_left_leg'/>\n    <touch name='front_right_leg' site='front_right_leg'/>\n    <touch name='left_back_leg' site='left_back_leg'/>\n    <touch name='right_back_leg' site='right_back_leg'/>\n  </sensor>\n</mujoco>\n"
  },
  {
    "path": "envs/assets/half_cheetah.xml",
    "content": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n====================================================== -->\n\n<!-- Cheetah Model\n\n    The state space is populated with joints in the order that they are\n    defined in this file. The actuators also operate on joints.\n\n    State-Space (name/joint/parameter):\n        - rootx     slider      position (m)\n        - rootz     slider      position (m)\n        - rooty     hinge       angle (rad)\n        - bthigh    hinge       angle (rad)\n        - bshin     hinge       angle (rad)\n        - bfoot     hinge       angle (rad)\n        - fthigh    hinge       angle (rad)\n        - fshin     hinge       angle (rad)\n        - ffoot     hinge       angle (rad)\n        - rootx     slider      velocity (m/s)\n        - rootz     slider      velocity (m/s)\n        - rooty     hinge       angular velocity (rad/s)\n        - bthigh    hinge       angular velocity (rad/s)\n        - bshin     hinge       angular velocity (rad/s)\n        - bfoot     hinge       angular velocity (rad/s)\n        - fthigh    hinge       angular velocity (rad/s)\n        - fshin     hinge       angular velocity (rad/s)\n        - ffoot     hinge       angular velocity (rad/s)\n\n    Actuators (name/actuator/parameter):\n        - bthigh    hinge       torque (N m)\n        - bshin     hinge       torque (N m)\n        - bfoot     hinge       torque (N m)\n        - fthigh    hinge       torque (N m)\n        - fshin     hinge       torque (N m)\n        - ffoot     hinge       torque (N m)\n\n-->\n<mujoco model=\"cheetah\">\n  <compiler angle=\"radian\" coordinate=\"local\" inertiafromgeom=\"true\" settotalmass=\"14\"/>\n  <default>\n    <joint armature=\".1\" damping=\".01\" limited=\"true\" solimplimit=\"0 .8 .03\" solreflimit=\".02 1\" stiffness=\"8\"/>\n    <geom conaffinity=\"0\" condim=\"3\" contype=\"1\" friction=\".4 .1 .1\" rgba=\"0.8 0.6 .4 1\" solimp=\"0.0 0.8 0.01\" solref=\"0.02 1\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\"/>\n  </default>\n  <size nstack=\"300000\" nuser_geom=\"1\"/>\n  <option gravity=\"0 0 -9.81\" timestep=\"0.01\"/>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    
<geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 0 .7\">\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <joint armature=\"0\" axis=\"1 0 0\" damping=\"0\" limited=\"false\" name=\"rootx\" pos=\"0 0 0\" stiffness=\"0\" type=\"slide\"/>\n      <joint armature=\"0\" axis=\"0 0 1\" damping=\"0\" limited=\"false\" name=\"rootz\" pos=\"0 0 0\" stiffness=\"0\" type=\"slide\"/>\n      <joint armature=\"0\" axis=\"0 1 0\" damping=\"0\" limited=\"false\" name=\"rooty\" pos=\"0 0 0\" stiffness=\"0\" type=\"hinge\"/>\n      <geom fromto=\"-.5 0 0 .5 0 0\" name=\"torso\" size=\"0.046\" type=\"capsule\"/>\n      <geom axisangle=\"0 1 0 .87\" name=\"head\" pos=\".6 0 .1\" size=\"0.046 .15\" type=\"capsule\"/>\n      <!-- <site name='tip'  pos='.15 0 .11'/>-->\n      <body name=\"bthigh\" pos=\"-.5 0 0\">\n        <joint axis=\"0 1 0\" damping=\"6\" name=\"bthigh\" pos=\"0 0 0\" range=\"-.52 1.05\" stiffness=\"240\" type=\"hinge\"/>\n        <geom axisangle=\"0 1 0 -3.8\" name=\"bthigh\" pos=\".1 0 -.13\" size=\"0.046 .145\" type=\"capsule\"/>\n        <body name=\"bshin\" pos=\".16 0 -.25\">\n          <joint axis=\"0 1 0\" damping=\"4.5\" name=\"bshin\" pos=\"0 0 0\" range=\"-.785 .785\" stiffness=\"180\" type=\"hinge\"/>\n          <geom axisangle=\"0 1 0 -2.03\" name=\"bshin\" pos=\"-.14 0 -.07\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .15\" type=\"capsule\"/>\n          <body name=\"bfoot\" pos=\"-.28 0 -.14\">\n            <joint axis=\"0 1 0\" damping=\"3\" name=\"bfoot\" pos=\"0 0 0\" range=\"-.4 .785\" stiffness=\"120\" type=\"hinge\"/>\n            <geom axisangle=\"0 1 0 -.27\" name=\"bfoot\" pos=\".03 0 -.097\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .094\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"fthigh\" pos=\".5 0 0\">\n        <joint axis=\"0 1 0\" damping=\"4.5\" name=\"fthigh\" pos=\"0 0 0\" range=\"-1 .7\" stiffness=\"180\" type=\"hinge\"/>\n        <geom axisangle=\"0 1 0 .52\" name=\"fthigh\" pos=\"-.07 0 -.12\" size=\"0.046 .133\" type=\"capsule\"/>\n        <body name=\"fshin\" pos=\"-.14 0 -.24\">\n          <joint axis=\"0 1 0\" damping=\"3\" name=\"fshin\" pos=\"0 0 0\" range=\"-1.2 .87\" stiffness=\"120\" type=\"hinge\"/>\n          <geom axisangle=\"0 1 0 -.6\" name=\"fshin\" pos=\".065 0 -.09\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .106\" type=\"capsule\"/>\n          <body name=\"ffoot\" pos=\".13 0 -.18\">\n            <joint axis=\"0 1 0\" damping=\"1.5\" name=\"ffoot\" pos=\"0 0 0\" range=\"-.5 .5\" stiffness=\"60\" type=\"hinge\"/>\n            <geom axisangle=\"0 1 0 -.6\" name=\"ffoot\" pos=\".045 0 -.07\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .07\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor gear=\"120\" joint=\"bthigh\" name=\"bthigh\"/>\n    <motor gear=\"90\" joint=\"bshin\" name=\"bshin\"/>\n    <motor gear=\"60\" joint=\"bfoot\" name=\"bfoot\"/>\n    <motor gear=\"120\" joint=\"fthigh\" name=\"fthigh\"/>\n    <motor gear=\"60\" joint=\"fshin\" name=\"fshin\"/>\n    <motor gear=\"30\" joint=\"ffoot\" name=\"ffoot\"/>\n  </actuator>\n</mujoco>\n"
  },
  {
    "path": "envs/assets/humanoid.xml",
    "content": "<!-- ======================================================\r\n# Copyright 2019 Google LLC\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#      http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n====================================================== -->\r\n\r\n<mujoco model=\"humanoid\">\r\n    <compiler angle=\"degree\" inertiafromgeom=\"true\"/>\r\n    <default>\r\n        <joint armature=\"1\" damping=\"1\" limited=\"true\"/>\r\n        <geom conaffinity=\"1\" condim=\"1\" contype=\"1\" margin=\"0.001\" material=\"geom\" rgba=\"0.8 0.6 .4 1\"/>\r\n        <motor ctrllimited=\"true\" ctrlrange=\"-.4 .4\"/>\r\n    </default>\r\n    <option integrator=\"RK4\" iterations=\"50\" solver=\"PGS\" timestep=\"0.003\">\r\n        <!-- <flags solverstat=\"enable\" energy=\"enable\"/>-->\r\n    </option>\r\n    <size nkey=\"5\" nuser_geom=\"1\"/>\r\n    <visual>\r\n        <map fogend=\"5\" fogstart=\"3\"/>\r\n    </visual>\r\n    <asset>\r\n        <texture builtin=\"gradient\" height=\"100\" rgb1=\".4 .5 .6\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\r\n        <!-- <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>-->\r\n        <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\r\n        <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\r\n        <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\r\n        <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\r\n    </asset>\r\n    <worldbody>\r\n        <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\r\n        <geom condim=\"3\" friction=\"1 .1 .1\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"20 20 0.125\" type=\"plane\"/>\r\n        <!-- <geom condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" size=\"10 10 0.125\" type=\"plane\"/>-->\r\n        <body name=\"torso\" pos=\"0 0 1.4\">\r\n            <camera name=\"track\" mode=\"trackcom\" pos=\"0 -4 0\" xyaxes=\"1 0 0 0 0 1\"/>\r\n            <joint armature=\"0\" damping=\"0\" limited=\"false\" name=\"root\" pos=\"0 0 0\" stiffness=\"0\" type=\"free\"/>\r\n            <geom fromto=\"0 -.07 0 0 .07 0\" name=\"torso1\" size=\"0.07\" type=\"capsule\"/>\r\n            <geom name=\"head\" pos=\"0 0 .19\" size=\".09\" type=\"sphere\" user=\"258\"/>\r\n            <geom fromto=\"-.01 -.06 -.12 -.01 .06 -.12\" name=\"uwaist\" size=\"0.06\" type=\"capsule\"/>\r\n            <body name=\"lwaist\" pos=\"-.01 0 -0.260\" quat=\"1.000 0 -0.002 0\">\r\n                <geom fromto=\"0 -.06 0 0 .06 0\" name=\"lwaist\" size=\"0.06\" type=\"capsule\"/>\r\n                <joint armature=\"0.02\" axis=\"0 0 1\" damping=\"5\" 
name=\"abdomen_z\" pos=\"0 0 0.065\" range=\"-45 45\" stiffness=\"20\" type=\"hinge\"/>\r\n                <joint armature=\"0.02\" axis=\"0 1 0\" damping=\"5\" name=\"abdomen_y\" pos=\"0 0 0.065\" range=\"-75 30\" stiffness=\"10\" type=\"hinge\"/>\r\n                <body name=\"pelvis\" pos=\"0 0 -0.165\" quat=\"1.000 0 -0.002 0\">\r\n                    <joint armature=\"0.02\" axis=\"1 0 0\" damping=\"5\" name=\"abdomen_x\" pos=\"0 0 0.1\" range=\"-35 35\" stiffness=\"10\" type=\"hinge\"/>\r\n                    <geom fromto=\"-.02 -.07 0 -.02 .07 0\" name=\"butt\" size=\"0.09\" type=\"capsule\"/>\r\n                    <body name=\"right_thigh\" pos=\"0 -0.1 -0.04\">\r\n                        <joint armature=\"0.01\" axis=\"1 0 0\" damping=\"5\" name=\"right_hip_x\" pos=\"0 0 0\" range=\"-25 5\" stiffness=\"10\" type=\"hinge\"/>\r\n                        <joint armature=\"0.01\" axis=\"0 0 1\" damping=\"5\" name=\"right_hip_z\" pos=\"0 0 0\" range=\"-60 35\" stiffness=\"10\" type=\"hinge\"/>\r\n                        <joint armature=\"0.0080\" axis=\"0 1 0\" damping=\"5\" name=\"right_hip_y\" pos=\"0 0 0\" range=\"-110 20\" stiffness=\"20\" type=\"hinge\"/>\r\n                        <geom fromto=\"0 0 0 0 0.01 -.34\" name=\"right_thigh1\" size=\"0.06\" type=\"capsule\"/>\r\n                        <body name=\"right_shin\" pos=\"0 0.01 -0.403\">\r\n                            <joint armature=\"0.0060\" axis=\"0 -1 0\" name=\"right_knee\" pos=\"0 0 .02\" range=\"-160 -2\" type=\"hinge\"/>\r\n                            <geom fromto=\"0 0 0 0 0 -.3\" name=\"right_shin1\" size=\"0.049\" type=\"capsule\"/>\r\n                            <body name=\"right_foot\" pos=\"0 0 -0.45\">\r\n                                <geom name=\"right_foot\" pos=\"0 0 0.1\" size=\"0.075\" type=\"sphere\" user=\"0\"/>\r\n                            </body>\r\n                        </body>\r\n                    </body>\r\n                    <body name=\"left_thigh\" pos=\"0 0.1 -0.04\">\r\n                        <joint armature=\"0.01\" axis=\"-1 0 0\" damping=\"5\" name=\"left_hip_x\" pos=\"0 0 0\" range=\"-25 5\" stiffness=\"10\" type=\"hinge\"/>\r\n                        <joint armature=\"0.01\" axis=\"0 0 -1\" damping=\"5\" name=\"left_hip_z\" pos=\"0 0 0\" range=\"-60 35\" stiffness=\"10\" type=\"hinge\"/>\r\n                        <joint armature=\"0.01\" axis=\"0 1 0\" damping=\"5\" name=\"left_hip_y\" pos=\"0 0 0\" range=\"-120 20\" stiffness=\"20\" type=\"hinge\"/>\r\n                        <geom fromto=\"0 0 0 0 -0.01 -.34\" name=\"left_thigh1\" size=\"0.06\" type=\"capsule\"/>\r\n                        <body name=\"left_shin\" pos=\"0 -0.01 -0.403\">\r\n                            <joint armature=\"0.0060\" axis=\"0 -1 0\" name=\"left_knee\" pos=\"0 0 .02\" range=\"-160 -2\" stiffness=\"1\" type=\"hinge\"/>\r\n                            <geom fromto=\"0 0 0 0 0 -.3\" name=\"left_shin1\" size=\"0.049\" type=\"capsule\"/>\r\n                            <body name=\"left_foot\" pos=\"0 0 -0.45\">\r\n                                <geom name=\"left_foot\" type=\"sphere\" size=\"0.075\" pos=\"0 0 0.1\" user=\"0\" />\r\n                            </body>\r\n                        </body>\r\n                    </body>\r\n                </body>\r\n            </body>\r\n            <body name=\"right_upper_arm\" pos=\"0 -0.17 0.06\">\r\n                <joint armature=\"0.0068\" axis=\"2 1 1\" name=\"right_shoulder1\" pos=\"0 0 0\" range=\"-85 60\" stiffness=\"1\" type=\"hinge\"/>\r\n   
             <joint armature=\"0.0051\" axis=\"0 -1 1\" name=\"right_shoulder2\" pos=\"0 0 0\" range=\"-85 60\" stiffness=\"1\" type=\"hinge\"/>\r\n                <geom fromto=\"0 0 0 .16 -.16 -.16\" name=\"right_uarm1\" size=\"0.04 0.16\" type=\"capsule\"/>\r\n                <body name=\"right_lower_arm\" pos=\".18 -.18 -.18\">\r\n                    <joint armature=\"0.0028\" axis=\"0 -1 1\" name=\"right_elbow\" pos=\"0 0 0\" range=\"-90 50\" stiffness=\"0\" type=\"hinge\"/>\r\n                    <geom fromto=\"0.01 0.01 0.01 .17 .17 .17\" name=\"right_larm\" size=\"0.031\" type=\"capsule\"/>\r\n                    <geom name=\"right_hand\" pos=\".18 .18 .18\" size=\"0.04\" type=\"sphere\"/>\r\n                    <camera pos=\"0 0 0\"/>\r\n                </body>\r\n            </body>\r\n            <body name=\"left_upper_arm\" pos=\"0 0.17 0.06\">\r\n                <joint armature=\"0.0068\" axis=\"2 -1 1\" name=\"left_shoulder1\" pos=\"0 0 0\" range=\"-60 85\" stiffness=\"1\" type=\"hinge\"/>\r\n                <joint armature=\"0.0051\" axis=\"0 1 1\" name=\"left_shoulder2\" pos=\"0 0 0\" range=\"-60 85\" stiffness=\"1\" type=\"hinge\"/>\r\n                <geom fromto=\"0 0 0 .16 .16 -.16\" name=\"left_uarm1\" size=\"0.04 0.16\" type=\"capsule\"/>\r\n                <body name=\"left_lower_arm\" pos=\".18 .18 -.18\">\r\n                    <joint armature=\"0.0028\" axis=\"0 -1 -1\" name=\"left_elbow\" pos=\"0 0 0\" range=\"-90 50\" stiffness=\"0\" type=\"hinge\"/>\r\n                    <geom fromto=\"0.01 -0.01 0.01 .17 -.17 .17\" name=\"left_larm\" size=\"0.031\" type=\"capsule\"/>\r\n                    <geom name=\"left_hand\" pos=\".18 -.18 .18\" size=\"0.04\" type=\"sphere\"/>\r\n                </body>\r\n            </body>\r\n        </body>\r\n    </worldbody>\r\n    <tendon>\r\n        <fixed name=\"left_hipknee\">\r\n            <joint coef=\"-1\" joint=\"left_hip_y\"/>\r\n            <joint coef=\"1\" joint=\"left_knee\"/>\r\n        </fixed>\r\n        <fixed name=\"right_hipknee\">\r\n            <joint coef=\"-1\" joint=\"right_hip_y\"/>\r\n            <joint coef=\"1\" joint=\"right_knee\"/>\r\n        </fixed>\r\n    </tendon>\r\n\r\n    <actuator>\r\n        <motor gear=\"100\" joint=\"abdomen_y\" name=\"abdomen_y\"/>\r\n        <motor gear=\"100\" joint=\"abdomen_z\" name=\"abdomen_z\"/>\r\n        <motor gear=\"100\" joint=\"abdomen_x\" name=\"abdomen_x\"/>\r\n        <motor gear=\"100\" joint=\"right_hip_x\" name=\"right_hip_x\"/>\r\n        <motor gear=\"100\" joint=\"right_hip_z\" name=\"right_hip_z\"/>\r\n        <motor gear=\"300\" joint=\"right_hip_y\" name=\"right_hip_y\"/>\r\n        <motor gear=\"200\" joint=\"right_knee\" name=\"right_knee\"/>\r\n        <motor gear=\"100\" joint=\"left_hip_x\" name=\"left_hip_x\"/>\r\n        <motor gear=\"100\" joint=\"left_hip_z\" name=\"left_hip_z\"/>\r\n        <motor gear=\"300\" joint=\"left_hip_y\" name=\"left_hip_y\"/>\r\n        <motor gear=\"200\" joint=\"left_knee\" name=\"left_knee\"/>\r\n        <motor gear=\"25\" joint=\"right_shoulder1\" name=\"right_shoulder1\"/>\r\n        <motor gear=\"25\" joint=\"right_shoulder2\" name=\"right_shoulder2\"/>\r\n        <motor gear=\"25\" joint=\"right_elbow\" name=\"right_elbow\"/>\r\n        <motor gear=\"25\" joint=\"left_shoulder1\" name=\"left_shoulder1\"/>\r\n        <motor gear=\"25\" joint=\"left_shoulder2\" name=\"left_shoulder2\"/>\r\n        <motor gear=\"25\" joint=\"left_elbow\" name=\"left_elbow\"/>\r\n    </actuator>\r\n</mujoco>\r\n"
  },
  {
    "path": "envs/assets/point.xml",
    "content": "<!-- ======================================================\n# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n====================================================== -->\n\n<mujoco>\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" timestep=\"0.02\"/>\n  <default>\n    <joint armature=\"0\" damping=\"0\" limited=\"false\"/>\n    <geom conaffinity=\"0\" condim=\"3\" density=\"100\" friction=\"1 0.5 0.5\" margin=\"0\" rgba=\"0.8 0.6 0.4 1\"/>\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"30 30\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 0 0\">\n      <geom name=\"pointbody\" pos=\"0 0 0.5\" size=\"0.5\" type=\"sphere\"/>\n      <geom name=\"pointarrow\" pos=\"0.6 0 0.5\" size=\"0.5 0.1 0.1\" type=\"box\"/>\n      <joint axis=\"1 0 0\" name=\"ballx\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 1 0\" name=\"bally\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 0 1\" limited=\"false\" name=\"rot\" pos=\"0 0 0\" type=\"hinge\"/>\n    </body>\n  </worldbody>\n  <actuator>\n    <!-- Those are just dummy actuators for providing ranges -->\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" joint=\"ballx\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-0.25 0.25\" joint=\"rot\"/>\n  </actuator>\n</mujoco>\n"
  },
  {
    "path": "envs/dclaw.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Turn tasks with DClaw robots.\n\nThis is a single rotation of an object from an initial angle to a target angle.\n\"\"\"\n\nimport abc\nimport collections\nfrom typing import Dict, Optional, Sequence\n\nimport numpy as np\n\nfrom robel.components.robot.dynamixel_robot import DynamixelRobotState\nfrom robel.dclaw.base_env import BaseDClawObjectEnv\nfrom robel.simulation.randomize import SimRandomizer\nfrom robel.utils.configurable import configurable\nfrom robel.utils.resources import get_asset_path\n\n# The observation keys that are concatenated as the environment observation.\nDEFAULT_OBSERVATION_KEYS = (\n    'object_x',\n    'object_y',\n    'claw_qpos',\n    'last_action',\n)\n\n# Reset pose for the claw joints.\nRESET_POSE = [0, -np.pi / 3, np.pi / 3] * 3\n\nDCLAW3_ASSET_PATH = 'robel/dclaw/assets/dclaw3xh_valve3_v0.xml'\n\n\nclass BaseDClawTurn(BaseDClawObjectEnv, metaclass=abc.ABCMeta):\n    \"\"\"Shared logic for DClaw turn tasks.\"\"\"\n\n    def __init__(self,\n                 asset_path: str = DCLAW3_ASSET_PATH,\n                 observation_keys: Sequence[str] = DEFAULT_OBSERVATION_KEYS,\n                 frame_skip: int = 40,\n                 **kwargs):\n        \"\"\"Initializes the environment.\n\n        Args:\n            asset_path: The XML model file to load.\n            observation_keys: The keys in `get_obs_dict` to concatenate as the\n                observations returned by `step` and `reset`.\n            frame_skip: The number of simulation steps per environment step.\n            interactive: If True, allows the hardware guide motor to freely\n                rotate and its current angle is used as the goal.\n            success_threshold: The difference threshold (in radians) of the\n                object position and the goal position within which we consider\n                as a sucesss.\n        \"\"\"\n        super().__init__(\n            sim_model=get_asset_path(asset_path),\n            observation_keys=observation_keys,\n            frame_skip=frame_skip,\n            **kwargs)\n\n        self._desired_claw_pos = RESET_POSE\n\n        # The following are modified (possibly every reset) by subclasses.\n        self._initial_object_pos = 0\n        self._initial_object_vel = 0\n\n    def _reset(self):\n        \"\"\"Resets the environment.\"\"\"\n        self._reset_dclaw_and_object(\n            claw_pos=RESET_POSE,\n            object_pos=self._initial_object_pos,\n            object_vel=self._initial_object_vel)\n\n    def _step(self, action: np.ndarray):\n        \"\"\"Applies an action to the robot.\"\"\"\n        self.robot.step({\n            'dclaw': action,\n        })\n\n    def get_obs_dict(self) -> Dict[str, np.ndarray]:\n        \"\"\"Returns the current observation of the environment.\n\n        Returns:\n            A dictionary of observation values. 
This should be an ordered\n            dictionary if `observation_keys` isn't set.\n        \"\"\"\n        claw_state, object_state = self.robot.get_state(\n            ['dclaw', 'object'])\n\n        obs_dict = collections.OrderedDict((\n            ('claw_qpos', claw_state.qpos),\n            ('claw_qvel', claw_state.qvel),\n            ('object_x', np.cos(object_state.qpos)),\n            ('object_y', np.sin(object_state.qpos)),\n            ('object_qvel', object_state.qvel),\n            ('last_action', self._get_last_action()),\n        ))\n        # Add hardware-specific state if present.\n        if isinstance(claw_state, DynamixelRobotState):\n            obs_dict['claw_current'] = claw_state.current\n\n        return obs_dict\n\n    def get_reward_dict(\n            self,\n            action: np.ndarray,\n            obs_dict: Dict[str, np.ndarray],\n    ) -> Dict[str, np.ndarray]:\n        \"\"\"Returns the reward for the given action and observation.\"\"\"\n        reward_dict = collections.OrderedDict(())\n        return reward_dict\n\n    def get_score_dict(\n            self,\n            obs_dict: Dict[str, np.ndarray],\n            reward_dict: Dict[str, np.ndarray],\n    ) -> Dict[str, np.ndarray]:\n        \"\"\"Returns a standardized measure of success for the environment.\"\"\"\n        return collections.OrderedDict(())\n\n    def get_done(\n            self,\n            obs_dict: Dict[str, np.ndarray],\n            reward_dict: Dict[str, np.ndarray],\n    ) -> np.ndarray:\n        \"\"\"Returns whether the episode should terminate.\"\"\"\n        return np.zeros_like([0], dtype=bool)\n\n\n@configurable(pickleable=True)\nclass DClawTurnRandom(BaseDClawTurn):\n    \"\"\"Turns the object with a random initial and random target position.\"\"\"\n\n    def _reset(self):\n        # Initial position is +/- 60 degrees.\n        self._initial_object_pos = self.np_random.uniform(\n            low=-np.pi / 3, high=np.pi / 3)\n        super()._reset()\n\n\n@configurable(pickleable=True)\nclass DClawTurnRandomDynamics(DClawTurnRandom):\n    \"\"\"Turns the object with a random initial and random target position.\n\n    The dynamics of the simulation are randomized each episode.\n    \"\"\"\n\n    def __init__(self,\n                 *args,\n                 sim_observation_noise: Optional[float] = 0.05,\n                 **kwargs):\n        super().__init__(\n            *args, sim_observation_noise=sim_observation_noise, **kwargs)\n        self._randomizer = SimRandomizer(self)\n        self._dof_indices = (\n            self.robot.get_config('dclaw').qvel_indices.tolist() +\n            self.robot.get_config('object').qvel_indices.tolist())\n\n    def _reset(self):\n        # Randomize joint dynamics.\n        self._randomizer.randomize_dofs(\n            self._dof_indices,\n            damping_range=(0.005, 0.1),\n            friction_loss_range=(0.001, 0.005),\n        )\n        self._randomizer.randomize_actuators(\n            all_same=True,\n            kp_range=(1, 3),\n        )\n        # Randomize friction on all geoms in the scene.\n        self._randomizer.randomize_geoms(\n            all_same=True,\n            friction_slide_range=(0.8, 1.2),\n            friction_spin_range=(0.003, 0.007),\n            friction_roll_range=(0.00005, 0.00015),\n        )\n        self._randomizer.randomize_bodies(\n            ['mount'],\n            position_perturb_range=(-0.01, 0.01),\n        )\n        self._randomizer.randomize_geoms(\n            ['mount'],\n            
color_range=(0.2, 0.9),\n        )\n        self._randomizer.randomize_geoms(\n            parent_body_names=['valve'],\n            color_range=(0.2, 0.9),\n        )\n        super()._reset()\n"
  },
  {
    "path": "envs/dkitty_redesign.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"DKitty redesign\n\"\"\"\n\nimport abc\nimport collections\nfrom typing import Dict, Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nfrom robel.components.tracking import TrackerState\nfrom robel.dkitty.base_env import BaseDKittyUprightEnv\nfrom robel.simulation.randomize import SimRandomizer\nfrom robel.utils.configurable import configurable\nfrom robel.utils.math_utils import calculate_cosine\nfrom robel.utils.resources import get_asset_path\n\nDKITTY_ASSET_PATH = 'robel/dkitty/assets/dkitty_walk-v0.xml'\n\nDEFAULT_OBSERVATION_KEYS = (\n    'root_pos',\n    'root_euler',\n    'kitty_qpos',\n    # 'root_vel',\n    # 'root_angular_vel',\n    'kitty_qvel',\n    'last_action',\n    'upright',\n)\n\n\nclass BaseDKittyWalk(BaseDKittyUprightEnv, metaclass=abc.ABCMeta):\n    \"\"\"Shared logic for DKitty walk tasks.\"\"\"\n\n    def __init__(\n            self,\n            asset_path: str = DKITTY_ASSET_PATH,\n            observation_keys: Sequence[str] = DEFAULT_OBSERVATION_KEYS,\n            device_path: Optional[str] = None,\n            torso_tracker_id: Optional[Union[str, int]] = None,\n            frame_skip: int = 40,\n            sticky_action_probability: float = 0.,\n            upright_threshold: float = 0.9,\n            upright_reward: float = 1,\n            falling_reward: float = -500,\n            expose_last_action: bool = True,\n            expose_upright: bool = True,\n            robot_noise_ratio: float = 0.05,\n            **kwargs):\n        \"\"\"Initializes the environment.\n\n        Args:\n            asset_path: The XML model file to load.\n            observation_keys: The keys in `get_obs_dict` to concatenate as the\n                observations returned by `step` and `reset`.\n            device_path: The device path to Dynamixel hardware.\n            torso_tracker_id: The device index or serial of the tracking device\n                for the D'Kitty torso.\n            frame_skip: The number of simulation steps per environment step.\n            sticky_action_probability: Repeat previous action with this\n                probability. Default 0 (no sticky actions).\n            upright_threshold: The threshold (in [0, 1]) above which the D'Kitty\n                is considered to be upright. 
If the cosine similarity of the\n                D'Kitty's z-axis with the global z-axis is below this threshold,\n                the D'Kitty is considered to have fallen.\n            upright_reward: The reward multiplier for uprightedness.\n            falling_reward: The reward multiplier for falling.\n            expose_last_action: If True, include the last action in the\n                observation.\n            expose_upright: If True, include the upright cosine in the\n                observation.\n            robot_noise_ratio: Scale of simulated observation noise (currently\n                unused; the corresponding robot_config override is commented\n                out below).\n        \"\"\"\n        self._expose_last_action = expose_last_action\n        self._expose_upright = expose_upright\n        observation_keys = observation_keys[:-2]\n        if self._expose_last_action:\n            observation_keys += ('last_action',)\n        if self._expose_upright:\n            observation_keys += ('upright',)\n\n        # robot_config = self.get_robot_config(device_path)\n        # if 'sim_observation_noise' in robot_config.keys():\n        #     robot_config['sim_observation_noise'] = robot_noise_ratio\n\n        super().__init__(\n            sim_model=get_asset_path(asset_path),\n            # robot_config=robot_config,\n            # tracker_config=self.get_tracker_config(\n            #     torso=torso_tracker_id,\n            # ),\n            observation_keys=observation_keys,\n            frame_skip=frame_skip,\n            upright_threshold=upright_threshold,\n            upright_reward=upright_reward,\n            falling_reward=falling_reward,\n            **kwargs)\n\n        self._last_action = np.zeros(12)\n        self._sticky_action_probability = sticky_action_probability\n        self._time_step = 0\n\n    def _reset(self):\n        \"\"\"Resets the environment.\"\"\"\n        self._reset_dkitty_standing()\n\n        # Set the tracker locations.\n        self.tracker.set_state({\n            'torso': TrackerState(pos=np.zeros(3), rot=np.identity(3)),\n        })\n\n        self._time_step = 0\n\n    def _step(self, action: np.ndarray):\n        \"\"\"Applies an action to the robot.\"\"\"\n        self._time_step += 1\n\n        # Sticky actions: with probability `sticky_action_probability`, repeat\n        # the previous action instead of the new one.\n        rand = self.np_random.uniform() < self._sticky_action_probability\n        action_to_apply = np.where(rand, self._last_action, action)\n\n        # Apply action.\n        self.robot.step({\n            'dkitty': action_to_apply,\n        })\n        # Save the action to add to the observation.\n        self._last_action = action\n\n    def get_obs_dict(self) -> Dict[str, np.ndarray]:\n        \"\"\"Returns the current observation of the environment.\n\n        Returns:\n            A dictionary of observation values. 
This should be an ordered\n            dictionary if `observation_keys` isn't set.\n        \"\"\"\n        robot_state = self.robot.get_state('dkitty')\n        torso_track_state = self.tracker.get_state(\n            ['torso'])[0]\n        obs_dict = (('root_pos', torso_track_state.pos),\n                    ('root_euler', torso_track_state.rot_euler),\n                    ('root_vel', torso_track_state.vel),\n                    ('root_angular_vel', torso_track_state.angular_vel),\n                    ('kitty_qpos', robot_state.qpos),\n                    ('kitty_qvel', robot_state.qvel))\n\n        if self._expose_last_action:\n            obs_dict += (('last_action', self._last_action),)\n\n        # Add observation terms relating to being upright.\n        if self._expose_upright:\n            obs_dict += (*self._get_upright_obs(torso_track_state).items(),)\n\n        return collections.OrderedDict(obs_dict)\n\n    def get_reward_dict(\n            self,\n            action: np.ndarray,\n            obs_dict: Dict[str, np.ndarray],\n    ) -> Dict[str, np.ndarray]:\n        \"\"\"Returns the reward for the given action and observation.\"\"\"\n        reward_dict = collections.OrderedDict(())\n        return reward_dict\n\n    def get_score_dict(\n            self,\n            obs_dict: Dict[str, np.ndarray],\n            reward_dict: Dict[str, np.ndarray],\n    ) -> Dict[str, np.ndarray]:\n        \"\"\"Returns a standardized measure of success for the environment.\"\"\"\n        return collections.OrderedDict(())\n\n@configurable(pickleable=True)\nclass DKittyRandomDynamics(BaseDKittyWalk):\n    \"\"\"Walk straight towards a random location.\"\"\"\n\n    def __init__(self, *args, randomize_hfield=0.0, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._randomizer = SimRandomizer(self)\n        self._randomize_hfield = randomize_hfield\n        self._dof_indices = (\n            self.robot.get_config('dkitty').qvel_indices.tolist())\n\n    def _reset(self):\n        \"\"\"Resets the environment.\"\"\"\n        # Randomize joint dynamics.\n        self._randomizer.randomize_dofs(\n            self._dof_indices,\n            all_same=True,\n            damping_range=(0.1, 0.2),\n            friction_loss_range=(0.001, 0.005),\n        )\n        self._randomizer.randomize_actuators(\n            all_same=True,\n            kp_range=(2.8, 3.2),\n        )\n        # Randomize friction on all geoms in the scene.\n        self._randomizer.randomize_geoms(\n            all_same=True,\n            friction_slide_range=(0.8, 1.2),\n            friction_spin_range=(0.003, 0.007),\n            friction_roll_range=(0.00005, 0.00015),\n        )\n        # Generate a random height field.\n        self._randomizer.randomize_global(\n            total_mass_range=(1.6, 2.0),\n            height_field_range=(0, self._randomize_hfield),\n        )\n        # if self._randomize_hfield > 0.0:\n        #     self.sim_scene.upload_height_field(0)\n        super()._reset()\n"
  },
  {
    "path": "envs/gym_mujoco/ant.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nfrom gym import utils\nimport numpy as np\nfrom gym.envs.mujoco import mujoco_env\n\ndef q_inv(a):\n  return [a[0], -a[1], -a[2], -a[3]]\n\n\ndef q_mult(a, b):  # multiply two quaternion\n  w = a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3]\n  i = a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2]\n  j = a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1]\n  k = a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0]\n  return [w, i, j, k]\n\n# pylint: disable=missing-docstring\nclass AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n\n  def __init__(self,\n               task=\"forward\",\n               goal=None,\n               expose_all_qpos=False,\n               expose_body_coms=None,\n               expose_body_comvels=None,\n               expose_foot_sensors=False,\n               use_alt_path=False,\n               model_path=\"ant.xml\"):\n    self._task = task\n    self._goal = goal\n    self._expose_all_qpos = expose_all_qpos\n    self._expose_body_coms = expose_body_coms\n    self._expose_body_comvels = expose_body_comvels\n    self._expose_foot_sensors = expose_foot_sensors\n    self._body_com_indices = {}\n    self._body_comvel_indices = {}\n\n    # Settings from\n    # https://github.com/openai/gym/blob/master/gym/envs/__init__.py\n\n    xml_path = \"envs/assets/\"\n    model_path = os.path.abspath(os.path.join(xml_path, model_path))\n    mujoco_env.MujocoEnv.__init__(self, model_path, 5)\n    utils.EzPickle.__init__(self)\n\n  def compute_reward(self, ob, next_ob, action=None):\n    xposbefore = ob[:, 0]\n    yposbefore = ob[:, 1]\n    xposafter = next_ob[:, 0]\n    yposafter = next_ob[:, 1]\n\n    forward_reward = (xposafter - xposbefore) / self.dt\n    sideward_reward = (yposafter - yposbefore) / self.dt\n\n    if action is not None:\n      ctrl_cost = .5 * np.square(action).sum(axis=1)\n      survive_reward = 1.0\n    if self._task == \"forward\":\n      reward = forward_reward - ctrl_cost + survive_reward\n    elif self._task == \"backward\":\n      reward = -forward_reward - ctrl_cost + survive_reward\n    elif self._task == \"left\":\n      reward = sideward_reward - ctrl_cost + survive_reward\n    elif self._task == \"right\":\n      reward = -sideward_reward - ctrl_cost + survive_reward\n    elif self._task == \"goal\":\n      reward = -np.linalg.norm(\n          np.array([xposafter, yposafter]).T - self._goal, axis=1)\n\n    return reward\n\n  def step(self, a):\n    xposbefore = self.get_body_com(\"torso\")[0]\n    yposbefore = self.sim.data.qpos.flat[1]\n    self.do_simulation(a, self.frame_skip)\n    xposafter = self.get_body_com(\"torso\")[0]\n    yposafter = self.sim.data.qpos.flat[1]\n\n    forward_reward = (xposafter - xposbefore) / self.dt\n    sideward_reward = (yposafter - yposbefore) / self.dt\n\n    ctrl_cost = .5 * 
np.square(a).sum()\n    survive_reward = 1.0\n    if self._task == \"forward\":\n      reward = forward_reward - ctrl_cost + survive_reward\n    elif self._task == \"backward\":\n      reward = -forward_reward - ctrl_cost + survive_reward\n    elif self._task == \"left\":\n      reward = sideward_reward - ctrl_cost + survive_reward\n    elif self._task == \"right\":\n      reward = -sideward_reward - ctrl_cost + survive_reward\n    elif self._task == \"goal\":\n      reward = -np.linalg.norm(np.array([xposafter, yposafter]) - self._goal)\n    elif self._task == \"motion\":\n      reward = np.max(np.abs(np.array([forward_reward, sideward_reward\n                                      ]))) - ctrl_cost + survive_reward\n\n    state = self.state_vector()\n    notdone = np.isfinite(state).all()\n    done = not notdone\n    ob = self._get_obs()\n    return ob, reward, done, dict(\n        reward_forward=forward_reward,\n        reward_sideward=sideward_reward,\n        reward_ctrl=-ctrl_cost,\n        reward_survive=survive_reward)\n\n  def _get_obs(self):\n    # No crfc observation\n    if self._expose_all_qpos:\n      obs = np.concatenate([\n          self.sim.data.qpos.flat[:15],\n          self.sim.data.qvel.flat[:14],\n      ])\n    else:\n      obs = np.concatenate([\n          self.sim.data.qpos.flat[2:15],\n          self.sim.data.qvel.flat[:14],\n      ])\n\n    if self._expose_body_coms is not None:\n      for name in self._expose_body_coms:\n        com = self.get_body_com(name)\n        if name not in self._body_com_indices:\n          indices = range(len(obs), len(obs) + len(com))\n          self._body_com_indices[name] = indices\n        obs = np.concatenate([obs, com])\n\n    if self._expose_body_comvels is not None:\n      for name in self._expose_body_comvels:\n        comvel = self.get_body_comvel(name)\n        if name not in self._body_comvel_indices:\n          indices = range(len(obs), len(obs) + len(comvel))\n          self._body_comvel_indices[name] = indices\n        obs = np.concatenate([obs, comvel])\n\n    if self._expose_foot_sensors:\n      obs = np.concatenate([obs, self.sim.data.sensordata])\n    return obs\n\n  def reset_model(self):\n    qpos = self.init_qpos + self.np_random.uniform(\n        size=self.sim.model.nq, low=-.1, high=.1)\n    qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .1\n\n    qpos[15:] = self.init_qpos[15:]\n    qvel[14:] = 0.\n\n    self.set_state(qpos, qvel)\n    return self._get_obs()\n\n  def viewer_setup(self):\n    self.viewer.cam.distance = self.model.stat.extent * 2.5\n\n  def get_ori(self):\n    ori = [0, 1, 0, 0]\n    rot = self.sim.data.qpos[3:7]  # take the quaternion\n    ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3]  # project onto x-y plane\n    ori = math.atan2(ori[1], ori[0])\n    return ori\n\n  @property\n  def body_com_indices(self):\n    return self._body_com_indices\n\n  @property\n  def body_comvel_indices(self):\n    return self._body_comvel_indices\n"
  },
  {
    "path": "envs/gym_mujoco/half_cheetah.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nfrom gym import utils\nimport numpy as np\nfrom gym.envs.mujoco import mujoco_env\n\n\nclass HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n\n  def __init__(self,\n               expose_all_qpos=False,\n               task='default',\n               target_velocity=None,\n               model_path='half_cheetah.xml'):\n    # Settings from\n    # https://github.com/openai/gym/blob/master/gym/envs/__init__.py\n    self._expose_all_qpos = expose_all_qpos\n    self._task = task\n    self._target_velocity = target_velocity\n\n    xml_path = \"envs/assets/\"\n    model_path = os.path.abspath(os.path.join(xml_path, model_path))\n\n    mujoco_env.MujocoEnv.__init__(\n        self,\n        model_path,\n        5)\n    utils.EzPickle.__init__(self)\n\n  def step(self, action):\n    xposbefore = self.sim.data.qpos[0]\n    self.do_simulation(action, self.frame_skip)\n    xposafter = self.sim.data.qpos[0]\n    xvelafter = self.sim.data.qvel[0]\n    ob = self._get_obs()\n    reward_ctrl = -0.1 * np.square(action).sum()\n\n    if self._task == 'default':\n      reward_vel = 0.\n      reward_run = (xposafter - xposbefore) / self.dt\n      reward = reward_ctrl + reward_run\n    elif self._task == 'target_velocity':\n      reward_vel = -(self._target_velocity - xvelafter)**2\n      reward = reward_ctrl + reward_vel\n    elif self._task == 'run_back':\n      reward_vel = 0.\n      reward_run = (xposbefore - xposafter) / self.dt\n      reward = reward_ctrl + reward_run\n\n    done = False\n    return ob, reward, done, dict(\n        reward_run=reward_run, reward_ctrl=reward_ctrl, reward_vel=reward_vel)\n\n  def _get_obs(self):\n    if self._expose_all_qpos:\n      return np.concatenate(\n          [self.sim.data.qpos.flat, self.sim.data.qvel.flat])\n    return np.concatenate([\n        self.sim.data.qpos.flat[1:],\n        self.sim.data.qvel.flat,\n    ])\n\n  def reset_model(self):\n    qpos = self.init_qpos + self.np_random.uniform(\n        low=-.1, high=.1, size=self.sim.model.nq)\n    qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .1\n    self.set_state(qpos, qvel)\n    return self._get_obs()\n\n  def viewer_setup(self):\n    self.viewer.cam.distance = self.model.stat.extent * 0.5\n"
  },
  {
    "path": "envs/gym_mujoco/humanoid.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nfrom gym import utils\nimport numpy as np\nfrom gym.envs.mujoco import mujoco_env\n\n\ndef mass_center(sim):\n  mass = np.expand_dims(sim.model.body_mass, 1)\n  xpos = sim.data.xipos\n  return (np.sum(mass * xpos, 0) / np.sum(mass))[0]\n\n\n# pylint: disable=missing-docstring\nclass HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n\n  def __init__(self, \n               expose_all_qpos=False,\n               model_path='humanoid.xml',\n               task=None,\n               goal=None):\n\n    self._task = task\n    self._goal = goal\n    if self._task == \"follow_goals\":\n      self._goal_list = [\n          np.array([3.0, -0.5]),\n          np.array([6.0, 8.0]),\n          np.array([12.0, 12.0]),\n      ]\n      self._goal = self._goal_list[0]\n      print(\"Following a trajectory of goals:\", self._goal_list)\n\n    self._expose_all_qpos = expose_all_qpos\n    xml_path = \"envs/assets/\"\n    model_path = os.path.abspath(os.path.join(xml_path, model_path))\n    mujoco_env.MujocoEnv.__init__(self, model_path, 5)\n    utils.EzPickle.__init__(self)\n\n  def _get_obs(self):\n    data = self.sim.data\n    if self._expose_all_qpos:\n      return np.concatenate([\n          data.qpos.flat, data.qvel.flat,\n          # data.cinert.flat, data.cvel.flat,\n          # data.qfrc_actuator.flat, data.cfrc_ext.flat\n      ])\n    return np.concatenate([\n        data.qpos.flat[2:], data.qvel.flat, data.cinert.flat, data.cvel.flat,\n        data.qfrc_actuator.flat, data.cfrc_ext.flat\n    ])\n\n  def compute_reward(self, ob, next_ob, action=None):\n    xposbefore = ob[:, 0]\n    yposbefore = ob[:, 1]\n    xposafter = next_ob[:, 0]\n    yposafter = next_ob[:, 1]\n\n    forward_reward = (xposafter - xposbefore) / self.dt\n    sideward_reward = (yposafter - yposbefore) / self.dt\n\n    if action is not None:\n      ctrl_cost = .5 * np.square(action).sum(axis=1)\n      survive_reward = 1.0\n    if self._task == \"forward\":\n      reward = forward_reward - ctrl_cost + survive_reward\n    elif self._task == \"backward\":\n      reward = -forward_reward - ctrl_cost + survive_reward\n    elif self._task == \"left\":\n      reward = sideward_reward - ctrl_cost + survive_reward\n    elif self._task == \"right\":\n      reward = -sideward_reward - ctrl_cost + survive_reward\n    elif self._task in [\"goal\", \"follow_goals\"]:\n      reward = -np.linalg.norm(\n          np.array([xposafter, yposafter]).T - self._goal, axis=1)\n    elif self._task in [\"sparse_goal\"]:\n      reward = (-np.linalg.norm(\n          np.array([xposafter, yposafter]).T - self._goal, axis=1) >\n                -0.3).astype(np.float32)\n    return reward\n\n  def step(self, a):\n    pos_before = mass_center(self.sim)\n    self.do_simulation(a, self.frame_skip)\n    pos_after = mass_center(self.sim)\n    
alive_bonus = 5.0\n    data = self.sim.data\n    lin_vel_cost = 0.25 * (\n        pos_after - pos_before) / self.sim.model.opt.timestep\n    quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()\n    quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum()\n    quad_impact_cost = min(quad_impact_cost, 10)\n    reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus\n\n    if self._task == \"follow_goals\":\n      xposafter = self.sim.data.qpos.flat[0]\n      yposafter = self.sim.data.qpos.flat[1]\n      reward = -np.linalg.norm(np.array([xposafter, yposafter]).T - self._goal)\n      # Advance to the next goal once close enough, guarding against an\n      # exhausted goal list.\n      if np.abs(reward) < 0.5 and self._goal_list:\n        self._goal = self._goal_list[0]\n        self._goal_list = self._goal_list[1:]\n        print(\"Goal Updated:\", self._goal)\n\n    elif self._task == \"goal\":\n      xposafter = self.sim.data.qpos.flat[0]\n      yposafter = self.sim.data.qpos.flat[1]\n      reward = -np.linalg.norm(np.array([xposafter, yposafter]).T - self._goal)\n\n    qpos = self.sim.data.qpos\n    done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))\n    return self._get_obs(), reward, done, dict(\n        reward_linvel=lin_vel_cost,\n        reward_quadctrl=-quad_ctrl_cost,\n        reward_alive=alive_bonus,\n        reward_impact=-quad_impact_cost)\n\n  def reset_model(self):\n    c = 0.01\n    self.set_state(\n        self.init_qpos + self.np_random.uniform(\n            low=-c, high=c, size=self.sim.model.nq),\n        self.init_qvel + self.np_random.uniform(\n            low=-c,\n            high=c,\n            size=self.sim.model.nv,\n        ))\n\n    if self._task == \"follow_goals\" and self._goal_list:\n      self._goal = self._goal_list[0]\n      self._goal_list = self._goal_list[1:]\n      print(\"Current goal:\", self._goal)\n\n    return self._get_obs()\n\n  def viewer_setup(self):\n    self.viewer.cam.distance = self.model.stat.extent * 2.0\n"
  },
  {
    "path": "envs/gym_mujoco/point_mass.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport os\n\nfrom gym import utils\nimport numpy as np\nfrom gym.envs.mujoco import mujoco_env\n\n\n# pylint: disable=missing-docstring\nclass PointMassEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n\n  def __init__(self,\n               target=None,\n               wiggly_weight=0.,\n               alt_xml=False,\n               expose_velocity=True,\n               expose_goal=True,\n               use_simulator=False,\n               model_path='point.xml'):\n    self._sample_target = target\n    if self._sample_target is not None:\n      self.goal = np.array([1.0, 1.0])\n\n    self._expose_velocity = expose_velocity\n    self._expose_goal = expose_goal\n    self._use_simulator = use_simulator\n    self._wiggly_weight = abs(wiggly_weight)\n    self._wiggle_direction = +1 if wiggly_weight > 0. else -1\n\n    xml_path = \"envs/assets/\"\n    model_path = os.path.abspath(os.path.join(xml_path, model_path))\n\n    if self._use_simulator:\n      mujoco_env.MujocoEnv.__init__(self, model_path, 5)\n    else:\n      mujoco_env.MujocoEnv.__init__(self, model_path, 1)\n    utils.EzPickle.__init__(self)\n\n  def step(self, action):\n    if self._use_simulator:\n      self.do_simulation(action, self.frame_skip)\n    else:\n      force = 0.2 * action[0]\n      rot = 1.0 * action[1]\n      qpos = self.sim.data.qpos.flat.copy()\n      qpos[2] += rot\n      ori = qpos[2]\n      dx = math.cos(ori) * force\n      dy = math.sin(ori) * force\n      qpos[0] = np.clip(qpos[0] + dx, -2, 2)\n      qpos[1] = np.clip(qpos[1] + dy, -2, 2)\n      qvel = self.sim.data.qvel.flat.copy()\n      self.set_state(qpos, qvel)\n\n    ob = self._get_obs()\n    if self._sample_target is not None and self.goal is not None:\n      reward = -np.linalg.norm(self.sim.data.qpos.flat[:2] - self.goal)**2\n    else:\n      reward = 0.\n\n    if self._wiggly_weight > 0.:\n      reward = (np.exp(-((-reward)**0.5))**(1. 
- self._wiggly_weight)) * (\n          max(self._wiggle_direction * action[1], 0)**self._wiggly_weight)\n    done = False\n    return ob, reward, done, {}\n\n  def _get_obs(self):\n    new_obs = [self.sim.data.qpos.flat]\n    if self._expose_velocity:\n      new_obs += [self.sim.data.qvel.flat]\n    if self._expose_goal and self.goal is not None:\n      new_obs += [self.goal]\n    return np.concatenate(new_obs)\n\n  def reset_model(self):\n    qpos = self.init_qpos + np.append(\n        self.np_random.uniform(low=-.2, high=.2, size=2),\n        self.np_random.uniform(-np.pi, np.pi, size=1))\n    qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .01\n    if self._sample_target is not None:\n      self.goal = self._sample_target(qpos[:2])\n    self.set_state(qpos, qvel)\n    return self._get_obs()\n\n  # Only valid when the goal is not exposed in the observation.\n  def set_qpos(self, state):\n    qvel = np.copy(self.sim.data.qvel.flat)\n    self.set_state(state, qvel)\n\n  def viewer_setup(self):\n    self.viewer.cam.distance = self.model.stat.extent * 0.5\n"
  },
  {
    "path": "envs/hand_block.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport gym\nimport os\nfrom gym import spaces\nfrom gym.envs.robotics.hand.manipulate import ManipulateEnv\nimport mujoco_py\n\nMANIPULATE_BLOCK_XML = os.path.join('hand', 'manipulate_block.xml')\n\nclass HandBlockCustomEnv(ManipulateEnv):\n\tdef __init__(self,\n\t\t\t\t model_path=MANIPULATE_BLOCK_XML,\n\t\t\t\t target_position='random',\n\t\t\t\t target_rotation='xyz',\n\t\t\t\t reward_type='sparse',\n\t\t\t\t horizontal_wrist_constraint=1.0,\n\t\t\t\t vertical_wrist_constraint=1.0,\n\t\t\t\t **kwargs):\n\t\tManipulateEnv.__init__(self,\n\t\t\t model_path=MANIPULATE_BLOCK_XML,\n\t\t\t target_position=target_position,\n\t\t\t target_rotation=target_rotation,\n\t\t\t target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),\n\t\t\t reward_type=reward_type,\n\t\t\t **kwargs)\n\n\t\tself._viewers = {}\n\n\t\t# constraining the movement of wrist (vertical movement more important than horizontal)\n\t\tself.action_space.low[0] = -horizontal_wrist_constraint\n\t\tself.action_space.high[0] = horizontal_wrist_constraint\n\t\tself.action_space.low[1] = -vertical_wrist_constraint\n\t\tself.action_space.high[1] = vertical_wrist_constraint\n\n\tdef _get_viewer(self, mode):\n\t\tself.viewer = self._viewers.get(mode)\n\t\tif self.viewer is None:\n\t\t\tif mode == 'human':\n\t\t\t\tself.viewer = mujoco_py.MjViewer(self.sim)\n\t\t\telif mode == 'rgb_array':\n\t\t\t\tself.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1)\n\t\t\t\tself._viewer_setup()\n\t\t\t\tself._viewers[mode] = self.viewer\n\t\treturn self.viewer\n\n\tdef _viewer_setup(self):\n\t\tbody_id = self.sim.model.body_name2id('robot0:palm')\n\t\tlookat = self.sim.data.body_xpos[body_id]\n\t\tfor idx, value in enumerate(lookat):\n\t\t\tself.viewer.cam.lookat[idx] = value\n\t\tself.viewer.cam.distance = 0.5\n\t\tself.viewer.cam.azimuth = 55.\n\t\tself.viewer.cam.elevation = -25.\n\n\tdef step(self, action):\n\t\t\n\t\tdef is_on_palm():\n\t\t\tself.sim.forward()\n\t\t\tcube_middle_idx = self.sim.model.site_name2id('object:center')\n\t\t\tcube_middle_pos = self.sim.data.site_xpos[cube_middle_idx]\n\t\t\tis_on_palm = (cube_middle_pos[2] > 0.04)\n\t\t\treturn is_on_palm\n\n\t\tobs, reward, done, info = super().step(action)\n\t\tdone = not is_on_palm()\n\t\treturn obs, reward, done, info\n\n\tdef render(self, mode='human', width=500, height=500):\n\t\tself._render_callback()\n\t\tif mode == 'rgb_array':\n\t\t\tself._get_viewer(mode).render(width, height)\n\t\t\t# window size used for old mujoco-py:\n\t\t\tdata = self._get_viewer(mode).read_pixels(width, height, depth=False)\n\t\t\t# original image is upside-down, so flip it\n\t\t\treturn data[::-1, :, :]\n\t\telif mode == 'human':\n\t\t\tself._get_viewer(mode).render()\n"
  },
  {
    "path": "envs/skill_wrapper.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\n\nimport gym\nfrom gym import Wrapper\n\nclass SkillWrapper(Wrapper):\n\n  def __init__(\n      self,\n      env,\n      # skill type and dimension\n      num_latent_skills=None,\n      skill_type='discrete_uniform',\n      # execute an episode with the same predefined skill, does not resample\n      preset_skill=None,\n      # resample skills within episode\n      min_steps_before_resample=10,\n      resample_prob=0.):\n\n    super(SkillWrapper, self).__init__(env)\n    self._skill_type = skill_type\n    if num_latent_skills is None:\n      self._num_skills = 0\n    else:\n      self._num_skills = num_latent_skills\n    self._preset_skill = preset_skill\n\n    # attributes for controlling skill resampling\n    self._min_steps_before_resample = min_steps_before_resample\n    self._resample_prob = resample_prob\n\n    if isinstance(self.env.observation_space, gym.spaces.Dict):\n      size = self.env.observation_space.spaces['observation'].shape[0] + self._num_skills\n    else:\n      size = self.env.observation_space.shape[0] + self._num_skills\n    self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(size,), dtype='float32')\n\n  def _remake_time_step(self, cur_obs):\n    if isinstance(self.env.observation_space, gym.spaces.Dict):\n      cur_obs = cur_obs['observation']\n\n    if self._num_skills == 0:\n      return cur_obs\n    else:\n      return np.concatenate([cur_obs, self.skill])\n\n  def _set_skill(self):\n    if self._num_skills:\n      if self._preset_skill is not None:\n        self.skill = self._preset_skill\n        print('Skill:', self.skill)\n      elif self._skill_type == 'discrete_uniform':\n        self.skill = np.random.multinomial(\n            1, [1. / self._num_skills] * self._num_skills)\n      elif self._skill_type == 'gaussian':\n        self.skill = np.random.multivariate_normal(\n            np.zeros(self._num_skills), np.eye(self._num_skills))\n      elif self._skill_type == 'cont_uniform':\n        self.skill = np.random.uniform(\n            low=-1.0, high=1.0, size=self._num_skills)\n\n  def reset(self):\n    cur_obs = self.env.reset()\n    self._set_skill()\n    self._step_count = 0\n    return self._remake_time_step(cur_obs)\n\n  def step(self, action):\n    cur_obs, reward, done, info = self.env.step(action)\n    self._step_count += 1\n    if self._preset_skill is None and self._step_count >= self._min_steps_before_resample and np.random.random(\n    ) < self._resample_prob:\n      self._set_skill()\n      self._step_count = 0\n    return self._remake_time_step(cur_obs), reward, done, info\n\n  def close(self):\n    return self.env.close()\n"
  },
  {
    "path": "envs/video_wrapper.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nimport gym\nfrom gym import Wrapper\nfrom gym.wrappers.monitoring import video_recorder\n\nclass VideoWrapper(Wrapper):\n\n  def __init__(self, env, base_path, base_name=None, new_video_every_reset=False):\n    super(VideoWrapper, self).__init__(env)\n\n    self._base_path = base_path\n    self._base_name = base_name\n\n    self._new_video_every_reset = new_video_every_reset\n    if self._new_video_every_reset:\n      self._counter = 0\n      self._recorder = None\n    else:\n      if self._base_name is not None:\n        self._vid_name = os.path.join(self._base_path, self._base_name)\n      else:\n        self._vid_name = self._base_path\n      self._recorder = video_recorder.VideoRecorder(self.env, path=self._vid_name + '.mp4')\n\n  def reset(self):\n    if self._new_video_every_reset:\n      if self._recorder is not None:\n        self._recorder.close()\n\n      self._counter += 1\n      if self._base_name is not None:\n        self._vid_name = os.path.join(self._base_path, self._base_name + '_' + str(self._counter))\n      else:\n        self._vid_name = self._base_path + '_' + str(self._counter)\n\n      self._recorder = video_recorder.VideoRecorder(self.env, path=self._vid_name + '.mp4')\n\n    return self.env.reset()\n\n  def step(self, action):\n    self._recorder.capture_frame()\n    return self.env.step(action)\n\n  def close(self):\n    self._recorder.encoder.proc.stdin.flush()\n    self._recorder.close()\n    return self.env.close()"
  },
  {
    "path": "lib/py_tf_policy.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Converts TensorFlow Policies into Python Policies.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import logging\n\nimport tensorflow as tf\nfrom tf_agents.policies import py_policy\nfrom tf_agents.policies import tf_policy\nfrom tf_agents.specs import tensor_spec\nfrom tf_agents.trajectories import policy_step\nfrom tf_agents.utils import common\nfrom tf_agents.utils import nest_utils\nfrom tf_agents.utils import session_utils\n\n\nclass PyTFPolicy(py_policy.Base, session_utils.SessionUser):\n  \"\"\"Exposes a Python policy as wrapper over a TF Policy.\"\"\"\n\n  # TODO(damienv): currently, the initial policy state must be batched\n  # if batch_size is given. Without losing too much generality, the initial\n  # policy state could be the same for every element in the batch.\n  # In that case, the initial policy state could be given with no batch\n  # dimension.\n  # TODO(sfishman): Remove batch_size param entirely.\n  def __init__(self, policy, batch_size=None, seed=None):\n    \"\"\"Initializes a new `PyTFPolicy`.\n\n    Args:\n      policy: A TF Policy implementing `tf_policy.Base`.\n      batch_size: (deprecated)\n      seed: Seed to use if policy performs random actions (optional).\n    \"\"\"\n    if not isinstance(policy, tf_policy.Base):\n      logging.warning('Policy should implement tf_policy.Base')\n\n    if batch_size is not None:\n      logging.warning('In PyTFPolicy constructor, `batch_size` is deprecated, '\n                      'this parameter has no effect. 
This argument will be '\n                      'removed on 2019-05-01')\n\n    time_step_spec = tensor_spec.to_nest_array_spec(policy.time_step_spec)\n    action_spec = tensor_spec.to_nest_array_spec(policy.action_spec)\n    super(PyTFPolicy, self).__init__(\n        time_step_spec, action_spec, policy_state_spec=(), info_spec=())\n\n    self._tf_policy = policy\n    self.session = None\n\n    self._policy_state_spec = tensor_spec.to_nest_array_spec(\n        self._tf_policy.policy_state_spec)\n\n    self._batch_size = None\n    self._batched = None\n    self._seed = seed\n    self._built = False\n\n  def _construct(self, batch_size, graph):\n    \"\"\"Construct the agent graph through placeholders.\"\"\"\n\n    self._batch_size = batch_size\n    self._batched = batch_size is not None\n\n    outer_dims = [self._batch_size] if self._batched else [1]\n    with graph.as_default():\n      self._time_step = tensor_spec.to_nest_placeholder(\n          self._tf_policy.time_step_spec, outer_dims=outer_dims)\n      self._tf_initial_state = self._tf_policy.get_initial_state(\n          batch_size=self._batch_size or 1)\n\n      self._policy_state = tf.nest.map_structure(\n          lambda ps: tf.compat.v1.placeholder(  # pylint: disable=g-long-lambda\n              ps.dtype,\n              ps.shape,\n              name='policy_state'),\n          self._tf_initial_state)\n      self._action_step = self._tf_policy.action(\n          self._time_step, self._policy_state, seed=self._seed)\n\n      self._actions = tensor_spec.to_nest_placeholder(\n          self._tf_policy.action_spec, outer_dims=outer_dims)\n      self._action_distribution = self._tf_policy.distribution(\n          self._time_step, policy_state=self._policy_state).action\n      self._log_prob = common.log_probability(self._action_distribution,\n                                              self._actions,\n                                              self._tf_policy.action_spec)\n\n  def initialize(self, batch_size, graph=None):\n    if self._built:\n      raise RuntimeError('PyTFPolicy can only be initialized once.')\n\n    if not graph:\n      graph = tf.compat.v1.get_default_graph()\n\n    self._construct(batch_size, graph)\n    var_list = tf.nest.flatten(self._tf_policy.variables())\n    common.initialize_uninitialized_variables(self.session, var_list)\n    self._built = True\n\n  def save(self, policy_dir=None, graph=None):\n    if not self._built:\n      raise RuntimeError('PyTFPolicy has not been initialized yet.')\n\n    if not graph:\n      graph = tf.compat.v1.get_default_graph()\n\n    with graph.as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      policy_checkpointer = common.Checkpointer(\n          ckpt_dir=policy_dir, policy=self._tf_policy, global_step=global_step)\n      policy_checkpointer.initialize_or_restore(self.session)\n      with self.session.as_default():\n        policy_checkpointer.save(global_step)\n\n  def restore(self, policy_dir, graph=None, assert_consumed=True):\n    \"\"\"Restores the policy from the checkpoint.\n\n    Args:\n      policy_dir: Directory with the checkpoint.\n      graph: A graph, inside which the policy is restored (optional).\n      assert_consumed: If true, contents of the checkpoint will be checked\n        for a match against graph variables.\n\n    Returns:\n      step: Global step associated with the restored policy checkpoint.\n\n    Raises:\n      RuntimeError: if the policy is not initialized.\n      AssertionError: if the checkpoint contains 
variables which do not have\n        matching names in the graph, and assert_consumed is set to True.\n\n    \"\"\"\n\n    if not self._built:\n      raise RuntimeError(\n          'PyTFPolicy must be initialized before being restored.')\n    if not graph:\n      graph = tf.compat.v1.get_default_graph()\n\n    with graph.as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      policy_checkpointer = common.Checkpointer(\n          ckpt_dir=policy_dir, policy=self._tf_policy, global_step=global_step)\n      status = policy_checkpointer.initialize_or_restore(self.session)\n      with self.session.as_default():\n        if assert_consumed:\n          status.assert_consumed()\n        status.run_restore_ops()\n      return self.session.run(global_step)\n\n  def _build_from_time_step(self, time_step):\n    outer_shape = nest_utils.get_outer_array_shape(time_step,\n                                                   self._time_step_spec)\n    if len(outer_shape) == 1:\n      self.initialize(outer_shape[0])\n    elif not outer_shape:\n      self.initialize(None)\n    else:\n      raise ValueError(\n          'Cannot handle more than one outer dimension. Saw {} outer '\n          'dimensions: {}'.format(len(outer_shape), outer_shape))\n\n  def _get_initial_state(self, batch_size):\n    if not self._built:\n      self.initialize(batch_size)\n    if batch_size != self._batch_size:\n      raise ValueError(\n          '`batch_size` argument is different from the batch size provided '\n          'previously. Expected {}, but saw {}.'.format(self._batch_size,\n                                                        batch_size))\n    return self.session.run(self._tf_initial_state)\n\n  def _action(self, time_step, policy_state):\n    if not self._built:\n      self._build_from_time_step(time_step)\n\n    batch_size = None\n    if time_step.step_type.shape:\n      batch_size = time_step.step_type.shape[0]\n    if self._batch_size != batch_size:\n      raise ValueError(\n          'The batch size of time_step is different from the batch size '\n          'provided previously. Expected {}, but saw {}.'.format(\n              self._batch_size, batch_size))\n\n    if not self._batched:\n      # Since policy_state is given in a batched form from the policy and we\n      # simply have to send it back we do not need to worry about it. 
Only\n      # update time_step.\n      time_step = nest_utils.batch_nested_array(time_step)\n\n    tf.nest.assert_same_structure(self._time_step, time_step)\n    feed_dict = {self._time_step: time_step}\n    if policy_state is not None:\n      # Flatten policy_state to handle specs that are not hashable due to lists.\n      for state_ph, state in zip(\n          tf.nest.flatten(self._policy_state), tf.nest.flatten(policy_state)):\n        feed_dict[state_ph] = state\n\n    action_step = self.session.run(self._action_step, feed_dict)\n    action, state, info = action_step\n\n    if not self._batched:\n      action, info = nest_utils.unbatch_nested_array([action, info])\n\n    return policy_step.PolicyStep(action, state, info)\n\n  def log_prob(self, time_step, action_step, policy_state=None):\n    if not self._built:\n      self._build_from_time_step(time_step)\n    tf.nest.assert_same_structure(self._time_step, time_step)\n    tf.nest.assert_same_structure(self._actions, action_step)\n    feed_dict = {self._time_step: time_step, self._actions: action_step}\n    if policy_state is not None:\n      feed_dict[self._policy_state] = policy_state\n    return self.session.run(self._log_prob, feed_dict)\n
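\n# Minimal usage sketch (illustration only, not part of the original file).\n# Assumes an existing tf_policy.Base instance and a single unbatched\n# time_step from a Python environment:\n#   policy = PyTFPolicy(my_tf_policy)\n#   policy.initialize(batch_size=None)\n#   action_step = policy.action(time_step)  # builds on first use if needed\n"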
  },
  {
    "path": "lib/py_uniform_replay_buffer.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Uniform replay buffer in Python.\n\nThe base class provides all the functionalities of a uniform replay buffer:\n  - add samples in a First In First Out way.\n  - read samples uniformly.\n\nPyHashedReplayBuffer is a flavor of the base class which\ncompresses the observations when the observations have some partial overlap\n(e.g. when using frame stacking).\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport threading\n\nimport numpy as np\nimport tensorflow as tf\nfrom tf_agents.replay_buffers import replay_buffer\nfrom tf_agents.specs import array_spec\nfrom tf_agents.utils import nest_utils\nfrom tf_agents.utils import numpy_storage\n\n\nclass PyUniformReplayBuffer(replay_buffer.ReplayBuffer):\n  \"\"\"A Python-based replay buffer that supports uniform sampling.\n\n  Writing and reading to this replay buffer is thread safe.\n\n  This replay buffer can be subclassed to change the encoding used for the\n  underlying storage by overriding _encoded_data_spec, _encode, _decode, and\n  _on_delete.\n  \"\"\"\n\n  def __init__(self, data_spec, capacity):\n    \"\"\"Creates a PyUniformReplayBuffer.\n\n    Args:\n      data_spec: An ArraySpec or a list/tuple/nest of ArraySpecs describing a\n        single item that can be stored in this buffer.\n      capacity: The maximum number of items that can be stored in the buffer.\n    \"\"\"\n    super(PyUniformReplayBuffer, self).__init__(data_spec, capacity)\n\n    self._storage = numpy_storage.NumpyStorage(self._encoded_data_spec(),\n                                               capacity)\n    self._lock = threading.Lock()\n    self._np_state = numpy_storage.NumpyState()\n\n    # Adding elements to the replay buffer is done in a circular way.\n    # Keeps track of the actual size of the replay buffer and the location\n    # where to add new elements.\n    self._np_state.size = np.int64(0)\n    self._np_state.cur_id = np.int64(0)\n\n    # Total number of items that went through the replay buffer.\n    self._np_state.item_count = np.int64(0)\n\n  def _encoded_data_spec(self):\n    \"\"\"Spec of data items after encoding using _encode.\"\"\"\n    return self._data_spec\n\n  def _encode(self, item):\n    \"\"\"Encodes an item (before adding it to the buffer).\"\"\"\n    return item\n\n  def _decode(self, item):\n    \"\"\"Decodes an item.\"\"\"\n    return item\n\n  def _on_delete(self, encoded_item):\n    \"\"\"Do any necessary cleanup.\"\"\"\n    pass\n\n  @property\n  def size(self):\n    return self._np_state.size\n\n  def _add_batch(self, items):\n    outer_shape = nest_utils.get_outer_array_shape(items, self._data_spec)\n    if outer_shape[0] != 1:\n      raise NotImplementedError('PyUniformReplayBuffer only supports a batch '\n                                'size of 1, but received `items` with batch '\n                                'size 
{}.'.format(outer_shape[0]))\n\n    item = nest_utils.unbatch_nested_array(items)\n    with self._lock:\n      if self._np_state.size == self._capacity:\n        # If we are at capacity, we are deleting element cur_id.\n        self._on_delete(self._storage.get(self._np_state.cur_id))\n      self._storage.set(self._np_state.cur_id, self._encode(item))\n      self._np_state.size = np.minimum(self._np_state.size + 1,\n                                       self._capacity)\n      self._np_state.cur_id = (self._np_state.cur_id + 1) % self._capacity\n      self._np_state.item_count += 1\n\n  def _get_next(self,\n                sample_batch_size=None,\n                num_steps=None,\n                time_stacked=True):\n    num_steps_value = num_steps if num_steps is not None else 1\n    def get_single():\n      \"\"\"Gets a single item from the replay buffer.\"\"\"\n      with self._lock:\n        if self._np_state.size <= 0:\n          def empty_item(spec):\n            return np.empty(spec.shape, dtype=spec.dtype)\n          if num_steps is not None:\n            item = [tf.nest.map_structure(empty_item, self.data_spec)\n                    for n in range(num_steps)]\n            if time_stacked:\n              item = nest_utils.stack_nested_arrays(item)\n          else:\n            item = tf.nest.map_structure(empty_item, self.data_spec)\n          return item\n        idx = np.random.randint(self._np_state.size - num_steps_value + 1)\n        if self._np_state.size == self._capacity:\n          # If the buffer is full, add cur_id (head of circular buffer) so that\n          # we sample from the range [cur_id, cur_id + size - num_steps_value].\n          # We will modulo the size below.\n          idx += self._np_state.cur_id\n\n        if num_steps is not None:\n          # TODO(b/120242830): Try getting data from numpy in one shot rather\n          # than num_steps_value.\n          item = [self._decode(self._storage.get((idx + n) % self._capacity))\n                  for n in range(num_steps)]\n        else:\n          item = self._decode(self._storage.get(idx % self._capacity))\n\n      if num_steps is not None and time_stacked:\n        item = nest_utils.stack_nested_arrays(item)\n      return item\n\n    if sample_batch_size is None:\n      return get_single()\n    else:\n      samples = [get_single() for _ in range(sample_batch_size)]\n      return nest_utils.stack_nested_arrays(samples)\n\n  def _as_dataset(self, sample_batch_size=None, num_steps=None,\n                  num_parallel_calls=None):\n    if num_parallel_calls is not None:\n      raise NotImplementedError('PyUniformReplayBuffer does not support '\n                                'num_parallel_calls (must be None).')\n\n    data_spec = self._data_spec\n    if sample_batch_size is not None:\n      data_spec = array_spec.add_outer_dims_nest(\n          data_spec, (sample_batch_size,))\n    if num_steps is not None:\n      data_spec = (data_spec,) * num_steps\n    shapes = tuple(s.shape for s in tf.nest.flatten(data_spec))\n    dtypes = tuple(s.dtype for s in tf.nest.flatten(data_spec))\n\n    def generator_fn():\n      while True:\n        if sample_batch_size is not None:\n          batch = [self._get_next(num_steps=num_steps, time_stacked=False)\n                   for _ in range(sample_batch_size)]\n          item = nest_utils.stack_nested_arrays(batch)\n        else:\n          item = self._get_next(num_steps=num_steps, time_stacked=False)\n        yield tuple(tf.nest.flatten(item))\n\n    def 
time_stack(*structures):\n      time_axis = 0 if sample_batch_size is None else 1\n      return tf.nest.map_structure(\n          lambda *elements: tf.stack(elements, axis=time_axis), *structures)\n\n    ds = tf.data.Dataset.from_generator(\n        generator_fn, dtypes,\n        shapes).map(lambda *items: tf.nest.pack_sequence_as(data_spec, items))\n    if num_steps is not None:\n      return ds.map(time_stack)\n    else:\n      return ds\n\n  def _gather_all(self):\n    data = [self._decode(self._storage.get(idx))\n            for idx in range(self._capacity)]\n    stacked = nest_utils.stack_nested_arrays(data)\n    batched = tf.nest.map_structure(lambda t: np.expand_dims(t, 0), stacked)\n    return batched\n\n  def _clear(self):\n    self._np_state.size = np.int64(0)\n    self._np_state.cur_id = np.int64(0)\n\n  def gather_all_transitions(self):\n    num_steps_value = 2\n\n    def get_single(idx):\n      \"\"\"Gets the idx-th item from the replay buffer.\"\"\"\n      with self._lock:\n        if self._np_state.size <= idx:\n\n          def empty_item(spec):\n            return np.empty(spec.shape, dtype=spec.dtype)\n\n          item = [\n              tf.nest.map_structure(empty_item, self.data_spec)\n              for n in range(num_steps_value)\n          ]\n          item = nest_utils.stack_nested_arrays(item)\n          return item\n\n        if self._np_state.size == self._capacity:\n          # If the buffer is full, add cur_id (head of circular buffer) so that\n          # we sample from the range [cur_id, cur_id + size - num_steps_value].\n          # We will modulo the size below.\n          idx += self._np_state.cur_id\n\n        item = [\n            self._decode(self._storage.get((idx + n) % self._capacity))\n            for n in range(num_steps_value)\n        ]\n\n      item = nest_utils.stack_nested_arrays(item)\n      return item\n\n    samples = [\n        get_single(idx)\n        for idx in range(self._np_state.size - num_steps_value + 1)\n    ]\n    return nest_utils.stack_nested_arrays(samples)\n
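\n# Minimal usage sketch (illustration only, not part of the original file):\n#   spec = array_spec.ArraySpec(shape=(3,), dtype=np.float32, name='obs')\n#   rb = PyUniformReplayBuffer(spec, capacity=1000)\n#   rb.add_batch(np.zeros((1, 3), dtype=np.float32))  # outer batch dim must be 1\n#   item = rb.get_next()  # one uniformly sampled item\n"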
  },
  {
    "path": "unsupervised_skill_learning/dads_agent.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"TF-Agents Class for DADS. Builds on top of the SAC agent.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nimport sys\nsys.path.append(os.path.abspath('./'))\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom tf_agents.agents.sac import sac_agent\n\nimport skill_dynamics\n\nnest = tf.nest\n\n\nclass DADSAgent(sac_agent.SacAgent):\n\n  def __init__(self,\n               save_directory,\n               skill_dynamics_observation_size,\n               observation_modify_fn=None,\n               restrict_input_size=0,\n               latent_size=2,\n               latent_prior='cont_uniform',\n               prior_samples=100,\n               fc_layer_params=(256, 256),\n               normalize_observations=True,\n               network_type='default',\n               num_mixture_components=4,\n               fix_variance=True,\n               skill_dynamics_learning_rate=3e-4,\n               reweigh_batches=False,\n               agent_graph=None,\n               skill_dynamics_graph=None,\n               *sac_args,\n               **sac_kwargs):\n    self._skill_dynamics_learning_rate = skill_dynamics_learning_rate\n    self._latent_size = latent_size\n    self._latent_prior = latent_prior\n    self._prior_samples = prior_samples\n    self._save_directory = save_directory\n    self._restrict_input_size = restrict_input_size\n    self._process_observation = observation_modify_fn\n\n    if agent_graph is None:\n      self._graph = tf.compat.v1.get_default_graph()\n    else:\n      self._graph = agent_graph\n\n    if skill_dynamics_graph is None:\n      skill_dynamics_graph = self._graph\n\n    # instantiate the skill dynamics\n    self._skill_dynamics = skill_dynamics.SkillDynamics(\n        observation_size=skill_dynamics_observation_size,\n        action_size=self._latent_size,\n        restrict_observation=self._restrict_input_size,\n        normalize_observations=normalize_observations,\n        fc_layer_params=fc_layer_params,\n        network_type=network_type,\n        num_components=num_mixture_components,\n        fix_variance=fix_variance,\n        reweigh_batches=reweigh_batches,\n        graph=skill_dynamics_graph)\n\n    super(DADSAgent, self).__init__(*sac_args, **sac_kwargs)\n    self._placeholders_in_place = False\n\n  def compute_dads_reward(self, input_obs, cur_skill, target_obs):\n    if self._process_observation is not None:\n      input_obs, target_obs = self._process_observation(\n          input_obs), self._process_observation(target_obs)\n\n    num_reps = self._prior_samples if self._prior_samples > 0 else self._latent_size - 1\n    input_obs_altz = np.concatenate([input_obs] * num_reps, axis=0)\n    target_obs_altz = np.concatenate([target_obs] * num_reps, axis=0)\n\n    # for marginalization of the denominator\n    if self._latent_prior == 'discrete_uniform' and not 
self._prior_samples:\n      alt_skill = np.concatenate(\n          [np.roll(cur_skill, i, axis=1) for i in range(1, num_reps + 1)],\n          axis=0)\n    elif self._latent_prior == 'discrete_uniform':\n      alt_skill = np.random.multinomial(\n          1, [1. / self._latent_size] * self._latent_size,\n          size=input_obs_altz.shape[0])\n    elif self._latent_prior == 'gaussian':\n      alt_skill = np.random.multivariate_normal(\n          np.zeros(self._latent_size),\n          np.eye(self._latent_size),\n          size=input_obs_altz.shape[0])\n    elif self._latent_prior == 'cont_uniform':\n      alt_skill = np.random.uniform(\n          low=-1.0, high=1.0, size=(input_obs_altz.shape[0], self._latent_size))\n\n    logp = self._skill_dynamics.get_log_prob(input_obs, cur_skill, target_obs)\n\n    # denominator may require more memory than that of a GPU, break computation\n    split_group = 20 * 4000\n    if input_obs_altz.shape[0] <= split_group:\n      logp_altz = self._skill_dynamics.get_log_prob(input_obs_altz, alt_skill,\n                                                    target_obs_altz)\n    else:\n      logp_altz = []\n      for split_idx in range(input_obs_altz.shape[0] // split_group):\n        start_split = split_idx * split_group\n        end_split = (split_idx + 1) * split_group\n        logp_altz.append(\n            self._skill_dynamics.get_log_prob(\n                input_obs_altz[start_split:end_split],\n                alt_skill[start_split:end_split],\n                target_obs_altz[start_split:end_split]))\n      if input_obs_altz.shape[0] % split_group:\n        start_split = input_obs_altz.shape[0] % split_group\n        logp_altz.append(\n            self._skill_dynamics.get_log_prob(input_obs_altz[-start_split:],\n                                              alt_skill[-start_split:],\n                                              target_obs_altz[-start_split:]))\n      logp_altz = np.concatenate(logp_altz)\n    logp_altz = np.array(np.array_split(logp_altz, num_reps))\n\n    # final DADS reward\n    intrinsic_reward = np.log(num_reps + 1) - np.log(1 + np.exp(\n        np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))\n\n    return intrinsic_reward, {'logp': logp, 'logp_altz': logp_altz.flatten()}\n\n  def get_experience_placeholder(self):\n    self._placeholders_in_place = True\n    self._placeholders = []\n    for item in nest.flatten(self.collect_data_spec):\n      self._placeholders += [\n          tf.compat.v1.placeholder(\n              item.dtype,\n              shape=(None, 2) if len(item.shape) == 0 else\n              (None, 2, item.shape[-1]),\n              name=item.name)\n      ]\n    self._policy_experience_ph = nest.pack_sequence_as(self.collect_data_spec,\n                                                       self._placeholders)\n    return self._policy_experience_ph\n\n  def build_agent_graph(self):\n    with self._graph.as_default():\n      self.get_experience_placeholder()\n      self.agent_train_op = self.train(self._policy_experience_ph)\n      self.summary_ops = tf.compat.v1.summary.all_v2_summary_ops()\n      return self.agent_train_op\n\n  def build_skill_dynamics_graph(self):\n    self._skill_dynamics.make_placeholders()\n    self._skill_dynamics.build_graph()\n    self._skill_dynamics.increase_prob_op(\n        learning_rate=self._skill_dynamics_learning_rate)\n\n  def create_savers(self):\n    self._skill_dynamics.create_saver(\n        save_prefix=os.path.join(self._save_directory, 'dynamics'))\n\n  def 
set_sessions(self, initialize_or_restore_skill_dynamics, session=None):\n    if session is not None:\n      self._session = session\n    else:\n      self._session = tf.compat.v1.Session(graph=self._graph)\n    self._skill_dynamics.set_session(\n        initialize_or_restore_variables=initialize_or_restore_skill_dynamics,\n        session=session)\n\n  def save_variables(self, global_step):\n    self._skill_dynamics.save_variables(global_step=global_step)\n\n  def _get_dict(self, trajectories, batch_size=-1):\n    tf.nest.assert_same_structure(self.collect_data_spec, trajectories)\n    if batch_size > 0:\n      shuffled_batch = np.random.permutation(\n          trajectories.observation.shape[0])[:batch_size]\n    else:\n      shuffled_batch = np.arange(trajectories.observation.shape[0])\n\n    return_dict = {}\n\n    for placeholder, val in zip(self._placeholders, nest.flatten(trajectories)):\n      return_dict[placeholder] = val[shuffled_batch]\n\n    return return_dict\n\n  def train_loop(self,\n                 trajectories,\n                 recompute_reward=False,\n                 batch_size=-1,\n                 num_steps=1):\n    if not self._placeholders_in_place:\n      return\n\n    if recompute_reward:\n      input_obs = trajectories.observation[:, 0, :-self._latent_size]\n      cur_skill = trajectories.observation[:, 0, -self._latent_size:]\n      target_obs = trajectories.observation[:, 1, :-self._latent_size]\n      new_reward, info = self.compute_dads_reward(input_obs, cur_skill,\n                                                  target_obs)\n      trajectories = trajectories._replace(\n          reward=np.concatenate(\n              [np.expand_dims(new_reward, axis=1), trajectories.reward[:, 1:]],\n              axis=1))\n\n    # TODO(architsh): all agent specs should be the same as env specs; shift preprocessing to actor/critic networks\n    if self._restrict_input_size > 0:\n      trajectories = trajectories._replace(\n          observation=trajectories.observation[:, :,\n                                               self._restrict_input_size:])\n\n    for _ in range(num_steps):\n      self._session.run([self.agent_train_op, self.summary_ops],\n                        feed_dict=self._get_dict(\n                            trajectories, batch_size=batch_size))\n\n    if recompute_reward:\n      return new_reward, info\n    else:\n      return None, None\n\n  @property\n  def skill_dynamics(self):\n    return self._skill_dynamics\n
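\n# Note on compute_dads_reward above (added commentary, not original docs):\n# with L = num_reps alternative skills z_i drawn from the prior p(z), it\n# computes, for each transition (s, z, s'),\n#   r = log(L + 1) - log(1 + sum_i exp(log q(s'|s, z_i) - log q(s'|s, z)))\n#     = log [ (L + 1) * q(s'|s, z) / (q(s'|s, z) + sum_i q(s'|s, z_i)) ],\n# a sample-based estimate of log q(s'|s, z) - log E_p(z')[ q(s'|s, z') ];\n# the log-ratios are clipped to [-50, 50] before exponentiation for stability.\n"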
  },
  {
    "path": "unsupervised_skill_learning/dads_off.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport time\nimport pickle as pkl\nimport os\nimport io\nfrom absl import flags, logging\nimport functools\n\nimport sys\nsys.path.append(os.path.abspath('./'))\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_probability as tfp\n\nfrom tf_agents.agents.ddpg import critic_network\nfrom tf_agents.agents.sac import sac_agent\nfrom tf_agents.environments import suite_mujoco\nfrom tf_agents.trajectories import time_step as ts\nfrom tf_agents.environments.suite_gym import wrap_env\nfrom tf_agents.trajectories.trajectory import from_transition, to_transition\nfrom tf_agents.networks import actor_distribution_network\nfrom tf_agents.networks import normal_projection_network\nfrom tf_agents.policies import ou_noise_policy\nfrom tf_agents.trajectories import policy_step\n# from tf_agents.policies import py_tf_policy\n# from tf_agents.replay_buffers import py_uniform_replay_buffer\nfrom tf_agents.specs import array_spec\nfrom tf_agents.specs import tensor_spec\nfrom tf_agents.utils import common\nfrom tf_agents.utils import nest_utils\n\nimport dads_agent\n\nfrom envs import skill_wrapper\nfrom envs import video_wrapper\nfrom envs.gym_mujoco import ant\nfrom envs.gym_mujoco import half_cheetah\nfrom envs.gym_mujoco import humanoid\nfrom envs.gym_mujoco import point_mass\n\nfrom envs import dclaw\nfrom envs import dkitty_redesign\nfrom envs import hand_block\n\nfrom lib import py_tf_policy\nfrom lib import py_uniform_replay_buffer\n\nFLAGS = flags.FLAGS\nnest = tf.nest\n\n# general hyperparameters\nflags.DEFINE_string('logdir', '~/tmp/dads', 'Directory for saving experiment data')\n\n# environment hyperparameters\nflags.DEFINE_string('environment', 'point_mass', 'Name of the environment')\nflags.DEFINE_integer('max_env_steps', 200,\n                     'Maximum number of steps in one episode')\nflags.DEFINE_integer('reduced_observation', 0,\n                     'Predict dynamics in a reduced observation space')\nflags.DEFINE_integer(\n    'min_steps_before_resample', 50,\n    'Minimum number of steps to execute before resampling skill')\nflags.DEFINE_float('resample_prob', 0.,\n                   'Creates stochasticity timesteps before resampling skill')\n\n# need to set save_model and save_freq\nflags.DEFINE_string(\n    'save_model', None,\n    'Name to save the model with, None implies the models are not saved.')\nflags.DEFINE_integer('save_freq', 100, 'Saving frequency for checkpoints')\nflags.DEFINE_string(\n    'vid_name', None,\n    'Base name for videos being saved, None implies videos are not recorded')\nflags.DEFINE_integer('record_freq', 100,\n                     'Video recording frequency within the training loop')\n\n# final evaluation after training is 
done\nflags.DEFINE_integer('run_eval', 0, 'Evaluate learnt skills')\n\n# evaluation type\nflags.DEFINE_integer('num_evals', 0, 'Number of skills to evaluate')\nflags.DEFINE_integer('deterministic_eval', 0,\n                     'Evaluate all skills, only works for discrete skills')\n\n# training\nflags.DEFINE_integer('run_train', 0, 'Train the agent')\nflags.DEFINE_integer('num_epochs', 500, 'Number of training epochs')\n\n# skill latent space\nflags.DEFINE_integer('num_skills', 2, 'Number of skills to learn')\nflags.DEFINE_string('skill_type', 'cont_uniform',\n                    'Type of skill and the prior over it')\n# network size hyperparameter\nflags.DEFINE_integer(\n    'hidden_layer_size', 512,\n    'Hidden layer size, shared by actors, critics and dynamics')\n\n# reward structure\nflags.DEFINE_integer(\n    'random_skills', 0,\n    'Number of skills to sample randomly for approximating mutual information')\n\n# optimization hyperparameters\nflags.DEFINE_integer('replay_buffer_capacity', int(1e6),\n                     'Capacity of the replay buffer')\nflags.DEFINE_integer(\n    'clear_buffer_every_iter', 0,\n    'Clear replay buffer every iteration to simulate on-policy training, use larger collect steps and train-steps'\n)\nflags.DEFINE_integer(\n    'initial_collect_steps', 2000,\n    'Steps collected initially before training to populate the buffer')\nflags.DEFINE_integer('collect_steps', 200, 'Steps collected per agent update')\n\n# relabelling\nflags.DEFINE_string('agent_relabel_type', None,\n                    'Type of skill relabelling used for agent')\nflags.DEFINE_integer(\n    'train_skill_dynamics_on_policy', 0,\n    'Train skill-dynamics on policy data, while the agent trains off-policy')\nflags.DEFINE_string('skill_dynamics_relabel_type', None,\n                    'Type of skill relabelling used for skill-dynamics')\nflags.DEFINE_integer(\n    'num_samples_for_relabelling', 100,\n    'Number of samples from prior for relabelling the current skill when using policy relabelling'\n)\nflags.DEFINE_float(\n    'is_clip_eps', 0.,\n    'PPO style clipping epsilon to constrain importance sampling weights to (1-eps, 1+eps)'\n)\nflags.DEFINE_float(\n    'action_clipping', 1.,\n    'Clip actions to (-eps, eps) per dimension to avoid difficulties with tanh')\nflags.DEFINE_integer('debug_skill_relabelling', 0,\n                     'Analysis of skill relabelling')\n\n# skill dynamics optimization hyperparameters\nflags.DEFINE_integer('skill_dyn_train_steps', 8,\n                     'Number of discriminator train steps on a batch of data')\nflags.DEFINE_float('skill_dynamics_lr', 3e-4,\n                   'Learning rate for increasing the log-likelihood')\nflags.DEFINE_integer('skill_dyn_batch_size', 256,\n                     'Batch size for discriminator updates')\n# agent optimization hyperparameters\nflags.DEFINE_integer('agent_batch_size', 256, 'Batch size for agent updates')\nflags.DEFINE_integer('agent_train_steps', 128,\n                     'Number of update steps per iteration')\nflags.DEFINE_float('agent_lr', 3e-4, 'Learning rate for the agent')\n\n# SAC hyperparameters\nflags.DEFINE_float('agent_entropy', 0.1, 'Entropy regularization coefficient')\nflags.DEFINE_float('agent_gamma', 0.99, 'Reward discount factor')\nflags.DEFINE_string(\n    'collect_policy', 'default',\n    'Can use the OUNoisePolicy to collect experience for better exploration')\n\n# skill-dynamics hyperparameters\nflags.DEFINE_string(\n    'graph_type', 'default',\n    'process skill input separately for 
more representational power')\nflags.DEFINE_integer('num_components', 4,\n                     'Number of components for Mixture of Gaussians')\nflags.DEFINE_integer('fix_variance', 1,\n                     'Fix the variance of output distribution')\nflags.DEFINE_integer('normalize_data', 1, 'Maintain running averages')\n\n# debug\nflags.DEFINE_integer('debug', 0, 'Creates extra summaries')\n\n# DKitty\nflags.DEFINE_integer('expose_last_action', 1, 'Add the last action to the observation')\nflags.DEFINE_integer('expose_upright', 1, 'Add the upright angle to the observation')\nflags.DEFINE_float('upright_threshold', 0.9, 'Threshold before which the DKitty episode is terminated')\nflags.DEFINE_float('robot_noise_ratio', 0.05, 'Noise ratio for robot joints')\nflags.DEFINE_float('root_noise_ratio', 0.002, 'Noise ratio for root position')\nflags.DEFINE_float('scale_root_position', 1, 'Multiply the root coordinates to magnify the change')\nflags.DEFINE_integer('run_on_hardware', 0, 'Flag for hardware runs')\nflags.DEFINE_float('randomize_hfield', 0.0, 'Randomize terrain for better DKitty transfer')\nflags.DEFINE_integer('observation_omission_size', 2, 'Dimensions to be omitted from policy input')\n\n# Manipulation Environments\nflags.DEFINE_integer('randomized_initial_distribution', 1, 'Fix the initial distribution or not')\nflags.DEFINE_float('horizontal_wrist_constraint', 1.0, 'Action space constraint to restrict horizontal motion of the wrist')\nflags.DEFINE_float('vertical_wrist_constraint', 1.0, 'Action space constraint to restrict vertical motion of the wrist')\n\n# MPC hyperparameters\nflags.DEFINE_integer('planning_horizon', 1, 'Number of primitives to plan in the future')\nflags.DEFINE_integer('primitive_horizon', 1, 'Horizon for every primitive')\nflags.DEFINE_integer('num_candidate_sequences', 50, 'Number of candidate sequences sampled from the proposal distribution')\nflags.DEFINE_integer('refine_steps', 10, 'Number of optimization steps')\nflags.DEFINE_float('mppi_gamma', 10.0, 'MPPI weighting hyperparameter')\nflags.DEFINE_string('prior_type', 'normal', 'Uniform or Gaussian prior for candidate skill(s)')\nflags.DEFINE_float('smoothing_beta', 0.9, 'Smoothing coefficient for sampled candidate skill sequences')\nflags.DEFINE_integer('top_primitives', 5, 'Optimization parameter when using uniform prior (CEM style)')\n\n# global variables for this script\nobservation_omit_size = 0\ngoal_coord = np.array([10., 10.])\nsample_count = 0\niter_count = 0\nepisode_size_buffer = []\nepisode_return_buffer = []\n\n# add a flag for state dependent std\ndef _normal_projection_net(action_spec, init_means_output_factor=0.1):\n  return normal_projection_network.NormalProjectionNetwork(\n      action_spec,\n      mean_transform=None,\n      state_dependent_std=True,\n      init_means_output_factor=init_means_output_factor,\n      std_transform=sac_agent.std_clip_transform,\n      scale_distribution=True)\n\ndef get_environment(env_name='point_mass'):\n  global observation_omit_size\n  if env_name == 'Ant-v1':\n    env = ant.AntEnv(\n        expose_all_qpos=True,\n        task='motion')\n    observation_omit_size = 2\n  elif env_name == 'Ant-v1_goal':\n    observation_omit_size = 2\n    return wrap_env(\n        ant.AntEnv(\n            task='goal',\n            goal=goal_coord,\n            expose_all_qpos=True),\n        max_episode_steps=FLAGS.max_env_steps)\n  elif env_name == 'Ant-v1_foot_sensor':\n    env = ant.AntEnv(\n        expose_all_qpos=True,\n        model_path='ant_footsensor.xml',\n        
expose_foot_sensors=True)\n    observation_omit_size = 2\n  elif env_name == 'HalfCheetah-v1':\n    env = half_cheetah.HalfCheetahEnv(expose_all_qpos=True, task='motion')\n    observation_omit_size = 1\n  elif env_name == 'Humanoid-v1':\n    env = humanoid.HumanoidEnv(expose_all_qpos=True)\n    observation_omit_size = 2\n  elif env_name == 'point_mass':\n    env = point_mass.PointMassEnv(expose_goal=False, expose_velocity=False)\n    observation_omit_size = 2\n  elif env_name == 'DClaw':\n    env = dclaw.DClawTurnRandom()\n    observation_omit_size = FLAGS.observation_omission_size\n  elif env_name == 'DClaw_randomized':\n    env = dclaw.DClawTurnRandomDynamics()\n    observation_omit_size = FLAGS.observation_omission_size\n  elif env_name == 'DKitty_redesign':\n    env = dkitty_redesign.BaseDKittyWalk(\n        expose_last_action=FLAGS.expose_last_action,\n        expose_upright=FLAGS.expose_upright,\n        robot_noise_ratio=FLAGS.robot_noise_ratio,\n        upright_threshold=FLAGS.upright_threshold)\n    observation_omit_size = FLAGS.observation_omission_size\n  elif env_name == 'DKitty_randomized':\n    env = dkitty_redesign.DKittyRandomDynamics(\n        randomize_hfield=FLAGS.randomize_hfield,\n        expose_last_action=FLAGS.expose_last_action,\n        expose_upright=FLAGS.expose_upright,\n        robot_noise_ratio=FLAGS.robot_noise_ratio,\n        upright_threshold=FLAGS.upright_threshold)\n    observation_omit_size = FLAGS.observation_omission_size\n  elif env_name == 'HandBlock':\n    observation_omit_size = 0\n    env = hand_block.HandBlockCustomEnv(\n        horizontal_wrist_constraint=FLAGS.horizontal_wrist_constraint,\n        vertical_wrist_constraint=FLAGS.vertical_wrist_constraint,\n        randomize_initial_position=bool(FLAGS.randomized_initial_distribution),\n        randomize_initial_rotation=bool(FLAGS.randomized_initial_distribution))\n  else:\n    # note this is already wrapped, no need to wrap again\n    env = suite_mujoco.load(env_name)\n  return env\n\ndef hide_coords(time_step):\n  global observation_omit_size\n  if observation_omit_size > 0:\n    sans_coords = time_step.observation[observation_omit_size:]\n    return time_step._replace(observation=sans_coords)\n\n  return time_step\n\n\ndef relabel_skill(trajectory_sample,\n                  relabel_type=None,\n                  cur_policy=None,\n                  cur_skill_dynamics=None):\n  global observation_omit_size\n  if relabel_type is None or ('importance_sampling' in relabel_type and\n                              FLAGS.is_clip_eps <= 1.0):\n    return trajectory_sample, None\n\n  # trajectory.to_transition, but for numpy arrays\n  next_trajectory = nest.map_structure(lambda x: x[:, 1:], trajectory_sample)\n  trajectory = nest.map_structure(lambda x: x[:, :-1], trajectory_sample)\n  action_steps = policy_step.PolicyStep(\n      action=trajectory.action, state=(), info=trajectory.policy_info)\n  time_steps = ts.TimeStep(\n      trajectory.step_type,\n      reward=nest.map_structure(np.zeros_like, trajectory.reward),  # unknown\n      discount=np.zeros_like(trajectory.discount),  # unknown\n      observation=trajectory.observation)\n  next_time_steps = ts.TimeStep(\n      step_type=trajectory.next_step_type,\n      reward=trajectory.reward,\n      discount=trajectory.discount,\n      observation=next_trajectory.observation)\n  time_steps, action_steps, next_time_steps = nest.map_structure(\n      lambda t: np.squeeze(t, axis=1),\n      (time_steps, action_steps, next_time_steps))\n\n  # just return 
the importance sampling weights for the given batch\n  if 'importance_sampling' in relabel_type:\n    old_log_probs = policy_step.get_log_probability(action_steps.info)\n    is_weights = []\n    for idx in range(time_steps.observation.shape[0]):\n      cur_time_step = nest.map_structure(lambda x: x[idx:idx + 1], time_steps)\n      cur_time_step = cur_time_step._replace(\n          observation=cur_time_step.observation[:, observation_omit_size:])\n      old_log_prob = old_log_probs[idx]\n      cur_log_prob = cur_policy.log_prob(cur_time_step,\n                                         action_steps.action[idx:idx + 1])[0]\n      is_weights.append(\n          np.clip(\n              np.exp(cur_log_prob - old_log_prob), 1. / FLAGS.is_clip_eps,\n              FLAGS.is_clip_eps))\n\n    is_weights = np.array(is_weights)\n    if relabel_type == 'normalized_importance_sampling':\n      is_weights = is_weights / is_weights.mean()\n\n    return trajectory_sample, is_weights\n\n  new_observation = np.zeros(time_steps.observation.shape)\n  for idx in range(time_steps.observation.shape[0]):\n    alt_time_steps = nest.map_structure(\n        lambda t: np.stack([t[idx]] * FLAGS.num_samples_for_relabelling),\n        time_steps)\n\n    # sample possible skills for relabelling from the prior\n    if FLAGS.skill_type == 'cont_uniform':\n      # always ensure that the original skill is one of the possible options for relabelling\n      alt_skills = np.concatenate([\n          np.random.uniform(\n              low=-1.0,\n              high=1.0,\n              size=(FLAGS.num_samples_for_relabelling - 1, FLAGS.num_skills)),\n          alt_time_steps.observation[:1, -FLAGS.num_skills:]\n      ])\n\n    # choose the skill which gives the highest log-probability to the current action\n    if relabel_type == 'policy':\n      cur_action = np.stack([action_steps.action[idx, :]] *\n                            FLAGS.num_samples_for_relabelling)\n      alt_time_steps = alt_time_steps._replace(\n          observation=np.concatenate([\n              alt_time_steps\n              .observation[:,\n                           observation_omit_size:-FLAGS.num_skills], alt_skills\n          ],\n                                     axis=1))\n      action_log_probs = cur_policy.log_prob(alt_time_steps, cur_action)\n      if FLAGS.debug_skill_relabelling:\n        print('\\n action_log_probs analysis----', idx,\n              time_steps.observation[idx, -FLAGS.num_skills:])\n        print('number of skills with higher log-probs:',\n              np.sum(action_log_probs >= action_log_probs[-1]))\n        print('Skills with log-probs higher than actual skill:')\n        skill_dist = []\n        for skill_idx in range(FLAGS.num_samples_for_relabelling):\n          if action_log_probs[skill_idx] >= action_log_probs[-1]:\n            print(alt_skills[skill_idx])\n            skill_dist.append(\n                np.linalg.norm(alt_skills[skill_idx] - alt_skills[-1]))\n        print('average distance of skills with higher-log-prob:',\n              np.mean(skill_dist))\n      max_skill_idx = np.argmax(action_log_probs)\n\n    # choose the skill which gets the highest log-probability under the dynamics posterior\n    elif relabel_type == 'dynamics_posterior':\n      cur_observations = alt_time_steps.observation[:, :-FLAGS.num_skills]\n      next_observations = np.stack(\n          [next_time_steps.observation[idx, :-FLAGS.num_skills]] *\n          FLAGS.num_samples_for_relabelling)\n\n      # max over posterior log probability is 
exactly the max over log-prob of transition under skill-dynamics\n      posterior_log_probs = cur_skill_dynamics.get_log_prob(\n          process_observation(cur_observations), alt_skills,\n          process_observation(next_observations))\n      if FLAGS.debug_skill_relabelling:\n        print('\\n dynamics_log_probs analysis----', idx,\n              time_steps.observation[idx, -FLAGS.num_skills:])\n        print('number of skills with higher log-probs:',\n              np.sum(posterior_log_probs >= posterior_log_probs[-1]))\n        print('Skills with log-probs higher than actual skill:')\n        skill_dist = []\n        for skill_idx in range(FLAGS.num_samples_for_relabelling):\n          if posterior_log_probs[skill_idx] >= posterior_log_probs[-1]:\n            print(alt_skills[skill_idx])\n            skill_dist.append(\n                np.linalg.norm(alt_skills[skill_idx] - alt_skills[-1]))\n        print('average distance of skills with higher-log-prob:',\n              np.mean(skill_dist))\n\n      max_skill_idx = np.argmax(posterior_log_probs)\n\n    # make the new observation with the relabelled skill\n    relabelled_skill = alt_skills[max_skill_idx]\n    new_observation[idx] = np.concatenate(\n        [time_steps.observation[idx, :-FLAGS.num_skills], relabelled_skill])\n\n  traj_observation = np.copy(trajectory_sample.observation)\n  traj_observation[:, 0] = new_observation\n  new_trajectory_sample = trajectory_sample._replace(\n      observation=traj_observation)\n\n  return new_trajectory_sample, None\n\n\n# hard-coding the state-space for dynamics\ndef process_observation(observation):\n\n  def _shape_based_observation_processing(observation, dim_idx):\n    if len(observation.shape) == 1:\n      return observation[dim_idx:dim_idx + 1]\n    elif len(observation.shape) == 2:\n      return observation[:, dim_idx:dim_idx + 1]\n    elif len(observation.shape) == 3:\n      return observation[:, :, dim_idx:dim_idx + 1]\n\n  # for consistent use\n  if FLAGS.reduced_observation == 0:\n    return observation\n\n  # process observation for dynamics with reduced observation space\n  if FLAGS.environment == 'HalfCheetah-v1':\n    qpos_dim = 9\n  elif FLAGS.environment == 'Ant-v1':\n    qpos_dim = 15\n  elif FLAGS.environment == 'Humanoid-v1':\n    qpos_dim = 26\n  elif 'DKitty' in FLAGS.environment:\n    qpos_dim = 36\n\n  # x-axis\n  if FLAGS.reduced_observation in [1, 5]:\n    red_obs = [_shape_based_observation_processing(observation, 0)]\n  # x-y plane\n  elif FLAGS.reduced_observation in [2, 6]:\n    if FLAGS.environment == 'Ant-v1' or 'DKitty' in FLAGS.environment or 'DClaw' in FLAGS.environment:\n      red_obs = [\n          _shape_based_observation_processing(observation, 0),\n          _shape_based_observation_processing(observation, 1)\n      ]\n    else:\n      red_obs = [\n          _shape_based_observation_processing(observation, 0),\n          _shape_based_observation_processing(observation, qpos_dim)\n      ]\n  # x-y plane, x-y velocities\n  elif FLAGS.reduced_observation in [4, 8]:\n    if FLAGS.reduced_observation == 4 and 'DKittyPush' in FLAGS.environment:\n      # position of the agent + relative position of the box\n      red_obs = [\n          _shape_based_observation_processing(observation, 0),\n          _shape_based_observation_processing(observation, 1),\n          _shape_based_observation_processing(observation, 3),\n          _shape_based_observation_processing(observation, 4)\n      ]\n    elif FLAGS.environment in ['Ant-v1']:\n      red_obs = [\n          
_shape_based_observation_processing(observation, 0),\n          _shape_based_observation_processing(observation, 1),\n          _shape_based_observation_processing(observation, qpos_dim),\n          _shape_based_observation_processing(observation, qpos_dim + 1)\n      ]\n\n  # (x, y, orientation), works only for ant, point_mass\n  elif FLAGS.reduced_observation == 3:\n    if FLAGS.environment in ['Ant-v1', 'point_mass']:\n      red_obs = [\n          _shape_based_observation_processing(observation, 0),\n          _shape_based_observation_processing(observation, 1),\n          _shape_based_observation_processing(observation,\n                                              observation.shape[1] - 1)\n      ]\n    # x, y, z of the center of the block\n    elif FLAGS.environment in ['HandBlock']:\n      red_obs = [\n          _shape_based_observation_processing(observation, \n                                              observation.shape[-1] - 7),\n          _shape_based_observation_processing(observation, \n                                              observation.shape[-1] - 6),\n          _shape_based_observation_processing(observation,\n                                              observation.shape[-1] - 5)\n      ]\n\n  if FLAGS.reduced_observation in [5, 6, 8]:\n    red_obs += [\n        _shape_based_observation_processing(observation,\n                                            observation.shape[1] - idx)\n        for idx in range(1, 5)\n    ]\n\n  if FLAGS.reduced_observation == 36 and 'DKitty' in FLAGS.environment:\n    red_obs = [\n        _shape_based_observation_processing(observation, idx)\n        for idx in range(qpos_dim)\n    ]\n\n  # x, y, z and the rotation quaternion\n  if FLAGS.reduced_observation == 7 and FLAGS.environment == 'HandBlock':\n    red_obs = [\n        _shape_based_observation_processing(observation, observation.shape[-1] - idx)\n        for idx in range(1, 8)\n    ][::-1]\n\n  # the rotation quaternion\n  if FLAGS.reduced_observation == 4 and FLAGS.environment == 'HandBlock':\n    red_obs = [\n        _shape_based_observation_processing(observation, observation.shape[-1] - idx)\n        for idx in range(1, 5)\n    ][::-1]\n\n  if isinstance(observation, np.ndarray):\n    input_obs = np.concatenate(red_obs, axis=len(observation.shape) - 1)\n  elif isinstance(observation, tf.Tensor):\n    input_obs = tf.concat(red_obs, axis=len(observation.shape) - 1)\n  return input_obs\n\n\ndef collect_experience(py_env,\n                       time_step,\n                       collect_policy,\n                       buffer_list,\n                       num_steps=1):\n\n  episode_sizes = []\n  extrinsic_reward = []\n  step_idx = 0\n  cur_return = 0.\n  for step_idx in range(num_steps):\n    if time_step.is_last():\n      episode_sizes.append(step_idx)\n      extrinsic_reward.append(cur_return)\n      cur_return = 0.\n\n    action_step = collect_policy.action(hide_coords(time_step))\n\n    if FLAGS.action_clipping < 1.:\n      action_step = action_step._replace(\n          action=np.clip(action_step.action, -FLAGS.action_clipping,\n                         FLAGS.action_clipping))\n\n    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:\n      cur_action_log_prob = collect_policy.log_prob(\n          nest_utils.batch_nested_array(hide_coords(time_step)),\n          np.expand_dims(action_step.action, 0))\n      action_step = action_step._replace(\n          
info=policy_step.set_log_probability(action_step.info,\n                                               cur_action_log_prob))\n\n    next_time_step = py_env.step(action_step.action)\n    cur_return += next_time_step.reward\n\n    # all modification to observations and training will be done within the agent\n    for buffer_ in buffer_list:\n      buffer_.add_batch(\n          from_transition(\n              nest_utils.batch_nested_array(time_step),\n              nest_utils.batch_nested_array(action_step),\n              nest_utils.batch_nested_array(next_time_step)))\n\n    time_step = next_time_step\n\n  # carry-over calculation for the next collection cycle\n  episode_sizes.append(step_idx + 1)\n  extrinsic_reward.append(cur_return)\n  for idx in range(1, len(episode_sizes)):\n    episode_sizes[-idx] -= episode_sizes[-idx - 1]\n\n  return time_step, {\n      'episode_sizes': episode_sizes,\n      'episode_return': extrinsic_reward\n  }\n\n\ndef run_on_env(env,\n               policy,\n               dynamics=None,\n               predict_trajectory_steps=0,\n               return_data=False,\n               close_environment=True):\n  time_step = env.reset()\n  data = []\n\n  if not return_data:\n    extrinsic_reward = []\n  while not time_step.is_last():\n    action_step = policy.action(hide_coords(time_step))\n    if FLAGS.action_clipping < 1.:\n      action_step = action_step._replace(\n          action=np.clip(action_step.action, -FLAGS.action_clipping,\n                         FLAGS.action_clipping))\n\n    env_action = action_step.action\n    next_time_step = env.step(env_action)\n\n    skill_size = FLAGS.num_skills\n    if skill_size > 0:\n      cur_observation = time_step.observation[:-skill_size]\n      cur_skill = time_step.observation[-skill_size:]\n      next_observation = next_time_step.observation[:-skill_size]\n    else:\n      cur_observation = time_step.observation\n      next_observation = next_time_step.observation\n\n    if dynamics is not None:\n      if FLAGS.reduced_observation:\n        cur_observation, next_observation = process_observation(\n            cur_observation), process_observation(next_observation)\n      logp = dynamics.get_log_prob(\n          np.expand_dims(cur_observation, 0), np.expand_dims(cur_skill, 0),\n          np.expand_dims(next_observation, 0))\n\n      cur_predicted_state = np.expand_dims(cur_observation, 0)\n      skill_expanded = np.expand_dims(cur_skill, 0)\n      cur_predicted_trajectory = [cur_predicted_state[0]]\n      for _ in range(predict_trajectory_steps):\n        next_predicted_state = dynamics.predict_state(cur_predicted_state,\n                                                      skill_expanded)\n        cur_predicted_trajectory.append(next_predicted_state[0])\n        cur_predicted_state = next_predicted_state\n    else:\n      logp = ()\n      cur_predicted_trajectory = []\n\n    if return_data:\n      data.append([\n          cur_observation, action_step.action, logp, next_time_step.reward,\n          np.array(cur_predicted_trajectory)\n      ])\n    else:\n      extrinsic_reward.append([next_time_step.reward])\n\n    time_step = next_time_step\n\n  if close_environment:\n    env.close()\n\n  if return_data:\n    return data\n  else:\n    return extrinsic_reward\n\n\ndef eval_loop(eval_dir,\n              eval_policy,\n              dynamics=None,\n              vid_name=None,\n              plot_name=None):\n  metadata = tf.io.gfile.GFile(\n      os.path.join(eval_dir, 'metadata.txt'), 'a')\n  if FLAGS.num_skills == 
0:\n    num_evals = FLAGS.num_evals\n  elif FLAGS.deterministic_eval:\n    num_evals = FLAGS.num_skills\n  else:\n    num_evals = FLAGS.num_evals\n\n  if plot_name is not None:\n    # color_map = ['b', 'g', 'r', 'c', 'm', 'y', 'k']\n    color_map = ['b', 'g', 'r', 'c', 'm', 'y']\n    style_map = []\n    for line_style in ['-', '--', '-.', ':']:\n      style_map += [color + line_style for color in color_map]\n\n    plt.xlim(-15, 15)\n    plt.ylim(-15, 15)\n    # all_trajectories = []\n    # all_predicted_trajectories = []\n\n  for idx in range(num_evals):\n    if FLAGS.num_skills > 0:\n      if FLAGS.deterministic_eval:\n        preset_skill = np.zeros(FLAGS.num_skills, dtype=np.int64)\n        preset_skill[idx] = 1\n      elif FLAGS.skill_type == 'discrete_uniform':\n        preset_skill = np.random.multinomial(1, [1. / FLAGS.num_skills] *\n                                             FLAGS.num_skills)\n      elif FLAGS.skill_type == 'gaussian':\n        preset_skill = np.random.multivariate_normal(\n            np.zeros(FLAGS.num_skills), np.eye(FLAGS.num_skills))\n      elif FLAGS.skill_type == 'cont_uniform':\n        preset_skill = np.random.uniform(\n            low=-1.0, high=1.0, size=FLAGS.num_skills)\n      elif FLAGS.skill_type == 'multivariate_bernoulli':\n        preset_skill = np.random.binomial(1, 0.5, size=FLAGS.num_skills)\n    else:\n      preset_skill = None\n\n    eval_env = get_environment(env_name=FLAGS.environment)\n    eval_env = wrap_env(\n        skill_wrapper.SkillWrapper(\n            eval_env,\n            num_latent_skills=FLAGS.num_skills,\n            skill_type=FLAGS.skill_type,\n            preset_skill=preset_skill,\n            min_steps_before_resample=FLAGS.min_steps_before_resample,\n            resample_prob=FLAGS.resample_prob),\n        max_episode_steps=FLAGS.max_env_steps)\n\n    # record videos for sampled trajectories\n    if vid_name is not None:\n      full_vid_name = vid_name + '_' + str(idx)\n      eval_env = video_wrapper.VideoWrapper(eval_env, base_path=eval_dir, base_name=full_vid_name)\n\n    mean_reward = 0.\n    per_skill_evaluations = 1\n    predict_trajectory_steps = 0\n    # trajectories_per_skill = []\n    # predicted_trajectories_per_skill = []\n    for eval_idx in range(per_skill_evaluations):\n      eval_trajectory = run_on_env(\n          eval_env,\n          eval_policy,\n          dynamics=dynamics,\n          predict_trajectory_steps=predict_trajectory_steps,\n          return_data=True,\n          close_environment=True if eval_idx == per_skill_evaluations -\n          1 else False)\n\n      trajectory_coordinates = np.array([\n          eval_trajectory[step_idx][0][:2]\n          for step_idx in range(len(eval_trajectory))\n      ])\n\n      # trajectory_states = np.array([\n      #     eval_trajectory[step_idx][0]\n      #     for step_idx in range(len(eval_trajectory))\n      # ])\n      # trajectories_per_skill.append(trajectory_states)\n      if plot_name is not None:\n        plt.plot(\n            trajectory_coordinates[:, 0],\n            trajectory_coordinates[:, 1],\n            style_map[idx % len(style_map)],\n            label=(str(idx) if eval_idx == 0 else None))\n        # plt.plot(\n        #     trajectory_coordinates[0, 0],\n        #     trajectory_coordinates[0, 1],\n        #     marker='o',\n        #     color=style_map[idx % len(style_map)][0])\n        if predict_trajectory_steps > 0:\n          # predicted_states = np.array([\n          #     eval_trajectory[step_idx][-1]\n          #     for 
step_idx in range(len(eval_trajectory))\n          # ])\n          # predicted_trajectories_per_skill.append(predicted_states)\n          for step_idx in range(len(eval_trajectory)):\n            if step_idx % 20 == 0:\n              plt.plot(eval_trajectory[step_idx][-1][:, 0],\n                       eval_trajectory[step_idx][-1][:, 1], 'k:')\n\n      mean_reward += np.mean([\n          eval_trajectory[step_idx][3]  # index 3 is the extrinsic reward\n          for step_idx in range(len(eval_trajectory))\n      ])\n      metadata.write(\n          str(idx) + ' ' + str(preset_skill) + ' ' +\n          str(trajectory_coordinates[-1, :]) + '\\n')\n\n    # all_predicted_trajectories.append(\n    #     np.stack(predicted_trajectories_per_skill))\n    # all_trajectories.append(np.stack(trajectories_per_skill))\n\n  # all_predicted_trajectories = np.stack(all_predicted_trajectories)\n  # all_trajectories = np.stack(all_trajectories)\n  # print(all_trajectories.shape, all_predicted_trajectories.shape)\n  # pkl.dump(\n  #     all_trajectories,\n  #     tf.io.gfile.GFile(\n  #         os.path.join(vid_dir, 'skill_dynamics_full_obs_r100_actual_trajectories.pkl'),\n  #         'wb'))\n  # pkl.dump(\n  #     all_predicted_trajectories,\n  #     tf.io.gfile.GFile(\n  #         os.path.join(vid_dir, 'skill_dynamics_full_obs_r100_predicted_trajectories.pkl'),\n  #         'wb'))\n  if plot_name is not None:\n    full_image_name = plot_name + '.png'\n\n    # to save images while writing to CNS\n    buf = io.BytesIO()\n    # plt.title('Trajectories in Continuous Skill Space')\n    plt.savefig(buf, dpi=600, bbox_inches='tight')\n    buf.seek(0)\n    image = tf.io.gfile.GFile(os.path.join(eval_dir, full_image_name), 'w')\n    image.write(buf.read(-1))\n\n    # clear before next plot\n    plt.clf()\n\n\n# discrete primitives only, useful with skill-dynamics\ndef eval_planning(env,\n                  dynamics,\n                  policy,\n                  latent_action_space_size,\n                  episode_horizon,\n                  planning_horizon=1,\n                  primitive_horizon=10,\n                  **kwargs):\n  \"\"\"env: tf-agents environment without the skill wrapper.\"\"\"\n  global goal_coord\n\n  # assuming only discrete action spaces\n  high_level_action_space = np.eye(latent_action_space_size)\n  time_step = env.reset()\n\n  actual_reward = 0.\n  actual_coords = [np.expand_dims(time_step.observation[:2], 0)]\n  predicted_coords = []\n  distance_to_goal_array = []\n\n  # planning loop\n  for _ in range(episode_horizon // primitive_horizon):\n    running_reward = np.zeros(latent_action_space_size)\n    running_cur_state = np.array([process_observation(time_step.observation)] *\n                                 latent_action_space_size)\n    cur_coord_predicted = [np.expand_dims(running_cur_state[:, :2], 1)]\n\n    # simulate all high level actions for K steps\n    for _ in range(planning_horizon):\n      predicted_next_state = dynamics.predict_state(running_cur_state,\n                                                    high_level_action_space)\n      cur_coord_predicted.append(np.expand_dims(predicted_next_state[:, :2], 1))\n\n      # update running stuff\n      running_reward += env.compute_reward(running_cur_state,\n                                           predicted_next_state)\n      running_cur_state = predicted_next_state\n\n    predicted_coords.append(np.concatenate(cur_coord_predicted, axis=1))\n\n    selected_high_level_action = np.argmax(running_reward)\n    for _ in range(primitive_horizon):\n      # concatenated observation\n
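      # the policy was trained on observations with the skill appended, so the\n      # selected one-hot primitive is concatenated back on before querying it\n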
      skill_concat_observation = np.concatenate([\n          time_step.observation,\n          high_level_action_space[selected_high_level_action]\n      ],\n                                                axis=0)\n      next_time_step = env.step(\n          np.clip(\n              policy.action(\n                  hide_coords(\n                      time_step._replace(\n                          observation=skill_concat_observation))).action,\n              -FLAGS.action_clipping, FLAGS.action_clipping))\n      actual_reward += next_time_step.reward\n      distance_to_goal_array.append(next_time_step.reward)\n\n      # prepare for next iteration\n      time_step = next_time_step\n      actual_coords.append(np.expand_dims(time_step.observation[:2], 0))\n\n  actual_coords = np.concatenate(actual_coords)\n  # match eval_mppi's return signature, expected by the planning evaluation\n  return actual_reward, actual_coords, predicted_coords, distance_to_goal_array\n\n\ndef eval_mppi(\n    env,\n    dynamics,\n    policy,\n    latent_action_space_size,\n    episode_horizon,\n    planning_horizon=1,\n    primitive_horizon=10,\n    num_candidate_sequences=50,\n    refine_steps=10,\n    mppi_gamma=10,\n    prior_type='normal',\n    smoothing_beta=0.9,\n    # no need to change generally\n    sparsify_rewards=False,\n    # only for uniform prior mode\n    top_primitives=5):\n  \"\"\"env: tf-agents environment without the skill wrapper.\n\n     dynamics: skill-dynamics model learnt by DADS.\n     policy: skill-conditioned policy learnt by DADS.\n     planning_horizon: number of latent skills to plan in the future.\n     primitive_horizon: number of steps each skill is executed for.\n     num_candidate_sequences: number of candidate plans sampled from the prior\n     per refinement step.\n     refine_steps: number of steps for which the plan is iterated upon before\n     execution (number of optimization steps).\n     mppi_gamma: MPPI temperature used to reweight candidate plans by reward.\n     prior_type: 'normal' implies MPPI, 'uniform' implies a CEM-like algorithm\n     (not tested).\n     smoothing_beta: for planning_horizon > 1, every sampled plan is smoothed\n     using an EMA (0 -> no smoothing, 1 -> perfectly smoothed).\n     sparsify_rewards: converts a dense reward problem into a sparse reward\n     (avoid using).\n     top_primitives: number of elites to choose, if using CEM (not tested).\n  \"\"\"\n\n  step_idx = 0\n\n  def _smooth_primitive_sequences(primitive_sequences):\n    for planning_idx in range(1, primitive_sequences.shape[1]):\n      primitive_sequences[:,\n                          planning_idx, :] = smoothing_beta * primitive_sequences[:, planning_idx - 1, :] + (\n                              1. - smoothing_beta\n                          ) * primitive_sequences[:, planning_idx, :]\n\n    return primitive_sequences\n\n  def _get_init_primitive_parameters():\n    if prior_type == 'normal':\n      prior_mean = functools.partial(\n          np.random.multivariate_normal,\n          mean=np.zeros(latent_action_space_size),\n          cov=np.diag(np.ones(latent_action_space_size)))\n      prior_cov = lambda: 1.5 * np.diag(np.ones(latent_action_space_size))\n      return [prior_mean(), prior_cov()]\n\n    elif prior_type == 'uniform':\n      prior_low = lambda: np.array([-1.] * latent_action_space_size)\n      prior_high = lambda: np.array([1.] 
* latent_action_space_size)\n      return [prior_low(), prior_high()]\n\n  def _sample_primitives(params):\n    if prior_type == 'normal':\n      sample = np.random.multivariate_normal(*params)\n    elif prior_type == 'uniform':\n      sample = np.random.uniform(*params)\n    return np.clip(sample, -1., 1.)\n\n  # update new primitive means for horizon sequence\n  def _update_parameters(candidates, reward, primitive_parameters):\n    # a more regular mppi\n    if prior_type == 'normal':\n      reward = np.exp(mppi_gamma * (reward - np.max(reward)))\n      reward = reward / (reward.sum() + 1e-10)\n      new_means = (candidates.T * reward).T.sum(axis=0)\n\n      for planning_idx in range(candidates.shape[1]):\n        primitive_parameters[planning_idx][0] = new_means[planning_idx]\n\n    # TODO(architsh): closer to cross-entropy/shooting method, figure out a better update\n    elif prior_type == 'uniform':\n      chosen_candidates = candidates[np.argsort(reward)[-top_primitives:]]\n      candidates_min = np.min(chosen_candidates, axis=0)\n      candidates_max = np.max(chosen_candidates, axis=0)\n\n      for planning_idx in range(candidates.shape[1]):\n        primitive_parameters[planning_idx][0] = candidates_min[planning_idx]\n        primitive_parameters[planning_idx][1] = candidates_max[planning_idx]\n\n  def _get_expected_primitive(params):\n    if prior_type == 'normal':\n      return params[0]\n    elif prior_type == 'uniform':\n      return (params[0] + params[1]) / 2\n\n  time_step = env.reset()\n  actual_coords = [np.expand_dims(time_step.observation[:2], 0)]\n  actual_reward = 0.\n  distance_to_goal_array = []\n\n  primitive_parameters = []\n  chosen_primitives = []\n  for _ in range(planning_horizon):\n    primitive_parameters.append(_get_init_primitive_parameters())\n\n  for _ in range(episode_horizon // primitive_horizon):\n    for _ in range(refine_steps):\n      # generate candidates sequences for primitives\n      candidate_primitive_sequences = []\n      for _ in range(num_candidate_sequences):\n        candidate_primitive_sequences.append([\n            _sample_primitives(primitive_parameters[planning_idx])\n            for planning_idx in range(planning_horizon)\n        ])\n\n      candidate_primitive_sequences = np.array(candidate_primitive_sequences)\n      candidate_primitive_sequences = _smooth_primitive_sequences(\n          candidate_primitive_sequences)\n\n      running_cur_state = np.array(\n          [process_observation(time_step.observation)] *\n          num_candidate_sequences)\n      running_reward = np.zeros(num_candidate_sequences)\n      for planning_idx in range(planning_horizon):\n        cur_primitives = candidate_primitive_sequences[:, planning_idx, :]\n        for _ in range(primitive_horizon):\n          predicted_next_state = dynamics.predict_state(running_cur_state,\n                                                        cur_primitives)\n\n          # update running stuff\n          dense_reward = env.compute_reward(running_cur_state,\n                                            predicted_next_state)\n          # modification for sparse_reward\n          if sparsify_rewards:\n            sparse_reward = 5.0 * (dense_reward > -2) + 0.0 * (\n                dense_reward <= -2)\n            running_reward += sparse_reward\n          else:\n            running_reward += dense_reward\n\n          running_cur_state = predicted_next_state\n\n      _update_parameters(candidate_primitive_sequences, running_reward,\n                         
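# refit the sampling distribution: MPPI moves the prior means to the\n                         # reward-weighted candidate average (CEM mode refits to the elites)\n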
                         primitive_parameters)\n\n    chosen_primitive = _get_expected_primitive(primitive_parameters[0])\n    chosen_primitives.append(chosen_primitive)\n\n    # a loop just to check what the chosen primitive is expected to do\n    # running_cur_state = np.array([process_observation(time_step.observation)])\n    # for _ in range(primitive_horizon):\n    #   predicted_next_state = dynamics.predict_state(\n    #       running_cur_state, np.expand_dims(chosen_primitive, 0))\n    #   running_cur_state = predicted_next_state\n    # print('Predicted next co-ordinates:', running_cur_state[0, :2])\n\n    for _ in range(primitive_horizon):\n      # concatenated observation\n      skill_concat_observation = np.concatenate(\n          [time_step.observation, chosen_primitive], axis=0)\n      next_time_step = env.step(\n          np.clip(\n              policy.action(\n                  hide_coords(\n                      time_step._replace(\n                          observation=skill_concat_observation))).action,\n              -FLAGS.action_clipping, FLAGS.action_clipping))\n      actual_reward += next_time_step.reward\n      distance_to_goal_array.append(next_time_step.reward)\n      # prepare for next iteration\n      time_step = next_time_step\n      actual_coords.append(np.expand_dims(time_step.observation[:2], 0))\n      step_idx += 1\n      # print(step_idx)\n    # print('Actual next co-ordinates:', actual_coords[-1])\n\n    primitive_parameters.pop(0)\n    primitive_parameters.append(_get_init_primitive_parameters())\n\n  actual_coords = np.concatenate(actual_coords)\n  return actual_reward, actual_coords, np.array(\n      chosen_primitives), distance_to_goal_array\n\n\ndef main(_):\n  # setting up\n  start_time = time.time()\n  tf.compat.v1.enable_resource_variables()\n  tf.compat.v1.disable_eager_execution()\n  logging.set_verbosity(logging.INFO)\n  global observation_omit_size, goal_coord, sample_count, iter_count, episode_size_buffer, episode_return_buffer\n\n  root_dir = os.path.abspath(os.path.expanduser(FLAGS.logdir))\n  if not tf.io.gfile.exists(root_dir):\n    tf.io.gfile.makedirs(root_dir)\n  log_dir = os.path.join(root_dir, FLAGS.environment)\n\n  if not tf.io.gfile.exists(log_dir):\n    tf.io.gfile.makedirs(log_dir)\n  save_dir = os.path.join(log_dir, 'models')\n  if not tf.io.gfile.exists(save_dir):\n    tf.io.gfile.makedirs(save_dir)\n\n  print('directory for recording experiment data:', log_dir)\n\n  # restore training state, in case a paused run is being resumed\n  try:\n    sample_count = np.load(os.path.join(log_dir, 'sample_count.npy')).tolist()\n    iter_count = np.load(os.path.join(log_dir, 'iter_count.npy')).tolist()\n    episode_size_buffer = np.load(os.path.join(log_dir, 'episode_size_buffer.npy')).tolist()\n    episode_return_buffer = np.load(os.path.join(log_dir, 'episode_return_buffer.npy')).tolist()\n  except Exception:  # no saved state found, start from scratch\n    sample_count = 0\n    iter_count = 0\n    episode_size_buffer = []\n    episode_return_buffer = []\n\n  train_summary_writer = tf.compat.v2.summary.create_file_writer(\n      os.path.join(log_dir, 'train', 'in_graph_data'), flush_millis=10 * 1000)\n  train_summary_writer.set_as_default()\n\n  global_step = tf.compat.v1.train.get_or_create_global_step()\n  with tf.compat.v2.summary.record_if(True):\n    # environment related stuff\n    py_env = get_environment(env_name=FLAGS.environment)\n    py_env = wrap_env(\n        skill_wrapper.SkillWrapper(\n            py_env,\n            num_latent_skills=FLAGS.num_skills,\n            skill_type=FLAGS.skill_type,\n
            preset_skill=None,\n            min_steps_before_resample=FLAGS.min_steps_before_resample,\n            resample_prob=FLAGS.resample_prob),\n        max_episode_steps=FLAGS.max_env_steps)\n\n    # all specifications required for all networks and agents\n    py_action_spec = py_env.action_spec()\n    tf_action_spec = tensor_spec.from_spec(\n        py_action_spec)  # policy, critic action spec\n    env_obs_spec = py_env.observation_spec()\n    py_env_time_step_spec = ts.time_step_spec(\n        env_obs_spec)  # replay buffer time_step spec\n    if observation_omit_size > 0:\n      agent_obs_spec = array_spec.BoundedArraySpec(\n          (env_obs_spec.shape[0] - observation_omit_size,),\n          env_obs_spec.dtype,\n          minimum=env_obs_spec.minimum,\n          maximum=env_obs_spec.maximum,\n          name=env_obs_spec.name)  # policy, critic observation spec\n    else:\n      agent_obs_spec = env_obs_spec\n    py_agent_time_step_spec = ts.time_step_spec(\n        agent_obs_spec)  # policy, critic time_step spec\n    tf_agent_time_step_spec = tensor_spec.from_spec(py_agent_time_step_spec)\n\n    if not FLAGS.reduced_observation:\n      skill_dynamics_observation_size = (\n          py_env_time_step_spec.observation.shape[0] - FLAGS.num_skills)\n    else:\n      skill_dynamics_observation_size = FLAGS.reduced_observation\n\n    # TODO(architsh): Shift co-ordinate hiding to actor_net and critic_net (good for further image-based processing as well)\n    actor_net = actor_distribution_network.ActorDistributionNetwork(\n        tf_agent_time_step_spec.observation,\n        tf_action_spec,\n        fc_layer_params=(FLAGS.hidden_layer_size,) * 2,\n        continuous_projection_net=_normal_projection_net)\n\n    critic_net = critic_network.CriticNetwork(\n        (tf_agent_time_step_spec.observation, tf_action_spec),\n        observation_fc_layer_params=None,\n        action_fc_layer_params=None,\n        joint_fc_layer_params=(FLAGS.hidden_layer_size,) * 2)\n\n    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:\n      reweigh_batches_flag = True\n    else:\n      reweigh_batches_flag = False\n\n    agent = dads_agent.DADSAgent(\n        # DADS parameters\n        save_dir,\n        skill_dynamics_observation_size,\n        observation_modify_fn=process_observation,\n        restrict_input_size=observation_omit_size,\n        latent_size=FLAGS.num_skills,\n        latent_prior=FLAGS.skill_type,\n        prior_samples=FLAGS.random_skills,\n        fc_layer_params=(FLAGS.hidden_layer_size,) * 2,\n        normalize_observations=FLAGS.normalize_data,\n        network_type=FLAGS.graph_type,\n        num_mixture_components=FLAGS.num_components,\n        fix_variance=FLAGS.fix_variance,\n        reweigh_batches=reweigh_batches_flag,\n        skill_dynamics_learning_rate=FLAGS.skill_dynamics_lr,\n        # SAC parameters\n        time_step_spec=tf_agent_time_step_spec,\n        action_spec=tf_action_spec,\n        actor_network=actor_net,\n        critic_network=critic_net,\n        target_update_tau=0.005,\n        target_update_period=1,\n        actor_optimizer=tf.compat.v1.train.AdamOptimizer(\n            learning_rate=FLAGS.agent_lr),\n        critic_optimizer=tf.compat.v1.train.AdamOptimizer(\n            learning_rate=FLAGS.agent_lr),\n        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(\n            learning_rate=FLAGS.agent_lr),\n
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,\n        gamma=FLAGS.agent_gamma,\n        reward_scale_factor=1. /\n        (FLAGS.agent_entropy + 1e-12),\n        gradient_clipping=None,\n        debug_summaries=FLAGS.debug,\n        train_step_counter=global_step)\n\n    # evaluation policy\n    eval_policy = py_tf_policy.PyTFPolicy(agent.policy)\n\n    # collection policy\n    if FLAGS.collect_policy == 'default':\n      collect_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)\n    elif FLAGS.collect_policy == 'ou_noise':\n      collect_policy = py_tf_policy.PyTFPolicy(\n          ou_noise_policy.OUNoisePolicy(\n              agent.collect_policy, ou_stddev=0.2, ou_damping=0.15))\n\n    # relabelling policy deals with batches of data, unlike collect and eval\n    relabel_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)\n\n    # constructing a replay buffer requires a python spec\n    policy_step_spec = policy_step.PolicyStep(\n        action=py_action_spec, state=(), info=())\n\n    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:\n      policy_step_spec = policy_step_spec._replace(\n          info=policy_step.set_log_probability(\n              policy_step_spec.info,\n              array_spec.ArraySpec(\n                  shape=(), dtype=np.float32, name='action_log_prob')))\n\n    trajectory_spec = from_transition(py_env_time_step_spec, policy_step_spec,\n                                      py_env_time_step_spec)\n    capacity = FLAGS.replay_buffer_capacity\n    # for all the data collected\n    rbuffer = py_uniform_replay_buffer.PyUniformReplayBuffer(\n        capacity=capacity, data_spec=trajectory_spec)\n\n    if FLAGS.train_skill_dynamics_on_policy:\n      # for on-policy data (if something special is required)\n      on_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(\n          capacity=FLAGS.initial_collect_steps + FLAGS.collect_steps + 10,\n          data_spec=trajectory_spec)\n\n    # insert experience manually with relabelled rewards and skills\n    agent.build_agent_graph()\n    agent.build_skill_dynamics_graph()\n    agent.create_savers()\n\n    # saving this way requires the saver to live outside the agent object\n    train_checkpointer = common.Checkpointer(\n        ckpt_dir=os.path.join(save_dir, 'agent'),\n        agent=agent,\n        global_step=global_step)\n    policy_checkpointer = common.Checkpointer(\n        ckpt_dir=os.path.join(save_dir, 'policy'),\n        policy=agent.policy,\n        global_step=global_step)\n    rb_checkpointer = common.Checkpointer(\n        ckpt_dir=os.path.join(save_dir, 'replay_buffer'),\n        max_to_keep=1,\n        replay_buffer=rbuffer)\n\n    setup_time = time.time() - start_time\n    print('Setup time:', setup_time)\n\n    with tf.compat.v1.Session().as_default() as sess:\n      train_checkpointer.initialize_or_restore(sess)\n      rb_checkpointer.initialize_or_restore(sess)\n      agent.set_sessions(\n          initialize_or_restore_skill_dynamics=True, session=sess)\n\n      meta_start_time = time.time()\n      if FLAGS.run_train:\n\n        train_writer = tf.compat.v1.summary.FileWriter(\n            os.path.join(log_dir, 'train'), sess.graph)\n        common.initialize_uninitialized_variables(sess)\n        sess.run(train_summary_writer.init())\n\n        time_step = py_env.reset()\n        episode_size_buffer.append(0)\n        episode_return_buffer.append(0.)\n\n        # maintain a buffer of episode lengths\n
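        # cur_data[0] extends the episode that was still in progress at the end\n        # of the previous collection call; later entries are episodes begun in\n        # this call. Only a window of recent episodes is kept for the summaries.\n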
        def _process_episodic_data(ep_buffer, cur_data):\n          ep_buffer[-1] += cur_data[0]\n          ep_buffer += cur_data[1:]\n\n          # only keep the last 100 episodes; truncate in place so the\n          # caller's list is actually updated\n          if len(ep_buffer) > 101:\n            del ep_buffer[:-101]\n\n        # remove invalid transitions from the replay buffer\n        def _filter_trajectories(trajectory):\n          # two consecutive samples in the buffer might not have been consecutive in the episode\n          valid_indices = (trajectory.step_type[:, 0] != 2)\n\n          return nest.map_structure(lambda x: x[valid_indices], trajectory)\n\n        if iter_count == 0:\n          start_time = time.time()\n          time_step, collect_info = collect_experience(\n              py_env,\n              time_step,\n              collect_policy,\n              buffer_list=[rbuffer] if not FLAGS.train_skill_dynamics_on_policy\n              else [rbuffer, on_buffer],\n              num_steps=FLAGS.initial_collect_steps)\n          _process_episodic_data(episode_size_buffer,\n                                 collect_info['episode_sizes'])\n          _process_episodic_data(episode_return_buffer,\n                                 collect_info['episode_return'])\n          sample_count += FLAGS.initial_collect_steps\n          initial_collect_time = time.time() - start_time\n          print('Initial data collection time:', initial_collect_time)\n\n        agent_end_train_time = time.time()\n        while iter_count < FLAGS.num_epochs:\n          print('iteration index:', iter_count)\n\n          # model save\n          if FLAGS.save_model is not None and iter_count % FLAGS.save_freq == 0:\n            print('Saving stuff')\n            train_checkpointer.save(global_step=iter_count)\n            policy_checkpointer.save(global_step=iter_count)\n            rb_checkpointer.save(global_step=iter_count)\n            agent.save_variables(global_step=iter_count)\n\n            np.save(os.path.join(log_dir, 'sample_count'), sample_count)\n            np.save(os.path.join(log_dir, 'episode_size_buffer'), episode_size_buffer)\n            np.save(os.path.join(log_dir, 'episode_return_buffer'), episode_return_buffer)\n            np.save(os.path.join(log_dir, 'iter_count'), iter_count)\n\n          collect_start_time = time.time()\n          print('intermediate time:', collect_start_time - agent_end_train_time)\n          time_step, collect_info = collect_experience(\n              py_env,\n              time_step,\n              collect_policy,\n              buffer_list=[rbuffer] if not FLAGS.train_skill_dynamics_on_policy\n              else [rbuffer, on_buffer],\n              num_steps=FLAGS.collect_steps)\n          sample_count += FLAGS.collect_steps\n          _process_episodic_data(episode_size_buffer,\n                                 collect_info['episode_sizes'])\n          _process_episodic_data(episode_return_buffer,\n                                 collect_info['episode_return'])\n          collect_end_time = time.time()\n          print('Iter collection time:', collect_end_time - collect_start_time)\n\n          # only for debugging skill relabelling\n          if iter_count >= 1 and FLAGS.debug_skill_relabelling:\n            trajectory_sample = rbuffer.get_next(\n                sample_batch_size=5, num_steps=2)\n            trajectory_sample = _filter_trajectories(trajectory_sample)\n            # trajectory_sample, _ = relabel_skill(\n            #     trajectory_sample,\n            #     relabel_type='policy',\n            #     
cur_policy=relabel_policy,\n            #     cur_skill_dynamics=agent.skill_dynamics)\n            trajectory_sample, is_weights = relabel_skill(\n                trajectory_sample,\n                relabel_type='importance_sampling',\n                cur_policy=relabel_policy,\n                cur_skill_dynamics=agent.skill_dynamics)\n            print(is_weights)\n\n          skill_dynamics_buffer = rbuffer\n          if FLAGS.train_skill_dynamics_on_policy:\n            skill_dynamics_buffer = on_buffer\n\n          # TODO(architsh): clear_buffer_every_iter needs to fix these as well\n          for _ in range(1 if FLAGS.clear_buffer_every_iter else FLAGS\n                         .skill_dyn_train_steps):\n            if FLAGS.clear_buffer_every_iter:\n              trajectory_sample = rbuffer.gather_all_transitions()\n            else:\n              trajectory_sample = skill_dynamics_buffer.get_next(\n                  sample_batch_size=FLAGS.skill_dyn_batch_size, num_steps=2)\n            trajectory_sample = _filter_trajectories(trajectory_sample)\n\n            # is_weights is None usually, unless relabelling involves importance_sampling\n            trajectory_sample, is_weights = relabel_skill(\n                trajectory_sample,\n                relabel_type=FLAGS.skill_dynamics_relabel_type,\n                cur_policy=relabel_policy,\n                cur_skill_dynamics=agent.skill_dynamics)\n            input_obs = process_observation(\n                trajectory_sample.observation[:, 0, :-FLAGS.num_skills])\n            cur_skill = trajectory_sample.observation[:, 0, -FLAGS.num_skills:]\n            target_obs = process_observation(\n                trajectory_sample.observation[:, 1, :-FLAGS.num_skills])\n            if FLAGS.clear_buffer_every_iter:\n              agent.skill_dynamics.train(\n                  input_obs,\n                  cur_skill,\n                  target_obs,\n                  batch_size=FLAGS.skill_dyn_batch_size,\n                  batch_weights=is_weights,\n                  num_steps=FLAGS.skill_dyn_train_steps)\n            else:\n              agent.skill_dynamics.train(\n                  input_obs,\n                  cur_skill,\n                  target_obs,\n                  batch_size=-1,\n                  batch_weights=is_weights,\n                  num_steps=1)\n\n          if FLAGS.train_skill_dynamics_on_policy:\n            on_buffer.clear()\n\n          skill_dynamics_end_train_time = time.time()\n          print('skill_dynamics train time:',\n                skill_dynamics_end_train_time - collect_end_time)\n\n          running_dads_reward, running_logp, running_logp_altz = [], [], []\n\n          # agent train loop analysis\n          within_agent_train_time = time.time()\n          sampling_time_arr, filtering_time_arr, relabelling_time_arr, train_time_arr = [], [], [], []\n          for _ in range(\n              1 if FLAGS.clear_buffer_every_iter else FLAGS.agent_train_steps):\n            if FLAGS.clear_buffer_every_iter:\n              trajectory_sample = rbuffer.gather_all_transitions()\n            else:\n              trajectory_sample = rbuffer.get_next(\n                  sample_batch_size=FLAGS.agent_batch_size, num_steps=2)\n\n            buffer_sampling_time = time.time()\n            sampling_time_arr.append(buffer_sampling_time -\n                                     within_agent_train_time)\n            trajectory_sample = _filter_trajectories(trajectory_sample)\n\n            filtering_time = time.time()\n           
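 # per-phase timing probes for the agent update: buffer sampling,\n            # boundary filtering, skill relabelling and the gradient steps\n           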
 filtering_time_arr.append(filtering_time - buffer_sampling_time)\n            trajectory_sample, _ = relabel_skill(\n                trajectory_sample,\n                relabel_type=FLAGS.agent_relabel_type,\n                cur_policy=relabel_policy,\n                cur_skill_dynamics=agent.skill_dynamics)\n            relabelling_time = time.time()\n            relabelling_time_arr.append(relabelling_time - filtering_time)\n\n            # need to match the assert structure\n            if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type:\n              trajectory_sample = trajectory_sample._replace(policy_info=())\n\n            if not FLAGS.clear_buffer_every_iter:\n              dads_reward, info = agent.train_loop(\n                  trajectory_sample,\n                  recompute_reward=True,  # turn False for normal SAC training\n                  batch_size=-1,\n                  num_steps=1)\n            else:\n              dads_reward, info = agent.train_loop(\n                  trajectory_sample,\n                  recompute_reward=True,  # turn False for normal SAC training\n                  batch_size=FLAGS.agent_batch_size,\n                  num_steps=FLAGS.agent_train_steps)\n\n            within_agent_train_time = time.time()\n            train_time_arr.append(within_agent_train_time - relabelling_time)\n            if dads_reward is not None:\n              running_dads_reward.append(dads_reward)\n              running_logp.append(info['logp'])\n              running_logp_altz.append(info['logp_altz'])\n\n          agent_end_train_time = time.time()\n          print('agent train time:',\n                agent_end_train_time - skill_dynamics_end_train_time)\n          print('\\t sampling time:', np.sum(sampling_time_arr))\n          print('\\t filtering_time:', np.sum(filtering_time_arr))\n          print('\\t relabelling time:', np.sum(relabelling_time_arr))\n          print('\\t train_time:', np.sum(train_time_arr))\n\n          if len(episode_size_buffer) > 1:\n            train_writer.add_summary(\n                tf.compat.v1.Summary(value=[\n                    tf.compat.v1.Summary.Value(\n                        tag='episode_size',\n                        simple_value=np.mean(episode_size_buffer[:-1]))\n                ]), sample_count)\n          if len(episode_return_buffer) > 1:\n            train_writer.add_summary(\n                tf.compat.v1.Summary(value=[\n                    tf.compat.v1.Summary.Value(\n                        tag='episode_return',\n                        simple_value=np.mean(episode_return_buffer[:-1]))\n                ]), sample_count)\n          train_writer.add_summary(\n              tf.compat.v1.Summary(value=[\n                  tf.compat.v1.Summary.Value(\n                      tag='dads/reward',\n                      simple_value=np.mean(\n                          np.concatenate(running_dads_reward)))\n              ]), sample_count)\n\n          train_writer.add_summary(\n              tf.compat.v1.Summary(value=[\n                  tf.compat.v1.Summary.Value(\n                      tag='dads/logp',\n                      simple_value=np.mean(np.concatenate(running_logp)))\n              ]), sample_count)\n          train_writer.add_summary(\n              tf.compat.v1.Summary(value=[\n                  tf.compat.v1.Summary.Value(\n                      tag='dads/logp_altz',\n                      simple_value=np.mean(np.concatenate(running_logp_altz)))\n      
        ]), sample_count)\n\n          if FLAGS.clear_buffer_every_iter:\n            rbuffer.clear()\n            time_step = py_env.reset()\n            episode_size_buffer = [0]\n            episode_return_buffer = [0.]\n\n          # within train loop evaluation\n          if FLAGS.record_freq is not None and iter_count % FLAGS.record_freq == 0:\n            cur_vid_dir = os.path.join(log_dir, 'videos', str(iter_count))\n            tf.io.gfile.makedirs(cur_vid_dir)\n            eval_loop(\n                cur_vid_dir,\n                eval_policy,\n                dynamics=agent.skill_dynamics,\n                vid_name=FLAGS.vid_name,\n                plot_name='traj_plot')\n\n          iter_count += 1\n\n        py_env.close()\n\n        print('Final statistics:')\n        print('\\ttotal time for %d epochs: %f' % (FLAGS.num_epochs, time.time() - meta_start_time))\n        print('\\tsteps collected during this time: %d' % (rbuffer.size))\n\n      # final evaluation, if any\n      if FLAGS.run_eval:\n        vid_dir = os.path.join(log_dir, 'videos', 'final_eval')\n        if not tf.io.gfile.exists(vid_dir):\n          tf.io.gfile.makedirs(vid_dir)\n        vid_name = FLAGS.vid_name\n\n        # generic skill evaluation\n        if FLAGS.deterministic_eval or FLAGS.num_evals > 0:\n          eval_loop(\n              vid_dir,\n              eval_policy,\n              dynamics=agent.skill_dynamics,\n              vid_name=vid_name,\n              plot_name='traj_plot')\n\n        # planning evaluations are recorded under a separate eval directory\n        eval_dir = os.path.join(log_dir, 'eval')\n        goal_coord_list = [\n            np.array([10.0, 10.0]),\n            np.array([-10.0, 10.0]),\n            np.array([-10.0, -10.0]),\n            np.array([10.0, -10.0]),\n            np.array([0.0, -10.0]),\n            np.array([5.0, 10.0])\n        ]\n\n        eval_dir = os.path.join(eval_dir, 'mpc_eval')\n        if not tf.io.gfile.exists(eval_dir):\n          tf.io.gfile.makedirs(eval_dir)\n        save_label = 'goal_'\n        if 'discrete' in FLAGS.skill_type:\n          planning_fn = eval_planning\n        else:\n          planning_fn = eval_mppi\n\n        color_map = ['b', 'g', 'r', 'c', 'm', 'y', 'k']\n\n        average_reward_all_goals = []\n        _, ax1 = plt.subplots(1, 1)\n        ax1.set_xlim(-20, 20)\n        ax1.set_ylim(-20, 20)\n\n        final_text = open(os.path.join(eval_dir, 'eval_data.txt'), 'w')\n\n        # goal_list = []\n        # for r in range(4, 50):\n        #   for _ in range(10):\n        #     theta = np.random.uniform(-np.pi, np.pi)\n        #     goal_x = r * np.cos(theta)\n        #     goal_y = r * np.sin(theta)\n        #     goal_list.append([r, theta, goal_x, goal_y])\n\n        # def _sample_goal():\n        #   goal_coords = np.random.uniform(0, 5, size=2)\n        #   # while np.linalg.norm(goal_coords) < np.linalg.norm([10., 10.]):\n        #   #   goal_coords = np.random.uniform(-25, 25, size=2)\n        #   return goal_coords\n\n        # goal_coord_list = [_sample_goal() for _ in range(50)]\n\n        for goal_idx, goal_coord in enumerate(goal_coord_list):\n          # for goal_idx in range(1):\n          print('Trying to reach the goal:', goal_coord)\n          # eval_plan_env = video_wrapper.VideoWrapper(\n          #     get_environment(env_name=FLAGS.environment + '_goal')\n          #     base_path=eval_dir,\n          #     base_name=save_label + '_' + str(goal_idx)))\n          # goal_coord = np.array(item[2:])\n
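          # reach each goal with MPC: plan in the learnt skill space using\n          # skill-dynamics, then execute the chosen primitives with the policy\n   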
       eval_plan_env = get_environment(env_name=FLAGS.environment + '_goal')\n          # _, (ax1, ax2) = plt.subplots(1, 2)\n          # ax1.set_xlim(-12, 12)\n          # ax1.set_ylim(-12, 12)\n          # ax2.set_xlim(-1, 1)\n          # ax2.set_ylim(-1, 1)\n          ax1.plot(goal_coord[0], goal_coord[1], marker='x', color='k')\n          reward_list = []\n\n          def _steps_to_goal(dist_array):\n            for idx in range(len(dist_array)):\n              if -dist_array[idx] < 1.5:\n                return idx\n            return -1\n\n          for _ in range(1):\n            reward, actual_coords, primitives, distance_to_goal_array = planning_fn(\n                eval_plan_env, agent.skill_dynamics, eval_policy,\n                latent_action_space_size=FLAGS.num_skills,\n                episode_horizon=FLAGS.max_env_steps,\n                planning_horizon=FLAGS.planning_horizon,\n                primitive_horizon=FLAGS.primitive_horizon,\n                num_candidate_sequences=FLAGS.num_candidate_sequences,\n                refine_steps=FLAGS.refine_steps,\n                mppi_gamma=FLAGS.mppi_gamma,\n                prior_type=FLAGS.prior_type,\n                smoothing_beta=FLAGS.smoothing_beta,\n                top_primitives=FLAGS.top_primitives\n            )\n            reward /= (FLAGS.max_env_steps * np.linalg.norm(goal_coord))\n            ax1.plot(\n                actual_coords[:, 0],\n                actual_coords[:, 1],\n                color_map[goal_idx % len(color_map)],\n                linewidth=1)\n            # ax2.plot(\n            #     primitives[:, 0],\n            #     primitives[:, 1],\n            #     marker='x',\n            #     color=color_map[try_idx % len(color_map)],\n            #     linewidth=1)\n            final_text.write(','.join([\n                str(item) for item in [\n                    goal_coord[0],\n                    goal_coord[1],\n                    reward,\n                    _steps_to_goal(distance_to_goal_array),\n                    distance_to_goal_array[-3],\n                    distance_to_goal_array[-2],\n                    distance_to_goal_array[-1],\n                ]\n            ]) + '\\n')\n            print(reward)\n            reward_list.append(reward)\n\n          eval_plan_env.close()\n          average_reward_all_goals.append(np.mean(reward_list))\n          print('Average reward:', np.mean(reward_list))\n\n        final_text.close()\n        # to save images while writing to CNS\n        buf = io.BytesIO()\n        plt.savefig(buf, dpi=600, bbox_inches='tight')\n        buf.seek(0)\n        image = tf.io.gfile.GFile(os.path.join(eval_dir, save_label + '.png'), 'w')\n        image.write(buf.read(-1))\n        plt.clf()\n\n        # for iter_idx in range(1, actual_coords.shape[0]):\n        #   _, ax1 = plt.subplots(1, 1)\n        #   ax1.set_xlim(-2, 15)\n        #   ax1.set_ylim(-2, 15)\n        #   ax1.plot(\n        #       actual_coords[:iter_idx, 0],\n        #       actual_coords[:iter_idx, 1],\n        #       linewidth=1.2)\n        #   ax1.scatter(\n        #       np.array(goal_coord_list)[:, 0],\n        #       np.array(goal_coord_list)[:, 1],\n        #       marker='x',\n        #       color='k')\n        #   buf = io.BytesIO()\n        #   plt.savefig(buf, dpi=200, bbox_inches='tight')\n        #   buf.seek(0)\n        #   image = tf.io.gfile.GFile(\n        #       os.path.join(eval_dir,\n        #                    save_label + '_' + '%04d' % (iter_idx) + '.png'),\n        #  
     'w')\n        #   image.write(buf.read(-1))\n        #   plt.clf()\n\n        plt.close()\n        print('Average reward for all goals:', average_reward_all_goals)\n\n\nif __name__ == '__main__':\n  tf.compat.v1.app.run(main)\n"
  },
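  {
    "path": "unsupervised_skill_learning/mppi_update_sketch.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Standalone sketch of the MPPI update used in eval_mppi (illustration only).\n\nThis file is not imported by training or evaluation. It re-implements, with\nhypothetical toy numbers, the reweighting step of eval_mppi: each candidate\nplan receives a weight proportional to exp(mppi_gamma * (reward - max reward)),\nand the per-step prior means move to the weighted average of the candidates.\nAll names and constants below are illustrative assumptions, not repo APIs.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\n\n\ndef mppi_update_means(candidates, rewards, mppi_gamma=10.0):\n  \"\"\"candidates: [num_candidates, planning_horizon, skill_dim] array.\n\n  rewards: [num_candidates] predicted returns, one per candidate plan.\n  Returns the new prior means, shape [planning_horizon, skill_dim].\n  \"\"\"\n  # softmax-style weights, shifted by the max reward for numerical stability\n  weights = np.exp(mppi_gamma * (rewards - np.max(rewards)))\n  weights = weights / (weights.sum() + 1e-10)\n\n  # reward-weighted average over candidates, independently per planning step\n  return (candidates.T * weights).T.sum(axis=0)\n\n\nif __name__ == '__main__':\n  np.random.seed(0)\n  num_candidates, planning_horizon, skill_dim = 50, 3, 2\n  # hypothetical candidate skill sequences, clipped to [-1, 1] like the planner\n  candidates = np.clip(\n      np.random.normal(size=(num_candidates, planning_horizon, skill_dim)),\n      -1., 1.)\n  # made-up stand-in for returns predicted by skill-dynamics rollouts\n  rewards = -np.linalg.norm(candidates - 0.5, axis=(1, 2))\n  print('updated prior means per planning step:')\n  print(mppi_update_means(candidates, rewards))\n"
  },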
  {
    "path": "unsupervised_skill_learning/skill_discriminator.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Skill Discriminator Prediction and Training.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_probability as tfp\n\nfrom tf_agents.distributions import tanh_bijector_stable\n\nclass SkillDiscriminator:\n\n  def __init__(\n      self,\n      observation_size,\n      skill_size,\n      skill_type,\n      normalize_observations=False,\n      # network properties\n      fc_layer_params=(256, 256),\n      fix_variance=False,\n      input_type='diayn',\n      # probably do not need to change these\n      graph=None,\n      scope_name='skill_discriminator'):\n\n    self._observation_size = observation_size\n    self._skill_size = skill_size\n    self._skill_type = skill_type\n    self._normalize_observations = normalize_observations\n\n    # tensorflow requirements\n    if graph is not None:\n      self._graph = graph\n    else:\n      self._graph = tf.get_default_graph()\n    self._scope_name = scope_name\n\n    # discriminator network properties\n    self._fc_layer_params = fc_layer_params\n    self._fix_variance = fix_variance\n    if not self._fix_variance:\n      self._std_lower_clip = 0.3\n      self._std_upper_clip = 10.0\n    self._input_type = input_type\n\n    self._use_placeholders = False\n    self.log_probability = None\n    self.disc_max_op = None\n    self.disc_min_op = None\n    self._session = None\n\n    # saving/restoring variables\n    self._saver = None\n\n  def _get_distributions(self, out):\n    if self._skill_type in ['gaussian', 'cont_uniform']:\n      mean = tf.layers.dense(\n          out, self._skill_size, name='mean', reuse=tf.AUTO_REUSE)\n      if not self._fix_variance:\n        stddev = tf.clip_by_value(\n            tf.layers.dense(\n                out,\n                self._skill_size,\n                activation=tf.nn.softplus,\n                name='stddev',\n                reuse=tf.AUTO_REUSE), self._std_lower_clip,\n            self._std_upper_clip)\n      else:\n        stddev = tf.fill([tf.shape(out)[0], self._skill_size], 1.0)\n\n      inference_distribution = tfp.distributions.MultivariateNormalDiag(\n          loc=mean, scale_diag=stddev)\n\n      if self._skill_type == 'gaussian':\n        prior_distribution = tfp.distributions.MultivariateNormalDiag(\n            loc=[0.] * self._skill_size, scale_diag=[1.] * self._skill_size)\n      elif self._skill_type == 'cont_uniform':\n        prior_distribution = tfp.distributions.Independent(\n            tfp.distributions.Uniform(\n                low=[-1.] * self._skill_size, high=[1.] 
* self._skill_size),\n            reinterpreted_batch_ndims=1)\n\n        # squash posterior to the right range of [-1, 1]\n        bijectors = []\n        bijectors.append(tanh_bijector_stable.Tanh())\n        bijector_chain = tfp.bijectors.Chain(bijectors)\n        inference_distribution = tfp.distributions.TransformedDistribution(\n            distribution=inference_distribution, bijector=bijector_chain)\n\n    elif self._skill_type == 'discrete_uniform':\n      logits = tf.layers.dense(\n          out, self._skill_size, name='logits', reuse=tf.AUTO_REUSE)\n      inference_distribution = tfp.distributions.OneHotCategorical(\n          logits=logits)\n      prior_distribution = tfp.distributions.OneHotCategorical(\n          probs=[1. / self._skill_size] * self._skill_size)\n    elif self._skill_type == 'multivariate_bernoulli':\n      raise NotImplementedError('multivariate_bernoulli skills are not supported yet')\n\n    return inference_distribution, prior_distribution\n\n  # simple discriminator graph\n  def _default_graph(self, timesteps):\n    out = timesteps\n    for idx, layer_size in enumerate(self._fc_layer_params):\n      out = tf.layers.dense(\n          out,\n          layer_size,\n          activation=tf.nn.relu,\n          name='hid_' + str(idx),\n          reuse=tf.AUTO_REUSE)\n\n    return self._get_distributions(out)\n\n  def _get_dict(self,\n                input_steps,\n                target_skills,\n                input_next_steps=None,\n                batch_size=-1,\n                batch_norm=False):\n    if batch_size > 0:\n      shuffled_batch = np.random.permutation(len(input_steps))[:batch_size]\n    else:\n      shuffled_batch = np.arange(len(input_steps))\n\n    batched_input = input_steps[shuffled_batch, :]\n    batched_skills = target_skills[shuffled_batch, :]\n    if self._input_type in ['diff', 'both']:\n      batched_targets = input_next_steps[shuffled_batch, :]\n\n    return_dict = {\n        self.timesteps_pl: batched_input,\n        self.skills_pl: batched_skills,\n    }\n\n    if self._input_type in ['diff', 'both']:\n      return_dict[self.next_timesteps_pl] = batched_targets\n    if self._normalize_observations:\n      return_dict[self.is_training_pl] = batch_norm\n\n    return return_dict\n\n  def make_placeholders(self):\n    self._use_placeholders = True\n    with self._graph.as_default(), tf.variable_scope(self._scope_name):\n      self.timesteps_pl = tf.placeholder(\n          tf.float32, shape=(None, self._observation_size), name='timesteps_pl')\n      self.skills_pl = tf.placeholder(\n          tf.float32, shape=(None, self._skill_size), name='skills_pl')\n      if self._input_type in ['diff', 'both']:\n        self.next_timesteps_pl = tf.placeholder(\n            tf.float32,\n            shape=(None, self._observation_size),\n            name='next_timesteps_pl')\n      if self._normalize_observations:\n        self.is_training_pl = tf.placeholder(tf.bool, name='batch_norm_pl')\n\n  def set_session(self, session=None, initialize_or_restore_variables=False):\n    if session is None:\n      self._session = tf.Session(graph=self._graph)\n    else:\n      self._session = session\n\n    # only initialize uninitialized variables\n    if initialize_or_restore_variables:\n      if tf.gfile.Exists(self._save_prefix):\n        self.restore_variables()\n      with self._graph.as_default():\n        is_initialized = self._session.run([\n            tf.compat.v1.is_variable_initialized(v)\n            for key, v in self._variable_list.items()\n        ])\n        uninitialized_vars = []\n        for flag, v in zip(is_initialized, self._variable_list.items()):\n          if not flag:\n            uninitialized_vars.append(v[1])\n\n        if uninitialized_vars:\n          self._session.run(\n              tf.compat.v1.variables_initializer(uninitialized_vars))\n\n  def build_graph(self,\n                  timesteps=None,\n                  skills=None,\n                  next_timesteps=None,\n                  is_training=None):\n    with self._graph.as_default(), tf.variable_scope(self._scope_name):\n      if self._use_placeholders:\n        timesteps = self.timesteps_pl\n        skills = self.skills_pl\n        if self._input_type in ['diff', 'both']:\n          next_timesteps = self.next_timesteps_pl\n        if self._normalize_observations:\n          is_training = self.is_training_pl\n\n      # use deltas\n      if self._input_type == 'both':\n        next_timesteps -= timesteps\n        timesteps = tf.concat([timesteps, next_timesteps], axis=1)\n      if self._input_type == 'diff':\n        timesteps = next_timesteps - timesteps\n\n      if self._normalize_observations:\n        timesteps = tf.layers.batch_normalization(\n            timesteps,\n            training=is_training,\n            name='input_normalization',\n            reuse=tf.AUTO_REUSE)\n\n      inference_distribution, prior_distribution = self._default_graph(\n          timesteps)\n\n      self.log_probability = inference_distribution.log_prob(skills)\n      self.prior_probability = prior_distribution.log_prob(skills)\n      return self.log_probability, self.prior_probability\n\n  def increase_prob_op(self, learning_rate=3e-4):\n    with self._graph.as_default():\n      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n      with tf.control_dependencies(update_ops):\n        self.disc_max_op = tf.train.AdamOptimizer(\n            learning_rate=learning_rate).minimize(\n                -tf.reduce_mean(self.log_probability))\n        return self.disc_max_op\n\n  def decrease_prob_op(self, learning_rate=3e-4):\n    with self._graph.as_default():\n      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n      with tf.control_dependencies(update_ops):\n        self.disc_min_op = tf.train.AdamOptimizer(\n            learning_rate=learning_rate).minimize(\n                tf.reduce_mean(self.log_probability))\n        return self.disc_min_op\n\n  # only useful when training uses placeholders, otherwise use the ops directly\n  def train(self,\n            timesteps,\n            skills,\n            next_timesteps=None,\n            batch_size=512,\n            num_steps=1,\n            increase_probs=True):\n    if not self._use_placeholders:\n      return\n\n    if increase_probs:\n      run_op = self.disc_max_op\n    else:\n      run_op = self.disc_min_op\n\n    for _ in range(num_steps):\n      self._session.run(\n          run_op,\n          feed_dict=self._get_dict(\n              timesteps,\n              skills,\n              input_next_steps=next_timesteps,\n              batch_size=batch_size,\n              batch_norm=True))\n\n  def get_log_probs(self, timesteps, skills, next_timesteps=None):\n    if not self._use_placeholders:\n      return\n\n    return self._session.run([self.log_probability, self.prior_probability],\n                             feed_dict=self._get_dict(\n                                 timesteps,\n                                 skills,\n                                 input_next_steps=next_timesteps,\n                                 batch_norm=False))\n\n  def create_saver(self, 
save_prefix):\n    if self._saver is not None:\n      return self._saver\n    else:\n      with self._graph.as_default():\n        self._variable_list = {}\n        for var in tf.get_collection(\n            tf.GraphKeys.GLOBAL_VARIABLES, scope=self._scope_name):\n          self._variable_list[var.name] = var\n        self._saver = tf.train.Saver(self._variable_list, save_relative_paths=True)\n        self._save_prefix = save_prefix\n\n  def save_variables(self, global_step):\n    if not tf.gfile.Exists(self._save_prefix):\n      tf.gfile.MakeDirs(self._save_prefix)\n\n    self._saver.save(\n        self._session,\n        os.path.join(self._save_prefix, 'ckpt'),\n        global_step=global_step)\n\n  def restore_variables(self):\n    self._saver.restore(self._session,\n                        tf.train.latest_checkpoint(self._save_prefix))\n"
  },
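  {
    "path": "unsupervised_skill_learning/skill_dyn_reward_sketch.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Standalone sketch of the DADS intrinsic reward (illustration only).\n\nThis file is not imported anywhere. Following the DADS objective, a transition\nis rewarded by how much more likely it is under the current skill than on\naverage under skills drawn from the prior:\n\n  r(s, z, s') ~ log q(s'|s, z) - log((1/L) * sum_i q(s'|s, z_i)),\n\nwith L prior samples (FLAGS.random_skills). dads_agent.py computes this from\nskill-dynamics log-probabilities (logged as 'logp' and 'logp_altz'); the toy\nlog-probabilities below are made-up stand-ins for those quantities.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\n\n\ndef _logsumexp(x, axis):\n  # numerically stable log-sum-exp along the given axis\n  x_max = np.max(x, axis=axis, keepdims=True)\n  return np.log(np.sum(np.exp(x - x_max), axis=axis)) + np.squeeze(x_max, axis)\n\n\ndef dads_intrinsic_reward(logp, logp_altz):\n  \"\"\"logp: [batch] log q(s'|s, z) under the skill that generated the data.\n\n  logp_altz: [batch, L] log q(s'|s, z_i) under L skills from the prior.\n  \"\"\"\n  num_prior_samples = logp_altz.shape[1]\n  # log of the average probability, not the average of the log-probabilities\n  return logp - (_logsumexp(logp_altz, axis=1) - np.log(num_prior_samples))\n\n\nif __name__ == '__main__':\n  rng = np.random.RandomState(0)\n  logp = rng.uniform(-3., 0., size=8)\n  logp_altz = rng.uniform(-6., 0., size=(8, 100))\n  print('intrinsic rewards:', dads_intrinsic_reward(logp, logp_altz))\n"
  },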
  {
    "path": "unsupervised_skill_learning/skill_dynamics.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Dynamics Prediction and Training.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_probability as tfp\n\n\n# TODO(architsh): Implement the dynamics with last K step input\nclass SkillDynamics:\n\n  def __init__(\n      self,\n      observation_size,\n      action_size,\n      restrict_observation=0,\n      normalize_observations=False,\n      # network properties\n      fc_layer_params=(256, 256),\n      network_type='default',\n      num_components=1,\n      fix_variance=False,\n      reweigh_batches=False,\n      graph=None,\n      scope_name='skill_dynamics'):\n\n    self._observation_size = observation_size\n    self._action_size = action_size\n    self._normalize_observations = normalize_observations\n    self._restrict_observation = restrict_observation\n    self._reweigh_batches = reweigh_batches\n\n    # tensorflow requirements\n    if graph is not None:\n      self._graph = graph\n    else:\n      self._graph = tf.compat.v1.get_default_graph()\n    self._scope_name = scope_name\n\n    # dynamics network properties\n    self._fc_layer_params = fc_layer_params\n    self._network_type = network_type\n    self._num_components = num_components\n    self._fix_variance = fix_variance\n    if not self._fix_variance:\n      self._std_lower_clip = 0.3\n      self._std_upper_clip = 10.0\n\n    self._use_placeholders = False\n    self.log_probability = None\n    self.dyn_max_op = None\n    self.dyn_min_op = None\n    self._session = None\n    self._use_modal_mean = False\n\n    # saving/restoring variables\n    self._saver = None\n\n  def _get_distribution(self, out):\n    if self._num_components > 1:\n      self.logits = tf.compat.v1.layers.dense(\n          out, self._num_components, name='logits', reuse=tf.compat.v1.AUTO_REUSE)\n      means, scale_diags = [], []\n      for component_id in range(self._num_components):\n        means.append(\n            tf.compat.v1.layers.dense(\n                out,\n                self._observation_size,\n                name='mean_' + str(component_id),\n                reuse=tf.compat.v1.AUTO_REUSE))\n        if not self._fix_variance:\n          scale_diags.append(\n              tf.clip_by_value(\n                  tf.compat.v1.layers.dense(\n                      out,\n                      self._observation_size,\n                      activation=tf.nn.softplus,\n                      name='stddev_' + str(component_id),\n                      reuse=tf.compat.v1.AUTO_REUSE), self._std_lower_clip,\n                  self._std_upper_clip))\n        else:\n          scale_diags.append(\n              tf.fill([tf.shape(out)[0], self._observation_size], 1.0))\n\n      self.means = tf.stack(means, axis=1)\n      self.scale_diags = tf.stack(scale_diags, axis=1)\n      return 
tfp.distributions.MixtureSameFamily(\n          mixture_distribution=tfp.distributions.Categorical(\n              logits=self.logits),\n          components_distribution=tfp.distributions.MultivariateNormalDiag(\n              loc=self.means, scale_diag=self.scale_diags))\n\n    else:\n      mean = tf.compat.v1.layers.dense(\n          out, self._observation_size, name='mean', reuse=tf.compat.v1.AUTO_REUSE)\n      if not self._fix_variance:\n        stddev = tf.clip_by_value(\n            tf.compat.v1.layers.dense(\n                out,\n                self._observation_size,\n                activation=tf.nn.softplus,\n                name='stddev',\n                reuse=tf.compat.v1.AUTO_REUSE), self._std_lower_clip,\n            self._std_upper_clip)\n      else:\n        stddev = tf.fill([tf.shape(out)[0], self._observation_size], 1.0)\n      return tfp.distributions.MultivariateNormalDiag(\n          loc=mean, scale_diag=stddev)\n\n  # dynamics graph with separate pipeline for skills and timesteps\n  def _graph_with_separate_skill_pipe(self, timesteps, actions):\n    skill_out = actions\n    with tf.compat.v1.variable_scope('action_pipe'):\n      for idx, layer_size in enumerate((self._fc_layer_params[0] // 2,)):\n        skill_out = tf.compat.v1.layers.dense(\n            skill_out,\n            layer_size,\n            activation=tf.nn.relu,\n            name='hid_' + str(idx),\n            reuse=tf.compat.v1.AUTO_REUSE)\n\n    ts_out = timesteps\n    with tf.compat.v1.variable_scope('ts_pipe'):\n      for idx, layer_size in enumerate((self._fc_layer_params[0] // 2,)):\n        ts_out = tf.compat.v1.layers.dense(\n            ts_out,\n            layer_size,\n            activation=tf.nn.relu,\n            name='hid_' + str(idx),\n            reuse=tf.compat.v1.AUTO_REUSE)\n\n    # out = tf.compat.v1.layers.flatten(tf.einsum('ai,aj->aij', ts_out, skill_out))\n    out = tf.concat([ts_out, skill_out], axis=1)\n    with tf.compat.v1.variable_scope('joint'):\n      for idx, layer_size in enumerate(self._fc_layer_params[1:]):\n        out = tf.compat.v1.layers.dense(\n            out,\n            layer_size,\n            activation=tf.nn.relu,\n            name='hid_' + str(idx),\n            reuse=tf.compat.v1.AUTO_REUSE)\n\n    return self._get_distribution(out)\n\n  # simple dynamics graph\n  def _default_graph(self, timesteps, actions):\n    out = tf.concat([timesteps, actions], axis=1)\n    for idx, layer_size in enumerate(self._fc_layer_params):\n      out = tf.compat.v1.layers.dense(\n          out,\n          layer_size,\n          activation=tf.nn.relu,\n          name='hid_' + str(idx),\n          reuse=tf.compat.v1.AUTO_REUSE)\n\n    return self._get_distribution(out)\n\n  def _get_dict(self,\n                input_data,\n                input_actions,\n                target_data,\n                batch_size=-1,\n                batch_weights=None,\n                batch_norm=False,\n                noise_targets=False,\n                noise_std=0.5):\n    if batch_size > 0:\n      shuffled_batch = np.random.permutation(len(input_data))[:batch_size]\n    else:\n      shuffled_batch = np.arange(len(input_data))\n\n    # if we are noising the targets, it is better to create a new copy of the numpy arrays\n    batched_input = input_data[shuffled_batch, :]\n    batched_skills = input_actions[shuffled_batch, :]\n    batched_targets = target_data[shuffled_batch, :]\n\n    if self._reweigh_batches and batch_weights is not None:\n      example_weights = 
batch_weights[shuffled_batch]\n\n    if noise_targets:\n      batched_targets += np.random.randn(*batched_targets.shape) * noise_std\n\n    return_dict = {\n        self.timesteps_pl: batched_input,\n        self.actions_pl: batched_skills,\n        self.next_timesteps_pl: batched_targets\n    }\n    if self._normalize_observations:\n      return_dict[self.is_training_pl] = batch_norm\n    if self._reweigh_batches and batch_weights is not None:\n      return_dict[self.batch_weights] = example_weights\n\n    return return_dict\n\n  def _get_run_dict(self, input_data, input_actions):\n    return_dict = {\n        self.timesteps_pl: input_data,\n        self.actions_pl: input_actions\n    }\n    if self._normalize_observations:\n      return_dict[self.is_training_pl] = False\n\n    return return_dict\n\n  def make_placeholders(self):\n    self._use_placeholders = True\n    with self._graph.as_default(), tf.compat.v1.variable_scope(self._scope_name):\n      self.timesteps_pl = tf.compat.v1.placeholder(\n          tf.float32, shape=(None, self._observation_size), name='timesteps_pl')\n      self.actions_pl = tf.compat.v1.placeholder(\n          tf.float32, shape=(None, self._action_size), name='actions_pl')\n      self.next_timesteps_pl = tf.compat.v1.placeholder(\n          tf.float32,\n          shape=(None, self._observation_size),\n          name='next_timesteps_pl')\n      if self._normalize_observations:\n        self.is_training_pl = tf.compat.v1.placeholder(tf.bool, name='batch_norm_pl')\n      if self._reweigh_batches:\n        self.batch_weights = tf.compat.v1.placeholder(\n            tf.float32, shape=(None,), name='importance_sampled_weights')\n\n  def set_session(self, session=None, initialize_or_restore_variables=False):\n    if session is None:\n      self._session = tf.compat.v1.Session(graph=self._graph)\n    else:\n      self._session = session\n\n    # only initialize uninitialized variables\n    if initialize_or_restore_variables:\n      # restoring from disk requires create_saver() to have been called first\n      if self._saver is not None and tf.io.gfile.exists(self._save_prefix):\n        self.restore_variables()\n      with self._graph.as_default():\n        var_list = tf.compat.v1.global_variables(\n        ) + tf.compat.v1.local_variables()\n        is_initialized = self._session.run(\n            [tf.compat.v1.is_variable_initialized(v) for v in var_list])\n        uninitialized_vars = []\n        for flag, v in zip(is_initialized, var_list):\n          if not flag:\n            uninitialized_vars.append(v)\n\n        if uninitialized_vars:\n          self._session.run(\n              tf.compat.v1.variables_initializer(uninitialized_vars))\n\n  def build_graph(self,\n                  timesteps=None,\n                  actions=None,\n                  next_timesteps=None,\n                  is_training=None):\n    with self._graph.as_default(), tf.compat.v1.variable_scope(\n        self._scope_name, reuse=tf.compat.v1.AUTO_REUSE):\n      if self._use_placeholders:\n        timesteps = self.timesteps_pl\n        actions = self.actions_pl\n        next_timesteps = self.next_timesteps_pl\n        if self._normalize_observations:\n          is_training = self.is_training_pl\n\n      # predict deltas instead of observations\n      next_timesteps -= timesteps\n\n      if self._restrict_observation > 0:\n        timesteps = timesteps[:, self._restrict_observation:]\n\n      if self._normalize_observations:\n        timesteps = tf.compat.v1.layers.batch_normalization(\n            timesteps,\n            training=is_training,\n            name='input_normalization',\n            
reuse=tf.compat.v1.AUTO_REUSE)\n        self.output_norm_layer = tf.compat.v1.layers.BatchNormalization(\n            scale=False, center=False, name='output_normalization')\n        next_timesteps = self.output_norm_layer(\n            next_timesteps, training=is_training)\n\n      if self._network_type == 'default':\n        self.base_distribution = self._default_graph(timesteps, actions)\n      elif self._network_type == 'separate':\n        self.base_distribution = self._graph_with_separate_skill_pipe(\n            timesteps, actions)\n      else:\n        raise NotImplementedError(\n            'Unknown network type: %s' % self._network_type)\n\n      # if building multiple times, be careful about which log_prob you are optimizing\n      self.log_probability = self.base_distribution.log_prob(next_timesteps)\n      self.mean = self.base_distribution.mean()\n\n      return self.log_probability\n\n  def increase_prob_op(self, learning_rate=3e-4, weights=None):\n    with self._graph.as_default():\n      update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)\n      with tf.control_dependencies(update_ops):\n        if self._reweigh_batches:\n          self.dyn_max_op = tf.compat.v1.train.AdamOptimizer(\n              learning_rate=learning_rate,\n              name='adam_max').minimize(-tf.reduce_mean(self.log_probability *\n                                                        self.batch_weights))\n        elif weights is not None:\n          self.dyn_max_op = tf.compat.v1.train.AdamOptimizer(\n              learning_rate=learning_rate,\n              name='adam_max').minimize(-tf.reduce_mean(self.log_probability *\n                                                        weights))\n        else:\n          self.dyn_max_op = tf.compat.v1.train.AdamOptimizer(\n              learning_rate=learning_rate,\n              name='adam_max').minimize(-tf.reduce_mean(self.log_probability))\n\n        return self.dyn_max_op\n\n  def decrease_prob_op(self, learning_rate=3e-4, weights=None):\n    with self._graph.as_default():\n      update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)\n      with tf.control_dependencies(update_ops):\n        if self._reweigh_batches:\n          self.dyn_min_op = tf.compat.v1.train.AdamOptimizer(\n              learning_rate=learning_rate, name='adam_min').minimize(\n                  tf.reduce_mean(self.log_probability * self.batch_weights))\n        elif weights is not None:\n          self.dyn_min_op = tf.compat.v1.train.AdamOptimizer(\n              learning_rate=learning_rate, name='adam_min').minimize(\n                  tf.reduce_mean(self.log_probability * weights))\n        else:\n          self.dyn_min_op = tf.compat.v1.train.AdamOptimizer(\n              learning_rate=learning_rate,\n              name='adam_min').minimize(tf.reduce_mean(self.log_probability))\n        return self.dyn_min_op\n\n  def create_saver(self, save_prefix):\n    if self._saver is not None:\n      return self._saver\n    else:\n      with self._graph.as_default():\n        self._variable_list = {}\n        for var in tf.compat.v1.get_collection(\n            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=self._scope_name):\n          self._variable_list[var.name] = var\n        self._saver = tf.compat.v1.train.Saver(\n            self._variable_list, save_relative_paths=True)\n        self._save_prefix = save_prefix\n        return self._saver\n\n  def save_variables(self, global_step):\n    if not tf.io.gfile.exists(self._save_prefix):\n      tf.io.gfile.makedirs(self._save_prefix)\n\n    self._saver.save(\n        self._session,\n        os.path.join(self._save_prefix, 
'ckpt'),\n        global_step=global_step)\n\n  def restore_variables(self):\n    self._saver.restore(self._session,\n                        tf.compat.v1.train.latest_checkpoint(self._save_prefix))\n\n  # all functions from here on require placeholders -----------------------------\n  def train(self,\n            timesteps,\n            actions,\n            next_timesteps,\n            batch_weights=None,\n            batch_size=512,\n            num_steps=1,\n            increase_probs=True):\n    if not self._use_placeholders:\n      return\n\n    if increase_probs:\n      run_op = self.dyn_max_op\n    else:\n      run_op = self.dyn_min_op\n\n    for _ in range(num_steps):\n      self._session.run(\n          run_op,\n          feed_dict=self._get_dict(\n              timesteps,\n              actions,\n              next_timesteps,\n              batch_weights=batch_weights,\n              batch_size=batch_size,\n              batch_norm=True))\n\n  def get_log_prob(self, timesteps, actions, next_timesteps):\n    if not self._use_placeholders:\n      return\n\n    return self._session.run(\n        self.log_probability,\n        feed_dict=self._get_dict(\n            timesteps, actions, next_timesteps, batch_norm=False))\n\n  def predict_state(self, timesteps, actions):\n    if not self._use_placeholders:\n      return\n\n    if self._use_modal_mean:\n      all_means, modal_mean_indices = self._session.run(\n          [self.means, tf.argmax(self.logits, axis=1)],\n          feed_dict=self._get_run_dict(timesteps, actions))\n      # pick the mean of the most likely mixture component for each sample\n      pred_state = all_means[np.arange(all_means.shape[0]),\n                             modal_mean_indices]\n    else:\n      pred_state = self._session.run(\n          self.mean, feed_dict=self._get_run_dict(timesteps, actions))\n\n    if self._normalize_observations:\n      # undo the output normalization using the learned moving statistics;\n      # 1e-3 matches the default epsilon of the batch-normalization layer\n      with self._session.as_default(), self._graph.as_default():\n        mean_correction, variance_correction = self.output_norm_layer.get_weights(\n        )\n\n      pred_state = pred_state * np.sqrt(variance_correction +\n                                        1e-3) + mean_correction\n\n    # the network predicts deltas, so add back the current observation\n    pred_state += timesteps\n    return pred_state\n
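\n\n# A minimal usage sketch (illustrative only; the sizes, save path, and numpy\n# arrays below are hypothetical placeholders, not part of the original file):\n#\n#   dynamics = SkillDynamics(observation_size=4, action_size=2)\n#   dynamics.make_placeholders()\n#   dynamics.build_graph()\n#   dynamics.increase_prob_op()\n#   dynamics.create_saver(save_prefix='./saved_models/skill_dynamics')\n#   dynamics.set_session(initialize_or_restore_variables=True)\n#   # timesteps, actions, next_timesteps: numpy arrays with matching batch size\n#   dynamics.train(timesteps, actions, next_timesteps, batch_size=256)\n#   predicted_next = dynamics.predict_state(timesteps, actions)\n"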
  }
]