Repository: csmile-1006/PreferenceTransformer
Branch: main
Commit: f71647bb075c
Files: 388
Total size: 1.6 MB
Directory structure:
gitextract_cso8n2ux/
├── .gitignore
├── JaxPref/
│ ├── MR.py
│ ├── NMR.py
│ ├── PrefTransformer.py
│ ├── __init__.py
│ ├── human_label_preprocess_adroit.py
│ ├── human_label_preprocess_antmaze.py
│ ├── human_label_preprocess_mujoco.py
│ ├── human_label_preprocess_robosuite.py
│ ├── jax_utils.py
│ ├── model.py
│ ├── new_preference_reward_main.py
│ ├── replay_buffer.py
│ ├── reward_transform.py
│ ├── sampler.py
│ └── utils.py
├── LICENSE
├── README.md
├── actor.py
├── common.py
├── configs/
│ ├── adroit_config.py
│ ├── antmaze_config.py
│ ├── antmaze_finetune_config.py
│ └── mujoco_config.py
├── critic.py
├── d4rl/
│ ├── .gitignore
│ ├── LICENSE
│ ├── MANIFEST.in
│ ├── README.md
│ ├── d4rl/
│ │ ├── __init__.py
│ │ ├── carla/
│ │ │ ├── __init__.py
│ │ │ ├── carla_env.py
│ │ │ ├── data_collection_agent_lane.py
│ │ │ ├── data_collection_town.py
│ │ │ └── town_agent.py
│ │ ├── flow/
│ │ │ ├── __init__.py
│ │ │ ├── bottleneck.py
│ │ │ ├── merge.py
│ │ │ └── traffic_light_grid.py
│ │ ├── gym_bullet/
│ │ │ ├── __init__.py
│ │ │ └── gym_envs.py
│ │ ├── gym_minigrid/
│ │ │ ├── __init__.py
│ │ │ ├── envs/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── empty.py
│ │ │ │ └── fourrooms.py
│ │ │ ├── fourroom_controller.py
│ │ │ ├── minigrid.py
│ │ │ ├── register.py
│ │ │ ├── rendering.py
│ │ │ ├── roomgrid.py
│ │ │ ├── window.py
│ │ │ └── wrappers.py
│ │ ├── gym_mujoco/
│ │ │ ├── __init__.py
│ │ │ └── gym_envs.py
│ │ ├── hand_manipulation_suite/
│ │ │ ├── Adroit/
│ │ │ │ ├── .gitignore
│ │ │ │ ├── Adroit_hand.xml
│ │ │ │ ├── Adroit_hand_withOverlay.xml
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.md
│ │ │ │ └── resources/
│ │ │ │ ├── assets.xml
│ │ │ │ ├── chain.xml
│ │ │ │ ├── chain1.xml
│ │ │ │ ├── joint_position_actuation.xml
│ │ │ │ ├── meshes/
│ │ │ │ │ ├── F1.stl
│ │ │ │ │ ├── F2.stl
│ │ │ │ │ ├── F3.stl
│ │ │ │ │ ├── TH1_z.stl
│ │ │ │ │ ├── TH2_z.stl
│ │ │ │ │ ├── TH3_z.stl
│ │ │ │ │ ├── arm_base.stl
│ │ │ │ │ ├── arm_trunk.stl
│ │ │ │ │ ├── arm_trunk_asmbly.stl
│ │ │ │ │ ├── distal_ellipsoid.stl
│ │ │ │ │ ├── elbow_flex.stl
│ │ │ │ │ ├── elbow_rotate_motor.stl
│ │ │ │ │ ├── elbow_rotate_muscle.stl
│ │ │ │ │ ├── forearm_Cy_PlateAsmbly(muscle_cone).stl
│ │ │ │ │ ├── forearm_Cy_PlateAsmbly.stl
│ │ │ │ │ ├── forearm_PlateAsmbly.stl
│ │ │ │ │ ├── forearm_electric.stl
│ │ │ │ │ ├── forearm_electric_cvx.stl
│ │ │ │ │ ├── forearm_muscle.stl
│ │ │ │ │ ├── forearm_simple.stl
│ │ │ │ │ ├── forearm_simple_cvx.stl
│ │ │ │ │ ├── forearm_weight.stl
│ │ │ │ │ ├── knuckle.stl
│ │ │ │ │ ├── lfmetacarpal.stl
│ │ │ │ │ ├── palm.stl
│ │ │ │ │ ├── upper_arm.stl
│ │ │ │ │ ├── upper_arm_asmbl_shoulder.stl
│ │ │ │ │ ├── upper_arm_ass.stl
│ │ │ │ │ └── wrist.stl
│ │ │ │ └── tendon_torque_actuation.xml
│ │ │ ├── __init__.py
│ │ │ ├── assets/
│ │ │ │ ├── DAPG_Adroit.xml
│ │ │ │ ├── DAPG_assets.xml
│ │ │ │ ├── DAPG_door.xml
│ │ │ │ ├── DAPG_hammer.xml
│ │ │ │ ├── DAPG_pen.xml
│ │ │ │ └── DAPG_relocate.xml
│ │ │ ├── door_v0.py
│ │ │ ├── hammer_v0.py
│ │ │ ├── pen_v0.py
│ │ │ └── relocate_v0.py
│ │ ├── infos.py
│ │ ├── kitchen/
│ │ │ ├── __init__.py
│ │ │ ├── adept_envs/
│ │ │ │ ├── .pylintrc
│ │ │ │ ├── .style.yapf
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_robot.py
│ │ │ │ ├── franka/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── assets/
│ │ │ │ │ │ └── franka_kitchen_jntpos_act_ab.xml
│ │ │ │ │ ├── kitchen_multitask_v0.py
│ │ │ │ │ └── robot/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── franka_config.xml
│ │ │ │ │ └── franka_robot.py
│ │ │ │ ├── mujoco_env.py
│ │ │ │ ├── robot_env.py
│ │ │ │ ├── simulation/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── module.py
│ │ │ │ │ ├── renderer.py
│ │ │ │ │ └── sim_robot.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ ├── configurable.py
│ │ │ │ ├── constants.py
│ │ │ │ ├── parse_demos.py
│ │ │ │ └── quatmath.py
│ │ │ ├── adept_models/
│ │ │ │ ├── .gitignore
│ │ │ │ ├── CONTRIBUTING.public.md
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.public.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── kitchen/
│ │ │ │ │ ├── assets/
│ │ │ │ │ │ ├── backwall_asset.xml
│ │ │ │ │ │ ├── backwall_chain.xml
│ │ │ │ │ │ ├── counters_asset.xml
│ │ │ │ │ │ ├── counters_chain.xml
│ │ │ │ │ │ ├── hingecabinet_asset.xml
│ │ │ │ │ │ ├── hingecabinet_chain.xml
│ │ │ │ │ │ ├── kettle_asset.xml
│ │ │ │ │ │ ├── kettle_chain.xml
│ │ │ │ │ │ ├── microwave_asset.xml
│ │ │ │ │ │ ├── microwave_chain.xml
│ │ │ │ │ │ ├── oven_asset.xml
│ │ │ │ │ │ ├── oven_chain.xml
│ │ │ │ │ │ ├── slidecabinet_asset.xml
│ │ │ │ │ │ └── slidecabinet_chain.xml
│ │ │ │ │ ├── counters.xml
│ │ │ │ │ ├── hingecabinet.xml
│ │ │ │ │ ├── kettle.xml
│ │ │ │ │ ├── kitchen.xml
│ │ │ │ │ ├── meshes/
│ │ │ │ │ │ ├── burnerplate.stl
│ │ │ │ │ │ ├── burnerplate_mesh.stl
│ │ │ │ │ │ ├── cabinetbase.stl
│ │ │ │ │ │ ├── cabinetdrawer.stl
│ │ │ │ │ │ ├── cabinethandle.stl
│ │ │ │ │ │ ├── countertop.stl
│ │ │ │ │ │ ├── faucet.stl
│ │ │ │ │ │ ├── handle2.stl
│ │ │ │ │ │ ├── hingecabinet.stl
│ │ │ │ │ │ ├── hingedoor.stl
│ │ │ │ │ │ ├── hingehandle.stl
│ │ │ │ │ │ ├── hood.stl
│ │ │ │ │ │ ├── kettle.stl
│ │ │ │ │ │ ├── kettlehandle.stl
│ │ │ │ │ │ ├── knob.stl
│ │ │ │ │ │ ├── lightswitch.stl
│ │ │ │ │ │ ├── lightswitchbase.stl
│ │ │ │ │ │ ├── micro.stl
│ │ │ │ │ │ ├── microbutton.stl
│ │ │ │ │ │ ├── microdoor.stl
│ │ │ │ │ │ ├── microefeet.stl
│ │ │ │ │ │ ├── microfeet.stl
│ │ │ │ │ │ ├── microhandle.stl
│ │ │ │ │ │ ├── microwindow.stl
│ │ │ │ │ │ ├── oven.stl
│ │ │ │ │ │ ├── ovenhandle.stl
│ │ │ │ │ │ ├── oventop.stl
│ │ │ │ │ │ ├── ovenwindow.stl
│ │ │ │ │ │ ├── slidecabinet.stl
│ │ │ │ │ │ ├── slidedoor.stl
│ │ │ │ │ │ ├── stoverim.stl
│ │ │ │ │ │ ├── tile.stl
│ │ │ │ │ │ └── wall.stl
│ │ │ │ │ ├── microwave.xml
│ │ │ │ │ ├── oven.xml
│ │ │ │ │ └── slidecabinet.xml
│ │ │ │ └── scenes/
│ │ │ │ └── basic_scene.xml
│ │ │ ├── kitchen_envs.py
│ │ │ └── third_party/
│ │ │ └── franka/
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── assets/
│ │ │ │ ├── actuator0.xml
│ │ │ │ ├── actuator1.xml
│ │ │ │ ├── assets.xml
│ │ │ │ ├── basic_scene.xml
│ │ │ │ ├── chain0.xml
│ │ │ │ ├── chain0_overlay.xml
│ │ │ │ ├── chain1.xml
│ │ │ │ └── teleop_actuator.xml
│ │ │ ├── bi-franka_panda.xml
│ │ │ ├── franka_panda.xml
│ │ │ ├── franka_panda_teleop.xml
│ │ │ └── meshes/
│ │ │ ├── collision/
│ │ │ │ ├── finger.stl
│ │ │ │ ├── hand.stl
│ │ │ │ ├── link0.stl
│ │ │ │ ├── link1.stl
│ │ │ │ ├── link2.stl
│ │ │ │ ├── link3.stl
│ │ │ │ ├── link4.stl
│ │ │ │ ├── link5.stl
│ │ │ │ ├── link6.stl
│ │ │ │ └── link7.stl
│ │ │ └── visual/
│ │ │ ├── finger.stl
│ │ │ ├── hand.stl
│ │ │ ├── link0.stl
│ │ │ ├── link1.stl
│ │ │ ├── link2.stl
│ │ │ ├── link3.stl
│ │ │ ├── link4.stl
│ │ │ ├── link5.stl
│ │ │ ├── link6.stl
│ │ │ └── link7.stl
│ │ ├── locomotion/
│ │ │ ├── __init__.py
│ │ │ ├── ant.py
│ │ │ ├── assets/
│ │ │ │ ├── ant.xml
│ │ │ │ └── point.xml
│ │ │ ├── common.py
│ │ │ ├── generate_dataset.py
│ │ │ ├── goal_reaching_env.py
│ │ │ ├── maze_env.py
│ │ │ ├── mujoco_goal_env.py
│ │ │ ├── point.py
│ │ │ ├── swimmer.py
│ │ │ └── wrappers.py
│ │ ├── offline_env.py
│ │ ├── ope.py
│ │ ├── pointmaze/
│ │ │ ├── __init__.py
│ │ │ ├── dynamic_mjc.py
│ │ │ ├── gridcraft/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── grid_env.py
│ │ │ │ ├── grid_spec.py
│ │ │ │ ├── utils.py
│ │ │ │ └── wrappers.py
│ │ │ ├── maze_model.py
│ │ │ ├── q_iteration.py
│ │ │ └── waypoint_controller.py
│ │ ├── pointmaze_bullet/
│ │ │ ├── __init__.py
│ │ │ ├── bullet_maze.py
│ │ │ └── bullet_robot.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── dataset_utils.py
│ │ ├── quatmath.py
│ │ ├── visualize_env.py
│ │ └── wrappers.py
│ ├── scripts/
│ │ ├── check_antmaze_datasets.py
│ │ ├── check_bullet.py
│ │ ├── check_envs.py
│ │ ├── check_mujoco_datasets.py
│ │ ├── generation/
│ │ │ ├── flow_idm.py
│ │ │ ├── generate_ant_maze_datasets.py
│ │ │ ├── generate_kitchen_datasets.py
│ │ │ ├── generate_maze2d_bullet_datasets.py
│ │ │ ├── generate_maze2d_datasets.py
│ │ │ ├── generate_minigrid_fourroom_data.py
│ │ │ ├── hand_dapg_combined.py
│ │ │ ├── hand_dapg_demos.py
│ │ │ ├── hand_dapg_jax.py
│ │ │ ├── hand_dapg_policies.py
│ │ │ ├── hand_dapg_random.py
│ │ │ ├── mujoco/
│ │ │ │ ├── collect_data.py
│ │ │ │ ├── convert_buffer.py
│ │ │ │ ├── fix_qpos_qvel.py
│ │ │ │ └── stitch_dataset.py
│ │ │ ├── relabel_antmaze_rewards.py
│ │ │ └── relabel_maze2d_rewards.py
│ │ ├── ope_rollout.py
│ │ ├── reference_scores/
│ │ │ ├── adroit_expert.py
│ │ │ ├── carla_lane_controller.py
│ │ │ ├── generate_ref_min_score.py
│ │ │ ├── generate_ref_min_score.sh
│ │ │ ├── maze2d_bullet_controller.py
│ │ │ ├── maze2d_controller.py
│ │ │ └── minigrid_controller.py
│ │ └── visualize_dataset.py
│ └── setup.py
├── dataset_utils.py
├── evaluation.py
├── flaxmodels/
│ ├── README.md
│ ├── flaxmodels/
│ │ ├── __init__.py
│ │ ├── gpt2/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── gpt2.py
│ │ │ ├── gpt2_demo.ipynb
│ │ │ ├── ops.py
│ │ │ ├── third_party/
│ │ │ │ ├── __init__.py
│ │ │ │ └── huggingface_transformers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configuration_gpt2.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── file_utils.py
│ │ │ │ ├── hf_api.py
│ │ │ │ ├── logging.py
│ │ │ │ ├── tokenization_utils.py
│ │ │ │ ├── tokenization_utils_base.py
│ │ │ │ └── versions.py
│ │ │ ├── tokenizer.py
│ │ │ └── trajectory_gpt2.py
│ │ ├── lstm/
│ │ │ ├── lstm.py
│ │ │ └── ops.py
│ │ └── utils.py
│ └── setup.py
├── human_label/
│ ├── Can_mh/
│ │ ├── indices_2_num500_q100
│ │ ├── indices_num500_q100
│ │ └── label_human
│ ├── Can_ph/
│ │ ├── indices_2_num100_q50
│ │ ├── indices_num100_q50
│ │ └── label_human
│ ├── Lift_mh/
│ │ ├── indices_2_num500_q100
│ │ ├── indices_num500_q100
│ │ └── label_human
│ ├── Lift_ph/
│ │ ├── indices_2_num100_q50
│ │ ├── indices_num100_q50
│ │ └── label_human
│ ├── README.md
│ ├── Square_mh/
│ │ ├── indices_2_num500_q100
│ │ ├── indices_num500_q100
│ │ └── label_human
│ ├── Square_ph/
│ │ ├── indices_2_num100_q50
│ │ ├── indices_num100_q50
│ │ └── label_human
│ ├── antmaze-large-diverse-v2/
│ │ ├── indices_2_num1000
│ │ ├── indices_num1000
│ │ └── label_human
│ ├── antmaze-large-play-v2/
│ │ ├── indices_2_num1000
│ │ ├── indices_num1000
│ │ └── label_human
│ ├── antmaze-medium-diverse-v2/
│ │ ├── indices_2_num1000
│ │ ├── indices_num1000
│ │ └── label_human
│ ├── antmaze-medium-play-v2/
│ │ ├── indices_2_num1000
│ │ ├── indices_num1000
│ │ └── label_human
│ ├── hammer-cloned-v1/
│ │ ├── indices_2_num100
│ │ ├── indices_num100
│ │ └── label_human
│ ├── hammer-human-v1/
│ │ ├── indices_2_num100
│ │ ├── indices_num100
│ │ └── label_human
│ ├── hopper-medium-expert-v2/
│ │ ├── indices_2_num100
│ │ ├── indices_num100
│ │ └── label_human
│ ├── hopper-medium-replay-v2/
│ │ ├── indices_2_num500
│ │ ├── indices_num500
│ │ └── label_human
│ ├── label_program.ipynb
│ ├── pen-cloned-v1/
│ │ ├── indices_2_num100
│ │ ├── indices_num100
│ │ └── label_human
│ ├── pen-human-v1/
│ │ ├── indices_2_num100
│ │ ├── indices_num100
│ │ └── label_human
│ ├── walker2d-medium-expert-v2/
│ │ ├── indices_2_num100
│ │ ├── indices_num100
│ │ └── label_human
│ └── walker2d-medium-replay-v2/
│ ├── indices_2_num500
│ ├── indices_num500
│ └── label_human
├── learner.py
├── policy.py
├── requirements.txt
├── robosuite_train_offline.py
├── train_finetune.py
├── train_offline.py
├── value_net.py
├── viskit/
│ ├── __init__.py
│ ├── core.py
│ ├── frontend.py
│ ├── logging.py
│ ├── static/
│ │ ├── css/
│ │ │ └── dropdowns-enhancement.css
│ │ └── js/
│ │ ├── dropdowns-enhancement.js
│ │ └── jquery.loadTemplate-1.5.6.js
│ ├── tabulate.py
│ └── templates/
│ └── main.html
├── visualize.py
└── wrappers/
├── __init__.py
├── common.py
├── episode_monitor.py
├── robosuite_wrapper.py
└── single_precision.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
logs/
runs*/
# jupyter notebook
notebooks/
# vscode
.vscode/
# output folder
tmp/
data/
deprecated_logs/
video/
final/
iql_final/
transformer_exp/
result/
================================================
FILE: JaxPref/MR.py
================================================
from functools import partial
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss
class MR(object):
    """Markovian Reward (MR) baseline.

    Trains a per-step reward network ``rf(obs, act) -> r`` from trajectory
    preference labels with the Bradley-Terry / cross-entropy objective.
    Also supports semi-supervised training with confidence-thresholded
    pseudo-labels (``train_semi``) and direct reward regression
    (``train_regression``).

    NOTE: the jitted loss functions return ``locals()`` as the auxiliary
    payload, so metric extraction below depends on the local variable *names*
    inside each ``loss_fn`` (e.g. ``rf_loss``, ``rf_pred``).  Do not rename
    those locals without updating the corresponding ``aux_values[...]`` keys.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default hyper-parameters, optionally merged with ``updates``."""
        config = ConfigDict()
        config.rf_lr = 3e-4
        config.optimizer_type = 'adam'

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, rf):
        """Create the optimizer and train state for the reward network ``rf``.

        Args:
            config: dict/ConfigDict of overrides merged into the defaults.
            rf: reward module exposing ``observation_dim``/``action_dim`` and a
                flax ``init``/``apply`` interface.
        """
        self.config = self.get_default_config(config)
        self.rf = rf
        self.observation_dim = rf.observation_dim
        self.action_dim = rf.action_dim
        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        # Initialize parameters with a dummy batch of 10 (obs, act) pairs.
        rf_params = self.rf.init(next_rng(), jnp.zeros((10, self.observation_dim)), jnp.zeros((10, self.action_dim)))
        self._train_states['rf'] = TrainState.create(
            params=rf_params,
            tx=optimizer_class(self.config.rf_lr),
            apply_fn=None,
        )

        model_keys = ['rf']
        self._model_keys = tuple(model_keys)
        self._total_steps = 0

    def evaluation(self, batch):
        """Compute preference-loss metrics on ``batch`` without updating parameters."""
        metrics = self._eval_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    def get_reward(self, batch):
        """Return predicted per-step rewards for ``batch`` (jitted forward pass)."""
        return self._get_reward_step(self._train_states, batch)

    @partial(jax.jit, static_argnames=('self',))
    def _get_reward_step(self, train_states, batch):
        """Jitted reward prediction: rf(observations, actions)."""
        obs = batch['observations']
        act = batch['actions']
        train_params = {key: train_states[key].params for key in self.model_keys}
        rf_pred = self.rf.apply(train_params['rf'], obs, act)
        return rf_pred

    @partial(jax.jit, static_argnames=('self',))
    def _eval_pref_step(self, train_states, rng, batch):
        """Jitted evaluation of the Bradley-Terry preference loss (no update)."""

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            labels = batch['labels']

            B, T, obs_dim = batch['observations'].shape
            B, T, act_dim = batch['actions'].shape

            # Flatten (B, T, dim) -> (B*T, dim) for the per-step reward net.
            obs_1 = obs_1.reshape(-1, obs_dim)
            obs_2 = obs_2.reshape(-1, obs_dim)
            act_1 = act_1.reshape(-1, act_dim)
            act_2 = act_2.reshape(-1, act_dim)

            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)
            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)

            # Segment score = mean per-step reward, shape (B, 1).
            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Reward-function (preference) loss; labels are treated as constants.
            label_target = jax.lax.stop_gradient(labels)
            rf_loss = cross_ent_loss(logits, label_target)
            loss_collection['rf'] = rf_loss
            # locals() is the aux payload; metric keys below index by name.
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        metrics = dict(
            eval_rf_loss=aux_values['rf_loss'],
        )
        return metrics

    def train(self, batch):
        """One supervised preference-learning step; returns training metrics."""
        self._total_steps += 1
        self._train_states, metrics = self._train_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    @partial(jax.jit, static_argnames=('self',))
    def _train_pref_step(self, train_states, rng, batch):
        """Jitted preference-loss step: gradient update of the reward net."""

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            labels = batch['labels']

            B, T, obs_dim = batch['observations'].shape
            B, T, act_dim = batch['actions'].shape

            # Flatten (B, T, dim) -> (B*T, dim) for the per-step reward net.
            obs_1 = obs_1.reshape(-1, obs_dim)
            obs_2 = obs_2.reshape(-1, obs_dim)
            act_1 = act_1.reshape(-1, act_dim)
            act_2 = act_2.reshape(-1, act_dim)

            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)
            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)

            # Segment score = mean per-step reward, shape (B, 1).
            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Reward-function (preference) loss.
            label_target = jax.lax.stop_gradient(labels)
            rf_loss = cross_ent_loss(logits, label_target)
            loss_collection['rf'] = rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        metrics = dict(
            rf_loss=aux_values['rf_loss'],
        )
        return new_train_states, metrics

    def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):
        """One semi-supervised step mixing labeled CE with pseudo-label CE.

        Args:
            labeled_batch: preference pairs with ground-truth labels.
            unlabeled_batch: preference pairs without labels.
            lmd: weight of the pseudo-label loss term.
            tau: confidence threshold for accepting a pseudo-label.
        """
        self._total_steps += 1
        self._train_states, metrics = self._train_semi_pref_step(
            self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng()
        )
        return metrics

    @partial(jax.jit, static_argnames=('self',))
    def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):
        """Jitted semi-supervised step: labeled CE + thresholded pseudo-label CE."""

        # BUG FIX: compute_logits previously closed over the *outer*
        # `train_params` (assigned below) instead of loss_fn's differentiated
        # parameter, so gradients of the loss w.r.t. the params being
        # differentiated were zero.  Parameters are now passed in explicitly.
        def compute_logits(train_params, batch):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            labels = batch['labels']

            B, T, obs_dim = batch['observations'].shape
            B, T, act_dim = batch['actions'].shape

            # Flatten (B, T, dim) -> (B*T, dim) for the per-step reward net.
            obs_1 = obs_1.reshape(-1, obs_dim)
            obs_2 = obs_2.reshape(-1, obs_dim)
            act_1 = act_1.reshape(-1, act_dim)
            act_2 = act_2.reshape(-1, act_dim)

            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)
            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)

            # Segment score = mean per-step reward, shape (B, 1).
            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
            return logits, labels

        def loss_fn(train_params, lmd, tau, rng):
            logits, labels = compute_logits(train_params, labeled_batch)
            u_logits, _ = compute_logits(train_params, unlabeled_batch)

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Supervised preference loss on the labeled pairs.
            label_target = jax.lax.stop_gradient(labels)
            rf_loss = cross_ent_loss(logits, label_target)

            # Pseudo-label loss: argmax label, kept only above confidence tau.
            u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1)
            pseudo_labels = jnp.argmax(u_logits, axis=-1)
            pseudo_label_target = jax.lax.stop_gradient(pseudo_labels)

            loss_ = optax.softmax_cross_entropy(logits=u_logits,
                                                labels=jax.nn.one_hot(pseudo_label_target, num_classes=2))
            u_rf_loss = jnp.where(u_confidence > tau, loss_, 0).mean()
            # Percentage of unlabeled pairs whose pseudo-label was accepted.
            u_rf_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100

            loss_collection['rf'] = rf_loss + lmd * u_rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        metrics = dict(
            rf_loss=aux_values['rf_loss'],
            u_rf_loss=aux_values['u_rf_loss'],
            u_rf_ratio=aux_values['u_rf_ratio']
        )
        return new_train_states, metrics

    def train_regression(self, batch):
        """One supervised regression step onto ground-truth rewards."""
        self._total_steps += 1
        self._train_states, metrics = self._train_regression_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    @partial(jax.jit, static_argnames=('self',))
    def _train_regression_step(self, train_states, rng, batch):
        """Jitted MSE regression of rf(obs, act) onto batch['rewards']."""

        def loss_fn(train_params, rng):
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Reward-function regression loss; targets are constants.
            rf_pred = self.rf.apply(train_params['rf'], observations, actions)
            reward_target = jax.lax.stop_gradient(rewards)
            rf_loss = mse_loss(rf_pred, reward_target)
            loss_collection['rf'] = rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        metrics = dict(
            rf_loss=aux_values['rf_loss'],
            average_rf=aux_values['rf_pred'].mean(),
        )
        return new_train_states, metrics

    @property
    def model_keys(self):
        """Tuple of train-state keys (always ('rf',))."""
        return self._model_keys

    @property
    def train_states(self):
        """Mapping of model key -> flax TrainState."""
        return self._train_states

    @property
    def train_params(self):
        """Mapping of model key -> current parameters."""
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        """Number of gradient steps taken so far."""
        return self._total_steps
================================================
FILE: JaxPref/NMR.py
================================================
from functools import partial
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss
class NMR(object):
    """Non-Markovian Reward (NMR) baseline: an LSTM reward model trained
    from trajectory preference labels with the Bradley-Terry objective.

    The segment score is, per ``config.train_type``, the mean, sum, or last
    element of the per-step LSTM predictions.

    NOTE: the jitted loss functions return ``locals()`` as the auxiliary
    payload, so metric extraction depends on the local variable *names*
    inside each ``loss_fn`` (e.g. ``lstm_loss``); do not rename them.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default hyper-parameters, optionally merged with ``updates``."""
        config = ConfigDict()
        config.lstm_lr = 1e-3
        config.optimizer_type = 'adam'
        config.scheduler_type = 'none'
        config.vocab_size = 1
        config.n_layer = 3
        config.embd_dim = 256
        config.n_embd = config.embd_dim
        config.n_head = 1
        config.n_inner = config.embd_dim // 2
        config.n_positions = 1024
        config.resid_pdrop = 0.1
        config.attn_pdrop = 0.1
        config.use_kld = False
        config.lambda_kld = 0.1
        config.softmax_temperature = 5
        config.train_type = "sum"
        config.train_diff_bool = False
        config.explicit_sparse = False
        config.k = 5

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, lstm):
        """Create the optimizer and train state for the LSTM reward model.

        Args:
            config: ConfigDict of hyper-parameters.  NOTE(review): unlike
                ``MR``, this is used as-is rather than merged through
                ``get_default_config`` -- callers must pass a complete config.
            lstm: LSTM module exposing ``observation_dim``/``action_dim`` and
                a flax ``init``/``apply`` interface.
        """
        self.config = config
        self.lstm = lstm
        self.observation_dim = lstm.observation_dim
        self.action_dim = lstm.action_dim
        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'adamw': optax.adamw,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        scheduler_class = {
            'none': None
        }[self.config.scheduler_type]

        if scheduler_class:
            tx = optimizer_class(scheduler_class)
        else:
            # No schedule configured: fixed learning rate.
            tx = optimizer_class(learning_rate=self.config.lstm_lr)

        # Initialize parameters with a dummy batch of 10 length-10 trajectories.
        lstm_params = self.lstm.init({"params": next_rng(), "dropout": next_rng()}, jnp.zeros((10, 10, self.observation_dim)), jnp.zeros((10, 10, self.action_dim)), jnp.ones((10, 10), dtype=jnp.int32))
        self._train_states['lstm'] = TrainState.create(
            params=lstm_params,
            tx=tx,
            apply_fn=None
        )

        model_keys = ['lstm']
        self._model_keys = tuple(model_keys)
        self._total_steps = 0

    def evaluation(self, batch):
        """Compute preference-loss metrics on ``batch`` without updating parameters."""
        metrics = self._eval_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    def get_reward(self, batch):
        """Return (per-step predicted rewards, None) for ``batch``."""
        return self._get_reward_step(self._train_states, batch)

    @partial(jax.jit, static_argnames=('self',))
    def _get_reward_step(self, train_states, batch):
        """Jitted reward prediction; second return slot is None for API parity
        with attention-based models that also return attention weights."""
        obs = batch['observations']
        act = batch['actions']
        timestep = batch['timestep']
        train_params = {key: train_states[key].params for key in self.model_keys}
        lstm_pred, _ = self.lstm.apply(train_params['lstm'], obs, act, timestep)
        return lstm_pred, None

    @partial(jax.jit, static_argnames=('self',))
    def _eval_pref_step(self, train_states, rng, batch):
        """Jitted evaluation of the Bradley-Terry preference loss (no update).

        NOTE(review): ``training=True`` keeps dropout active during
        evaluation; confirm whether this is intended.
        """

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']

            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape

            rng, _ = jax.random.split(rng)
            lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
            lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})

            # Segment score: mean / sum / last per-step prediction, shape (B, 1).
            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)

            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Preference loss; labels are treated as constants.
            label_target = jax.lax.stop_gradient(labels)
            lstm_loss = cross_ent_loss(logits, label_target)
            loss_collection['lstm'] = lstm_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        metrics = dict(
            eval_lstm_loss=aux_values['lstm_loss'],
        )
        return metrics

    def train(self, batch):
        """One supervised preference-learning step; returns training metrics."""
        self._total_steps += 1
        self._train_states, metrics = self._train_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    @partial(jax.jit, static_argnames=('self',))
    def _train_pref_step(self, train_states, rng, batch):
        """Jitted preference-loss step: gradient update of the LSTM."""

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']

            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape

            rng, _ = jax.random.split(rng)
            lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
            lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})

            # Segment score: mean / sum / last per-step prediction, shape (B, 1).
            # (Made a proper if/elif chain for consistency with _eval_pref_step;
            # the branches are mutually exclusive so behavior is unchanged.)
            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)

            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Preference loss.
            label_target = jax.lax.stop_gradient(labels)
            lstm_loss = cross_ent_loss(logits, label_target)
            loss_collection['lstm'] = lstm_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        metrics = dict(
            lstm_loss=aux_values['lstm_loss'],
        )
        return new_train_states, metrics

    def train_regression(self, batch):
        """One supervised regression step onto ground-truth rewards."""
        self._total_steps += 1
        self._train_states, metrics = self._train_regression_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    @partial(jax.jit, static_argnames=('self',))
    def _train_regression_step(self, train_states, rng, batch):
        """Jitted MSE regression of the LSTM prediction onto batch['rewards'].

        BUG FIX: this method was copy-pasted from MR and referenced the
        nonexistent ``self.rf`` with loss key 'rf' (model_keys is ('lstm',)),
        which raised at runtime.  It now uses the LSTM like the other steps.
        NOTE(review): assumes the regression batch carries a 'timestep' entry
        like the reward batches do -- confirm against callers.
        """

        def loss_fn(train_params, rng):
            observations = batch['observations']
            actions = batch['actions']
            timestep = batch['timestep']
            rewards = batch['rewards']

            loss_collection = {}
            rng, split_rng = jax.random.split(rng)

            # Regression loss; targets are constants.
            lstm_pred, _ = self.lstm.apply(train_params['lstm'], observations, actions, timestep)
            reward_target = jax.lax.stop_gradient(rewards)
            lstm_loss = mse_loss(lstm_pred, reward_target)
            loss_collection['lstm'] = lstm_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        metrics = dict(
            lstm_loss=aux_values['lstm_loss'],
            average_lstm=aux_values['lstm_pred'].mean(),
        )
        return new_train_states, metrics

    @property
    def model_keys(self):
        """Tuple of train-state keys (always ('lstm',))."""
        return self._model_keys

    @property
    def train_states(self):
        """Mapping of model key -> flax TrainState."""
        return self._train_states

    @property
    def train_params(self):
        """Mapping of model key -> current parameters."""
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        """Number of gradient steps taken so far."""
        return self._total_steps
================================================
FILE: JaxPref/PrefTransformer.py
================================================
from functools import partial
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from flax.training.train_state import TrainState
from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss, kld_loss
class PrefTransformer(object):
@staticmethod
def get_default_config(updates=None):
    """Build the default PrefTransformer hyper-parameters.

    Args:
        updates: optional mapping of overrides merged into the defaults.

    Returns:
        A ConfigDict with optimizer, schedule, and transformer settings.
    """
    config = ConfigDict(dict(
        trans_lr=1e-4,
        optimizer_type='adamw',
        scheduler_type='CosineDecay',
        vocab_size=1,
        n_layer=3,
        embd_dim=256,
        n_head=1,
        n_positions=1024,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
        pref_attn_embd_dim=256,
        train_type="mean",
        # Weighted Sum option
        use_weighted_sum=False,
    ))
    # Transformer width mirrors the embedding dimension.
    config.n_embd = config.embd_dim
    if updates is not None:
        config.update(ConfigDict(updates).copy_and_resolve_references())
    return config
def __init__(self, config, trans):
self.config = config
self.trans = trans
self.observation_dim = trans.observation_dim
self.action_dim = trans.action_dim
self._train_states = {}
optimizer_class = {
'adam': optax.adam,
'adamw': optax.adamw,
'sgd': optax.sgd,
}[self.config.optimizer_type]
scheduler_class = {
'CosineDecay': optax.warmup_cosine_decay_schedule(
init_value=self.config.trans_lr,
peak_value=self.config.trans_lr * 10,
warmup_steps=self.config.warmup_steps,
decay_steps=self.config.total_steps,
end_value=self.config.trans_lr
),
"OnlyWarmup": optax.join_schedules(
[
optax.linear_schedule(
init_value=0.0,
end_value=self.config.trans_lr,
transition_steps=self.config.warmup_steps,
),
optax.constant_schedule(
value=self.config.trans_lr
)
],
[self.config.warmup_steps]
),
'none': None
}[self.config.scheduler_type]
if scheduler_class:
tx = optimizer_class(scheduler_class)
else:
tx = optimizer_class(learning_rate=self.config.trans_lr)
trans_params = self.trans.init(
{"params": next_rng(), "dropout": next_rng()},
jnp.zeros((10, 25, self.observation_dim)),
jnp.zeros((10, 25, self.action_dim)),
jnp.ones((10, 25), dtype=jnp.int32)
)
self._train_states['trans'] = TrainState.create(
params=trans_params,
tx=tx,
apply_fn=None
)
model_keys = ['trans']
self._model_keys = tuple(model_keys)
self._total_steps = 0
def evaluation(self, batch):
metrics = self._eval_pref_step(
self._train_states, next_rng(), batch
)
return metrics
def get_reward(self, batch):
return self._get_reward_step(self._train_states, batch)
@partial(jax.jit, static_argnames=('self'))
def _get_reward_step(self, train_states, batch):
obs = batch['observations']
act = batch['actions']
timestep = batch['timestep']
# n_obs = batch['next_observations']
attn_mask = batch['attn_mask']
train_params = {key: train_states[key].params for key in self.model_keys}
trans_pred, attn_weights = self.trans.apply(train_params['trans'], obs, act, timestep, attn_mask=attn_mask, reverse=False)
return trans_pred["value"], attn_weights[-1]
@partial(jax.jit, static_argnames=('self'))
def _eval_pref_step(self, train_states, rng, batch):
def loss_fn(train_params, rng):
obs_1 = batch['observations']
act_1 = batch['actions']
obs_2 = batch['observations_2']
act_2 = batch['actions_2']
timestep_1 = batch['timestep_1']
timestep_2 = batch['timestep_2']
labels = batch['labels']
B, T, _ = batch['observations'].shape
B, T, _ = batch['actions'].shape
rng, _ = jax.random.split(rng)
trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=False, attn_mask=None, rngs={"dropout": rng})
trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=False, attn_mask=None, rngs={"dropout": rng})
if self.config.use_weighted_sum:
trans_pred_1 = trans_pred_1["weighted_sum"]
trans_pred_2 = trans_pred_2["weighted_sum"]
else:
trans_pred_1 = trans_pred_1["value"]
trans_pred_2 = trans_pred_2["value"]
if self.config.train_type == "mean":
sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
elif self.config.train_type == "sum":
sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
elif self.config.train_type == "last":
sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
loss_collection = {}
rng, split_rng = jax.random.split(rng)
""" reward function loss """
label_target = jax.lax.stop_gradient(labels)
trans_loss = cross_ent_loss(logits, label_target)
cse_loss = trans_loss
loss_collection['trans'] = trans_loss
return tuple(loss_collection[key] for key in self.model_keys), locals()
train_params = {key: train_states[key].params for key in self.model_keys}
(_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)
metrics = dict(
eval_cse_loss=aux_values['cse_loss'],
eval_trans_loss=aux_values['trans_loss'],
)
return metrics
def train(self, batch):
self._total_steps += 1
self._train_states, metrics = self._train_pref_step(
self._train_states, next_rng(), batch
)
return metrics
@partial(jax.jit, static_argnames=('self'))
def _train_pref_step(self, train_states, rng, batch):
def loss_fn(train_params, rng):
obs_1 = batch['observations']
act_1 = batch['actions']
obs_2 = batch['observations_2']
act_2 = batch['actions_2']
timestep_1 = batch['timestep_1']
timestep_2 = batch['timestep_2']
labels = batch['labels']
B, T, _ = batch['observations'].shape
B, T, _ = batch['actions'].shape
rng, _ = jax.random.split(rng)
trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})
if self.config.use_weighted_sum:
trans_pred_1 = trans_pred_1["weighted_sum"]
trans_pred_2 = trans_pred_2["weighted_sum"]
else:
trans_pred_1 = trans_pred_1["value"]
trans_pred_2 = trans_pred_2["value"]
if self.config.train_type == "mean":
sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
elif self.config.train_type == "sum":
sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
elif self.config.train_type == "last":
sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
loss_collection = {}
rng, split_rng = jax.random.split(rng)
""" reward function loss """
label_target = jax.lax.stop_gradient(labels)
trans_loss = cross_ent_loss(logits, label_target)
cse_loss = trans_loss
loss_collection['trans'] = trans_loss
return tuple(loss_collection[key] for key in self.model_keys), locals()
train_params = {key: train_states[key].params for key in self.model_keys}
(_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)
new_train_states = {
key: train_states[key].apply_gradients(grads=grads[i][key])
for i, key in enumerate(self.model_keys)
}
metrics = dict(
cse_loss=aux_values['cse_loss'],
trans_loss=aux_values['trans_loss'],
)
return new_train_states, metrics
def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):
self._total_steps += 1
self._train_states, metrics = self._train_semi_pref_step(
self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng()
)
return metrics
@partial(jax.jit, static_argnames=('self'))
def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):
def compute_logits(train_params, batch, rng):
obs_1 = batch['observations']
act_1 = batch['actions']
obs_2 = batch['observations_2']
act_2 = batch['actions_2']
timestep_1 = batch['timestep_1']
timestep_2 = batch['timestep_2']
labels = batch['labels']
B, T, _ = batch['observations'].shape
B, T, _ = batch['actions'].shape
rng, _ = jax.random.split(rng)
trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})
if self.config.use_weighted_sum:
trans_pred_1 = trans_pred_1["weighted_sum"]
trans_pred_2 = trans_pred_2["weighted_sum"]
else:
trans_pred_1 = trans_pred_1["value"]
trans_pred_2 = trans_pred_2["value"]
if self.config.train_type == "mean":
sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
elif self.config.train_type == "sum":
sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
elif self.config.train_type == "last":
sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
return logits, labels
def loss_fn(train_params, lmd, tau, rng):
rng, _ = jax.random.split(rng)
logits, labels = compute_logits(train_params, labeled_batch, rng)
u_logits, _ = compute_logits(train_params, unlabeled_batch, rng)
loss_collection = {}
rng, split_rng = jax.random.split(rng)
""" reward function loss """
label_target = jax.lax.stop_gradient(labels)
trans_loss = cross_ent_loss(logits, label_target)
u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1)
pseudo_labels = jnp.argmax(u_logits, axis=-1)
pseudo_label_target = jax.lax.stop_gradient(pseudo_labels)
loss_ = optax.softmax_cross_entropy(logits=u_logits, labels=jax.nn.one_hot(pseudo_label_target, num_classes=2))
u_trans_loss = jnp.sum(jnp.where(u_confidence > tau, loss_, 0)) / (jnp.count_nonzero(u_confidence > tau) + 1e-4)
u_trans_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100
# labeling neutral cases.
binarized_idx = jnp.where(unlabeled_batch["labels"][:, 0] != 0.5, 1., 0.)
real_label = jnp.argmax(unlabeled_batch["labels"], axis=-1)
u_trans_acc = jnp.sum(jnp.where(pseudo_label_target == real_label, 1., 0.) * binarized_idx) / jnp.sum(binarized_idx) * 100
loss_collection['trans'] = last_loss = trans_loss + lmd * u_trans_loss
return tuple(loss_collection[key] for key in self.model_keys), locals()
train_params = {key: train_states[key].params for key in self.model_keys}
(_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng)
new_train_states = {
key: train_states[key].apply_gradients(grads=grads[i][key])
for i, key in enumerate(self.model_keys)
}
metrics = dict(
trans_loss=aux_values['trans_loss'],
u_trans_loss=aux_values['u_trans_loss'],
last_loss=aux_values['last_loss'],
u_trans_ratio=aux_values['u_trans_ratio'],
u_train_acc=aux_values['u_trans_acc']
)
return new_train_states, metrics
def train_regression(self, batch):
self._total_steps += 1
self._train_states, metrics = self._train_regression_step(
self._train_states, next_rng(), batch
)
return metrics
@partial(jax.jit, static_argnames=('self'))
def _train_regression_step(self, train_states, rng, batch):
def loss_fn(train_params, rng):
observations = batch['observations']
next_observations = batch['next_observations']
actions = batch['actions']
rewards = batch['rewards']
in_obs = jnp.concatenate([observations, next_observations], axis=-1)
loss_collection = {}
rng, split_rng = jax.random.split(rng)
""" reward function loss """
rf_pred = self.rf.apply(train_params['rf'], observations, actions)
reward_target = jax.lax.stop_gradient(rewards)
rf_loss = mse_loss(rf_pred, reward_target)
loss_collection['rf'] = rf_loss
return tuple(loss_collection[key] for key in self.model_keys), locals()
train_params = {key: train_states[key].params for key in self.model_keys}
(_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)
new_train_states = {
key: train_states[key].apply_gradients(grads=grads[i][key])
for i, key in enumerate(self.model_keys)
}
metrics = dict(
rf_loss=aux_values['rf_loss'],
average_rf=aux_values['rf_pred'].mean(),
)
return new_train_states, metrics
@property
def model_keys(self):
return self._model_keys
@property
def train_states(self):
return self._train_states
@property
def train_params(self):
return {key: self.train_states[key].params for key in self.model_keys}
@property
def total_steps(self):
return self._total_steps
================================================
FILE: JaxPref/__init__.py
================================================
================================================
FILE: JaxPref/human_label_preprocess_adroit.py
================================================
import os
import pickle
import gym
import imageio
import jax
import numpy as np
from absl import app, flags
from tqdm import tqdm, trange
import d4rl
from JaxPref.reward_transform import get_queries_from_multi
FLAGS = flags.FLAGS
# NOTE(review): the default env_name is an antmaze task even though this
# script preprocesses adroit data — presumably always overridden on the CLI.
flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
flags.DEFINE_integer("num_query", 1000, "number of query.")
flags.DEFINE_integer("query_len", 100, "length of each query.")
flags.DEFINE_integer("label_type", 1, "label type.")
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")
# Video resolution (width, height) per maze size.
video_size = {"medium": (500, 500), "large": (600, 450)}
def set_seed(env, seed):
    """Seed numpy's global RNG and every RNG attached to the environment."""
    np.random.seed(seed)
    env.seed(seed)
    for space in (env.observation_space, env.action_space):
        space.seed(seed)
def qlearning_adroit_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)
    N = dataset["rewards"].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    xy_ = []
    done_bef_ = []
    qpos_ = []
    qvel_ = []
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatability.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True
    episode_step = 0
    for i in range(N - 1):
        obs = dataset["observations"][i].astype(np.float32)
        new_obs = dataset["observations"][i + 1].astype(np.float32)
        action = dataset["actions"][i].astype(np.float32)
        reward = dataset["rewards"][i].astype(np.float32)
        done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1
        # NOTE(review): qpos[:2] is labelled "xy" here as in the antmaze
        # variant of this function; for adroit hands the first two qpos
        # entries are not an xy position — confirm whether xys is ever used.
        xy = dataset["infos/qpos"][i][:2].astype(np.float32)
        # Raw simulator state, kept so queries can be replayed for video.
        qpos = dataset["infos/qpos"][i]
        qvel = dataset["infos/qvel"][i]
        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
            next_final_timestep = dataset["timeouts"][i + 1]
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
            next_final_timestep = episode_step == env._max_episode_steps - 2
        done_bef = bool(next_final_timestep)
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0
        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        xy_.append(xy)
        done_bef_.append(done_bef)
        qpos_.append(qpos)
        qvel_.append(qvel)
        episode_step += 1
    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "xys": np.array(xy_),
        "dones_bef": np.array(done_bef_),
        "qposes": np.array(qpos_),
        "qvels": np.array(qvel_),
    }
class Dataset(object):
    """Plain container of parallel transition arrays.

    Row i of every array describes the same transition; qposes/qvels keep raw
    MuJoCo state so trajectories can be replayed for rendering.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        qposes: np.ndarray,
        qvels: np.ndarray,
        size: int,
    ):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        # masks = 1 - terminal (0 where the episode truly ended)
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.qposes = qposes
        self.qvels = qvels
        self.size = size
class D4RLDataset(Dataset):
    """Dataset assembled from a d4rl adroit environment."""

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        dataset = qlearning_adroit_dataset(env)
        if clip_to_eps:
            # Keep actions strictly inside (-1, 1), e.g. for tanh-based policies.
            lim = 1 - eps
            dataset["actions"] = np.clip(dataset["actions"], -lim, lim)
        dones_float = np.zeros_like(dataset["rewards"])
        for i in range(len(dones_float) - 1):
            # Mark an episode boundary where the stored next_observation does
            # not match the following observation (trajectory break), or an
            # explicit terminal was recorded.
            if (
                np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5
                or dataset["terminals"][i] == 1.0
            ):
                dones_float[i] = 1
            else:
                dones_float[i] = 0
        dones_float[-1] = 1
        super().__init__(
            dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=1.0 - dataset["terminals"].astype(np.float32),
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32),
            qposes=dataset["qposes"].astype(np.float32),
            qvels=dataset["qvels"].astype(np.float32),
            size=len(dataset["observations"]),
        )
def visualize_query(
    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False
):
    """Render each query pair side by side to mp4 and pickle the query indices.

    For every query, the two segments are replayed by restoring qpos/qvel into
    the simulator and rendering offscreen; frames are concatenated
    horizontally (axis=2) into one video per query.
    """
    save_dir = os.path.join(save_dir, gym_env.spec.id)
    os.makedirs(save_dir, exist_ok=True)
    for seg_idx in trange(num_query):
        start_1, start_2 = (
            batch["start_indices"][seg_idx],
            batch["start_indices_2"][seg_idx],
        )
        frames = []
        frames_2 = []
        start_indices = range(start_1, start_1 + query_len)
        start_indices_2 = range(start_2, start_2 + query_len)
        gym_env.reset()
        if verbose:
            print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}")
            print("=" * 50)
            print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}")
        camera_name = "fixed"
        # Replay and render the first segment.
        for t in trange(query_len, leave=False):
            gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]])
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            # flipud: the offscreen render is vertically inverted
            frames.append(np.flipud(curr_frame))
        gym_env.reset()
        # Replay and render the second segment.
        for t in trange(query_len, leave=False):
            gym_env.set_state(
                dataset["qposes"][start_indices_2[t]],
                dataset["qvels"][start_indices_2[t]],
            )
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames_2.append(np.flipud(curr_frame))
        # Stack the two segments side by side (concat along width).
        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)
        writer = imageio.get_writer(os.path.join(save_dir, f"./idx{seg_idx}.mp4"), fps=30)
        for frame in tqdm(video, leave=False):
            writer.append_data(frame)
        writer.close()
    print("save query indices.")
    # Persist the sampled start indices so human labels can be joined back.
    with open(
        os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices"], f)
    with open(
        os.path.join(
            save_dir,
            f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl",
        ),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices_2"], f)
def main(_):
    """Sample query pairs from the offline dataset and render them to video."""
    gym_env = gym.make(FLAGS.env_name)
    # Fixed 500x500 rendering for adroit (video_size table is unused here).
    width, height = 500, 500
    set_seed(gym_env, FLAGS.seed)
    ds = qlearning_adroit_dataset(gym_env)
    batch = get_queries_from_multi(
        gym_env,
        ds,
        data_dir="./",
        num_query=FLAGS.num_query,
        len_query=FLAGS.query_len,
        label_type=FLAGS.label_type,
    )
    visualize_query(
        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir
    )
# Script entry point: parse absl flags, then run main.
if __name__ == "__main__":
    app.run(main)
================================================
FILE: JaxPref/human_label_preprocess_antmaze.py
================================================
import os
import pickle
import gym
import imageio
import jax
import numpy as np
from absl import app, flags
from tqdm import tqdm, trange
from PIL import Image, ImageDraw
import d4rl
from JaxPref.reward_transform import load_queries_with_indices
FLAGS = flags.FLAGS
flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
flags.DEFINE_string("query_path", "./human_label/", "query path")
flags.DEFINE_integer("num_query", 1000, "number of query.")
flags.DEFINE_integer("query_len", 100, "length of each query.")
flags.DEFINE_integer("label_type", 1, "label type.")
# slow mode renders at low fps with frame counters, for external labelers.
flags.DEFINE_bool("slow", False, "slow option for external feedback.")
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")
# Video resolution (width, height) per maze size.
video_size = {"medium": (500, 500), "large": (600, 450)}
def set_seed(env, seed):
    """Make the run reproducible: seed numpy plus the env and its spaces."""
    np.random.seed(seed)
    for target in (env, env.observation_space, env.action_space):
        target.seed(seed)
def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)
    N = dataset["rewards"].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    goal_ = []
    xy_ = []
    done_bef_ = []
    qpos_ = []
    qvel_ = []
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatability.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True
    episode_step = 0
    for i in range(N - 1):
        obs = dataset["observations"][i].astype(np.float32)
        new_obs = dataset["observations"][i + 1].astype(np.float32)
        action = dataset["actions"][i].astype(np.float32)
        reward = dataset["rewards"][i].astype(np.float32)
        done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1
        # Per-step goal and ant xy position (first two qpos entries).
        goal = dataset["infos/goal"][i].astype(np.float32)
        xy = dataset["infos/qpos"][i][:2].astype(np.float32)
        # Raw simulator state, kept so queries can be replayed for video.
        qpos = dataset["infos/qpos"][i]
        qvel = dataset["infos/qvel"][i]
        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
            next_final_timestep = dataset["timeouts"][i + 1]
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
            next_final_timestep = episode_step == env._max_episode_steps - 2
        done_bef = bool(next_final_timestep)
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0
        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        goal_.append(goal)
        xy_.append(xy)
        done_bef_.append(done_bef)
        qpos_.append(qpos)
        qvel_.append(qvel)
        episode_step += 1
    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "goals": np.array(goal_),
        "xys": np.array(xy_),
        "dones_bef": np.array(done_bef_),
        "qposes": np.array(qpos_),
        "qvels": np.array(qvel_),
    }
class Dataset(object):
    """Plain container of parallel transition arrays (antmaze variant).

    Same layout as the adroit/mujoco containers but additionally stores the
    per-step goal position used by antmaze tasks.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        qposes: np.ndarray,
        qvels: np.ndarray,
        goals: np.ndarray,
        size: int,
    ):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        # masks = 1 - terminal (0 where the episode truly ended)
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.qposes = qposes
        self.qvels = qvels
        self.goals = goals
        self.size = size
class D4RLDataset(Dataset):
    """Dataset assembled from a d4rl antmaze environment."""

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        dataset = qlearning_ant_dataset(env)
        if clip_to_eps:
            # Keep actions strictly inside (-1, 1), e.g. for tanh-based policies.
            lim = 1 - eps
            dataset["actions"] = np.clip(dataset["actions"], -lim, lim)
        dones_float = np.zeros_like(dataset["rewards"])
        for i in range(len(dones_float) - 1):
            # Mark an episode boundary where the stored next_observation does
            # not match the following observation (trajectory break), or an
            # explicit terminal was recorded.
            if (
                np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5
                or dataset["terminals"][i] == 1.0
            ):
                dones_float[i] = 1
            else:
                dones_float[i] = 0
        dones_float[-1] = 1
        super().__init__(
            dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=1.0 - dataset["terminals"].astype(np.float32),
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32),
            qposes=dataset["qposes"].astype(np.float32),
            qvels=dataset["qvels"].astype(np.float32),
            goals=dataset["goals"].astype(np.float32),
            size=len(dataset["observations"]),
        )
def visualize_query(
    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False
):
    """Render each antmaze query pair side by side to mp4 and pickle the indices.

    Each segment is replayed by restoring qpos/qvel into the simulator,
    rendered from a bird's-eye camera with a red marker painted at the goal
    cell, and the two segments are concatenated horizontally into one video.
    In ``FLAGS.slow`` mode a frame counter and a segment id ("0"/"1") are
    drawn and the video is written at 3 fps instead of 30.

    Args:
        gym_env: antmaze gym environment (provides set_state / physics.render).
        dataset: dict from qlearning_ant_dataset (qposes, qvels, goals).
        batch: query dict with "start_indices" / "start_indices_2".
        query_len: number of timesteps per segment.
        num_query: number of query pairs to render.
        width, height: render resolution.
        save_dir: root output directory (env id appended).
        verbose: print start/goal positions per query.

    Side effects: writes idx{i}.mp4 files and two pickles of start indices.
    """
    save_dir = os.path.join(save_dir, gym_env.spec.id)
    if FLAGS.slow:
        save_dir = os.path.join(save_dir, "slow")
    os.makedirs(save_dir, exist_ok=True)

    # Maze-size-dependent mapping from world coordinates to pixel offsets,
    # invariant across queries so it is hoisted out of the loop.
    if "medium" in gym_env.spec.id:
        dist_per_pixel = 15  # 1.0 world unit -> 15 px
        start_x = 95
        start_y = 95
        camera_name = "birdview"
    else:
        dist_per_pixel = 11
        start_x = 80
        start_y = 110
        camera_name = "birdview_large"

    def _render_segment(indices, label):
        # Replay one segment and return its frames (each step repeated 10x).
        frames = []
        gym_env.reset()
        for t in trange(query_len, leave=False):
            gym_env.set_state(dataset["qposes"][indices[t]], dataset["qvels"][indices[t]])
            # "diverse" tasks store a per-step goal; otherwise use the fixed one.
            if "diverse" in gym_env.spec.id:
                goal_x, goal_y = map(lambda x: round(x), dataset["goals"][indices[t]])
            else:
                goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal)
            curr_frame = gym_env.physics.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            # Paint a 10x10 red square at the goal position.
            curr_frame[
                start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * dist_per_pixel) + 10,
                start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10,
            ] = np.array((255, 0, 0)).astype(np.uint8)
            if FLAGS.slow:
                # Overlay the frame counter and the segment label.
                frame_img = Image.fromarray(curr_frame)
                draw = ImageDraw.Draw(frame_img)
                draw.text((width - 10, 0), f"{t + 1}", fill="black")
                draw.text((0, 0), label, fill="black")
                curr_frame = np.asarray(frame_img)
            # Repeat each rendered state so the clip plays at a watchable pace.
            for _ in range(10):
                frames.append(curr_frame)
        return frames

    for seg_idx in trange(num_query):
        start_1, start_2 = (
            batch["start_indices"][seg_idx],
            batch["start_indices_2"][seg_idx],
        )
        start_indices = range(start_1, start_1 + query_len)
        start_indices_2 = range(start_2, start_2 + query_len)
        if verbose:
            print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}")
            print(f"goal pos of first one: {dataset['goals'][start_indices[0]]}")
            print("=" * 50)
            print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}")
            print(f"goal pos of second one: {dataset['goals'][start_indices_2[0]]}")
        frames = _render_segment(start_indices, "0")
        frames_2 = _render_segment(start_indices_2, "1")
        # Stack the two segments side by side (concat along width).
        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)
        fps = 3 if FLAGS.slow else 30
        # BUGFIX: fps was computed but get_writer previously hardcoded fps=30,
        # so slow mode never actually slowed the video down.
        writer = imageio.get_writer(os.path.join(save_dir, f"./idx{seg_idx}.mp4"), fps=fps)
        for frame in tqdm(video, leave=False):
            writer.append_data(frame)
        writer.close()
    print("save query indices.")
    # Persist the sampled start indices so human labels can be joined back.
    with open(
        os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices"], f)
    with open(
        os.path.join(
            save_dir,
            f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl",
        ),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices_2"], f)
def main(_):
    """Load previously saved query indices and re-render them as videos."""
    gym_env = gym.make(FLAGS.env_name)
    # NOTE(review): width/height are only set for "medium"/"large" env names;
    # any other env_name would raise NameError at the visualize_query call.
    if "medium" in FLAGS.env_name:
        width, height = video_size["medium"]
    elif "large" in FLAGS.env_name:
        width, height = video_size["large"]
    set_seed(gym_env, FLAGS.seed)
    ds = qlearning_ant_dataset(gym_env)
    base_path = os.path.join(FLAGS.query_path, FLAGS.env_name)
    # NOTE(review): assumes the query dir contains exactly three files whose
    # sorted order is (indices_2, indices_1, labels) — fragile; verify layout.
    human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path))
    with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:  # Unpickling
        human_indices = pickle.load(fp)
    with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:  # Unpickling
        human_indices_2 = pickle.load(fp)
    human_labels = None
    batch = load_queries_with_indices(
        gym_env,
        ds,
        saved_indices=[human_indices, human_indices_2],
        saved_labels=human_labels,
        num_query=FLAGS.num_query,
        len_query=FLAGS.query_len,
        label_type=FLAGS.label_type,
        scripted_teacher=True
    )
    visualize_query(
        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir
    )
# Script entry point: parse absl flags, then run main.
if __name__ == "__main__":
    app.run(main)
================================================
FILE: JaxPref/human_label_preprocess_mujoco.py
================================================
import os
import pickle
import gym
import imageio
import jax
import numpy as np
from absl import app, flags
from tqdm import tqdm, trange
import d4rl
from JaxPref.reward_transform import load_queries_with_indices
FLAGS = flags.FLAGS
# NOTE(review): the default env_name is an antmaze task even though this
# script preprocesses mujoco locomotion data — presumably overridden on the CLI.
flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
flags.DEFINE_string("query_path", "./human_label/", "query path")
flags.DEFINE_integer("num_query", 1000, "number of query.")
flags.DEFINE_integer("query_len", 100, "length of each query.")
flags.DEFINE_integer("label_type", 1, "label type.")
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")
# Video resolution (width, height) per maze size.
video_size = {"medium": (500, 500), "large": (600, 450)}
def set_seed(env, seed):
    """Seed every RNG source involved in data generation with `seed`."""
    np.random.seed(seed)
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)
    return None
def qlearning_mujoco_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)
    N = dataset["rewards"].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    xy_ = []
    done_bef_ = []
    qpos_ = []
    qvel_ = []
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatability.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True
    episode_step = 0
    for i in range(N - 1):
        obs = dataset["observations"][i].astype(np.float32)
        new_obs = dataset["observations"][i + 1].astype(np.float32)
        action = dataset["actions"][i].astype(np.float32)
        reward = dataset["rewards"][i].astype(np.float32)
        done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1
        # First two qpos entries, labelled "xy" as in the antmaze variant —
        # for locomotion tasks this is (rootx, rootz/rooty); TODO confirm use.
        xy = dataset["infos/qpos"][i][:2].astype(np.float32)
        # Raw simulator state, kept so queries can be replayed for video.
        qpos = dataset["infos/qpos"][i]
        qvel = dataset["infos/qvel"][i]
        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
            next_final_timestep = dataset["timeouts"][i + 1]
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
            next_final_timestep = episode_step == env._max_episode_steps - 2
        done_bef = bool(next_final_timestep)
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0
        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        xy_.append(xy)
        done_bef_.append(done_bef)
        qpos_.append(qpos)
        qvel_.append(qvel)
        episode_step += 1
    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "xys": np.array(xy_),
        "dones_bef": np.array(done_bef_),
        "qposes": np.array(qpos_),
        "qvels": np.array(qvel_),
    }
class Dataset(object):
    """Plain container for an offline RL transition dataset.

    Holds parallel per-transition arrays plus the simulator state
    (qposes/qvels) needed to restore and render any stored step.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        qposes: np.ndarray,
        qvels: np.ndarray,
        size: int,
    ):
        # Store every field verbatim; no copies or validation are performed.
        fields = {
            "observations": observations,
            "actions": actions,
            "rewards": rewards,
            "masks": masks,
            "dones_float": dones_float,
            "next_observations": next_observations,
            "qposes": qposes,
            "qvels": qvels,
            "size": size,
        }
        for attr_name, value in fields.items():
            setattr(self, attr_name, value)
class D4RLDataset(Dataset):
    """Dataset built from a D4RL mujoco env via qlearning_mujoco_dataset.

    Episode boundaries (`dones_float`) are recomputed by detecting jumps
    between consecutive observations, since terminal flags alone do not mark
    timeout-truncated episodes.
    """

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        raw = qlearning_mujoco_dataset(env)
        if clip_to_eps:
            bound = 1 - eps
            raw["actions"] = np.clip(raw["actions"], -bound, bound)

        episode_ends = np.zeros_like(raw["rewards"])
        for idx in range(len(episode_ends) - 1):
            # A gap between next_observations[idx] and observations[idx + 1]
            # means a new trajectory started at idx + 1.
            gap = np.linalg.norm(raw["observations"][idx + 1] - raw["next_observations"][idx])
            episode_ends[idx] = 1 if (gap > 1e-5 or raw["terminals"][idx] == 1.0) else 0
        episode_ends[-1] = 1

        super().__init__(
            raw["observations"].astype(np.float32),
            actions=raw["actions"].astype(np.float32),
            rewards=raw["rewards"].astype(np.float32),
            masks=1.0 - raw["terminals"].astype(np.float32),
            dones_float=episode_ends.astype(np.float32),
            next_observations=raw["next_observations"].astype(np.float32),
            qposes=raw["qposes"].astype(np.float32),
            qvels=raw["qvels"].astype(np.float32),
            size=len(raw["observations"]),
        )
def visualize_query(
    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False
):
    """Render each query pair as a side-by-side video and save the query indices.

    For every query, the two segments (starting at batch["start_indices"] and
    batch["start_indices_2"]) are replayed by restoring the saved simulator
    state (qpos/qvel) step by step, rendered offscreen, concatenated
    horizontally, and written to `<save_dir>/<env id>/idx<k>.mp4`.

    NOTE(review): reads the module-level absl FLAGS for the seed used in the
    pickle filenames — confirm FLAGS is initialized before calling this.
    """
    save_dir = os.path.join(save_dir, gym_env.spec.id)
    os.makedirs(save_dir, exist_ok=True)
    for seg_idx in trange(num_query):
        start_1, start_2 = (
            batch["start_indices"][seg_idx],
            batch["start_indices_2"][seg_idx],
        )
        frames = []
        frames_2 = []
        # Consecutive dataset indices covered by each segment.
        start_indices = range(start_1, start_1 + query_len)
        start_indices_2 = range(start_2, start_2 + query_len)
        gym_env.reset()
        if verbose:
            print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}")
            print("=" * 50)
            print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}")
        camera_name = "track"
        for t in trange(query_len, leave=False):
            # Restore the exact simulator state for this step, then render it.
            gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]])
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames.append(np.flipud(curr_frame))  # offscreen render comes out vertically flipped
        gym_env.reset()
        for t in trange(query_len, leave=False):
            gym_env.set_state(
                dataset["qposes"][start_indices_2[t]],
                dataset["qvels"][start_indices_2[t]],
            )
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames_2.append(np.flipud(curr_frame))
        # Put the two segments side by side (concat along image width, axis=2).
        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)
        writer = imageio.get_writer(os.path.join(save_dir, f"./idx{seg_idx}.mp4"), fps=30)
        for frame in tqdm(video, leave=False):
            writer.append_data(frame)
        writer.close()
    print("save query indices.")
    # Persist the query start indices so human labels can be matched back later.
    with open(
        os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices"], f)
    with open(
        os.path.join(
            save_dir,
            f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl",
        ),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices_2"], f)
def main(_):
    """Entry point: load saved human query indices and render query videos.

    NOTE(review): `width`/`height` are assigned only when the env name
    contains "medium" or "large"; any other env name raises NameError at the
    visualize_query call — confirm whether more sizes should be supported.
    """
    gym_env = gym.make(FLAGS.env_name)
    if "medium" in FLAGS.env_name:
        width, height = video_size["medium"]
    elif "large" in FLAGS.env_name:
        width, height = video_size["large"]
    set_seed(gym_env, FLAGS.seed)
    ds = qlearning_mujoco_dataset(gym_env)
    base_path = os.path.join(FLAGS.query_path, FLAGS.env_name)
    # Files are consumed in sorted() order: the "_2" indices file sorts first.
    human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path))
    with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:  # Unpickling
        human_indices = pickle.load(fp)
    with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:  # Unpickling
        human_indices_2 = pickle.load(fp)
    # No labels needed for visualization; scripted_teacher fills them in.
    human_labels = None
    batch = load_queries_with_indices(
        gym_env,
        ds,
        saved_indices=[human_indices, human_indices_2],
        saved_labels=human_labels,
        num_query=FLAGS.num_query,
        len_query=FLAGS.query_len,
        label_type=FLAGS.label_type,
        scripted_teacher=True
    )
    visualize_query(
        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir
    )
# Script entry point: absl parses command-line flags before invoking main.
if __name__ == "__main__":
    app.run(main)
================================================
FILE: JaxPref/human_label_preprocess_robosuite.py
================================================
"""
A script to visualize dataset trajectories by loading the simulation states
one by one or loading the first state and playing actions back open-loop.
The script can generate videos as well, by rendering simulation frames
during playback. The videos can also be generated using the image observations
in the dataset (this is useful for real-robot datasets) by using the
--use-obs argument.
Args:
dataset (str): path to hdf5 dataset
filter_key (str): if provided, use the subset of trajectories
in the file that correspond to this filter key
n (int): if provided, stop after n trajectories are processed
use-obs (bool): if flag is provided, visualize trajectories with dataset
image observations instead of simulator
use-actions (bool): if flag is provided, use open-loop action playback
instead of loading sim states
render (bool): if flag is provided, use on-screen rendering during playback
video_path (str): if provided, render trajectories to this video file path
video_skip (int): render frames to a video every @video_skip steps
render_image_names (str or [str]): camera name(s) / image observation(s) to
use for rendering on-screen or to video
first (bool): if flag is provided, use first frame of each episode for playback
instead of the entire episode. Useful for visualizing task initializations.
Example usage below:
# force simulation states one by one, and render agentview and wrist view cameras to video
python playback_dataset.py --dataset /path/to/dataset.hdf5 \
--render_image_names agentview robot0_eye_in_hand \
--video_path /tmp/playback_dataset.mp4
# playback the actions in the dataset, and render agentview camera during playback to video
python playback_dataset.py --dataset /path/to/dataset.hdf5 \
--use-actions --render_image_names agentview \
--video_path /tmp/playback_dataset_with_actions.mp4
# use the observations stored in the dataset to render videos of the dataset trajectories
python playback_dataset.py --dataset /path/to/dataset.hdf5 \
--use-obs --render_image_names agentview_image \
--video_path /tmp/obs_trajectory.mp4
# visualize initial states in the demonstration data
python playback_dataset.py --dataset /path/to/dataset.hdf5 \
--first --render_image_names agentview \
--video_path /tmp/dataset_task_inits.mp4
"""
import os
import json
import h5py
import pickle
import argparse
import imageio
import numpy as np
from tqdm import tqdm
from PIL import Image
import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.envs.env_base import EnvBase, EnvType
from .reward_transform import qlearning_robosuite_dataset
# Default camera names to render for each robomimic environment type.
# NOTE(review): the GYM_TYPE entry is a ValueError *instance*, not a raise —
# looking it up returns the exception object instead of raising; confirm
# callers never reach this key with a gym-type env.
DEFAULT_CAMERAS = {
    EnvType.ROBOSUITE_TYPE: ["agentview"],
    EnvType.IG_MOMART_TYPE: ["rgb"],
    EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
}
def playback_trajectory_with_env(
    env,
    initial_state,
    states,
    actions=None,
    render=False,
    video_writer=None,
    video_skip=5,
    camera_names=None,
    first=False,
):
    """
    Helper function to playback a single trajectory using the simulator environment.
    If @actions are not None, it will play them open-loop after loading the initial state.
    Otherwise, @states are loaded one by one.

    Args:
        env (instance of EnvBase): environment
        initial_state (dict): initial simulation state to load
        states (np.array): array of simulation states to load
        actions (np.array): if provided, play actions back open-loop instead of using @states
        render (bool): if True, render on-screen
        video_writer (imageio writer): video writer
        video_skip (int): determines rate at which environment frames are written to video
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
        first (bool): if True, only use the first frame of each episode.
    """
    assert isinstance(env, EnvBase)
    write_video = (video_writer is not None)
    video_count = 0
    # On-screen and offscreen (video) rendering are mutually exclusive.
    assert not (render and write_video)

    # load the initial state
    env.reset()
    env.reset_to(initial_state)

    traj_len = states.shape[0]
    action_playback = (actions is not None)
    if action_playback:
        assert states.shape[0] == actions.shape[0]

    for i in range(traj_len):
        if action_playback:
            env.step(actions[i])
            if i < traj_len - 1:
                # check whether the actions deterministically lead to the same recorded states
                state_playback = env.get_state()["states"]
                if not np.all(np.equal(states[i + 1], state_playback)):
                    err = np.linalg.norm(states[i + 1] - state_playback)
                    print("warning: playback diverged by {} at step {}".format(err, i))
        else:
            # No actions: jump the simulator directly to each recorded state.
            env.reset_to({"states" : states[i]})

        # on-screen render
        if render:
            env.render(mode="human", camera_name=camera_names[0])

        # video render: write only every @video_skip-th frame
        if write_video:
            if video_count % video_skip == 0:
                video_img = []
                for cam_name in camera_names:
                    video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                video_img = np.concatenate(video_img, axis=1)  # concatenate horizontally
                video_writer.append_data(video_img)
            video_count += 1

        if first:
            break
def playback_trajectory_with_obs(
    traj_grp,
    segs,
    seg_length,
    video_writer,
    video_skip=5,
    image_names=None,
    first=False,
):
    """Render a preference query (a pair of trajectory segments) to video
    using the image observations stored in the dataset.

    Frames are collected for each of the two segments, then paired up,
    concatenated horizontally, resized to 512x256 and written to
    `video_writer`.

    Args:
        traj_grp (list of hdf5 groups): exactly two trajectory groups.
        segs (list of int): start index of the segment within each trajectory.
        seg_length (int): number of steps in each segment.
        video_writer (imageio writer): video writer.
        video_skip (int): write only every @video_skip-th frame.
        image_names (list): image observation key(s); multiple keys are
            concatenated horizontally within each trajectory's frame.
        first (bool): if True, only use the first frame of each segment.
    """
    assert image_names is not None, "error: must specify at least one image observation to use in @image_names"
    assert len(traj_grp) == len(segs) == 2, "you should have 2 trajs with corresponding segment points."

    frames = [[], []]
    for idx in range(2):
        grp, seg = traj_grp[idx], segs[idx]
        video_count = 0
        for i in range(seg, seg + seg_length):
            if video_count % video_skip == 0:
                # concatenate image obs together
                # BUGFIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit before re-raising.
                try:
                    im = [grp["obs/{}".format(k)][i] for k in image_names]
                except Exception:
                    # Most likely the segment runs past the end of this
                    # trajectory; report which one before re-raising.
                    print(f"trajectory number: {grp.name}")
                    print(f"length of trajectory: {len(grp['obs/agentview_image'])}")
                    raise
                frame = np.concatenate(im, axis=1)
                frames[idx].append(frame)
            video_count += 1
            if first:
                break

    # Pair up frames from the two segments and write them side by side.
    for frame_1, frame_2 in zip(*frames):
        image = np.concatenate([frame_1, frame_2], axis=1)
        image = np.asarray(Image.fromarray(image).resize((512, 256), Image.HAMMING))
        video_writer.append_data(image)
def playback_dataset(args):
    """Render saved preference queries from a robomimic hdf5 dataset to videos.

    Loads the saved query index files, maps each flat dataset index to a
    (trajectory id, segment start) pair, and writes one side-by-side video
    per query into `<video_path>/<env>/<dataset_type>/video_<k>.mp4`.

    Args:
        args: parsed argparse namespace (see the __main__ block of this file).
    """
    # some arg checking
    write_video = (args.video_path is not None)
    assert not (args.render and write_video)  # either on-screen or video but not both

    dataset_path = os.path.join(args.dataset, args.env.lower(), args.dataset_type, "image.hdf5")

    # Auto-fill camera rendering info if not specified
    if args.render_image_names is None:
        # We fill in the automatic values
        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
        env_type = EnvUtils.get_env_type(env_meta=env_meta)
        args.render_image_names = DEFAULT_CAMERAS[env_type]

    if args.render:
        # on-screen rendering can only support one camera
        assert len(args.render_image_names) == 1

    if args.use_obs:
        assert write_video, "playback with observations can only write to video"
        assert not args.use_actions, "playback with observations is offline and does not support action playback"

    # create environment only if not playing back with observations
    if not args.use_obs:
        # need to make sure ObsUtils knows which observations are images, but it doesn't matter
        # for playback since observations are unused. Pass a dummy spec here.
        dummy_spec = dict(
            obs=dict(
                low_dim=["robot0_eef_pos"],
                rgb=[],
            ),
        )
        ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
        env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=args.render, render_offscreen=write_video)

    f = h5py.File(dataset_path, "r")
    ds = qlearning_robosuite_dataset(dataset_path)

    # list of all demonstration episodes (sorted in increasing number order)
    if args.filter_key is not None:
        print("using filter key: {}".format(args.filter_key))
        demos = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(args.filter_key)])]
    else:
        demos = list(f["data"].keys())

    # Map the saved flat indices back to (trajectory id, in-trajectory offset).
    # BUGFIX: args.indices_path is only touched after the None check —
    # os.path.join(None, ...) raised TypeError before the check could run.
    if args.indices_path is not None:
        indices_path = os.path.join(args.indices_path, f"{args.env}_{args.dataset_type}")
        with open(os.path.join(indices_path, f"indices_num{args.num_query}_q{args.query_len}"), "rb") as f1, open(os.path.join(indices_path, f"indices_2_num{args.num_query}_q{args.query_len}"), "rb") as g1:
            indices_1 = pickle.load(f1)
            indices_2 = pickle.load(g1)
        trajs_1, segs_1 = ds["traj_indices"][indices_1], ds["seg_indices"][indices_1]
        trajs_2, segs_2 = ds["traj_indices"][indices_2], ds["seg_indices"][indices_2]
        trajs = list(zip(trajs_1, trajs_2))
        segs = list(zip(segs_1, segs_2))

    video_path = os.path.join(args.video_path, args.env.lower(), args.dataset_type)
    os.makedirs(video_path, exist_ok=True)
    for idx, (trj_1, trj_2) in tqdm(enumerate(trajs), total=len(trajs)):
        # One video file per query pair; closed at the end of each iteration.
        video_writer = imageio.get_writer(os.path.join(video_path, f"video_{idx}.mp4"))
        ep_1_key, ep_2_key = f"demo_{trj_1}", f"demo_{trj_2}"
        if args.use_obs:
            playback_trajectory_with_obs(
                traj_grp=[f[f"data/{ep_1_key}"], f[f"data/{ep_2_key}"]],
                segs=segs[idx],
                seg_length=args.query_len,
                video_writer=video_writer,
                video_skip=args.video_skip,
                image_names=args.render_image_names,
                first=args.first
            )
        video_writer.close()
    # Close the hdf5 file exactly once (it was previously closed twice).
    f.close()
if __name__ == "__main__":
    # CLI for rendering preference-query videos from a robomimic hdf5 dataset.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        help="path to hdf5 dataset",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="ph",
        help="hdf5 type of dataset."
    )
    parser.add_argument(
        "--env",
        type=str,
        default="lift",
        help="env name."
    )
    parser.add_argument(
        "--filter_key",
        type=str,
        default=None,
        help="(optional) filter key, to select a subset of trajectories in the file",
    )

    # number of trajectories to playback. If omitted, playback all of them.
    parser.add_argument(
        "--n",
        type=int,
        default=None,
        help="(optional) stop after n trajectories are played",
    )

    # Use image observations instead of doing playback using the simulator env.
    parser.add_argument(
        "--use-obs",
        action='store_true',
        help="visualize trajectories with dataset image observations instead of simulator",
    )

    # Playback stored dataset actions open-loop instead of loading from simulation states.
    parser.add_argument(
        "--use-actions",
        action='store_true',
        help="use open-loop action playback instead of loading sim states",
    )

    # Whether to render playback to screen
    parser.add_argument(
        "--render",
        action='store_true',
        help="on-screen rendering",
    )

    # Dump a video of the dataset playback to the specified path
    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="(optional) render trajectories to this video file path",
    )

    # How often to write video frames during the playback
    parser.add_argument(
        "--video_skip",
        type=int,
        default=5,
        help="render frames to video every n steps",
    )

    # camera names to render, or image observations to use for writing to video
    parser.add_argument(
        "--render_image_names",
        type=str,
        nargs='+',
        default=None,
        help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. Default is"
        "None, which corresponds to a predefined camera for each env type",
    )

    # Only use the first frame of each episode
    parser.add_argument(
        "--first",
        action='store_true',
        help="use first frame of each episode",
    )
    parser.add_argument(
        "--indices_path",
        type=str,
        default=None,
        help="path for indices file."
    )
    parser.add_argument(
        "--query_len",
        type=int,
        default=50,
        help="query length for making videos."
    )
    parser.add_argument(
        "--num_query",
        type=int,
        default=1000,
        help="number of queries in offline dataset."
    )
    args = parser.parse_args()
    playback_dataset(args)
================================================
FILE: JaxPref/jax_utils.py
================================================
import numpy as np
import jax
import jax.numpy as jnp
import optax
class JaxRNG(object):
    """Stateful dispenser of JAX PRNG keys.

    Each call splits the internal key and hands back a fresh subkey, so
    successive calls never reuse randomness.
    """

    def __init__(self, seed):
        self.rng = jax.random.PRNGKey(seed)

    def __call__(self):
        new_base, subkey = jax.random.split(self.rng)
        self.rng = new_base
        return subkey
def init_rng(seed):
    """Initialize the module-global JaxRNG consumed by next_rng()."""
    global jax_utils_rng
    jax_utils_rng = JaxRNG(seed)
def next_rng():
    """Return a fresh PRNG key from the module-global JaxRNG (see init_rng)."""
    global jax_utils_rng
    return jax_utils_rng()
def extend_and_repeat(tensor, axis, repeat):
    """Insert a new axis of size `repeat` at `axis` by tiling `tensor`."""
    expanded = jnp.expand_dims(tensor, axis)
    return jnp.repeat(expanded, repeat, axis=axis)
def mse_loss(val, target):
    """Mean squared error between `val` and `target`."""
    residual = val - target
    return jnp.mean(jnp.square(residual))
def cross_ent_loss(logits, target):
    """Mean softmax cross-entropy.

    `target` may be a 1-D array of integer class ids (converted to 2-class
    one-hot) or an already-expanded one-hot / soft-label array.
    """
    if target.ndim == 1:
        labels = jax.nn.one_hot(target, num_classes=2)
    else:
        labels = target
    per_example = optax.softmax_cross_entropy(logits=logits, labels=labels)
    return jnp.mean(per_example)
def kld_loss(p, q):
    """Mean KL divergence sum_i p_i * log(p_i / q_i), with p_i == 0 terms
    contributing zero along the last axis."""
    pointwise = jnp.where(p != 0, p * (jnp.log(p) - jnp.log(q)), 0)
    return jnp.mean(jnp.sum(pointwise, axis=-1))
def custom_softmax(array, axis=-1, temperature=1.0):
    """Temperature-scaled softmax along `axis` (higher temperature -> flatter)."""
    scaled = array / temperature
    return jax.nn.softmax(scaled, axis=axis)
def pref_accuracy(logits, target):
    """Fraction of rows where argmax(logits) matches argmax(target)."""
    hits = jnp.argmax(logits, axis=1) == jnp.argmax(target, axis=1)
    return jnp.mean(hits)
def value_and_multi_grad(fun, n_outputs, argnums=0, has_aux=False):
    """Like jax.value_and_grad for a function that returns `n_outputs` values.

    A separate jax.value_and_grad is built for each output by wrapping `fun`
    in a selector, so `fun` is evaluated once per output.

    Returns:
        multi_grad_fn(*args, **kwargs) -> ((values_tuple, *aux), grads_tuple)
        where `aux` comes from the LAST gradient evaluation.

    NOTE(review): the `(value, *aux), grad` unpacking below matches the
    ((value, aux), grad) structure returned when has_aux=True; with
    has_aux=False it would try to unpack a scalar — confirm all callers
    pass has_aux=True.
    """
    def select_output(index):
        # Wrap fun so it returns only output `index` (plus aux passthrough),
        # making that single output differentiable via jax.value_and_grad.
        def wrapped(*args, **kwargs):
            if has_aux:
                x, *aux = fun(*args, **kwargs)
                return (x[index], *aux)
            else:
                x = fun(*args, **kwargs)
                return x[index]
        return wrapped

    grad_fns = tuple(
        jax.value_and_grad(select_output(i), argnums=argnums, has_aux=has_aux)
        for i in range(n_outputs)
    )

    def multi_grad_fn(*args, **kwargs):
        grads = []
        values = []
        for grad_fn in grad_fns:
            (value, *aux), grad = grad_fn(*args, **kwargs)
            values.append(value)
            grads.append(grad)
        return (tuple(values), *aux), tuple(grads)
    return multi_grad_fn
@jax.jit
def batch_to_jax(batch):
    # Move every array in the batch pytree onto the default JAX device.
    return jax.tree_util.tree_map(jax.device_put, batch)
================================================
FILE: JaxPref/model.py
================================================
from functools import partial
from typing import Callable
import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import distrax
from .jax_utils import extend_and_repeat, next_rng
def multiple_action_q_function(forward):
    """Decorator letting a Q-function score several candidate actions per state.

    When `actions` is (batch, n_actions, action_dim) while `observations` is
    (batch, obs_dim), observations are tiled to match, both are flattened to
    2-D for the wrapped forward pass, and the Q-values are reshaped back to
    (batch, n_actions). Otherwise the call is passed through unchanged.
    """
    def wrapped(self, observations, actions, **kwargs):
        needs_flatten = actions.ndim == 3 and observations.ndim == 2
        n_obs = observations.shape[0]
        if needs_flatten:
            observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1])
            actions = actions.reshape(-1, actions.shape[-1])
        q_values = forward(self, observations, actions, **kwargs)
        if needs_flatten:
            # Restore the (batch, n_actions) layout.
            q_values = q_values.reshape(n_obs, -1)
        return q_values
    return wrapped
class FullyConnectedNetwork(nn.Module):
    """MLP with hidden widths parsed from `arch` (e.g. "256-256").

    With orthogonal_init, hidden layers use orthogonal(sqrt(2)) kernels;
    otherwise flax's Dense defaults apply. The output layer always uses a
    small-gain (1e-2) initializer so initial outputs start near zero, and an
    optional `activation_final` is applied to it.
    """
    output_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activation_final: Callable[[jnp.ndarray], jnp.ndarray] = None

    @nn.compact
    def __call__(self, input_tensor):
        x = input_tensor
        for width in (int(part) for part in self.arch.split('-')):
            if self.orthogonal_init:
                layer = nn.Dense(
                    width,
                    kernel_init=jax.nn.initializers.orthogonal(jnp.sqrt(2.0)),
                    bias_init=jax.nn.initializers.zeros,
                )
            else:
                layer = nn.Dense(width)
            x = self.activations(layer(x))

        # Output head: small-scale init in both branches, zero bias.
        if self.orthogonal_init:
            final_kernel_init = jax.nn.initializers.orthogonal(1e-2)
        else:
            final_kernel_init = jax.nn.initializers.variance_scaling(1e-2, 'fan_in', 'uniform')
        output = nn.Dense(
            self.output_dim,
            kernel_init=final_kernel_init,
            bias_init=jax.nn.initializers.zeros,
        )(x)

        if self.activation_final is not None:
            output = self.activation_final(output)
        return output
class FullyConnectedQFunction(nn.Module):
    """Q(s, a) MLP head: concatenates observation and action, maps to a scalar.

    The string-valued activation settings are resolved to callables at call
    time; the multiple_action_q_function decorator additionally allows
    scoring several candidate actions per observation.
    """
    observation_dim: int
    action_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    activations: str = 'relu'
    activation_final: str = 'none'

    @nn.compact
    @multiple_action_q_function
    def __call__(self, observations, actions):
        hidden_act = {
            'relu': nn.relu,
            'leaky_relu': nn.leaky_relu,
        }[self.activations]
        output_act = {
            'none': None,
            'tanh': nn.tanh,
        }[self.activation_final]
        joint_input = jnp.concatenate([observations, actions], axis=-1)
        q = FullyConnectedNetwork(
            output_dim=1,
            arch=self.arch,
            orthogonal_init=self.orthogonal_init,
            activations=hidden_act,
            activation_final=output_act,
        )(joint_input)
        return jnp.squeeze(q, -1)
================================================
FILE: JaxPref/new_preference_reward_main.py
================================================
import os
import pickle
from collections import defaultdict
import numpy as np
import transformers
import gym
import wrappers as wrappers
import absl.app
import absl.flags
from flax.training.early_stopping import EarlyStopping
from flaxmodels.flaxmodels.lstm.lstm import LSTMRewardModel
from flaxmodels.flaxmodels.gpt2.trajectory_gpt2 import TransRewardModel
import robosuite as suite
from robosuite.wrappers import GymWrapper
import robomimic.utils.env_utils as EnvUtils
from .sampler import TrajSampler
from .jax_utils import batch_to_jax
import JaxPref.reward_transform as r_tf
from .model import FullyConnectedQFunction
from viskit.logging import logger, setup_logger
from .MR import MR
from .replay_buffer import get_d4rl_dataset, index_batch
from .NMR import NMR
from .PrefTransformer import PrefTransformer
from .utils import Timer, define_flags_with_default, set_random_seed, get_user_flags, prefix_metrics, WandBLogger, save_pickle
# Jax memory
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
# Default values for all command-line flags (absl). The nested configs
# (reward/transformer/lstm/logging) pull their defaults from the respective
# model classes.
FLAGS_DEF = define_flags_with_default(
    env='halfcheetah-medium-v2',
    model_type='MLP',
    max_traj_length=1000,
    seed=42,
    data_seed=42,
    save_model=True,
    batch_size=64,
    # early-stopping settings
    early_stop=False,
    min_delta=1e-3,
    patience=10,
    reward_scale=1.0,
    reward_bias=0.0,
    clip_action=0.999,
    # MLP reward-model architecture
    reward_arch='256-256',
    orthogonal_init=False,
    activations='relu',
    activation_final='none',
    training=True,
    n_epochs=2000,
    eval_period=5,
    # preference-query settings
    data_dir='./human_label',
    num_query=1000,
    query_len=25,
    skip_flag=0,
    balance=False,
    topk=10,
    window=2,
    use_human_label=False,
    feedback_random=False,
    feedback_uniform=False,
    enable_bootstrap=False,
    comment='',
    # robosuite-specific settings
    robosuite=False,
    robosuite_dataset_type="ph",
    robosuite_dataset_path='./data',
    robosuite_max_episode_steps=500,
    # nested model / logging configs
    reward=MR.get_default_config(),
    transformer=PrefTransformer.get_default_config(),
    lstm=NMR.get_default_config(),
    logging=WandBLogger.get_default_config(),
)
def main(_):
    """Train a preference-based reward model (MR / PrefTransformer / NMR).

    Loads an offline dataset plus saved human (or scripted) preference
    queries, trains the selected reward model with optional early stopping,
    logs metrics to viskit and W&B, and pickles the best/final model.
    """
    FLAGS = absl.flags.FLAGS
    variant = get_user_flags(FLAGS, FLAGS_DEF)

    # Output layout: <output_dir>/<env>/<model_type>/<comment>/s<seed>
    save_dir = FLAGS.logging.output_dir + '/' + FLAGS.env
    save_dir += '/' + str(FLAGS.model_type) + '/'
    FLAGS.logging.group = f"{FLAGS.env}_{FLAGS.model_type}"
    assert FLAGS.comment, "You must leave your comment for logging experiment."
    FLAGS.logging.group += f"_{FLAGS.comment}"
    FLAGS.logging.experiment_id = FLAGS.logging.group + f"_s{FLAGS.seed}"
    save_dir += f"{FLAGS.comment}" + "/"
    save_dir += 's' + str(FLAGS.seed)

    setup_logger(
        variant=variant,
        seed=FLAGS.seed,
        base_log_dir=save_dir,
        include_exp_prefix_sub_dir=False
    )
    FLAGS.logging.output_dir = save_dir
    wb_logger = WandBLogger(FLAGS.logging, variant=variant)

    set_random_seed(FLAGS.seed)

    # Build the environment + offline dataset; label_type selects how
    # preference labels are derived downstream.
    if FLAGS.robosuite:
        dataset = r_tf.qlearning_robosuite_dataset(os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env.lower(), FLAGS.robosuite_dataset_type, "low_dim.hdf5"))
        env = EnvUtils.create_env_from_metadata(
            env_meta=dataset['env_meta'],
            render=False,
            render_offscreen=False
        ).env
        gym_env = GymWrapper(env)
        gym_env._max_episode_steps = gym_env.horizon
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        gym_env.ignore_done = False
        label_type = 1
    elif 'ant' in FLAGS.env:
        gym_env = gym.make(FLAGS.env)
        gym_env = wrappers.EpisodeMonitor(gym_env)
        gym_env = wrappers.SinglePrecision(gym_env)
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        dataset = r_tf.qlearning_ant_dataset(gym_env)
        label_type = 1
    else:
        gym_env = gym.make(FLAGS.env)
        eval_sampler = TrajSampler(gym_env.unwrapped, FLAGS.max_traj_length)
        dataset = get_d4rl_dataset(eval_sampler.env)
        label_type = 0
    dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action)

    # use fixed seed for collecting segments.
    set_random_seed(FLAGS.data_seed)

    print("load saved indices.")
    if 'dense' in FLAGS.env:
        env = "-".join(FLAGS.env.split("-")[:-2] + [FLAGS.env.split("-")[-1]])
    elif FLAGS.robosuite:
        env = f"{FLAGS.env}_{FLAGS.robosuite_dataset_type}"
    else:
        env = FLAGS.env
    base_path = os.path.join(FLAGS.data_dir, env)
    if os.path.exists(base_path):
        # Saved query files sort as: indices_2 file, indices_1 file, labels file.
        human_indices_2_file, human_indices_1_file, human_labels_file = sorted(os.listdir(base_path))
        with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:  # Unpickling
            human_indices = pickle.load(fp)
        with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:  # Unpickling
            human_indices_2 = pickle.load(fp)
        with open(os.path.join(base_path, human_labels_file), "rb") as fp:  # Unpickling
            human_labels = pickle.load(fp)
        pref_dataset = r_tf.load_queries_with_indices(
            gym_env, dataset, FLAGS.num_query, FLAGS.query_len,
            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
            balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label)
        # More labels than training queries means a held-out eval split exists.
        true_eval = True if len(human_labels) > FLAGS.num_query else False
        pref_eval_dataset = r_tf.load_queries_with_indices(
            gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,
            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
            balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label)
    else:
        # No saved queries: sample fresh ones (index files are written under base_path).
        pref_dataset = r_tf.get_queries_from_multi(
            gym_env, dataset, FLAGS.num_query, FLAGS.query_len,
            data_dir=base_path, label_type=label_type, balance=FLAGS.balance)
        human_indices_2_file, human_indices_1_file, script_labels_file = sorted(os.listdir(base_path))
        with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:  # Unpickling
            human_indices = pickle.load(fp)
        with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:  # Unpickling
            human_indices_2 = pickle.load(fp)
        with open(os.path.join(base_path, script_labels_file), "rb") as fp:  # Unpickling
            human_labels = pickle.load(fp)
        true_eval = True if len(human_labels) > FLAGS.num_query else False
        pref_eval_dataset = r_tf.load_queries_with_indices(
            gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,
            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
            balance=FLAGS.balance, topk=FLAGS.topk, scripted_teacher=True, window=FLAGS.window,
            feedback_random=FLAGS.feedback_random, pref_attn_n_head=FLAGS.transformer.pref_attn_n_head, true_eval=true_eval)

    set_random_seed(FLAGS.seed)

    observation_dim = gym_env.observation_space.shape[0]
    action_dim = gym_env.action_space.shape[0]

    # Number of minibatches per epoch (train and eval).
    data_size = pref_dataset["observations"].shape[0]
    interval = int(data_size / FLAGS.batch_size) + 1
    eval_data_size = pref_eval_dataset["observations"].shape[0]
    eval_interval = int(eval_data_size / FLAGS.batch_size) + 1

    early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience)

    # NOTE(review): an unrecognized model_type leaves reward_model undefined
    # and fails later with NameError — confirm the intended flag values.
    if FLAGS.model_type == "MR":
        rf = FullyConnectedQFunction(observation_dim, action_dim, FLAGS.reward_arch, FLAGS.orthogonal_init, FLAGS.activations, FLAGS.activation_final)
        reward_model = MR(FLAGS.reward, rf)
    elif FLAGS.model_type == "PrefTransformer":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(
            **FLAGS.transformer
        )
        # 10% linear warmup over the full training schedule.
        config.warmup_steps = int(total_epochs * 0.1 * interval)
        config.total_steps = total_epochs * interval
        trans = TransRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final)
        reward_model = PrefTransformer(config, trans)
    elif FLAGS.model_type == "NMR":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(
            **FLAGS.lstm
        )
        config.warmup_steps = int(total_epochs * 0.1 * interval)
        config.total_steps = total_epochs * interval
        lstm = LSTMRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final)
        reward_model = NMR(config, lstm)

    # Metric key used for early stopping when training loss is the criterion.
    if FLAGS.model_type == "MR":
        train_loss = "reward/rf_loss"
    elif FLAGS.model_type == "NMR":
        train_loss = "reward/lstm_loss"
    elif FLAGS.model_type == "PrefTransformer":
        train_loss = "reward/trans_loss"
    criteria_key = None

    for epoch in range(FLAGS.n_epochs + 1):
        metrics = defaultdict(list)
        metrics['epoch'] = epoch
        if epoch:
            # train phase
            shuffled_idx = np.random.permutation(pref_dataset["observations"].shape[0])
            for i in range(interval):
                start_pt = i * FLAGS.batch_size
                end_pt = min((i + 1) * FLAGS.batch_size, pref_dataset["observations"].shape[0])
                with Timer() as train_timer:
                    # train
                    batch = batch_to_jax(index_batch(pref_dataset, shuffled_idx[start_pt:end_pt]))
                    for key, val in prefix_metrics(reward_model.train(batch), 'reward').items():
                        metrics[key].append(val)
                metrics['train_time'] = train_timer()
        else:
            # for using early stopping with train loss.
            metrics[train_loss] = [float(FLAGS.query_len)]

        # eval phase
        if epoch % FLAGS.eval_period == 0:
            for j in range(eval_interval):
                eval_start_pt, eval_end_pt = j * FLAGS.batch_size, min((j + 1) * FLAGS.batch_size, pref_eval_dataset["observations"].shape[0])
                batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt)))
                for key, val in prefix_metrics(reward_model.evaluation(batch_eval), 'reward').items():
                    metrics[key].append(val)
            # NOTE(review): `key` below is whatever metric the eval loop
            # produced last — the early-stop criterion therefore depends on
            # metric dict ordering; verify this picks the intended eval loss.
            if not criteria_key:
                if "antmaze" in FLAGS.env and not "dense" in FLAGS.env and not true_eval:
                    # choose train loss as criteria.
                    criteria_key = train_loss
                else:
                    # choose eval loss as criteria.
                    criteria_key = key
            criteria = np.mean(metrics[criteria_key])
            has_improved, early_stop = early_stop.update(criteria)
            if early_stop.should_stop and FLAGS.early_stop:
                # Flush the final metrics before stopping.
                for key, val in metrics.items():
                    if isinstance(val, list):
                        metrics[key] = np.mean(val)
                logger.record_dict(metrics)
                logger.dump_tabular(with_prefix=False, with_timestamp=False)
                wb_logger.log(metrics)
                print('Met early stopping criteria, breaking...')
                break
            elif epoch > 0 and has_improved:
                # Snapshot the best-so-far model.
                metrics["best_epoch"] = epoch
                metrics[f"{key}_best"] = criteria
                save_data = {"reward_model": reward_model, "variant": variant, "epoch": epoch}
                save_pickle(save_data, "best_model.pkl", save_dir)

        # Average list-valued metrics, then log this epoch.
        for key, val in metrics.items():
            if isinstance(val, list):
                metrics[key] = np.mean(val)
        logger.record_dict(metrics)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
        wb_logger.log(metrics)

    if FLAGS.save_model:
        save_data = {'reward_model': reward_model, 'variant': variant, 'epoch': epoch}
        save_pickle(save_data, 'model.pkl', save_dir)
# Script entry point: absl parses flags before invoking main.
if __name__ == '__main__':
    absl.app.run(main)
================================================
FILE: JaxPref/replay_buffer.py
================================================
from copy import copy, deepcopy
from queue import Queue
import threading
import d4rl
import numpy as np
import jax.numpy as jnp
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of (s, a, r, s', done) transitions backed by numpy arrays.

    Storage is allocated lazily on the first sample, once the observation and
    action dimensionalities are known. When full, the oldest entries are
    overwritten in FIFO order.
    """

    def __init__(self, max_size, data=None):
        self._max_size = max_size
        self._next_idx = 0
        self._size = 0
        self._initialized = False
        self._total_steps = 0
        if data is not None:
            # Grow capacity if needed so the seed batch fits entirely.
            provided = data['observations'].shape[0]
            if provided > self._max_size:
                self._max_size = provided
            self.add_batch(data)

    def __len__(self):
        return self._size

    def _init_storage(self, observation_dim, action_dim):
        # Allocate flat float32 arrays sized to capacity.
        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._observations = np.zeros((self._max_size, observation_dim), dtype=np.float32)
        self._next_observations = np.zeros((self._max_size, observation_dim), dtype=np.float32)
        self._actions = np.zeros((self._max_size, action_dim), dtype=np.float32)
        self._rewards = np.zeros(self._max_size, dtype=np.float32)
        self._dones = np.zeros(self._max_size, dtype=np.float32)
        self._next_idx = 0
        self._size = 0
        self._initialized = True

    def add_sample(self, observation, action, reward, next_observation, done):
        """Insert one transition, overwriting the oldest slot when full."""
        if not self._initialized:
            self._init_storage(observation.size, action.size)
        slot = self._next_idx
        self._observations[slot, :] = np.array(observation, dtype=np.float32)
        self._next_observations[slot, :] = np.array(next_observation, dtype=np.float32)
        self._actions[slot, :] = np.array(action, dtype=np.float32)
        self._rewards[slot] = reward
        self._dones[slot] = float(done)
        self._size = min(self._size + 1, self._max_size)
        self._next_idx = (slot + 1) % self._max_size
        self._total_steps += 1

    def add_traj(self, observations, actions, rewards, next_observations, dones):
        """Insert a whole trajectory, one transition at a time."""
        for obs, act, rew, next_obs, done in zip(observations, actions, rewards, next_observations, dones):
            self.add_sample(obs, act, rew, next_obs, done)

    def add_batch(self, batch):
        """Insert every transition contained in a batch dict."""
        self.add_traj(
            batch['observations'], batch['actions'], batch['rewards'],
            batch['next_observations'], batch['dones']
        )

    def sample(self, batch_size):
        """Return `batch_size` transitions drawn uniformly at random (with replacement)."""
        picks = np.random.randint(len(self), size=batch_size)
        return self.select(picks)

    def select(self, indices):
        """Return the transitions stored at the given slot indices as a batch dict."""
        return dict(
            observations=self._observations[indices, ...],
            actions=self._actions[indices, ...],
            rewards=self._rewards[indices, ...],
            next_observations=self._next_observations[indices, ...],
            dones=self._dones[indices, ...],
        )

    def generator(self, batch_size, n_batchs=None):
        """Yield random batches forever, or exactly `n_batchs` times when given."""
        produced = 0
        while n_batchs is None or produced < n_batchs:
            yield self.sample(batch_size)
            produced += 1

    @property
    def total_steps(self):
        # Lifetime insertion count (not capped by capacity).
        return self._total_steps

    @property
    def data(self):
        # View of the currently filled portion of storage.
        return dict(
            observations=self._observations[:self._size, ...],
            actions=self._actions[:self._size, ...],
            rewards=self._rewards[:self._size, ...],
            next_observations=self._next_observations[:self._size, ...],
            dones=self._dones[:self._size, ...]
        )
def get_d4rl_dataset(env):
    """Fetch the D4RL q-learning dataset for `env`, renaming 'terminals' to float 'dones'."""
    raw = d4rl.qlearning_dataset(env)
    out = {key: raw[key] for key in ('observations', 'actions', 'next_observations', 'rewards')}
    out['dones'] = raw['terminals'].astype(np.float32)
    return out
def index_batch(batch, indices):
    """Return a new dict with every array in `batch` indexed by `indices` along axis 0."""
    return {key: value[indices, ...] for key, value in batch.items()}
def parition_batch_train_test(batch, train_ratio):
    """Randomly split `batch` row-wise into (train, test) with expected train fraction `train_ratio`.

    NOTE: the misspelled name is kept for caller compatibility.
    """
    num_rows = batch['observations'].shape[0]
    is_train = np.random.rand(num_rows) < train_ratio
    return index_batch(batch, is_train), index_batch(batch, ~is_train)
def subsample_batch(batch, size):
    """Draw `size` rows from `batch` uniformly at random (with replacement)."""
    num_rows = batch['observations'].shape[0]
    picks = np.random.randint(num_rows, size=size)
    return index_batch(batch, picks)
def concatenate_batches(batches):
    """Stack a list of batch dicts along axis 0, casting every field to float32."""
    return {
        key: np.concatenate([b[key] for b in batches], axis=0).astype(np.float32)
        for key in batches[0].keys()
    }
def split_batch(batch, batch_size):
    """Chop `batch` row-wise into consecutive mini-batches of at most `batch_size` rows."""
    total = batch['observations'].shape[0]
    # Python slicing clamps past-the-end automatically, so no explicit min() needed.
    return [
        {key: value[start:start + batch_size, ...] for key, value in batch.items()}
        for start in range(0, total, batch_size)
    ]
def split_data_by_traj(data, max_traj_length):
    """Split flat transition data into trajectory chunks, ending a chunk at each done flag
    or once it reaches `max_traj_length` rows. A trailing partial chunk is kept."""
    done_flags = data['dones'].astype(bool)
    segments = []
    seg_start = 0
    for idx, is_done in enumerate(done_flags):
        if is_done or idx - seg_start + 1 >= max_traj_length:
            segments.append(index_batch(data, slice(seg_start, idx + 1)))
            seg_start = idx + 1
    if seg_start < len(done_flags):
        # Remaining transitions after the last boundary.
        segments.append(index_batch(data, slice(seg_start, None)))
    return segments
================================================
FILE: JaxPref/reward_transform.py
================================================
import os
import h5py
import pickle
from tqdm import tqdm
import numpy as np
import ujson as json
import jax.numpy as jnp
def get_goal(name):
    """Map an antmaze environment name to its fixed goal (x, y); None if no size token matches."""
    size_to_goal = (
        ('large', (32.0, 24.0)),
        ('medium', (20.0, 20.0)),
        ('umaze', (0.0, 8.0)),
    )
    for token, goal_xy in size_to_goal:
        if token in name:
            return goal_xy
    return None
def new_get_trj_idx(env, terminate_on_end=False, **kwargs):
    """Return a list of inclusive [start, end] index pairs delimiting trajectories.

    The dataset comes from ``env.get_dataset()`` when available, otherwise from
    ``kwargs['dataset']``. Episode boundaries come from terminals (or goal changes
    for maze envs) and from timeouts (explicit field if present, else counted
    against ``env._max_episode_steps``).
    """
    if not hasattr(env, 'get_dataset'):
        dataset = kwargs['dataset']
    else:
        dataset = env.get_dataset()
    N = dataset['rewards'].shape[0]
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = False
    if 'timeouts' in dataset:
        use_timeouts = True
    episode_step = 0
    start_idx, data_idx = 0, 0
    trj_idx_list = []
    for i in range(N-1):
        if env.spec and 'maze' in env.spec.id:
            # Maze datasets mark episode ends by a change of goal rather than by terminals.
            done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0
        else:
            done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            # NOTE(review): data_idx is intentionally NOT advanced here, mirroring
            # d4rl's qlearning_dataset which drops the timed-out transition —
            # confirm the dataset passed in was filtered the same way.
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx-1])
            start_idx = data_idx
            continue
        if done_bool or final_timestep:
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx])
            start_idx = data_idx + 1
        episode_step += 1
        data_idx += 1
    # Close the final (possibly unterminated) trajectory.
    trj_idx_list.append([start_idx, data_idx])
    return trj_idx_list
def get_queries_from_multi(env, dataset, num_query, len_query, data_dir=None, balance=False, label_type=0, skip_flag=0):
    """Sample `num_query` pairs of length-`len_query` segments and label them by reward sum.

    The chosen segment start indices are cached under `data_dir`; when cache files
    already exist, the cached indices are reloaded via `load_queries_with_indices`
    instead of re-sampling.

    Args:
        env: environment (used by `new_get_trj_idx` for trajectory boundaries).
        dataset: flat transition dict with observations/actions/rewards/next_observations.
        num_query: number of segment pairs to produce.
        len_query: length of each segment.
        data_dir: directory for the cached index files (created if missing).
        balance: if True, rebalance so equal-reward ([0.5, 0.5]) pairs match the rest.
        label_type: 0 = hard argmax label; 1 = same, but ties get soft [0.5, 0.5].
        skip_flag: 1 = always resample when both segments have equal reward sum;
            2 / 3 = only allow equal-reward pairs for the first 50% / 20% of queries.

    Returns:
        Batch dict with paired segments, labels, timesteps and start indices.
    """
    os.makedirs(data_dir, exist_ok=True)
    trj_idx_list = new_get_trj_idx(env, dataset=dataset)  # get_nonmdp_trj_idx(env)
    # NOTE(review): labeler_info is all zeros, so the multi-labeler filtering below
    # is currently vestigial — every trajectory is eligible for every labeler.
    labeler_info = np.zeros(len(trj_idx_list) - 1)
    # to-do: parallel implementation
    trj_idx_list = np.array(trj_idx_list)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1
    assert max(trj_len_list) > len_query
    total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query))
    observation_dim = dataset["observations"].shape[-1]
    total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    action_dim = dataset["actions"].shape[-1]
    total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim))
    total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32)
    start_indices_1, start_indices_2 = np.zeros(num_query), np.zeros(num_query)
    time_indices_1, time_indices_2 = np.zeros(num_query), np.zeros(num_query)
    indices_1_filename = os.path.join(data_dir, f"indices_num{num_query}_q{len_query}")
    indices_2_filename = os.path.join(data_dir, f"indices_2_num{num_query}_q{len_query}")
    label_dummy_filename = os.path.join(data_dir, "label_dummy")
    if not os.path.exists(indices_1_filename) or not os.path.exists(indices_2_filename):
        for query_count in tqdm(range(num_query), desc="get queries"):
            temp_count = 0
            labeler = -1
            while temp_count < 2:
                trj_idx = np.random.choice(np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)])
                len_trj = trj_len_list[trj_idx]
                if len_trj > len_query and (temp_count == 0 or labeler_info[trj_idx] == labeler):
                    labeler = labeler_info[trj_idx]
                    time_idx = np.random.choice(len_trj - len_query + 1)
                    start_idx = trj_idx_list[trj_idx][0] + time_idx
                    end_idx = start_idx + len_query
                    assert end_idx <= trj_idx_list[trj_idx][1] + 1
                    reward_seq = dataset['rewards'][start_idx:end_idx]
                    obs_seq = dataset['observations'][start_idx:end_idx]
                    next_obs_seq = dataset['next_observations'][start_idx:end_idx]
                    act_seq = dataset['actions'][start_idx:end_idx]
                    # Timesteps are segment-local (1..len_query), not episode-local.
                    timestep_seq = np.arange(1, len_query + 1)
                    # BUG FIX: the equal-reward checks below compared against
                    # total_reward_seq_1[-1], which is still all zeros until the
                    # very last query is filled; the current query's first segment
                    # lives at row `query_count`.
                    # skip flag 1: skip queries with equal rewards.
                    if skip_flag == 1 and temp_count == 1:
                        if np.sum(total_reward_seq_1[query_count]) == np.sum(reward_seq):
                            continue
                    # skip flag 2: keep queries with equal reward until 50% of num_query.
                    if skip_flag == 2 and temp_count == 1 and query_count < int(0.5 * num_query):
                        if np.sum(total_reward_seq_1[query_count]) == np.sum(reward_seq):
                            continue
                    # skip flag 3: keep queries with equal reward until 20% of num_query.
                    if skip_flag == 3 and temp_count == 1 and query_count < int(0.2 * num_query):
                        if np.sum(total_reward_seq_1[query_count]) == np.sum(reward_seq):
                            continue
                    if temp_count == 0:
                        start_indices_1[query_count] = start_idx
                        time_indices_1[query_count] = time_idx
                        total_reward_seq_1[query_count] = reward_seq
                        total_obs_seq_1[query_count] = obs_seq
                        total_next_obs_seq_1[query_count] = next_obs_seq
                        total_act_seq_1[query_count] = act_seq
                        total_timestep_1[query_count] = timestep_seq
                    else:
                        start_indices_2[query_count] = start_idx
                        time_indices_2[query_count] = time_idx
                        total_reward_seq_2[query_count] = reward_seq
                        total_obs_seq_2[query_count] = obs_seq
                        total_next_obs_seq_2[query_count] = next_obs_seq
                        total_act_seq_2[query_count] = act_seq
                        total_timestep_2[query_count] = timestep_seq
                    temp_count += 1
        seg_reward_1 = total_reward_seq_1.copy()
        seg_reward_2 = total_reward_seq_2.copy()
        seg_obs_1 = total_obs_seq_1.copy()
        seg_obs_2 = total_obs_seq_2.copy()
        seg_next_obs_1 = total_next_obs_seq_1.copy()
        seg_next_obs_2 = total_next_obs_seq_2.copy()
        seq_act_1 = total_act_seq_1.copy()
        seq_act_2 = total_act_seq_2.copy()
        seq_timestep_1 = total_timestep_1.copy()
        seq_timestep_2 = total_timestep_2.copy()
        # NOTE(review): label_type outside {0, 1} leaves rational_labels undefined
        # and raises NameError below — kept as-is to preserve the original contract.
        if label_type == 0:  # perfectly rational
            sum_r_t_1 = np.sum(seg_reward_1, axis=1)
            sum_r_t_2 = np.sum(seg_reward_2, axis=1)
            binary_label = 1 * (sum_r_t_1 < sum_r_t_2)
            rational_labels = np.zeros((len(binary_label), 2))
            rational_labels[np.arange(binary_label.size), binary_label] = 1.0
        elif label_type == 1:
            sum_r_t_1 = np.sum(seg_reward_1, axis=1)
            sum_r_t_2 = np.sum(seg_reward_2, axis=1)
            binary_label = 1 * (sum_r_t_1 < sum_r_t_2)
            rational_labels = np.zeros((len(binary_label), 2))
            rational_labels[np.arange(binary_label.size), binary_label] = 1.0
            # Exact reward ties become a soft (uniform) preference.
            margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
            rational_labels[margin_index] = 0.5
        start_indices_1 = np.array(start_indices_1, dtype=np.int32)
        start_indices_2 = np.array(start_indices_2, dtype=np.int32)
        time_indices_1 = np.array(time_indices_1, dtype=np.int32)
        time_indices_2 = np.array(time_indices_2, dtype=np.int32)
        batch = {}
        batch['labels'] = rational_labels
        batch['observations'] = seg_obs_1  # for compatibility, remove "_1"
        batch['next_observations'] = seg_next_obs_1
        batch['actions'] = seq_act_1
        batch['observations_2'] = seg_obs_2
        batch['next_observations_2'] = seg_next_obs_2
        batch['actions_2'] = seq_act_2
        batch['timestep_1'] = seq_timestep_1
        batch['timestep_2'] = seq_timestep_2
        batch['start_indices'] = start_indices_1
        batch['start_indices_2'] = start_indices_2
        # balancing data with zero_labels
        if balance:
            nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1)
            nonzero_idx, = np.where(nonzero_condition)
            zero_idx, = np.where(np.logical_not(nonzero_condition))
            selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))
            for key, val in batch.items():
                batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]
            print(f"size of batch after balancing: {len(batch['labels'])}")
        # Cache chosen indices (and a dummy label file) for reproducible reloads.
        with open(indices_1_filename, "wb") as fp, open(indices_2_filename, "wb") as gp, open(label_dummy_filename, "wb") as hp:
            pickle.dump(batch['start_indices'], fp)
            pickle.dump(batch['start_indices_2'], gp)
            pickle.dump(np.ones_like(batch['labels']), hp)
    else:
        with open(indices_1_filename, "rb") as fp, open(indices_2_filename, "rb") as gp:
            indices_1, indices_2 = pickle.load(fp), pickle.load(gp)
        return load_queries_with_indices(
            env, dataset, num_query, len_query,
            label_type=label_type, saved_indices=[indices_1, indices_2],
            saved_labels=None, balance=balance, scripted_teacher=True
        )
    return batch
def find_time_idx(trj_idx_list, idx):
    """Return the within-trajectory offset of flat index `idx`; None if no segment contains it."""
    for segment_start, segment_end in trj_idx_list:
        if segment_start <= idx <= segment_end:
            return idx - segment_start
    return None
def load_queries_with_indices(env, dataset, num_query, len_query, label_type, saved_indices, saved_labels, balance=False, scripted_teacher=False):
    """Rebuild `num_query` preference segment pairs from previously saved start indices.

    Args:
        env: environment (used only by `new_get_trj_idx` for trajectory boundaries).
        dataset: flat transition dict.
        num_query: number of pairs to load.
        len_query: segment length.
        label_type: 0 = hard argmax label from reward sums; 1 = same with ties -> [0.5, 0.5].
        saved_indices: [indices_1, indices_2] arrays of segment start indices.
        saved_labels: human labels (0 / 1 / -1) or None to use scripted reward labels only.
        balance: if True, rebalance equal-reward pairs against the rest.
        scripted_teacher: if True, return scripted labels in 'labels'; otherwise
            human labels go to 'labels' and scripted ones to 'script_labels'.
    """
    trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env)
    # to-do: parallel implementation
    trj_idx_list = np.array(trj_idx_list)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1
    assert max(trj_len_list) > len_query
    total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query))
    observation_dim = dataset["observations"].shape[-1]
    action_dim = dataset["actions"].shape[-1]
    total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim))
    total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32)
    if saved_labels is None:
        query_range = np.arange(num_query)
    else:
        # Use the LAST num_query saved labels (earlier ones may belong to other splits).
        query_range = np.arange(len(saved_labels) - num_query, len(saved_labels))
    for query_count, i in enumerate(tqdm(query_range, desc="get queries from saved indices")):
        temp_count = 0
        # temp_count 0 fills segment 1, temp_count 1 fills segment 2 of the pair.
        while(temp_count < 2):
            start_idx = saved_indices[temp_count][i]
            end_idx = start_idx + len_query
            reward_seq = dataset['rewards'][start_idx:end_idx]
            obs_seq = dataset['observations'][start_idx:end_idx]
            next_obs_seq = dataset['next_observations'][start_idx:end_idx]
            act_seq = dataset['actions'][start_idx:end_idx]
            # Segment-local timesteps 1..len_query.
            timestep_seq = np.arange(1, len_query + 1)
            if temp_count == 0:
                total_reward_seq_1[query_count] = reward_seq
                total_obs_seq_1[query_count] = obs_seq
                total_next_obs_seq_1[query_count] = next_obs_seq
                total_act_seq_1[query_count] = act_seq
                total_timestep_1[query_count] = timestep_seq
            else:
                total_reward_seq_2[query_count] = reward_seq
                total_obs_seq_2[query_count] = obs_seq
                total_next_obs_seq_2[query_count] = next_obs_seq
                total_act_seq_2[query_count] = act_seq
                total_timestep_2[query_count] = timestep_seq
            temp_count += 1
    seg_reward_1 = total_reward_seq_1.copy()
    seg_reward_2 = total_reward_seq_2.copy()
    seg_obs_1 = total_obs_seq_1.copy()
    seg_obs_2 = total_obs_seq_2.copy()
    seg_next_obs_1 = total_next_obs_seq_1.copy()
    seg_next_obs_2 = total_next_obs_seq_2.copy()
    seq_act_1 = total_act_seq_1.copy()
    seq_act_2 = total_act_seq_2.copy()
    seq_timestep_1 = total_timestep_1.copy()
    seq_timestep_2 = total_timestep_2.copy()
    if label_type == 0: # perfectly rational
        sum_r_t_1 = np.sum(seg_reward_1, axis=1)
        sum_r_t_2 = np.sum(seg_reward_2, axis=1)
        binary_label = 1*(sum_r_t_1 < sum_r_t_2)
        rational_labels = np.zeros((len(binary_label), 2))
        rational_labels[np.arange(binary_label.size), binary_label] = 1.0
    elif label_type == 1:
        sum_r_t_1 = np.sum(seg_reward_1, axis=1)
        sum_r_t_2 = np.sum(seg_reward_2, axis=1)
        binary_label = 1*(sum_r_t_1 < sum_r_t_2)
        rational_labels = np.zeros((len(binary_label), 2))
        rational_labels[np.arange(binary_label.size), binary_label] = 1.0
        # Exact reward ties become a soft (uniform) preference.
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
        rational_labels[margin_index] = 0.5
    batch = {}
    if scripted_teacher:
        # counter part of human label for comparing with human label.
        batch['labels'] = rational_labels
    else:
        # Convert raw human labels (0 / 1 / -1) to two-column soft labels.
        human_labels = np.zeros((len(saved_labels), 2))
        human_labels[np.array(saved_labels)==0,0] = 1.
        human_labels[np.array(saved_labels)==1,1] = 1.
        human_labels[np.array(saved_labels)==-1] = 0.5
        human_labels = human_labels[query_range]
        batch['labels'] = human_labels
        batch['script_labels'] = rational_labels
    batch['observations'] = seg_obs_1 # for compatibility, remove "_1"
    batch['next_observations'] = seg_next_obs_1
    batch['actions'] = seq_act_1
    batch['observations_2'] = seg_obs_2
    batch['next_observations_2'] = seg_next_obs_2
    batch['actions_2'] = seq_act_2
    batch['timestep_1'] = seq_timestep_1
    batch['timestep_2'] = seq_timestep_2
    batch['start_indices'] = saved_indices[0]
    batch['start_indices_2'] = saved_indices[1]
    if balance:
        # Down/up-sample equal-reward pairs to match the count of decisive pairs.
        nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1)
        nonzero_idx, = np.where(nonzero_condition)
        zero_idx, = np.where(np.logical_not(nonzero_condition))
        selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))
        for key, val in batch.items():
            batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]
        print(f"size of batch after balancing: {len(batch['labels'])}")
    return batch
def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns an antmaze dataset formatted for standard Q-learning algorithms,
    additionally carrying goals, agent xy positions, and a one-step-ahead
    timeout flag.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset().
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of episode termination flags.
            goals: An N x 2 array of goal positions.
            xys: An N x 2 array of agent xy positions (first two qpos entries).
            dones_bef: An N-dim boolean array marking the step *before* a timeout.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)
    num_steps = dataset['rewards'].shape[0]
    # Newer datasets carry an explicit timeouts field; otherwise count steps
    # against env._max_episode_steps.
    has_timeouts = 'timeouts' in dataset

    obs_, next_obs_, action_, reward_ = [], [], [], []
    done_, goal_, xy_, done_bef_ = [], [], [], []
    episode_step = 0
    for i in range(num_steps - 1):
        if has_timeouts:
            final_timestep = dataset['timeouts'][i]
            next_final_timestep = dataset['timeouts'][i + 1]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
            next_final_timestep = (episode_step == env._max_episode_steps - 2)

        if (not terminate_on_end) and final_timestep:
            # Drop the last transition of a timed-out episode entirely.
            episode_step = 0
            continue

        done_bool = bool(dataset['terminals'][i])
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(dataset['observations'][i].astype(np.float32))
        next_obs_.append(dataset['observations'][i + 1].astype(np.float32))
        action_.append(dataset['actions'][i].astype(np.float32))
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(done_bool)
        goal_.append(dataset['infos/goal'][i].astype(np.float32))
        xy_.append(dataset['infos/qpos'][i][:2].astype(np.float32))
        done_bef_.append(bool(next_final_timestep))
        episode_step += 1

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'goals': np.array(goal_),
        'xys': np.array(xy_),
        'dones_bef': np.array(done_bef_)
    }
def qlearning_robosuite_dataset(dataset_path, terminate_on_end=False, **kwargs):
    """
    Returns a robosuite demonstration dataset formatted for standard
    Q-learning algorithms, with observations, actions, next_observations,
    rewards, and a terminal flag.

    Args:
        dataset_path: Path to a robomimic-style HDF5 file with a 'data' group
            containing one 'demo_*' group per trajectory.
        terminate_on_end (bool): Kept for interface parity with the other
            dataset loaders; currently unused (the last transition of each
            demo is always dropped).
        **kwargs: 'obs_key' may override the list of observation sub-keys that
            are concatenated into a flat observation vector.

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" flags.
            env_meta: parsed environment metadata from the file's 'env_args'.
            traj_indices: An N-dim array of source-demo indices.
            seg_indices: An N-dim array of within-demo step indices.
    """
    obs_keys = kwargs.get("obs_key", ["object", "robot0_joint_pos", "robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "robot0_gripper_qvel"])
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    traj_idx_ = []
    seg_idx_ = []
    # Context manager so the HDF5 handle is closed even on error (the original
    # version leaked it). All returned arrays are materialized copies, so
    # closing before return is safe.
    with h5py.File(dataset_path, 'r') as f:
        demos = list(f['data'].keys())
        for ep in tqdm(demos, desc="load robosuite demonstrations"):
            ep_grp = f[f"data/{ep}"]
            traj_len = ep_grp["actions"].shape[0]
            for i in range(traj_len - 1):
                total_obs = ep_grp["obs"]
                # Flatten the per-key observation dict into one vector.
                obs = np.concatenate([total_obs[key][i].tolist() for key in obs_keys], axis=0)
                new_obs = np.concatenate([total_obs[key][i + 1].tolist() for key in obs_keys], axis=0)
                obs_.append(obs)
                next_obs_.append(new_obs)
                action_.append(ep_grp["actions"][i])
                reward_.append(ep_grp["rewards"][i])
                done_.append(bool(ep_grp["dones"][i]))
                # Demo groups are named like "demo_<idx>" — assumes that prefix;
                # TODO(review): confirm against the dataset generator.
                traj_idx_.append(int(ep[5:]))
                seg_idx_.append(i)
        env_meta = json.loads(f["data"].attrs["env_args"])
    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'env_meta': env_meta,
        'traj_indices': np.array(traj_idx_),
        'seg_indices': np.array(seg_idx_),
    }
================================================
FILE: JaxPref/sampler.py
================================================
import numpy as np
import JaxPref.reward_transform as r_tf
class StepSampler(object):
    """Step-wise rollout collector: advances a persistent environment with a policy,
    optionally rewriting rewards via `reward_trans` and filling a replay buffer."""

    def __init__(self, env, max_traj_length=1000, reward_trans=None, act_flag=False, act_coeff=1e-3):
        self.max_traj_length = max_traj_length
        self._env = env
        self._traj_steps = 0
        self._current_observation = self.env.reset()
        self._reward_trans = reward_trans
        self._act_flag = act_flag
        self._act_coeff = act_coeff

    def sample(self, policy, n_steps, deterministic=False, replay_buffer=None):
        """Collect `n_steps` transitions, resetting the env at episode end or
        after `max_traj_length` steps; returns them as a float32 batch dict."""
        obs_list, act_list, rew_list = [], [], []
        next_obs_list, done_list = [], []
        for _ in range(n_steps):
            self._traj_steps += 1
            obs = self._current_observation
            act = policy(obs.reshape(1, -1), deterministic=deterministic).reshape(-1)
            next_obs, rew, done, info = self.env.step(act)
            if self._reward_trans is not None:
                if self._act_flag:
                    # Presumably restores the control-cost term before transforming,
                    # then passes the action magnitude separately — verify against
                    # the reward_trans implementation.
                    act_sq = np.square(act).sum()
                    rew = self._reward_trans(rew + self._act_coeff * act_sq, act_sq)
                else:
                    rew = self._reward_trans(rew)
            obs_list.append(obs)
            act_list.append(act)
            rew_list.append(rew)
            done_list.append(done)
            next_obs_list.append(next_obs)
            if replay_buffer is not None:
                replay_buffer.add_sample(obs, act, rew, next_obs, done)
            self._current_observation = next_obs
            if done or self._traj_steps >= self.max_traj_length:
                # Episode boundary: restart and carry on sampling.
                self._traj_steps = 0
                self._current_observation = self.env.reset()
        return dict(
            observations=np.array(obs_list, dtype=np.float32),
            actions=np.array(act_list, dtype=np.float32),
            rewards=np.array(rew_list, dtype=np.float32),
            next_observations=np.array(next_obs_list, dtype=np.float32),
            dones=np.array(done_list, dtype=np.float32),
        )

    @property
    def env(self):
        return self._env
class TrajSampler(object):
    """Episode-wise rollout collector. With `loco_flag` it records locomotion
    reward components; otherwise it records distance to the antmaze goal."""

    def __init__(self, env, max_traj_length=1000, loco_flag=True):
        self.max_traj_length = max_traj_length
        self._env = env
        self._loco_flag = loco_flag
        if not self._loco_flag:
            # antmaze-style env: cache the goal position for distance reporting.
            self.goal = r_tf.get_goal(env.unwrapped.spec.id)

    def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None):
        """Roll out `n_trajs` full episodes (truncated at `max_traj_length`)."""
        return [
            self._rollout_one(policy, deterministic, replay_buffer)
            for _ in range(n_trajs)
        ]

    def _rollout_one(self, policy, deterministic, replay_buffer):
        # Collect a single episode into parallel lists.
        obs_l, act_l, rew_l = [], [], []
        run_l, ctrl_l, next_l, done_l, dist_l = [], [], [], [], []
        observation = self.env.reset()
        for _ in range(self.max_traj_length):
            action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1)
            next_observation, reward, done, info = self.env.step(action)
            obs_l.append(observation)
            act_l.append(action)
            rew_l.append(reward)
            if self._loco_flag:
                run_l.append(info['reward_run'])
                ctrl_l.append(info['reward_ctrl'])
            else:
                # First two observation dims taken as the agent's xy position —
                # verify per environment.
                dist_l.append(np.linalg.norm(next_observation[:2] - self.goal))
            done_l.append(done)
            next_l.append(next_observation)
            if replay_buffer is not None:
                replay_buffer.add_sample(observation, action, reward, next_observation, done)
            observation = next_observation
            if done:
                break
        return dict(
            observations=np.array(obs_l, dtype=np.float32),
            actions=np.array(act_l, dtype=np.float32),
            rewards=np.array(rew_l, dtype=np.float32),
            rewards_run=np.array(run_l, dtype=np.float32),
            rewards_ctrl=np.array(ctrl_l, dtype=np.float32),
            next_observations=np.array(next_l, dtype=np.float32),
            dones=np.array(done_l, dtype=np.float32),
            distance=np.array(dist_l, dtype=np.float32)
        )

    @property
    def env(self):
        return self._env
================================================
FILE: JaxPref/utils.py
================================================
import random
import pprint
import time
import uuid
import tempfile
import os
from copy import copy
from socket import gethostname
import cloudpickle as pickle
import numpy as np
import absl.flags
from absl import logging
from ml_collections import ConfigDict
from ml_collections.config_flags import config_flags
from ml_collections.config_dict import config_dict
import wandb
from .jax_utils import init_rng
class Timer(object):
    """Measure wall-clock duration of a ``with`` block; call the instance to read seconds."""

    def __init__(self):
        # Elapsed seconds of the most recently completed block; None before first use.
        self._time = None

    def __enter__(self):
        self._start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        elapsed = time.time() - self._start_time
        self._time = elapsed

    def __call__(self):
        return self._time
class WandBLogger(object):
    """Thin wrapper around wandb: configures a run from a ConfigDict and logs metrics/artifacts."""

    @staticmethod
    def get_default_config(updates=None):
        """Build the default logger config, optionally overridden by `updates`."""
        config = ConfigDict()
        config.online = False
        config.prefix = ''
        config.project = 'PrefRL'
        config.output_dir = './reward_model'
        # Random startup delay (seconds) to stagger many simultaneous runs.
        config.random_delay = 0.0
        config.group = config_dict.placeholder(str)
        config.experiment_id = config_dict.placeholder(str)
        config.anonymous = config_dict.placeholder(str)
        config.notes = config_dict.placeholder(str)
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, variant):
        """Resolve config defaults, prepare the output dir, and start the wandb run.

        Args:
            config: partial logger config (merged over the defaults above).
            variant: experiment hyperparameters, recorded as the run config.
        """
        self.config = self.get_default_config(config)
        if self.config.experiment_id is None:
            self.config.experiment_id = uuid.uuid4().hex
        if self.config.prefix != '':
            self.config.project = '{}--{}'.format(self.config.prefix, self.config.project)
        if self.config.output_dir == '':
            self.config.output_dir = tempfile.mkdtemp()
        else:
            # self.config.output_dir = os.path.join(self.config.output_dir, self.config.experiment_id)
            os.makedirs(self.config.output_dir, exist_ok=True)
        self._variant = copy(variant)
        if 'hostname' not in self._variant:
            self._variant['hostname'] = gethostname()
        if self.config.random_delay > 0:
            time.sleep(np.random.uniform(0, self.config.random_delay))
        self.run = wandb.init(
            reinit=True,
            config=self._variant,
            project=self.config.project,
            dir=self.config.output_dir,
            group=self.config.group,
            name=self.config.experiment_id,
            # anonymous=self.config.anonymous,
            notes=self.config.notes,
            settings=wandb.Settings(
                start_method="thread",
                _disable_stats=True,
            ),
            mode='online' if self.config.online else 'offline',
        )

    def log(self, *args, **kwargs):
        """Forward metrics to wandb's run.log."""
        self.run.log(*args, **kwargs)

    def save_pickle(self, obj, filename):
        """Serialize `obj` with cloudpickle into the run's output directory."""
        with open(os.path.join(self.config.output_dir, filename), 'wb') as fout:
            pickle.dump(obj, fout)

    @property
    def experiment_id(self):
        return self.config.experiment_id

    @property
    def variant(self):
        # Bug fix: the config built by get_default_config never holds a
        # 'variant' entry, so `self.config.variant` raised; return the copy
        # stored in __init__ instead.
        return self._variant

    @property
    def output_dir(self):
        return self.config.output_dir
def define_flags_with_default(**kwargs):
    """Define one absl flag per keyword argument, typed from its default value.

    Returns the kwargs dict unchanged so callers can keep the defaults around.
    """
    typed_definers = (
        # bool must be tested before int: True/False are instances of int.
        (bool, absl.flags.DEFINE_bool),
        (int, absl.flags.DEFINE_integer),
        (float, absl.flags.DEFINE_float),
        (str, absl.flags.DEFINE_string),
    )
    for name, default in kwargs.items():
        if isinstance(default, ConfigDict):
            config_flags.DEFINE_config_dict(name, default)
            continue
        for flag_type, definer in typed_definers:
            if isinstance(default, flag_type):
                definer(name, default, 'automatically defined flag')
                break
        else:
            raise ValueError('Incorrect value type')
    return kwargs
def set_random_seed(seed):
    """Seed Python's, NumPy's, and this package's JAX RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    init_rng(seed)
def print_flags(flags, flags_def):
    """Log every user-defined flag as 'name: value' at INFO level."""
    flag_lines = [
        '{}: {}'.format(name, value)
        for name, value in get_user_flags(flags, flags_def).items()
    ]
    logging.info(
        'Running training with hyperparameters: \n{}'.format(pprint.pformat(flag_lines))
    )
def get_user_flags(flags, flags_def):
    """Collect flag values into a flat dict, expanding ConfigDict flags into dotted keys."""
    flat = {}
    for name in flags_def:
        value = getattr(flags, name)
        if isinstance(value, ConfigDict):
            flat.update(flatten_config_dict(value, prefix=name))
        else:
            flat[name] = value
    return flat
def flatten_config_dict(config, prefix=None):
    """Recursively flatten a (possibly nested) ConfigDict into {'a.b.c': value} pairs."""
    flat = {}
    for name, value in config.items():
        full_name = name if prefix is None else '{}.{}'.format(prefix, name)
        if isinstance(value, ConfigDict):
            flat.update(flatten_config_dict(value, prefix=full_name))
        else:
            flat[full_name] = value
    return flat
def save_pickle(obj, filename, output_dir):
    """Serialize `obj` with (cloud)pickle to output_dir/filename."""
    target_path = os.path.join(output_dir, filename)
    with open(target_path, 'wb') as fout:
        pickle.dump(obj, fout)
def prefix_metrics(metrics, prefix):
    """Return a copy of `metrics` with every key renamed to '<prefix>/<key>'."""
    return {f'{prefix}/{key}': value for key, value in metrics.items()}
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Ilya Kostrikov, Ashvin Nair, Sergey Levine
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Preference Transformer: Modeling Human Preferences using Transformers for RL (ICLR 2023)
Official Jax/Flax implementation of **[Preference Transformer: Modeling Human Preferences using Transformers for RL](https://openreview.net/forum?id=Peot1SFDX0)** by [Changyeon Kim*](https://changyeon.page)¹, [Jongjin Park*](https://pjj4288.github.io/)¹, [Jinwoo Shin](https://alinlab.kaist.ac.kr/shin.html)¹, [Honglak Lee](https://web.eecs.umich.edu/~honglak/)²˒³, [Pieter Abbeel](http://people.eecs.berkeley.edu/~pabbeel/)⁴, [Kimin Lee](https://sites.google.com/view/kiminlee)⁵ (¹KAIST, ²University of Michigan, ³LG AI Research, ⁴UC Berkeley, ⁵Google Research)
**TL;DR**: We introduce a transformer-based architecture for preference-based RL considering non-Markovian rewards.
[paper](https://openreview.net/pdf?id=Peot1SFDX0)
Overview of Preference Transformer. We first construct hidden embeddings $\{\mathbf{x}_t\}$ through the causal transformer, where each represents the context information from the initial timestep to timestep $t$. The preference attention layer with a bidirectional self-attention computes the non-Markovian rewards $\{\hat{r}_t\}$ and their convex combinations $\{z_t\}$ from those hidden embeddings, then we aggregate $\{z_t\}$ for modeling the weighted sum of non-Markovian rewards $\sum_{t}{w_t \hat{r}_t}$.
## NOTICE
In this new version, we release the **real human preferences** for various datasets in D4RL and Robosuite.
## How to run the code
### Install dependencies
```
conda create -y -n offline python=3.8
conda activate offline
pip install --upgrade pip
conda install -y -c conda-forge cudatoolkit=11.1 cudnn=8.2.1
pip install -r requirements.txt
cd d4rl
pip install -e .
cd ..
# Installs the wheel compatible with Cuda 11 and cudnn 8.
pip install "jax[cuda11_cudnn805]>=0.2.27" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install protobuf==3.20.1 "gym<0.24.0" distrax==0.1.2 wandb
pip install transformers
```
## D4RL
### Run Training Reward Model
```python
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer
# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR
# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR
```
### Run IQL with learned Reward Model
```python
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for ant , 10 o.w.} --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}
# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for ant , 10 o.w.} --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}
# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for ant , 10 o.w.} --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}
```
## Robosuite
### Preliminaries
You must download the robomimic (https://robomimic.github.io/) dataset.
Please refer to this website: https://robomimic.github.io/docs/datasets/robomimic_v0.1.html
### Run Training Reward Model
```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer
# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR
# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query 100000 --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR
```
### Run IQL with learned Reward Model
```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}
# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}
# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}
```
## Citation
```
@inproceedings{
kim2023preference,
title={Preference Transformer: Modeling Human Preferences using Transformers for {RL}},
author={Changyeon Kim and Jongjin Park and Jinwoo Shin and Honglak Lee and Pieter Abbeel and Kimin Lee},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=Peot1SFDX0}
}
```
## Acknowledgments
Our code is based on the implementation of [Flaxmodels](https://github.com/matthias-wright/flaxmodels) and [IQL](https://github.com/ikostrikov/implicit_q_learning).
================================================
FILE: actor.py
================================================
from typing import Tuple
import jax
import jax.numpy as jnp
from common import Batch, InfoDict, Model, Params, PRNGKey
def update(key: PRNGKey, actor: Model, critic: Model, value: Model,
           batch: Batch, temperature: float) -> Tuple[Model, InfoDict]:
    """One advantage-weighted actor update step.

    Weights log-probabilities by exp(advantage * temperature), clipped at 100
    to avoid exploding weights, and takes one gradient step on the actor.
    """
    state_value = value(batch.observations)
    q_one, q_two = critic(batch.observations, batch.actions)
    min_q = jnp.minimum(q_one, q_two)
    advantage = min_q - state_value
    exp_weight = jnp.minimum(jnp.exp(advantage * temperature), 100.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        # Re-run the policy with the candidate params; dropout needs an rng.
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})
        log_probs = dist.log_prob(batch.actions)
        actor_loss = -(exp_weight * log_probs).mean()
        return actor_loss, {'actor_loss': actor_loss, 'adv': advantage}

    new_actor, info = actor.apply_gradient(actor_loss_fn)
    return new_actor, info
================================================
FILE: common.py
================================================
import collections
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
# A batch of transitions sampled from the dataset/replay buffer.
# NOTE(review): `masks` presumably holds 1.0 for non-terminal transitions and
# 0.0 at termination (it multiplies the bootstrap term in critic.update_q) —
# confirm against the dataset builders.
Batch = collections.namedtuple(
    'Batch',
    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    """Return an orthogonal kernel initializer with gain `scale` (default sqrt(2))."""
    return nn.initializers.orthogonal(scale)
# Type aliases shared across the codebase.
PRNGKey = Any  # output of jax.random.PRNGKey (previously defined twice; duplicate removed)
Params = flax.core.FrozenDict[str, Any]  # immutable parameter pytree
Shape = Sequence[int]
Dtype = Any  # this could be a real type?
InfoDict = Dict[str, float]  # scalar diagnostics returned by update steps
class MLP(nn.Module):
    """Plain fully-connected network with orthogonal-initialized Dense layers.

    Activation (and optional dropout) is applied after every layer except the
    last, unless `activate_final` is set.
    """
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: int = False  # treated as a bool flag
    dropout_rate: Optional[float] = None  # None disables dropout entirely

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        """Apply the MLP; dropout is active only when `training` is True."""
        for i, size in enumerate(self.hidden_dims):
            x = nn.Dense(size, kernel_init=default_init())(x)
            # Skip activation/dropout on the final layer unless activate_final.
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        return x
@flax.struct.dataclass
class Model:
    """Immutable training state: a flax module, its params, and an optax optimizer.

    Being a `flax.struct.dataclass`, instances are pytrees: updates return a
    new `Model` via `.replace(...)` instead of mutating in place.
    """
    step: int
    apply_fn: nn.Module = flax.struct.field(pytree_node=False)
    params: Params
    tx: Optional[optax.GradientTransformation] = flax.struct.field(
        pytree_node=False)
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(cls,
               model_def: nn.Module,
               inputs: Sequence[jnp.ndarray],
               tx: Optional[optax.GradientTransformation] = None) -> 'Model':
        """Initialize params from `model_def.init(*inputs)`; `tx` is optional."""
        variables = model_def.init(*inputs)
        # Split the 'params' collection out of the initialized variables.
        _, params = variables.pop('params')

        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None

        return cls(step=1,
                   apply_fn=model_def,
                   params=params,
                   tx=tx,
                   opt_state=opt_state)

    def __call__(self, *args, **kwargs):
        """Apply the module with the stored params."""
        return self.apply_fn.apply({'params': self.params}, *args, **kwargs)

    def apply(self, *args, **kwargs):
        """Apply the module with caller-supplied variables (e.g. candidate params)."""
        return self.apply_fn.apply(*args, **kwargs)

    def apply_gradient(self, loss_fn) -> Tuple[Any, 'Model']:
        """Take one optimizer step on `loss_fn(params) -> (loss, aux)`.

        Returns (new_model, aux). Requires `tx`/`opt_state` to be set.
        """
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, info = grad_fn(self.params)

        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(step=self.step + 1,
                            params=new_params,
                            opt_state=new_opt_state), info

    def save(self, save_path: str):
        """Serialize only the params (not the optimizer state) to `save_path`."""
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(flax.serialization.to_bytes(self.params))

    def load(self, load_path: str) -> 'Model':
        """Return a copy of this Model with params read back from `load_path`."""
        with open(load_path, 'rb') as f:
            params = flax.serialization.from_bytes(self.params, f.read())
        return self.replace(params=params)
================================================
FILE: configs/adroit_config.py
================================================
import ml_collections
def get_config():
    """IQL hyperparameters for the Adroit tasks."""
    return ml_collections.ConfigDict(dict(
        actor_lr=3e-4,
        value_lr=3e-4,
        critic_lr=3e-4,
        hidden_dims=(256, 256),
        discount=0.99,
        expectile=0.7,  # The actual tau for expectiles.
        temperature=0.5,
        dropout_rate=0.1,
        tau=0.005,  # For soft target updates.
    ))
================================================
FILE: configs/antmaze_config.py
================================================
import ml_collections
def get_config():
    """IQL hyperparameters for the AntMaze tasks."""
    return ml_collections.ConfigDict(dict(
        actor_lr=3e-4,
        value_lr=3e-4,
        critic_lr=3e-4,
        hidden_dims=(256, 256),
        discount=0.99,
        expectile=0.9,  # The actual tau for expectiles.
        temperature=10.0,
        dropout_rate=None,
        tau=0.005,  # For soft target updates.
    ))
================================================
FILE: configs/antmaze_finetune_config.py
================================================
import ml_collections
def get_config():
    """IQL hyperparameters for AntMaze online fine-tuning."""
    return ml_collections.ConfigDict(dict(
        actor_lr=3e-4,
        value_lr=3e-4,
        critic_lr=3e-4,
        hidden_dims=(256, 256),
        discount=0.99,
        expectile=0.9,  # The actual tau for expectiles.
        temperature=10.0,
        dropout_rate=None,
        tau=0.005,  # For soft target updates.
        opt_decay_schedule=None,  # Don't decay optimizer lr
    ))
================================================
FILE: configs/mujoco_config.py
================================================
import ml_collections
def get_config():
    """IQL hyperparameters for the Gym-MuJoCo locomotion tasks."""
    return ml_collections.ConfigDict(dict(
        actor_lr=3e-4,
        value_lr=3e-4,
        critic_lr=3e-4,
        hidden_dims=(256, 256),
        discount=0.99,
        expectile=0.7,  # The actual tau for expectiles.
        temperature=3.0,
        dropout_rate=None,
        tau=0.005,  # For soft target updates.
    ))
================================================
FILE: critic.py
================================================
from typing import Tuple
import jax.numpy as jnp
from common import Batch, InfoDict, Model, Params
def loss(diff, expectile=0.8):
    """Asymmetric L2 (expectile) loss.

    Positive residuals are weighted by `expectile`, non-positive ones by
    `1 - expectile`, so expectile > 0.5 penalizes underestimation more.
    """
    return jnp.where(diff > 0, expectile, 1 - expectile) * jnp.square(diff)
def update_v(critic: Model, value: Model, batch: Batch,
             expectile: float) -> Tuple[Model, InfoDict]:
    """One expectile-regression step fitting V(s) toward min(Q1, Q2)."""
    q_one, q_two = critic(batch.observations, batch.actions)
    q_target = jnp.minimum(q_one, q_two)

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v = value.apply({'params': value_params}, batch.observations)
        value_loss = loss(q_target - v, expectile).mean()
        return value_loss, {
            'value_loss': value_loss,
            'v': v.mean(),
        }

    new_value, info = value.apply_gradient(value_loss_fn)
    return new_value, info
def update_q(critic: Model, target_value: Model, batch: Batch,
             discount: float) -> Tuple[Model, InfoDict]:
    """One TD step fitting both Q heads to r + discount * mask * V(s')."""
    v_next = target_value(batch.next_observations)
    td_target = batch.rewards + discount * batch.masks * v_next

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        qa, qb = critic.apply({'params': critic_params}, batch.observations,
                              batch.actions)
        critic_loss = ((qa - td_target) ** 2 + (qb - td_target) ** 2).mean()
        return critic_loss, {
            'critic_loss': critic_loss,
            'q1': qa.mean(),
            'q2': qb.mean()
        }

    new_critic, info = critic.apply_gradient(critic_loss_fn)
    return new_critic, info
================================================
FILE: d4rl/.gitignore
================================================
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: d4rl/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: d4rl/MANIFEST.in
================================================
recursive-include * *.xml
recursive-include * *.stl
recursive-include * *.png
================================================
FILE: d4rl/README.md
================================================
# D4RL: Datasets for Deep Data-Driven Reinforcement Learning
[](https://opensource.org/licenses/Apache-2.0)
[](https://creativecommons.org/licenses/by/4.0/)
D4RL is an open-source benchmark for offline reinforcement learning. It provides standardized environments and datasets for training and benchmarking algorithms. A supplementary [whitepaper](https://arxiv.org/abs/2004.07219) and [website](https://sites.google.com/view/d4rl/home) are also available.
## Setup
D4RL can be installed by cloning the repository as follows:
```
git clone https://github.com/rail-berkeley/d4rl.git
cd d4rl
pip install -e .
```
Or, alternatively:
```
pip install git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl
```
The control environments require MuJoCo as a dependency. You may need to obtain a [license](https://www.roboti.us/license.html) and follow the setup instructions for mujoco_py. This mostly involves copying the key to your MuJoCo installation folder.
The Flow and CARLA tasks also require additional installation steps:
- Instructions for installing CARLA can be found [here](https://github.com/rail-berkeley/d4rl/wiki/CARLA-Setup)
- Instructions for installing Flow can be found [here](https://flow.readthedocs.io/en/latest/flow_setup.html). Make sure to install using the SUMO simulator, and add the flow repository to your PYTHONPATH once finished.
## Using d4rl
d4rl uses the [OpenAI Gym](https://github.com/openai/gym) API. Tasks are created via the `gym.make` function. A full list of all tasks is [available here](https://github.com/rail-berkeley/d4rl/wiki/Tasks).
Each task is associated with a fixed offline dataset, which can be obtained with the `env.get_dataset()` method. This method returns a dictionary with:
- `observations`: An N by observation dimensional array of observations.
- `actions`: An N by action dimensional array of actions.
- `rewards`: An N dimensional array of rewards.
- `terminals`: An N dimensional array of episode termination flags. This is true when episodes end due to termination conditions such as falling over.
- `timeouts`: An N dimensional array of termination flags. This is true when episodes end due to reaching the maximum episode length.
- `infos`: Contains optional task-specific debugging information.
You can also load data using `d4rl.qlearning_dataset(env)`, which formats the data for use by typical Q-learning algorithms by adding a `next_observations` key.
```python
import gym
import d4rl # Import required to register environments
# Create the environment
env = gym.make('maze2d-umaze-v1')
# d4rl abides by the OpenAI gym interface
env.reset()
env.step(env.action_space.sample())
# Each task is associated with a dataset
# dataset contains observations, actions, rewards, terminals, and infos
dataset = env.get_dataset()
print(dataset['observations']) # An N x dim_observation Numpy array of observations
# Alternatively, use d4rl.qlearning_dataset which
# also adds next_observations.
dataset = d4rl.qlearning_dataset(env)
```
Datasets are automatically downloaded to the `~/.d4rl/datasets` directory when `get_dataset()` is called. If you would like to change the location of this directory, you can set the `$D4RL_DATASET_DIR` environment variable to the directory of your choosing, or pass in the dataset filepath directly into the `get_dataset` method.
### Normalizing Scores
You can use the `env.get_normalized_score(returns)` function to compute a normalized score for an episode, where `returns` is the undiscounted total sum of rewards accumulated during an episode.
The individual min and max reference scores are stored in `d4rl/infos.py` for reference.
## Algorithm Implementations
We have aggregated implementations of various offline RL algorithms in a [separate repository](https://github.com/rail-berkeley/d4rl_evaluations).
## Off-Policy Evaluations
D4RL currently has limited support for off-policy evaluation methods, on a select few locomotion tasks. We provide trained reference policies and a set of performance metrics. Additional details can be found in the [wiki](https://github.com/rail-berkeley/d4rl/wiki/Off-Policy-Evaluation).
## Recent Updates
### 2-12-2020
- Added new Gym-MuJoCo datasets (labeled v2) which fixed Hopper's performance and the qpos/qvel fields.
- Added additional wiki documentation on [generating datasets](https://github.com/rail-berkeley/d4rl/wiki/Dataset-Reproducibility-Guide).
## Acknowledgements
D4RL builds on top of several excellent domains and environments built by various researchers. We would like to thank the authors of:
- [hand_dapg](https://github.com/aravindr93/hand_dapg)
- [gym-minigrid](https://github.com/maximecb/gym-minigrid)
- [carla](https://github.com/carla-simulator/carla)
- [flow](https://github.com/flow-project/flow)
- [adept_envs](https://github.com/google-research/relay-policy-learning)
## Citation
Please use the following bibtex for citations:
```
@misc{fu2020d4rl,
title={D4RL: Datasets for Deep Data-Driven Reinforcement Learning},
author={Justin Fu and Aviral Kumar and Ofir Nachum and George Tucker and Sergey Levine},
year={2020},
eprint={2004.07219},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
## Licenses
Unless otherwise noted, all datasets are licensed under the [Creative Commons Attribution 4.0 License (CC BY)](https://creativecommons.org/licenses/by/4.0/), and code is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.html).
================================================
FILE: d4rl/d4rl/__init__.py
================================================
import os
import sys
import collections
import numpy as np
import d4rl.infos
from d4rl.offline_env import set_dataset_path, get_keys
# Set D4RL_SUPPRESS_IMPORT_ERROR=1 to silence warnings about optional
# environment families whose simulator backends are not installed.
SUPPRESS_MESSAGES = bool(os.environ.get('D4RL_SUPPRESS_IMPORT_ERROR', 0))
_ERROR_MESSAGE = 'Warning: %s failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.'

# Each environment family below is imported best-effort: if its backend
# (MuJoCo, SUMO/Flow, CARLA, pybullet, ...) is missing, only that family
# is disabled and a warning is printed to stderr.
try:
    import d4rl.locomotion
    import d4rl.hand_manipulation_suite
    import d4rl.pointmaze
    import d4rl.gym_minigrid
    import d4rl.gym_mujoco
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'Mujoco-based envs', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.flow
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'Flow', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.kitchen
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'FrankaKitchen', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.carla
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'CARLA', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.gym_bullet
    import d4rl.pointmaze_bullet
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'GymBullet', file=sys.stderr)
        print(e, file=sys.stderr)
def reverse_normalized_score(env_name, score):
    """Map a normalized score back to a raw (unnormalized) return for ``env_name``."""
    lo = d4rl.infos.REF_MIN_SCORE[env_name]
    hi = d4rl.infos.REF_MAX_SCORE[env_name]
    return score * (hi - lo) + lo
def get_normalized_score(env_name, score):
    """Normalize a raw return for ``env_name`` so 0 is random and 1 is expert."""
    lo = d4rl.infos.REF_MIN_SCORE[env_name]
    hi = d4rl.infos.REF_MAX_SCORE[env_name]
    return (score - lo) / (hi - lo)
def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []

    # Newer dataset releases ship an explicit 'timeouts' field; older ones
    # fall back to counting steps against env._max_episode_steps.
    use_timeouts = 'timeouts' in dataset

    episode_step = 0
    for i in range(N - 1):
        obs = dataset['observations'][i].astype(np.float32)
        new_obs = dataset['observations'][i + 1].astype(np.float32)
        action = dataset['actions'][i].astype(np.float32)
        reward = dataset['rewards'][i].astype(np.float32)
        # Bug fix: an unreachable `if False:` branch (dead maze-specific
        # 'infos/goal' logic) has been removed; terminals are always taken
        # from the dataset's 'terminals' field.
        done_bool = bool(dataset['terminals'][i])

        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step
            # of an episode: (obs[i], obs[i+1]) would straddle two episodes.
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        episode_step += 1

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
    }
def sequence_dataset(env, dataset=None, **kwargs):
    """
    Iterate over complete trajectories in a dataset.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        **kwargs: Arguments to pass to env.get_dataset().

    Yields:
        Dictionaries mapping every dataset key (observations, actions,
        rewards, terminals, ...) to the numpy array for one episode.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    num_steps = dataset['rewards'].shape[0]
    # Newer datasets carry an explicit 'timeouts' flag; otherwise episodes
    # are cut when the step counter reaches env._max_episode_steps.
    has_timeouts = 'timeouts' in dataset

    buffer = collections.defaultdict(list)
    steps_in_episode = 0
    for t in range(num_steps):
        for key, values in dataset.items():
            buffer[key].append(values[t])

        is_terminal = bool(dataset['terminals'][t])
        if has_timeouts:
            is_last = dataset['timeouts'][t]
        else:
            is_last = (steps_in_episode == env._max_episode_steps - 1)

        if is_terminal or is_last:
            steps_in_episode = 0
            yield {key: np.array(vals) for key, vals in buffer.items()}
            buffer = collections.defaultdict(list)
        # Matches the original bookkeeping: the counter is incremented even
        # on the step that closed an episode, so the next episode's counter
        # starts at 1 (only relevant on the no-timeouts fallback path).
        steps_in_episode += 1
================================================
FILE: d4rl/d4rl/carla/__init__.py
================================================
from .carla_env import CarlaObsDictEnv
from .carla_env import CarlaObsEnv
from gym.envs.registration import register
# Gym registrations for the CARLA-based offline RL tasks. The
# ref_min_score / ref_max_score pairs feed get_normalized_score(); per
# the inline comments they are average random / average dataset returns.
register(
    id='carla-lane-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=250,
    kwargs={
        'ref_min_score': -0.8503839912088142,
        'ref_max_score': 1023.5784385429523,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5',
        'reward_type': 'lane_follow',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=250,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)

register(
    id='carla-lane-render-v0',
    # NOTE(review): 'CarlaDictEnv' is not among this module's imports
    # (only CarlaObsDictEnv / CarlaObsEnv are); confirm this entry point
    # resolves before relying on carla-lane-render-v0.
    entry_point='d4rl.carla:CarlaDictEnv',
    max_episode_steps=250,
    kwargs={
        'ref_min_score': -0.8503839912088142,
        'ref_max_score': 1023.5784385429523,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5',
        'reward_type': 'lane_follow',
        'render_images': True,
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=250,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)

# Episode length shared by all town (goal-reaching) variants.
TOWN_STEPS = 1000

register(
    id='carla-town-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=TOWN_STEPS,
    kwargs={
        'ref_min_score': -114.81579500772153,  # Average random returns
        'ref_max_score': 2440.1772022247314,  # Average dataset returns
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5',
        'reward_type': 'goal_reaching',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=TOWN_STEPS,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)

register(
    id='carla-town-full-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=TOWN_STEPS,
    kwargs={
        'ref_min_score': -114.81579500772153,  # Average random returns
        'ref_max_score': 2440.1772022247314,  # Average dataset returns
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',
        'reward_type': 'goal_reaching',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=TOWN_STEPS,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)

register(
    id='carla-town-render-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=TOWN_STEPS,
    kwargs={
        # No reference scores: normalized scoring is unavailable here.
        'ref_min_score': None,
        'ref_max_score': None,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',
        'render_images': True,
        'reward_type': 'goal_reaching',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=TOWN_STEPS,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)
================================================
FILE: d4rl/d4rl/carla/carla_env.py
================================================
import argparse
import datetime
import glob
import os
import random
import sys
import time
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import gym
from gym import Env
import gym.spaces as spaces
#from . import proxy_env
from d4rl.offline_env import OfflineEnv
try:
sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (
sys.version_info.major,
sys.version_info.minor,
'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])
except IndexError:
pass
import carla
import math
from dotmap import DotMap
try:
import pygame
except ImportError:
raise RuntimeError('cannot import pygame, make sure pygame package is installed')
try:
import numpy as np
except ImportError:
raise RuntimeError('cannot import numpy, make sure numpy package is installed')
try:
import queue
except ImportError:
import Queue as queue
# This is CARLA agent
from agents.navigation.agent import Agent, AgentState
from agents.navigation.local_planner import LocalPlanner
from agents.navigation.global_route_planner import GlobalRoutePlanner
from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO
from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle
def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):
    """
    Check whether a target object lies within a distance/angle window of a
    reference object. A target straight ahead is at roughly 0 degrees, one
    directly behind at roughly 180 degrees.

    :param target_location: location of the target object
    :param current_location: location of the reference object
    :param orientation: orientation of the reference object (degrees)
    :param max_distance: maximum allowed distance
    :param d_angle_th_up: upper threshold for the angle (degrees)
    :param d_angle_th_low: lower threshold for the angle (optional, default 0)
    :return: True if the target is within the window ahead of the reference
    """
    delta = np.array([target_location.x - current_location.x,
                      target_location.y - current_location.y])
    distance = np.linalg.norm(delta)

    # Degenerate case: the two objects effectively coincide.
    if distance < 0.001:
        return True
    if distance > max_distance:
        return False

    heading = math.radians(orientation)
    forward = np.array([math.cos(heading), math.sin(heading)])
    cos_angle = np.clip(np.dot(forward, delta) / distance, -1., 1.)
    d_angle = math.degrees(math.acos(cos_angle))
    return d_angle_th_low < d_angle < d_angle_th_up
def compute_distance(location_1, location_2):
    """
    Euclidean distance between two 3D points.

    :param location_1, location_2: points with x/y/z attributes
    :return: distance plus a tiny epsilon, so the result is never exactly zero
    """
    diff = [location_2.x - location_1.x,
            location_2.y - location_1.y,
            location_2.z - location_1.z]
    return np.linalg.norm(diff) + np.finfo(float).eps
class CustomGlobalRoutePlanner(GlobalRoutePlanner):
    """GlobalRoutePlanner extended with distance/velocity queries along the
    planned route (used by the goal-reaching reward)."""

    def __init__(self, dao):
        super(CustomGlobalRoutePlanner, self).__init__(dao=dao)

    def compute_direction_velocities(self, origin, velocity, destination):
        """Decompose ``velocity`` relative to the route from ``origin`` to
        ``destination``.

        Returns ``(vel_s, vel_perp)``: the signed speed along the direction
        toward the first route node and the non-negative speed perpendicular
        to it.
        """
        # _path_search is the superclass' private search over the road graph;
        # it returns a list of graph node ids from origin to destination.
        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)

        origin_xy = np.array([origin.x, origin.y])
        velocity_xy = np.array([velocity.x, velocity.y])

        # NOTE(review): this uses node_list[0] as the first waypoint while
        # compute_distance below uses node_list[1] -- confirm intentional.
        first_node_xy = self._graph.nodes[node_list[0]]['vertex']
        first_node_xy = np.array([first_node_xy[0], first_node_xy[1]])
        target_direction_vector = first_node_xy - origin_xy
        target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector)

        # Signed component of the velocity along the route direction.
        vel_s = np.dot(velocity_xy, target_unit_vector)

        # Perpendicular component obtained from the angle between the
        # velocity and the route direction (1e-8 guards zero velocity).
        unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8)
        angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0))
        vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle)
        return vel_s, vel_perp

    def compute_distance(self, origin, destination):
        """Approximate driving distance from ``origin`` to ``destination``:
        straight-line hop to an early route node plus the graph's heuristic
        distance between consecutive route nodes."""
        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)
        first_node_xy = self._graph.nodes[node_list[1]]['vertex']
        distances = []
        distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy)))
        for idx in range(len(node_list) - 1):
            distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx + 1]))
        return np.sum(distances)
class CarlaSyncMode(object):
    """
    Context manager to synchronize output from different sensors. Synchronous
    mode is enabled as long as we are inside this context

        with CarlaSyncMode(world, sensors) as sync_mode:
            while True:
                data = sync_mode.tick(timeout=1.0)
    """

    def __init__(self, world, *sensors, **kwargs):
        self.world = world
        self.sensors = sensors
        self.frame = None
        self.delta_seconds = 1.0 / kwargs.get('fps', 20)
        self._queues = []
        self._settings = None
        # Synchronous mode is enabled immediately, not deferred to __enter__.
        self.start()

    def start(self):
        """Switch the world to synchronous mode and register one queue per
        data source (world ticks plus each sensor)."""
        self._settings = self.world.get_settings()
        self.frame = self.world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=self.delta_seconds))

        def make_queue(register_event):
            q = queue.Queue()
            register_event(q.put)
            self._queues.append(q)

        make_queue(self.world.on_tick)
        for sensor in self.sensors:
            make_queue(sensor.listen)

    def __enter__(self):
        # Bug fix: the class documents `with CarlaSyncMode(...)` usage but
        # only defined __exit__; without __enter__ the with-statement raises
        # AttributeError. start() already ran in __init__, so return self.
        return self

    def tick(self, timeout):
        """Advance the simulation one step and return one sample per queue,
        all guaranteed to belong to the same frame."""
        self.frame = self.world.tick()
        data = [self._retrieve_data(q, timeout) for q in self._queues]
        assert all(x.frame == self.frame for x in data)
        return data

    def __exit__(self, *args, **kwargs):
        # Restore the world settings captured in start().
        self.world.apply_settings(self._settings)

    def _retrieve_data(self, sensor_queue, timeout):
        # Drain stale samples until we find the one for the current frame.
        while True:
            data = sensor_queue.get(timeout=timeout)
            if data.frame == self.frame:
                return data
class Sun(object):
    """Simple day cycle model: azimuth sweeps around while altitude follows
    a cosine between 20 and 90 degrees."""

    def __init__(self, azimuth, altitude):
        self.azimuth = azimuth
        self.altitude = altitude
        self._t = 0.0

    def tick(self, delta_seconds):
        """Advance the cycle by ``delta_seconds`` simulated seconds."""
        self._t = (self._t + 0.008 * delta_seconds) % (2.0 * math.pi)
        self.azimuth = (self.azimuth + 0.25 * delta_seconds) % 360.0
        min_alt, max_alt = 20, 90
        midpoint = 0.5 * (max_alt + min_alt)
        amplitude = 0.5 * (max_alt - min_alt)
        self.altitude = midpoint + amplitude * math.cos(self._t)

    def __str__(self):
        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)
class Storm(object):
    """Storm state machine: a scalar ``_t`` in [-250, 100] oscillates between
    the extremes and drives clouds, rain, puddles, wetness, wind and fog."""

    def __init__(self, precipitation):
        # Start mid-cycle if precipitation is given, otherwise at -50 (clear).
        self._t = precipitation if precipitation > 0.0 else -50.0
        self._increasing = True
        self.clouds = 0.0
        self.rain = 0.0
        self.wetness = 0.0
        self.puddles = 0.0
        self.wind = 0.0
        self.fog = 0.0

    def tick(self, delta_seconds):
        """Advance the storm by ``delta_seconds`` and refresh all derived values."""
        delta = (1.3 if self._increasing else -1.3) * delta_seconds
        self._t = clamp(delta + self._t, -250.0, 100.0)
        # Bug fix: the original assigned self.clouds twice in a row
        # (clamp(..., 0, 90) immediately overwritten by clamp(..., 0, 60));
        # the dead first assignment is removed, keeping the effective value.
        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)
        self.rain = clamp(self._t, 0.0, 80.0)
        delay = -10.0 if self._increasing else 90.0
        self.puddles = clamp(self._t + delay, 0.0, 85.0)
        self.wetness = clamp(self._t * 5, 0.0, 100.0)
        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
        self.fog = clamp(self._t - 10, 0.0, 30.0)
        # Reverse direction when an extreme is reached.
        if self._t == -250.0:
            self._increasing = True
        if self._t == 100.0:
            self._increasing = False

    def __str__(self):
        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)
class Weather(object):
    """Drives the world's weather by ticking a Sun and a Storm model and
    pushing the combined parameters back into the simulator."""

    def __init__(self, world, changing_weather_speed):
        self.world = world
        # reset() runs first so the models are seeded from clear-sky weather.
        self.reset()
        self.weather = world.get_weather()
        self.changing_weather_speed = changing_weather_speed
        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)
        self._storm = Storm(self.weather.precipitation)

    def reset(self):
        """Install clear weather with the sun at its zenith."""
        self.world.set_weather(carla.WeatherParameters(sun_altitude_angle=90.))

    def tick(self):
        """Advance both models and apply the resulting weather to the world."""
        speed = self.changing_weather_speed
        self._sun.tick(speed)
        self._storm.tick(speed)
        w = self.weather
        w.cloudiness = self._storm.clouds
        w.precipitation = self._storm.rain
        w.precipitation_deposits = self._storm.puddles
        w.wind_intensity = self._storm.wind
        w.fog_density = self._storm.fog
        w.wetness = self._storm.wetness
        w.sun_azimuth_angle = self._sun.azimuth
        w.sun_altitude_angle = self._sun.altitude
        self.world.set_weather(w)

    def __str__(self):
        return '%s %s' % (self._sun, self._storm)
def clamp(value, minimum=0.0, maximum=100.0):
    """Clip ``value`` to the inclusive range [minimum, maximum]."""
    capped = min(value, maximum)
    return max(minimum, capped)
## Now the actual env
class CarlaEnv(object):
    """
    CARLA agent, we will wrap this in a proxy env to get a gym env
    """
    def __init__(self, render=False, carla_port=2000, record=False, record_dir=None, args=None, record_vision=False, reward_type='lane_follow', **kwargs):
        """Connect to a CARLA server on ``carla_port`` and build the scene:
        ego vehicle, traffic, cameras, weather and route planner.

        :param render: if True, open a pygame window with a chase camera
        :param record: if True, record the display (implies render)
        :param record_dir: output directory for recordings (auto-named if None)
        :param args: dict of scene settings (vision_size, vision_fov, weather,
            frame_skip, steps, multiagent, lane, lights, ...)
        :param record_vision: if True, save the vision camera frames
        :param reward_type: 'lane_follow' or 'goal_reaching'
        """
        self.render_display = render
        self.record_display = record
        print('[CarlaEnv] record_vision:', record_vision)
        self.record_vision = record_vision
        self.record_dir = record_dir
        self.reward_type = reward_type
        self.vision_size = args['vision_size']
        self.vision_fov = args['vision_fov']
        self.changing_weather_speed = float(args['weather'])
        self.frame_skip = args['frame_skip']
        self.max_episode_steps = args['steps']  # DMC uses this
        self.multiagent = args['multiagent']
        self.start_lane = args['lane']
        self.follow_traffic_lights = args['lights']
        if self.record_display:
            assert self.render_display

        self.actor_list = []

        if self.render_display:
            pygame.init()
            self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF)
            self.font = get_font()
            self.clock = pygame.time.Clock()

        self.client = carla.Client('localhost', carla_port)
        self.client.set_timeout(2.0)
        self.world = self.client.get_world()
        self.map = self.world.get_map()

        # tests specific to map 4: the fixed start lane only exists in Town04
        if self.start_lane and self.map.name != "Town04":
            raise NotImplementedError

        # remove old vehicles and sensors (in case they survived a crash)
        self.world.tick()
        actor_list = self.world.get_actors()
        for vehicle in actor_list.filter("*vehicle*"):
            print("Warning: removing old vehicle")
            vehicle.destroy()
        for sensor in actor_list.filter("*sensor*"):
            print("Warning: removing old sensor")
            sensor.destroy()

        self.vehicle = None
        self.vehicles_list = []  # their ids
        self.reset_vehicle()  # creates self.vehicle
        self.actor_list.append(self.vehicle)

        blueprint_library = self.world.get_blueprint_library()

        # Chase camera used only for the pygame display window.
        if self.render_display:
            self.camera_display = self.world.spawn_actor(
                blueprint_library.find('sensor.camera.rgb'),
                carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)),
                attach_to=self.vehicle)
            self.actor_list.append(self.camera_display)

        # Forward-facing vision camera: this is the agent's observation.
        bp = blueprint_library.find('sensor.camera.rgb')
        bp.set_attribute('image_size_x', str(self.vision_size))
        bp.set_attribute('image_size_y', str(self.vision_size))
        bp.set_attribute('fov', str(self.vision_fov))
        location = carla.Location(x=1.6, z=1.7)
        self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle)
        self.actor_list.append(self.camera_vision)

        if self.record_display or self.record_vision:
            if self.record_dir is None:
                # Auto-generate a directory name encoding the settings.
                # (note: '-mutiagent' typo is kept for name compatibility)
                self.record_dir = "carla-{}-{}x{}-fov{}".format(
                    self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov)
                if self.frame_skip > 1:
                    self.record_dir += '-{}'.format(self.frame_skip)
                if self.changing_weather_speed > 0.0:
                    self.record_dir += '-weather'
                if self.multiagent:
                    self.record_dir += '-mutiagent'
                if self.follow_traffic_lights:
                    self.record_dir += '-lights'
                self.record_dir += '-{}k'.format(self.max_episode_steps // 1000)
            now = datetime.datetime.now()
            self.record_dir += now.strftime("-%Y-%m-%d-%H-%M-%S")
            os.mkdir(self.record_dir)

        # Synchronous mode ties world ticks to the registered cameras.
        if self.render_display:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20)
        else:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20)

        # weather
        self.weather = Weather(self.world, self.changing_weather_speed)

        # dummy variables, to match deep mind control's APIs
        low = -1.0
        high = 1.0
        self.action_space = spaces.Box(low=np.array((low, low)), high=np.array((high, high)))
        self.observation_space = DotMap()
        self.observation_space.shape = (3, self.vision_size, self.vision_size)
        self.observation_space.dtype = np.dtype(np.uint8)
        self.reward_range = None
        self.metadata = None
        self.horizon = self.max_episode_steps
        self.image_shape = (3, self.vision_size, self.vision_size)

        # roaming carla agent
        self.count = 0
        self.world.tick()
        self.reset_init()

        self._proximity_threshold = 10.0
        self._traffic_light_threshold = 5.0
        # Cache actor handles used by the hazard/reward checks.
        self.actor_list = self.world.get_actors()
        self.vehicle_list = self.actor_list.filter("*vehicle*")
        self.lights_list = self.actor_list.filter("*traffic_light*")
        self.object_list = self.actor_list.filter("*traffic.*")

        # town nav: route planner used by the goal-reaching reward.
        self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1)
        self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao)
        self.route_planner.setup()
        # Fixed goal location used by the goal-reaching reward.
        self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433)

        self.reset()  # creates self.agent
def reset_init(self):
    """Re-spawn the ego and traffic vehicles and zero the step counter.

    Does not return an observation; reset() builds on this to produce one.
    """
    self.reset_vehicle()
    self.world.tick()
    self.reset_other_vehicles()
    self.world.tick()
    self.count = 0
def reset(self):
    """Return an initial observation, re-spawning the ego vehicle until a
    non-collided state is reached (at most 11 attempts)."""
    observation, _, collided, _ = self.step()

    # Keep resetting until the vehicle is not in a collided/done state.
    retries = 0
    while collided:
        self.reset_vehicle()
        self.world.tick()
        observation, _, collided, _ = self.step()
        retries += 1
        if retries > 10:
            break
    return observation
def reset_vehicle(self):
    """Move the ego vehicle to a spawn point (creating it on first call)
    and zero its linear and angular velocity."""
    if self.map.name == "Town04":
        # Town04: fixed highway start, heading -90 degrees.
        self.start_lane = -1  # np.random.choice([-1, -2, -3, -4]) # their positive values, not negative
        start_x = 5.
        vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90))
    else:
        # Other towns: pick a random predefined spawn point.
        init_transforms = self.world.get_map().get_spawn_points()
        vehicle_init_transform = random.choice(init_transforms)

    if self.vehicle is None:  # then create the ego vehicle
        blueprint_library = self.world.get_blueprint_library()
        vehicle_blueprint = blueprint_library.find('vehicle.audi.a2')
        self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform)

    self.vehicle.set_transform(vehicle_init_transform)
    self.vehicle.set_velocity(carla.Vector3D())
    self.vehicle.set_angular_velocity(carla.Vector3D())
def reset_other_vehicles(self):
    """Destroy previously spawned traffic vehicles and spawn a fresh set of
    20 autopilot vehicles (no-op unless multiagent mode is enabled)."""
    if not self.multiagent:
        return

    # clear out old vehicles
    self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
    self.world.tick()
    self.vehicles_list = []

    traffic_manager = self.client.get_trafficmanager()
    traffic_manager.set_global_distance_to_leading_vehicle(2.0)
    traffic_manager.set_synchronous_mode(True)
    # Only 4-wheeled vehicles (excludes bikes/motorcycles).
    blueprints = self.world.get_blueprint_library().filter('vehicle.*')
    blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4]

    num_vehicles = 20
    if self.map.name == "Town04":
        # Town04: scatter vehicles along highway road 47's four lanes.
        road_id = 47
        road_length = 117.
        init_transforms = []
        for _ in range(num_vehicles):
            lane_id = random.choice([-1, -2, -3, -4])
            vehicle_s = np.random.uniform(road_length)  # length of road 47
            init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform)
    else:
        init_transforms = self.world.get_map().get_spawn_points()
        init_transforms = np.random.choice(init_transforms, num_vehicles)

    # --------------
    # Spawn vehicles
    # --------------
    batch = []
    for transform in init_transforms:
        transform.location.z += 0.1  # otherwise can collide with the road it starts on
        blueprint = random.choice(blueprints)
        if blueprint.has_attribute('color'):
            color = random.choice(blueprint.get_attribute('color').recommended_values)
            blueprint.set_attribute('color', color)
        if blueprint.has_attribute('driver_id'):
            driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values)
            blueprint.set_attribute('driver_id', driver_id)
        blueprint.set_attribute('role_name', 'autopilot')
        batch.append(carla.command.SpawnActor(blueprint, transform).then(
            carla.command.SetAutopilot(carla.command.FutureActor, True)))

    # Bug fix: the batch used to be submitted twice (spawning each vehicle
    # twice, with the duplicates failing on collision), and the first pass
    # recorded actor ids even for failed spawns. Submit once and keep only
    # the successful spawns.
    for response in self.client.apply_batch_sync(batch, False):
        if not response.error:
            self.vehicles_list.append(response.actor_id)

    traffic_manager.global_percentage_speed_difference(30.0)
def step(self, action=None, traffic_light_color=""):
    """Advance the simulator one step; returns (obs, reward, done, info).

    Frame-skipping is currently disabled: the call delegates directly to
    _simulator_step. The original averaging loop is kept for reference:

        rewards = []
        for _ in range(self.frame_skip):  # default 1
            next_obs, reward, done, info = self._simulator_step(action, traffic_light_color)
            rewards.append(reward)
            if done:
                break
        return next_obs, np.mean(rewards), done, info
    """
    return self._simulator_step(action, traffic_light_color)
def _is_vehicle_hazard(self, vehicle, vehicle_list):
"""
:param vehicle_list: list of potential obstacle to check
:return: a tuple given by (bool_flag, vehicle), where
- bool_flag is True if there is a vehicle ahead blocking us
and False otherwise
- vehicle is the blocker object itself
"""
ego_vehicle_location = vehicle.get_location()
ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)
for target_vehicle in vehicle_list:
# do not account for the ego vehicle
if target_vehicle.id == vehicle.id:
continue
# if the object is not in our lane it's not an obstacle
target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())
if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \
target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:
continue
if is_within_distance_ahead(target_vehicle.get_transform(),
vehicle.get_transform(),
self._proximity_threshold/10.0):
return (True, -1.0, target_vehicle)
return (False, 0.0, None)
def _is_object_hazard(self, vehicle, object_list):
"""
:param vehicle_list: list of potential obstacle to check
:return: a tuple given by (bool_flag, vehicle), where
- bool_flag is True if there is a vehicle ahead blocking us
and False otherwise
- vehicle is the blocker object itself
"""
ego_vehicle_location = vehicle.get_location()
ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)
for target_vehicle in object_list:
# do not account for the ego vehicle
if target_vehicle.id == vehicle.id:
continue
# if the object is not in our lane it's not an obstacle
target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())
if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \
target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:
continue
if is_within_distance_ahead(target_vehicle.get_transform(),
vehicle.get_transform(),
self._proximity_threshold/40.0):
return (True, -1.0, target_vehicle)
return (False, 0.0, None)
def _is_light_red(self, vehicle):
"""
Method to check if there is a red light affecting us. This version of
the method is compatible with both European and US style traffic lights.
:param lights_list: list containing TrafficLight objects
:return: a tuple given by (bool_flag, traffic_light), where
- bool_flag is True if there is a traffic light in RED
affecting us and False otherwise
- traffic_light is the object itself or None if there is no
red traffic light affecting us
"""
ego_vehicle_location = vehicle.get_location()
ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)
for traffic_light in self.lights_list:
object_location = self._get_trafficlight_trigger_location(traffic_light)
object_waypoint = self.map.get_waypoint(object_location)
if object_waypoint.road_id != ego_vehicle_waypoint.road_id:
continue
ve_dir = ego_vehicle_waypoint.transform.get_forward_vector()
wp_dir = object_waypoint.transform.get_forward_vector()
dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z
if dot_ve_wp < 0:
continue
if is_within_distance_ahead(object_waypoint.transform,
vehicle.get_transform(),
self._traffic_light_threshold):
if traffic_light.state == carla.TrafficLightState.Red:
return (True, -0.1, traffic_light)
return (False, 0.0, None)
def _get_trafficlight_trigger_location(self, traffic_light):  # pylint: disable=no-self-use
    """
    Calculates the world location of the waypoint that represents the
    trigger volume of the traffic light.
    """
    def rotate_point(point, radians):
        """Rotate ``point`` counter-clockwise around the Z axis."""
        # Bug fix: the original computed rotated_y as
        # sin(r)*x - cos(r)*y, which is not a rotation (wrong sign on the
        # cos term). It had no visible effect because the point rotated
        # below always has x == y == 0, but the formula is now the
        # standard 2D rotation: y' = sin(r)*x + cos(r)*y.
        rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y
        rotated_y = math.sin(radians) * point.x + math.cos(radians) * point.y
        return carla.Vector3D(rotated_x, rotated_y, point.z)

    base_transform = traffic_light.get_transform()
    base_rot = base_transform.rotation.yaw
    area_loc = base_transform.transform(traffic_light.trigger_volume.location)
    area_ext = traffic_light.trigger_volume.extent

    # The trigger point sits area_ext.z ahead of the light, rotated by the
    # light's yaw; note the rotated input vector has x = y = 0.
    point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot))
    point_location = area_loc + carla.Location(x=point.x, y=point.y)
    return carla.Location(point_location.x, point_location.y, point_location.z)
def _get_collision_reward(self, vehicle):
    """Return (hazard_flag, reward) combining the vehicle-hazard and lane checks.

    A hazard (reward -1.0) is reported when another vehicle is too close,
    when the ego vehicle cannot be located or projected onto the map, or
    when it has drifted out of the innermost lanes (ids -1 / 1).
    """
    hazard, reward, _hazard_vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list)
    location = vehicle.get_location()
    waypoint = self.map.get_waypoint(location) if location is not None else None
    if waypoint is None:
        # No usable position on the map: treat as a hazard.
        return True, -1.0
    if waypoint.lane_id not in (-1, 1):
        # Drifted off the expected lanes.
        return True, -1.0
    return hazard, reward
def _get_traffic_light_reward(self, vehicle):
    """Return (red_light_hazard, 0.0).

    The hazard flag from the red-light check is propagated, but the reward
    contribution is deliberately zeroed out.
    """
    hazard, _reward, _light_id = self._is_light_red(vehicle)
    return hazard, 0.0
def _get_object_collided_reward(self, vehicle):
    """Return (hazard_flag, reward) for collisions with static objects."""
    hazard, penalty, _object_id = self._is_object_hazard(vehicle, self.object_list)
    return hazard, penalty
def goal_reaching_reward(self, vehicle):
    """Reward for driving toward `self.target_location`.

    The base reward is the velocity component along the planned route to the
    target; a large penalty is added on collision with another vehicle.
    Traffic-light and object-collision terms are computed (for logging and
    done flags) but are currently excluded from the total.

    :param vehicle: ego carla.Vehicle
    :return: (total_reward, reward_dict, done_dict)
    """
    vehicle_location = vehicle.get_location()
    vehicle_velocity = vehicle.get_velocity()
    target_location = self.target_location

    # Distance / direction computation along the planner's route graph.
    try:
        dist = self.route_planner.compute_distance(vehicle_location, target_location)
        vel_forward, vel_perp = self.route_planner.compute_direction_velocities(
            vehicle_location, vehicle_velocity, target_location)
    except TypeError:
        # Weird bug where the planner graph disappears; fall back to zeros.
        # BUG FIX: `dist` was previously left unbound on this path.
        dist = 0.0
        vel_forward = 0
        vel_perp = 0

    base_reward = vel_forward
    collided_done, collision_reward = self._get_collision_reward(vehicle)
    traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)
    object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)
    # Traffic-light / object terms intentionally excluded from the total.
    total_reward = base_reward + 100 * collision_reward  # + 100 * traffic_light_reward + 100.0 * object_collided_reward

    reward_dict = {
        'collision': collision_reward,
        'traffic_light': traffic_light_reward,
        'object_collision': object_collided_reward,
        'base_reward': base_reward,
    }
    done_dict = {
        'collided_done': collided_done,
        'traffic_light_done': traffic_light_done,
        'object_collided_done': object_collided_done,
    }
    return total_reward, reward_dict, done_dict
def lane_follow_reward(self, vehicle):
    """Reward for following the highway lane.

    Base reward is the velocity along the lane direction minus 5x the
    perpendicular (lane-crossing) velocity, clipped to [-5, 5]; collision,
    traffic-light and object-collision terms are added on top.

    :param vehicle: ego carla.Vehicle
    :return: (total_reward, reward_dict, done_dict)
    """
    # assume on highway
    vehicle_location = vehicle.get_location()
    vehicle_waypoint = self.map.get_waypoint(vehicle_location)
    vehicle_xy = np.array([vehicle_location.x, vehicle_location.y])
    vehicle_s = vehicle_waypoint.s  # distance along the road (OpenDRIVE s-coordinate)
    vehicle_velocity = vehicle.get_velocity()  # Vector3D
    vehicle_velocity_xy = np.array([vehicle_velocity.x, vehicle_velocity.y])
    # print ('Velocity: ', vehicle_velocity_xy)
    speed = np.linalg.norm(vehicle_velocity_xy)

    # Project onto the nearest drivable lane to identify the current road.
    vehicle_waypoint_closest_to_road = \
        self.map.get_waypoint(vehicle_location, project_to_road=True, lane_type=carla.LaneType.Driving)
    road_id = vehicle_waypoint_closest_to_road.road_id
    assert road_id is not None
    goal_abs_lane_id = 1  # just for goal-following
    # Lane-id sign encodes driving direction in OpenDRIVE.
    lane_id_sign = int(np.sign(vehicle_waypoint_closest_to_road.lane_id))
    assert lane_id_sign in [-1, 1]
    goal_lane_id = goal_abs_lane_id * lane_id_sign

    current_waypoint = self.map.get_waypoint(vehicle_location, project_to_road=False)
    goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s)

    # Check for valid goal waypoint
    if goal_waypoint is None:
        print ('goal waypoint is None...')
        # try to fix, bit of a hack, with CARLA waypoint discretizations
        carla_waypoint_discretization = 0.02  # meters
        goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s - carla_waypoint_discretization)
        if goal_waypoint is None:
            goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s + carla_waypoint_discretization)

    # set distance to 100 if the waypoint is off the road
    if goal_waypoint is None:
        print("Episode fail: goal waypoint is off the road! (frame %d)" % self.count)
        # NOTE(review): `vel_perp` is never assigned on this path, so the
        # `base_reward` computation below would raise NameError — confirm.
        done, dist, vel_s = True, 100., 0.
    else:
        goal_location = goal_waypoint.transform.location
        goal_xy = np.array([goal_location.x, goal_location.y])
        # dist = np.linalg.norm(vehicle_xy - goal_xy)
        # Distance to the nearest lane center among lanes 1-4.
        dists = []
        for abs_lane_id in [1, 2, 3, 4]:
            lane_id_ = abs_lane_id * lane_id_sign
            wp = self.map.get_waypoint_xodr(road_id, lane_id_, vehicle_s)
            if wp is not None:  # lane 4 might not exist where the highway has a turnoff
                loc = wp.transform.location
                xy = np.array([loc.x, loc.y])
                dists.append(np.linalg.norm(vehicle_xy - xy))
        if dists:
            dist = min(dists)  # just try to get to the center of one of the lanes
        else:
            dist = 0.

        next_goal_waypoint = goal_waypoint.next(0.1)  # waypoints are ever 0.02 meters
        if len(next_goal_waypoint) != 1:
            print('warning: {} waypoints (not 1)'.format(len(next_goal_waypoint)))
        if len(next_goal_waypoint) == 0:
            print("Episode done: no more waypoints left. (frame %d)" % self.count)
            done, vel_s, vel_perp = True, 0., 0.
        else:
            # Lane direction = vector from the goal waypoint to the one ahead.
            location_ahead = next_goal_waypoint[0].transform.location
            highway_vector = np.array([location_ahead.x, location_ahead.y]) - goal_xy
            highway_unit_vector = np.array(highway_vector) / np.linalg.norm(highway_vector)
            # Velocity component along the lane.
            vel_s = np.dot(vehicle_velocity_xy, highway_unit_vector)
            # Velocity component perpendicular to the lane (via the angle
            # between the velocity and the lane direction).
            unit_velocity = vehicle_velocity_xy / (np.linalg.norm(vehicle_velocity_xy) + 1e-8)
            angle = np.arccos(np.clip(np.dot(unit_velocity, highway_unit_vector), -1.0, 1.0))
            #vel_forward = np.linalg.norm(vehicle_velocity_xy) * np.cos(angle)
            vel_perp = np.linalg.norm(vehicle_velocity_xy) * np.sin(angle)
            #print('R:', np.clip(vel_s-5*vel_perp, -5.0, 5.0), 'vel_s:', vel_s, 'vel_perp:', vel_perp)
            #import pdb; pdb.set_trace()
            done = False

    # not algorithm's fault, but the simulator sometimes throws the car in the air wierdly
    # usually in initial few frames, which can be ignored
    """
    if vehicle_velocity.z > 1. and self.count < 20:
        print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_velocity.z, self.count))
        done = True
    if vehicle_location.z > 0.5 and self.count < 20:
        print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_location.z, self.count))
        done = True
    """

    ## Add rewards for collision and optionally traffic lights
    vehicle_location = vehicle.get_location()
    base_reward = np.clip(vel_s - 5*vel_perp, -5.0, 5.0)
    collided_done, collision_reward = self._get_collision_reward(vehicle)
    traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)
    object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)
    total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward
    reward_dict = dict()
    reward_dict['collision'] = collision_reward
    reward_dict['traffic_light'] = traffic_light_reward
    reward_dict['object_collision'] = object_collided_reward
    reward_dict['base_reward'] = base_reward
    reward_dict['base_reward_vel_s'] = vel_s
    reward_dict['base_reward_vel_perp'] = vel_perp
    done_dict = dict()
    done_dict['collided_done'] = collided_done
    done_dict['traffic_light_done'] = traffic_light_done
    done_dict['object_collided_done'] = object_collided_done
    done_dict['base_done'] = done
    return total_reward, reward_dict, done_dict
def _simulator_step(self, action, traffic_light_color):
    """Apply `action`, advance the simulator one frame, and compute the reward.

    :param action: [throttle_brake, steer]; positive first component is
        throttle, negative is brake. None means "apply no control".
    :param traffic_light_color: string label recorded into saved frames
    :return: (next_obs RGB array, reward, done, info dict)
    """
    if action is None:
        throttle, steer, brake = 0., 0., 0.
    else:
        steer = float(action[1])
        throttle_brake = float(action[0])
        # Split the single throttle/brake axis into the two pedals.
        if throttle_brake >= 0.0:
            throttle = throttle_brake
            brake = 0.0
        else:
            throttle = 0.0
            brake = -throttle_brake
        vehicle_control = carla.VehicleControl(
            throttle=float(throttle),
            steer=float(steer),
            brake=float(brake),
            hand_brake=False,
            reverse=False,
            manual_gear_shift=False
        )
        self.vehicle.apply_control(vehicle_control)

    # Advance the simulation and wait for the data.
    if self.render_display:
        snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0)
    else:
        snapshot, vision_image = self.sync_mode.tick(timeout=2.0)

    # Weather evolves
    self.weather.tick()

    # Draw the display.
    if self.render_display:
        self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10))
        self.render_display.blit(self.font.render('Control: %5.2f thottle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28))
        self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46))
        self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64))
        pygame.display.flip()

    # Format rl image
    bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4)  # BGRA format
    bgr = bgra[:, :, :3]  # BGR format (84 x 84 x 3)
    rgb = np.flip(bgr, axis=2)  # RGB format (84 x 84 x 3)

    if self.render_display and self.record_display:
        image_name = os.path.join(self.record_dir, "display%08d.jpg" % self.count)
        pygame.image.save(self.render_display, image_name)
        # # Can animate with:
        # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4
    if self.record_vision:
        image_name = os.path.join(self.record_dir, "vision%08d.png" % self.count)
        print('savedimg:', image_name)
        im = Image.fromarray(rgb)
        # add any meta data you like into the image before we save it:
        metadata = PngInfo()
        metadata.add_text("throttle", str(throttle))
        metadata.add_text("steer", str(steer))
        metadata.add_text("brake", str(brake))
        metadata.add_text("lights", traffic_light_color)
        # acceleration
        acceleration = self.vehicle.get_acceleration()
        metadata.add_text("acceleration_x", str(acceleration.x))
        metadata.add_text("acceleration_y", str(acceleration.y))
        metadata.add_text("acceleration_z", str(acceleration.z))
        # angular velocity
        angular_velocity = self.vehicle.get_angular_velocity()
        metadata.add_text("angular_velocity_x", str(angular_velocity.x))
        metadata.add_text("angular_velocity_y", str(angular_velocity.y))
        metadata.add_text("angular_velocity_z", str(angular_velocity.z))
        # location
        location = self.vehicle.get_location()
        metadata.add_text("location_x", str(location.x))
        metadata.add_text("location_y", str(location.y))
        metadata.add_text("location_z", str(location.z))
        # rotation
        rotation = self.vehicle.get_transform().rotation
        metadata.add_text("rotation_pitch", str(rotation.pitch))
        metadata.add_text("rotation_yaw", str(rotation.yaw))
        metadata.add_text("rotation_roll", str(rotation.roll))
        forward_vector = rotation.get_forward_vector()
        metadata.add_text("forward_vector_x", str(forward_vector.x))
        metadata.add_text("forward_vector_y", str(forward_vector.y))
        metadata.add_text("forward_vector_z", str(forward_vector.z))
        # velocity
        velocity = self.vehicle.get_velocity()
        metadata.add_text("velocity_x", str(velocity.x))
        metadata.add_text("velocity_y", str(velocity.y))
        metadata.add_text("velocity_z", str(velocity.z))
        # weather
        metadata.add_text("weather_cloudiness ", str(self.weather.weather.cloudiness))
        metadata.add_text("weather_precipitation", str(self.weather.weather.precipitation))
        metadata.add_text("weather_precipitation_deposits", str(self.weather.weather.precipitation_deposits))
        metadata.add_text("weather_wind_intensity", str(self.weather.weather.wind_intensity))
        metadata.add_text("weather_fog_density", str(self.weather.weather.fog_density))
        metadata.add_text("weather_wetness", str(self.weather.weather.wetness))
        metadata.add_text("weather_sun_azimuth_angle", str(self.weather.weather.sun_azimuth_angle))
        # settings
        metadata.add_text("settings_map", self.map.name)
        metadata.add_text("settings_vision_size", str(self.vision_size))
        metadata.add_text("settings_vision_fov", str(self.vision_fov))
        metadata.add_text("settings_changing_weather_speed", str(self.changing_weather_speed))
        metadata.add_text("settings_multiagent", str(self.multiagent))
        # traffic lights
        metadata.add_text("traffic_lights_color", "UNLABELED")
        # NOTE(review): `reward`, `reward_dict` and `done_dict` are only
        # assigned further down this method, so recording vision here would
        # raise NameError — confirm whether the reward computation should
        # precede this block.
        metadata.add_text("reward", str(reward))
        ## Add in reward dict
        for key in reward_dict:
            metadata.add_text("reward_" + str(key), str(reward_dict[key]))
        for key in done_dict:
            metadata.add_text("done_" + str(key), str(done_dict[key]))
        ## Save the target location as well
        metadata.add_text('target_location_x', str(self.target_location.x))
        metadata.add_text('target_location_y', str(self.target_location.y))
        metadata.add_text('target_location_z', str(self.target_location.z))
        im.save(image_name, "PNG", pnginfo=metadata)
    self.count += 1

    next_obs = rgb
    # NOTE(review): `done` is hard-coded False, so the success branch below
    # is currently unreachable.
    done = False
    if done:
        print("Episode success: I've reached the episode horizon ({}).".format(self.max_episode_steps))
    if self.reward_type=='lane_follow':
        reward, reward_dict, done_dict = self.lane_follow_reward(self.vehicle)
    elif self.reward_type=='goal_reaching':
        reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle)
    else:
        raise ValueError('unknown reward type:', self.reward_type)
    info = reward_dict
    info.update(done_dict)
    # Episode terminates if any sub-task signals done.
    done = False
    for key in done_dict:
        done = (done or done_dict[key])
    #if done:
    #    print('done_dict:', done_dict, 'r:', reward)
    return next_obs, reward, done, info
def finish(self):
    """Tear down the episode: destroy every spawned actor, then quit pygame."""
    print('destroying actors.')
    for spawned in self.actor_list:
        spawned.destroy()
    print('\ndestroying %d vehicles' % len(self.vehicles_list))
    destroy_commands = [carla.command.DestroyActor(v) for v in self.vehicles_list]
    self.client.apply_batch(destroy_commands)
    time.sleep(0.5)  # give the server a moment to process the batch
    pygame.quit()
    print('done.')
class CarlaObsDictEnv(OfflineEnv):
    """Dict-observation wrapper around CarlaEnv.

    Observations are dictionaries {'image': flattened float32 image scaled
    to [0, 1]}; everything else is delegated to the wrapped CarlaEnv.
    """

    def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs):
        print('[CarlaObsDictEnv] render_images:', render_images)
        # BUG FIX: CarlaEnv used to be constructed twice here, and the second
        # construction silently dropped `reward_type`; construct exactly once
        # (matching CarlaObsEnv below).
        self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images)
        self.action_space = self._wrapped_env.action_space
        self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape))
        self.observation_space = spaces.Dict({
            'image': spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0, ] * self.observation_size))
        })
        print(self.observation_space)
        super(CarlaObsDictEnv, self).__init__(**kwargs)

    @property
    def wrapped_env(self):
        return self._wrapped_env

    def reset(self, **kwargs):
        """Re-initialise the wrapped env and return the first observation dict."""
        self._wrapped_env.reset_init()
        obs = self._wrapped_env.reset(**kwargs)
        obs_dict = dict()
        # Normalise pixels to [0, 1] and flatten.
        obs_dict['image'] = (obs.astype(np.float32) / 255.0).flatten()
        return obs_dict

    def step(self, action):
        """Step the wrapped env, wrapping the observation into a dict."""
        next_obs, reward, done, info = self._wrapped_env.step(action)
        next_obs_dict = dict()
        next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten()
        return next_obs_dict, reward, done, info

    def render(self, *args, **kwargs):
        return self._wrapped_env.render(*args, **kwargs)

    @property
    def horizon(self):
        return self._wrapped_env.horizon

    def terminate(self):
        if hasattr(self.wrapped_env, "terminate"):
            self._wrapped_env.terminate()

    def __getattr__(self, attr):
        # Avoid infinite recursion when _wrapped_env itself is missing
        # (e.g. during unpickling before __setstate__ runs).
        if attr == '_wrapped_env':
            raise AttributeError()
        return getattr(self._wrapped_env, attr)

    def __getstate__(self):
        """
        This is useful to override in case the wrapped env has some funky
        __getstate__ that doesn't play well with overriding __getattr__.
        The main problematic case is/was gym's EzPickle serialization scheme.
        :return:
        """
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __str__(self):
        return '{}({})'.format(type(self).__name__, self.wrapped_env)
class CarlaObsEnv(OfflineEnv):
    """Flat-observation wrapper around CarlaEnv.

    Observations are flattened float32 images scaled to [0, 1]; everything
    else is delegated to the wrapped CarlaEnv.
    """

    def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs):
        self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images)
        self.action_space = self._wrapped_env.action_space
        self.observation_space = self._wrapped_env.observation_space
        self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape))
        self.observation_space = spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0, ] * self.observation_size))
        super(CarlaObsEnv, self).__init__(**kwargs)

    @property
    def wrapped_env(self):
        return self._wrapped_env

    def reset(self, **kwargs):
        """Re-initialise the wrapped env; return a normalised flat observation."""
        self._wrapped_env.reset_init()
        raw_obs = self._wrapped_env.reset(**kwargs)
        return (raw_obs.astype(np.float32) / 255.0).flatten()

    def step(self, action):
        """Step the wrapped env; the image observation is normalised and flattened."""
        raw_obs, reward, done, info = self._wrapped_env.step(action)
        flat_obs = (raw_obs.astype(np.float32) / 255.0).flatten()
        return flat_obs, reward, done, info

    def render(self, *args, **kwargs):
        return self._wrapped_env.render(*args, **kwargs)

    @property
    def horizon(self):
        return self._wrapped_env.horizon

    def terminate(self):
        if hasattr(self.wrapped_env, "terminate"):
            self._wrapped_env.terminate()

    def __getattr__(self, attr):
        # Avoid infinite recursion when _wrapped_env itself is missing
        # (e.g. during unpickling before __setstate__ runs).
        if attr == '_wrapped_env':
            raise AttributeError()
        return getattr(self._wrapped_env, attr)

    def __getstate__(self):
        """
        This is useful to override in case the wrapped env has some funky
        __getstate__ that doesn't play well with overriding __getattr__.
        The main problematic case is/was gym's EzPickle serialization scheme.
        :return:
        """
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __str__(self):
        return '{}({})'.format(type(self).__name__, self.wrapped_env)
if __name__ == '__main__':
    # Smoke-test configuration for a bare CarlaEnv.
    variant = {
        'vision_size': 48,
        'vision_fov': 48,
        'weather': False,
        'frame_skip': 1,
        'steps': 100000,
        'multiagent': False,
        'lane': 0,
        'lights': False,
        'record_dir': None,
    }
    env = CarlaEnv(args=variant)
    carla_gym_env = proxy_env.ProxyEnv(env)
================================================
FILE: d4rl/d4rl/carla/data_collection_agent_lane.py
================================================
# !/usr/bin/env python
# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de
# Barcelona (UAB).
#
# This work is licensed under the terms of the MIT license.
# For a copy, see .
#
# Modified by Rowan McAllister on 20 April 2020
import argparse
import datetime
import glob
import os
import random
import sys
import time
from PIL import Image
from PIL.PngImagePlugin import PngInfo
try:
sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (
sys.version_info.major,
sys.version_info.minor,
'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])
except IndexError:
pass
import carla
import math
from dotmap import DotMap
try:
import pygame
except ImportError:
raise RuntimeError('cannot import pygame, make sure pygame package is installed')
try:
import numpy as np
except ImportError:
raise RuntimeError('cannot import numpy, make sure numpy package is installed')
try:
import queue
except ImportError:
import Queue as queue
from agents.navigation.agent import Agent, AgentState
from agents.navigation.local_planner import LocalPlanner
from agents.navigation.global_route_planner import GlobalRoutePlanner
from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle
from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO
def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):
    """
    Check if a target object is within a certain distance from a reference object.
    A vehicle in front would be something around 0 deg, while one behind around 180 deg.

    :param target_location: location of the target object
    :param current_location: location of the reference object
    :param orientation: orientation of the reference object (degrees)
    :param max_distance: maximum allowed distance
    :param d_angle_th_up: upper threshold for angle (degrees)
    :param d_angle_th_low: lower threshold for angle (optional, default is 0)
    :return: True if target object is within max_distance ahead of the reference object
    """
    delta = np.array([target_location.x - current_location.x,
                      target_location.y - current_location.y])
    distance = np.linalg.norm(delta)

    # Degenerate case: target essentially on top of us.
    if distance < 0.001:
        return True
    if distance > max_distance:
        return False

    heading = np.array([math.cos(math.radians(orientation)),
                        math.sin(math.radians(orientation))])
    cos_angle = np.clip(np.dot(heading, delta) / distance, -1., 1.)
    angle_deg = math.degrees(math.acos(cos_angle))
    return d_angle_th_low < angle_deg < d_angle_th_up
def compute_distance(location_1, location_2):
    """
    Euclidean distance between two 3D points.

    :param location_1, location_2: 3D points (with .x/.y/.z attributes)
    :return: distance, with a tiny epsilon added so it is never exactly zero
    """
    deltas = [location_2.x - location_1.x,
              location_2.y - location_1.y,
              location_2.z - location_1.z]
    # eps keeps downstream divisions by this distance safe.
    return np.linalg.norm(deltas) + np.finfo(float).eps
class CarlaSyncMode(object):
    """
    Context manager to synchronize output from different sensors. Synchronous
    mode is enabled as long as we are inside this context

        with CarlaSyncMode(world, sensors) as sync_mode:
            while True:
                data = sync_mode.tick(timeout=1.0)
    """

    def __init__(self, world, *sensors, **kwargs):
        self.world = world
        self.sensors = sensors
        self.frame = None
        self.delta_seconds = 1.0 / kwargs.get('fps', 20)
        self._queues = []
        self._settings = None
        # Synchronous mode is enabled immediately (not deferred to __enter__),
        # so the object is usable with or without a `with` statement.
        self.start()

    def start(self):
        """Switch the world to synchronous mode; attach a queue per event source."""
        self._settings = self.world.get_settings()
        self.frame = self.world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=self.delta_seconds))

        def make_queue(register_event):
            q = queue.Queue()
            register_event(q.put)
            self._queues.append(q)

        make_queue(self.world.on_tick)
        for sensor in self.sensors:
            make_queue(sensor.listen)

    def __enter__(self):
        # BUG FIX: __enter__ was missing, so the documented
        # `with CarlaSyncMode(...)` usage raised AttributeError.
        return self

    def tick(self, timeout):
        """Advance the simulation one step; return one datum per queue."""
        self.frame = self.world.tick()
        data = [self._retrieve_data(q, timeout) for q in self._queues]
        # Every sensor datum must belong to the frame we just ticked.
        assert all(x.frame == self.frame for x in data)
        return data

    def __exit__(self, *args, **kwargs):
        # Restore the settings captured in start().
        self.world.apply_settings(self._settings)

    def _retrieve_data(self, sensor_queue, timeout):
        """Block until a datum matching the current tick's frame arrives."""
        while True:
            data = sensor_queue.get(timeout=timeout)
            if data.frame == self.frame:
                return data
def draw_image(surface, image, blend=False):
    """Blit a CARLA BGRA image onto a pygame surface (converted to RGB)."""
    raw = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
    pixels = np.reshape(raw, (image.height, image.width, 4))
    pixels = pixels[:, :, :3]       # drop alpha
    pixels = pixels[:, :, ::-1]     # BGR -> RGB
    image_surface = pygame.surfarray.make_surface(pixels.swapaxes(0, 1))
    if blend:
        image_surface.set_alpha(100)
    surface.blit(image_surface, (0, 0))
def get_font():
    """Return a 14pt pygame font, preferring 'ubuntumono' when available."""
    available = [name for name in pygame.font.get_fonts()]
    choice = 'ubuntumono' if 'ubuntumono' in available else available[0]
    return pygame.font.Font(pygame.font.match_font(choice), 14)
def should_quit():
    """Return True if the user closed the window or pressed ESC."""
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            return True
        if event.type == pygame.KEYUP and event.key == pygame.K_ESCAPE:
            return True
    return False
def clamp(value, minimum=0.0, maximum=100.0):
    """Clip `value` into the closed interval [minimum, maximum]."""
    upper_bounded = min(value, maximum)
    return max(minimum, upper_bounded)
class Sun(object):
    """Sun model: azimuth sweeps steadily; altitude oscillates between 20 and 90 deg."""

    def __init__(self, azimuth, altitude):
        self.azimuth = azimuth
        self.altitude = altitude
        self._t = 0.0  # phase of the altitude oscillation

    def tick(self, delta_seconds):
        """Advance the sun by `delta_seconds` of simulated weather time."""
        self._t = (self._t + 0.008 * delta_seconds) % (2.0 * math.pi)
        self.azimuth = (self.azimuth + 0.25 * delta_seconds) % 360.0
        min_alt, max_alt = [20, 90]
        midpoint = 0.5 * (max_alt + min_alt)
        half_range = 0.5 * (max_alt - min_alt)
        self.altitude = midpoint + half_range * math.cos(self._t)

    def __str__(self):
        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)
class Storm(object):
    """Storm model driven by a single scalar `_t` that ramps up and down.

    `_t` ramps between -250 (fully clear) and 100 (peak storm); all public
    attributes are percentages derived from it on each tick.
    """

    def __init__(self, precipitation):
        self._t = precipitation if precipitation > 0.0 else -50.0
        self._increasing = True
        self.clouds = 0.0
        self.rain = 0.0
        self.wetness = 0.0
        self.puddles = 0.0
        self.wind = 0.0
        self.fog = 0.0

    def tick(self, delta_seconds):
        """Advance the storm state by `delta_seconds`."""
        delta = (1.3 if self._increasing else -1.3) * delta_seconds
        self._t = clamp(delta + self._t, -250.0, 100.0)
        # BUG FIX: clouds used to be assigned twice (capped at 90, then
        # immediately re-assigned capped at 60); the dead first assignment
        # has been removed. The effective cap is 60.
        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)
        self.rain = clamp(self._t, 0.0, 80.0)
        # Puddles lag behind the storm when it is clearing up.
        delay = -10.0 if self._increasing else 90.0
        self.puddles = clamp(self._t + delay, 0.0, 85.0)
        self.wetness = clamp(self._t * 5, 0.0, 100.0)
        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
        self.fog = clamp(self._t - 10, 0.0, 30.0)
        # Reverse direction at the extremes so the storm cycles.
        if self._t == -250.0:
            self._increasing = True
        if self._t == 100.0:
            self._increasing = False

    def __str__(self):
        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)
class Weather(object):
    """Drives the CARLA world's weather from Sun and Storm models each tick."""

    def __init__(self, world, changing_weather_speed):
        self.world = world
        self.reset()
        self.weather = world.get_weather()
        self.changing_weather_speed = changing_weather_speed
        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)
        self._storm = Storm(self.weather.precipitation)

    def reset(self):
        """Reset the world to clear weather with the sun overhead."""
        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)
        self.world.set_weather(weather_params)

    def tick(self):
        """Advance the sun/storm models and push the result to the simulator."""
        self._sun.tick(self.changing_weather_speed)
        self._storm.tick(self.changing_weather_speed)
        weather = self.weather
        weather.cloudiness = self._storm.clouds
        weather.precipitation = self._storm.rain
        weather.precipitation_deposits = self._storm.puddles
        weather.wind_intensity = self._storm.wind
        weather.fog_density = self._storm.fog
        weather.wetness = self._storm.wetness
        weather.sun_azimuth_angle = self._sun.azimuth
        weather.sun_altitude_angle = self._sun.altitude
        self.world.set_weather(weather)

    def __str__(self):
        return '%s %s' % (self._sun, self._storm)
def parse_args():
    """Parse command-line options for the data-collection script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--vision_size', type=int, default=84)
    parser.add_argument('--vision_fov', type=int, default=90)
    parser.add_argument('--weather', default=False, action='store_true')
    parser.add_argument('--frame_skip', type=int, default=1)
    parser.add_argument('--steps', type=int, default=100000)
    parser.add_argument('--multiagent', default=False, action='store_true')
    parser.add_argument('--lane', type=int, default=0)
    parser.add_argument('--lights', default=False, action='store_true')
    return parser.parse_args()
class LocalPlannerModified(LocalPlanner):
    """LocalPlanner tweaks for data collection."""

    def __del__(self):
        # Intentionally a no-op: the base destructor would destroy the
        # vehicle actor, which this environment manages itself.
        pass  # otherwise it deletes our vehicle object

    def run_step(self):
        # Disable debug drawing: the planner would otherwise render its
        # waypoints, which would contaminate recorded camera images.
        return super().run_step(debug=False)  # otherwise by default shows waypoints, that interfere with our camera
class RoamingAgent(Agent):
    """
    RoamingAgent implements a basic agent that navigates scenes making random
    choices when facing an intersection.

    This agent respects traffic lights and other vehicles.
    NOTE: need to re-create after each env reset
    """

    def __init__(self, env):
        """
        :param env: environment providing `vehicle` (the actor to apply the
            local planner logic onto) and `follow_traffic_lights`
        """
        vehicle = env.vehicle
        follow_traffic_lights = env.follow_traffic_lights
        super(RoamingAgent, self).__init__(vehicle)
        self._proximity_threshold = 10.0  # meters
        self._state = AgentState.NAVIGATING
        self._local_planner = LocalPlannerModified(self._vehicle)
        self._follow_traffic_lights = follow_traffic_lights

    def compute_action(self):
        """Return a 2-element action array [throttle_or_negative_brake, steer]."""
        action, traffic_light = self.run_step()
        throttle = action.throttle
        brake = action.brake
        steer = action.steer
        if brake == 0.0:
            return np.array([throttle, steer])
        else:
            # Encode braking as negative throttle in the first action slot.
            return np.array([-brake, steer])

    def run_step(self):
        """
        Execute one step of navigation.

        :return: (carla.VehicleControl, traffic_light_color string)
        """
        # is there an obstacle in front of us?
        hazard_detected = False

        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles
        actor_list = self._world.get_actors()
        vehicle_list = actor_list.filter("*vehicle*")
        lights_list = actor_list.filter("*traffic_light*")

        # check possible obstacles
        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
        if vehicle_state:
            self._state = AgentState.BLOCKED_BY_VEHICLE
            hazard_detected = True

        # check for the state of the traffic lights
        traffic_light_color = self._is_light_red(lights_list)
        if traffic_light_color == 'RED' and self._follow_traffic_lights:
            self._state = AgentState.BLOCKED_RED_LIGHT
            hazard_detected = True

        if hazard_detected:
            control = self.emergency_stop()
        else:
            self._state = AgentState.NAVIGATING
            # standard local planner behavior
            control = self._local_planner.run_step()

        return control, traffic_light_color

    # override case class
    def _is_light_red_europe_style(self, lights_list):
        """
        This method is specialized to check European style traffic lights.
        Only suitable for Towns 03 -- 07.

        Returns the most severe applicable color: RED > YELLOW > GREEN > NONE.
        """
        ego_vehicle_location = self._vehicle.get_location()
        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)

        traffic_light_color = "NONE"  # default, if no traffic lights are seen

        for traffic_light in lights_list:
            object_waypoint = self._map.get_waypoint(traffic_light.get_location())
            # Only consider lights on our own road and lane.
            if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \
                    object_waypoint.lane_id != ego_vehicle_waypoint.lane_id:
                continue

            if is_within_distance_ahead(traffic_light.get_transform(),
                                        self._vehicle.get_transform(),
                                        self._proximity_threshold):
                if traffic_light.state == carla.TrafficLightState.Red:
                    return "RED"
                elif traffic_light.state == carla.TrafficLightState.Yellow:
                    traffic_light_color = "YELLOW"
                elif traffic_light.state == carla.TrafficLightState.Green:
                    # BUG FIX: was `is not "YELLOW"` (identity comparison with
                    # a string literal); use equality so YELLOW keeps priority
                    # (more severe).
                    if traffic_light_color != "YELLOW":
                        traffic_light_color = "GREEN"
                else:
                    # Unexpected light state (Off/Unknown); drop into the
                    # debugger so data collection surfaces it immediately.
                    # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate
                    import pdb; pdb.set_trace()

        return traffic_light_color

    # override case class
    def _is_light_red_us_style(self, lights_list, debug=False):
        """US-style traffic light check; relies on the local planner's target waypoint."""
        ego_vehicle_location = self._vehicle.get_location()
        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)

        traffic_light_color = "NONE"  # default, if no traffic lights are seen

        if ego_vehicle_waypoint.is_junction:
            # It is too late. Do not block the intersection! Keep going!
            return "JUNCTION"

        if self._local_planner.target_waypoint is not None:
            if self._local_planner.target_waypoint.is_junction:
                min_angle = 180.0
                sel_magnitude = 0.0
                sel_traffic_light = None
                # Pick the light with the smallest bearing within 60 m / 25 deg.
                for traffic_light in lights_list:
                    loc = traffic_light.get_location()
                    magnitude, angle = compute_magnitude_angle(loc,
                                                               ego_vehicle_location,
                                                               self._vehicle.get_transform().rotation.yaw)
                    if magnitude < 60.0 and angle < min(25.0, min_angle):
                        sel_magnitude = magnitude
                        sel_traffic_light = traffic_light
                        min_angle = angle

                if sel_traffic_light is not None:
                    if debug:
                        print('=== Magnitude = {} | Angle = {} | ID = {}'.format(
                            sel_magnitude, min_angle, sel_traffic_light.id))

                    if self._last_traffic_light is None:
                        self._last_traffic_light = sel_traffic_light

                    if self._last_traffic_light.state == carla.TrafficLightState.Red:
                        return "RED"
                    elif self._last_traffic_light.state == carla.TrafficLightState.Yellow:
                        traffic_light_color = "YELLOW"
                    elif self._last_traffic_light.state == carla.TrafficLightState.Green:
                        # BUG FIX: was `is not "YELLOW"`; use equality comparison.
                        if traffic_light_color != "YELLOW":
                            traffic_light_color = "GREEN"
                    else:
                        # Unexpected light state; see
                        # https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate
                        import pdb; pdb.set_trace()
                else:
                    self._last_traffic_light = None

        return traffic_light_color
if __name__ == '__main__':
    # example call:
    # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05
    # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights
    # NOTE(review): `CarlaEnv` is never defined or imported in this file, so
    # this entry point raises NameError as written — confirm which module it
    # should be imported from.
    args = parse_args()
    env = CarlaEnv(args)
    try:
        done = False
        while not done:
            action, traffic_light_color = env.compute_action()
            next_obs, reward, done, info = env.step(action, traffic_light_color)
            print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location())
            if done:
                # env.reset_init()
                # env.reset()
                # Episode end is ignored so collection keeps running.
                done = False
    finally:
        env.finish()
================================================
FILE: d4rl/d4rl/carla/data_collection_town.py
================================================
#!/usr/bin/env python
# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de
# Barcelona (UAB).
#
# This work is licensed under the terms of the MIT license.
# For a copy, see .
#
# Modified by Rowan McAllister on 20 April 2020
import argparse
import datetime
import glob
import os
import random
import sys
import time
from PIL import Image
from PIL.PngImagePlugin import PngInfo
try:
sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (
sys.version_info.major,
sys.version_info.minor,
'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])
except IndexError:
pass
import carla
import math
from dotmap import DotMap
try:
import pygame
except ImportError:
raise RuntimeError('cannot import pygame, make sure pygame package is installed')
try:
import numpy as np
except ImportError:
raise RuntimeError('cannot import numpy, make sure numpy package is installed')
try:
import queue
except ImportError:
import Queue as queue
from agents.navigation.agent import Agent, AgentState
from agents.navigation.local_planner import LocalPlanner
from agents.navigation.global_route_planner import GlobalRoutePlanner
from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO
from agents.tools.misc import is_within_distance_ahead #, is_within_distance, compute_distance
from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle
def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):
    """
    Check if a target object is within a certain distance from a reference object.
    A vehicle in front would be something around 0 deg, while one behind around 180 deg.

    :param target_location: location of the target object
    :param current_location: location of the reference object
    :param orientation: orientation of the reference object (degrees)
    :param max_distance: maximum allowed distance
    :param d_angle_th_up: upper threshold for angle
    :param d_angle_th_low: lower threshold for angle (optional, default is 0)
    :return: True if target object is within max_distance ahead of the reference object
    """
    offset = np.array([target_location.x - current_location.x,
                       target_location.y - current_location.y])
    distance = np.linalg.norm(offset)

    # Degenerate case: the target is (numerically) on top of us.
    if distance < 0.001:
        return True
    if distance > max_distance:
        return False

    heading = np.array([math.cos(math.radians(orientation)),
                        math.sin(math.radians(orientation))])
    cos_angle = np.clip(np.dot(heading, offset) / distance, -1., 1.)
    relative_angle = math.degrees(math.acos(cos_angle))
    return d_angle_th_low < relative_angle < d_angle_th_up
def compute_distance(location_1, location_2):
    """
    Euclidean distance between 3D points, padded with machine epsilon so the
    result is strictly positive even for identical points.

    :param location_1, location_2: 3D points
    """
    delta = np.array([location_2.x - location_1.x,
                      location_2.y - location_1.y,
                      location_2.z - location_1.z])
    return np.linalg.norm(delta) + np.finfo(float).eps
class CustomGlobalRoutePlanner(GlobalRoutePlanner):
    """GlobalRoutePlanner extended with goal-distance and velocity-decomposition
    queries over the road graph built by setup()."""

    def __init__(self, dao):
        super(CustomGlobalRoutePlanner, self).__init__(dao=dao)

    def compute_direction_velocities(self, origin, velocity, destination):
        """Split *velocity* into the component along the route toward
        *destination* (vel_s) and the perpendicular component (vel_perp)."""
        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)

        origin_xy = np.array([origin.x, origin.y])
        velocity_xy = np.array([velocity.x, velocity.y])

        # Direction toward the next node of the planned route.
        next_vertex = self._graph.nodes[node_list[1]]['vertex']
        toward_next = np.array([next_vertex[0], next_vertex[1]]) - origin_xy
        unit_toward = toward_next / np.linalg.norm(toward_next)

        vel_s = np.dot(velocity_xy, unit_toward)

        # Perpendicular magnitude via the angle between velocity and route.
        unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8)
        angle = np.arccos(np.clip(np.dot(unit_velocity, unit_toward), -1.0, 1.0))
        vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle)
        return vel_s, vel_perp

    def compute_distance(self, origin, destination):
        """Route length from *origin* to *destination*: offset to the first
        route node plus summed heuristic distances along the node path."""
        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)
        first_vertex = self._graph.nodes[node_list[0]]['vertex']
        segments = [np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_vertex))]
        for a, b in zip(node_list[:-1], node_list[1:]):
            segments.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(a, b))
        return np.sum(segments)
class CarlaSyncMode(object):
    """
    Context manager to synchronize output from different sensors. Synchronous
    mode is enabled as long as we are inside this context

    with CarlaSyncMode(world, sensors) as sync_mode:
        while True:
            data = sync_mode.tick(timeout=1.0)
    """

    def __init__(self, world, *sensors, **kwargs):
        self.world = world
        self.sensors = sensors
        self.frame = None  # frame id of the most recent world tick
        self.delta_seconds = 1.0 / kwargs.get('fps', 20)
        self._queues = []  # one queue per event source (world tick + each sensor)
        self._settings = None  # settings to restore on __exit__
        self.start()

    def start(self):
        """Switch the world into synchronous mode and hook one queue per
        event source (world on_tick plus every sensor's listen callback)."""
        self._settings = self.world.get_settings()
        self.frame = self.world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=self.delta_seconds))

        def make_queue(register_event):
            # Each event source pushes its payloads into its own FIFO.
            q = queue.Queue()
            register_event(q.put)
            self._queues.append(q)

        make_queue(self.world.on_tick)
        for sensor in self.sensors:
            make_queue(sensor.listen)

    def __enter__(self):
        # BUGFIX: the class docstring advertises `with CarlaSyncMode(...) as
        # sync_mode:` but no __enter__ was defined, so that usage raised.
        # Synchronous mode is already enabled by __init__ via start(); the
        # context entry only needs to hand back the manager itself.
        return self

    def tick(self, timeout):
        """Advance the simulation one fixed step and return the list
        [world snapshot, sensor data...] for exactly that frame."""
        self.frame = self.world.tick()
        data = [self._retrieve_data(q, timeout) for q in self._queues]
        # Every queue must deliver data stamped with the frame just ticked.
        assert all(x.frame == self.frame for x in data)
        return data

    def __exit__(self, *args, **kwargs):
        # Restore whatever settings were active before synchronous mode.
        self.world.apply_settings(self._settings)

    def _retrieve_data(self, sensor_queue, timeout):
        # Discard stale events until we find the one for the current frame.
        while True:
            data = sensor_queue.get(timeout=timeout)
            if data.frame == self.frame:
                return data
def draw_image(surface, image, blend=False):
    """Blit a CARLA BGRA camera image onto *surface*, optionally semi-transparent."""
    raw = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
    bgra = np.reshape(raw, (image.height, image.width, 4))
    rgb = bgra[:, :, :3][:, :, ::-1]  # drop alpha, then BGR -> RGB
    image_surface = pygame.surfarray.make_surface(rgb.swapaxes(0, 1))
    if blend:
        image_surface.set_alpha(100)
    surface.blit(image_surface, (0, 0))
def get_font():
    """Return a 14pt pygame font, preferring 'ubuntumono' when installed."""
    available = [name for name in pygame.font.get_fonts()]
    preferred = 'ubuntumono'
    chosen = preferred if preferred in available else available[0]
    return pygame.font.Font(pygame.font.match_font(chosen), 14)
def should_quit():
    """Return True when the user closed the window or released ESC."""
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            return True
        if event.type == pygame.KEYUP and event.key == pygame.K_ESCAPE:
            return True
    return False
def clamp(value, minimum=0.0, maximum=100.0):
    """Clip *value* into the closed interval [minimum, maximum]."""
    capped = min(value, maximum)
    return max(minimum, capped)
class Sun(object):
    """Toy sun model: azimuth advances linearly with time while altitude
    follows a cosine oscillation between 20 and 90 degrees."""

    def __init__(self, azimuth, altitude):
        self.azimuth = azimuth
        self.altitude = altitude
        self._t = 0.0  # phase of the altitude oscillation, in radians

    def tick(self, delta_seconds):
        """Advance the sun by *delta_seconds* of simulated time."""
        self._t += 0.008 * delta_seconds
        self._t %= 2.0 * math.pi
        self.azimuth += 0.25 * delta_seconds
        self.azimuth %= 360.0
        lo, hi = 20, 90
        midpoint = 0.5 * (hi + lo)
        half_span = 0.5 * (hi - lo)
        self.altitude = midpoint + half_span * math.cos(self._t)

    def __str__(self):
        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)
class Storm(object):
    """Toy storm model driven by a single scalar _t that oscillates between
    -250 (perfectly clear) and 100 (full storm); all visible weather
    quantities (clouds, rain, puddles, wetness, wind, fog) derive from it."""

    def __init__(self, precipitation):
        # Start at the requested precipitation level, or deep in the clear
        # regime when none is requested.
        self._t = precipitation if precipitation > 0.0 else -50.0
        self._increasing = True  # direction of the _t sweep
        self.clouds = 0.0
        self.rain = 0.0
        self.wetness = 0.0
        self.puddles = 0.0
        self.wind = 0.0
        self.fog = 0.0

    def tick(self, delta_seconds):
        """Advance the storm by *delta_seconds*, updating all derived fields."""
        delta = (1.3 if self._increasing else -1.3) * delta_seconds
        self._t = clamp(delta + self._t, -250.0, 100.0)
        # BUGFIX(cleanup): clouds was assigned twice in a row; the first
        # clamp(..., 0.0, 90.0) was dead code immediately overwritten by the
        # 0..60 clamp, so only the effective assignment is kept.
        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)
        self.rain = clamp(self._t, 0.0, 80.0)
        # Puddles lag rain on the way in and linger on the way out.
        delay = -10.0 if self._increasing else 90.0
        self.puddles = clamp(self._t + delay, 0.0, 85.0)
        self.wetness = clamp(self._t * 5, 0.0, 100.0)
        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
        self.fog = clamp(self._t - 10, 0.0, 30.0)
        # Reverse direction at either extreme of the sweep.
        if self._t == -250.0:
            self._increasing = True
        if self._t == 100.0:
            self._increasing = False

    def __str__(self):
        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)
class Weather(object):
    """Couples a Sun and a Storm model and pushes their combined state into
    the CARLA world every tick."""

    def __init__(self, world, changing_weather_speed):
        self.world = world
        self.reset()
        self.weather = world.get_weather()
        self.changing_weather_speed = changing_weather_speed
        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)
        self._storm = Storm(self.weather.precipitation)

    def reset(self):
        """Force a clear midday sky."""
        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)
        self.world.set_weather(weather_params)

    def tick(self):
        """Advance sun and storm, then apply the combined state to the world."""
        self._sun.tick(self.changing_weather_speed)
        self._storm.tick(self.changing_weather_speed)
        storm = self._storm
        self.weather.cloudiness = storm.clouds
        self.weather.precipitation = storm.rain
        self.weather.precipitation_deposits = storm.puddles
        self.weather.wind_intensity = storm.wind
        self.weather.fog_density = storm.fog
        self.weather.wetness = storm.wetness
        self.weather.sun_azimuth_angle = self._sun.azimuth
        self.weather.sun_altitude_angle = self._sun.altitude
        self.world.set_weather(self.weather)

    def __str__(self):
        return '%s %s' % (self._sun, self._storm)
def parse_args():
    """Parse the data-collection command line.

    Returns an argparse.Namespace with:
      vision_size / vision_fov -- camera resolution and field of view
      weather                  -- evolve the weather over time
      frame_skip               -- simulator frames per env step
      steps                    -- episode horizon
      multiagent               -- spawn background autopilot traffic
      lane                     -- start lane (Town04 only)
      lights                   -- make the agent respect traffic lights
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--vision_size', type=int, default=84)
    parser.add_argument('--vision_fov', type=int, default=90)
    parser.add_argument('--weather', default=False, action='store_true')
    # NOTE: stray trailing commas after two add_argument calls built useless
    # one-element tuples; removed (no behavior change).
    parser.add_argument('--frame_skip', type=int, default=1)
    parser.add_argument('--steps', type=int, default=100000)
    parser.add_argument('--multiagent', default=False, action='store_true')
    parser.add_argument('--lane', type=int, default=0)
    parser.add_argument('--lights', default=False, action='store_true')
    args = parser.parse_args()
    return args
class CarlaEnv(object):
    """Town-driving CARLA environment with a dm_control-style interface.

    Connects to a CARLA server on localhost:2000, spawns an ego vehicle and
    (optionally) background traffic, and exposes step()/reset() returning RGB
    camera observations. When record_vision is on, every frame is saved as a
    PNG whose text metadata records controls, kinematics, weather, reward
    terms and done flags.
    """

    def __init__(self, args):
        # Rendering / recording switches; this collector runs headless by
        # default, saving only the vision camera frames.
        self.render_display = False
        self.record_display = False
        self.record_vision = True
        self.record_dir = None #'/nfs/kun1/users/aviralkumar/carla_data/'

        self.vision_size = args.vision_size
        self.vision_fov = args.vision_fov
        self.changing_weather_speed = float(args.weather)  # bool flag -> 0.0 / 1.0
        self.frame_skip = args.frame_skip
        self.max_episode_steps = args.steps
        self.multiagent = args.multiagent
        self.start_lane = args.lane
        self.follow_traffic_lights = args.lights
        # Recording the display requires the display to exist.
        if self.record_display:
            assert self.render_display

        self.actor_list = []

        if self.render_display:
            pygame.init()
            self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF)
            self.font = get_font()
            self.clock = pygame.time.Clock()

        self.client = carla.Client('localhost', 2000)
        self.client.set_timeout(2.0)

        self.world = self.client.get_world()
        self.map = self.world.get_map()

        ## Define the route planner
        self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1)
        self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao)

        # tests specific to map 4:
        if self.start_lane and self.map.name != "Town04":
            raise NotImplementedError

        # remove old vehicles and sensors (in case they survived)
        self.world.tick()
        actor_list = self.world.get_actors()
        for vehicle in actor_list.filter("*vehicle*"):
            print("Warning: removing old vehicle")
            vehicle.destroy()
        for sensor in actor_list.filter("*sensor*"):
            print("Warning: removing old sensor")
            sensor.destroy()

        self.vehicle = None
        self.vehicles_list = []  # their ids
        self.reset_vehicle()  # creates self.vehicle
        self.actor_list.append(self.vehicle)

        blueprint_library = self.world.get_blueprint_library()

        if self.render_display:
            # Third-person chase camera, only for the on-screen display.
            self.camera_display = self.world.spawn_actor(
                blueprint_library.find('sensor.camera.rgb'),
                carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)),
                attach_to=self.vehicle)
            self.actor_list.append(self.camera_display)

        # Forward-facing RGB camera that produces the agent's observation.
        bp = blueprint_library.find('sensor.camera.rgb')
        bp.set_attribute('image_size_x', str(self.vision_size))
        bp.set_attribute('image_size_y', str(self.vision_size))
        bp.set_attribute('fov', str(self.vision_fov))
        location = carla.Location(x=1.6, z=1.7)
        self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle)
        self.actor_list.append(self.camera_vision)

        if self.record_display or self.record_vision:
            if self.record_dir is None:
                # Encode the collection settings and a timestamp into the
                # output directory name.
                self.record_dir = "carla-{}-{}x{}-fov{}".format(
                    self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov)
                if self.frame_skip > 1:
                    self.record_dir += '-{}'.format(self.frame_skip)
                if self.changing_weather_speed > 0.0:
                    self.record_dir += '-weather'
                if self.multiagent:
                    self.record_dir += '-mutiagent'
                if self.follow_traffic_lights:
                    self.record_dir += '-lights'
                self.record_dir += '-{}k'.format(self.max_episode_steps // 1000)

                now = datetime.datetime.now()
                self.record_dir += now.strftime("-%Y-%m-%d-%H-%M-%S")
            if not os.path.exists(self.record_dir):
                os.mkdir(self.record_dir)

        if self.render_display:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20)
        else:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20)

        # weather
        self.weather = Weather(self.world, self.changing_weather_speed)

        # dummy variables, to match deep mind control's APIs
        low = -1.0
        high = 1.0
        self.action_space = DotMap()
        self.action_space.low.min = lambda: low
        self.action_space.high.max = lambda: high
        self.action_space.shape = [2]
        self.observation_space = DotMap()
        self.observation_space.shape = (3, self.vision_size, self.vision_size)
        self.observation_space.dtype = np.dtype(np.uint8)
        self.reward_range = None
        self.metadata = None
        self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32)

        # roaming carla agent
        self.agent = None
        self.world.tick()
        self.reset_init()  # creates self.agent

        ## Initialize the route planner
        self.route_planner.setup()

        ## Collision detection
        self._proximity_threshold = 10.0
        self._traffic_light_threshold = 5.0
        # NOTE(review): this rebinds actor_list from the Python list of actors
        # we spawned to the server-side ActorList of *all* world actors, so
        # finish() will attempt to destroy every actor in the world — confirm
        # this is intended.
        self.actor_list = self.world.get_actors()
        for idx in range(len(self.actor_list)):
            print (idx, self.actor_list[idx])
        # import ipdb; ipdb.set_trace()
        self.vehicle_list = self.actor_list.filter("*vehicle*")
        self.lights_list = self.actor_list.filter("*traffic_light*")
        self.object_list = self.actor_list.filter("*traffic.*")

        ## Initialize the route planner
        ## NOTE(review): second setup() call — appears redundant with the one
        ## above; confirm it is harmless before removing.
        self.route_planner.setup()

        ## The map is deterministic so for reward relabelling, we can
        ## instantiate the environment object and then query the distance function
        ## in the env, which directly uses this map_graph, and we need not save it.
        self._map_graph = self.route_planner._graph

        ## This is a dummy for the target location, we can make this an input
        ## to the env in RL code.
        self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433)

        ## Now reset the env once
        self.reset()

    def reset_init(self):
        """Respawn the ego vehicle and background traffic, and build a fresh
        autopilot agent; resets the frame counter and image-name timestamp."""
        self.reset_vehicle()
        self.world.tick()
        self.reset_other_vehicles()
        self.world.tick()
        self.agent = RoamingAgent(self.vehicle, follow_traffic_lights=self.follow_traffic_lights)
        self.count = 0  # frame counter within this collection run
        self.ts = int(time.time())  # timestamp used to namespace saved images

    def reset(self):
        """Return the current observation by stepping once (no-op action)."""
        # get obs:
        obs, _, _, _ = self.step()
        return obs

    def reset_vehicle(self):
        """Move (or first spawn) the ego vehicle at a start pose and zero its
        linear and angular velocity."""
        if self.map.name == "Town04":
            start_lane = -1  # currently unused; start pose is fixed below
            start_x = 5.0
            vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90))
        else:
            # Any town other than Town04: pick a random spawn point.
            init_transforms = self.world.get_map().get_spawn_points()
            vehicle_init_transform = random.choice(init_transforms)
            # TODO(aviral): start lane not defined for town, also for the town, we may not want to have
            # the lane following reward, so it should be okay.

        if self.vehicle is None:  # then create the ego vehicle
            blueprint_library = self.world.get_blueprint_library()
            vehicle_blueprint = blueprint_library.find('vehicle.audi.a2')
            self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform)

        self.vehicle.set_transform(vehicle_init_transform)
        self.vehicle.set_velocity(carla.Vector3D())
        self.vehicle.set_angular_velocity(carla.Vector3D())

    def reset_other_vehicles(self):
        """Destroy previously spawned traffic and spawn 20 autopilot vehicles
        (no-op unless --multiagent was given)."""
        if not self.multiagent:
            return

        # clear out old vehicles
        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
        self.world.tick()
        self.vehicles_list = []

        traffic_manager = self.client.get_trafficmanager()
        traffic_manager.set_global_distance_to_leading_vehicle(2.0)
        traffic_manager.set_synchronous_mode(True)
        blueprints = self.world.get_blueprint_library().filter('vehicle.*')
        # Only four-wheeled vehicles (no bikes/motorcycles).
        blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4]

        num_vehicles = 20
        if self.map.name == "Town04":
            # Scatter vehicles along road 47 at random lanes/offsets.
            road_id = 47
            road_length = 117.
            init_transforms = []
            for _ in range(num_vehicles):
                lane_id = random.choice([-1, -2, -3, -4])
                vehicle_s = np.random.uniform(road_length)  # length of road 47
                init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform)
        else:
            init_transforms = self.world.get_map().get_spawn_points()
            init_transforms = np.random.choice(init_transforms, num_vehicles)

        # --------------
        # Spawn vehicles
        # --------------
        batch = []
        for transform in init_transforms:
            transform.location.z += 0.1  # otherwise can collide with the road it starts on
            blueprint = random.choice(blueprints)
            if blueprint.has_attribute('color'):
                color = random.choice(blueprint.get_attribute('color').recommended_values)
                blueprint.set_attribute('color', color)
            if blueprint.has_attribute('driver_id'):
                driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values)
                blueprint.set_attribute('driver_id', driver_id)
            blueprint.set_attribute('role_name', 'autopilot')
            batch.append(carla.command.SpawnActor(blueprint, transform).then(
                carla.command.SetAutopilot(carla.command.FutureActor, True)))

        # NOTE(review): the spawn batch is submitted twice, and the first pass
        # records actor ids even for failed responses — this looks like it
        # double-spawns traffic; confirm against the CARLA batch API.
        for response in self.client.apply_batch_sync(batch, False):
            self.vehicles_list.append(response.actor_id)
        for response in self.client.apply_batch_sync(batch):
            if response.error:
                pass
            else:
                self.vehicles_list.append(response.actor_id)

        traffic_manager.global_percentage_speed_difference(30.0)

    def compute_action(self):
        """Delegate control selection to the roaming autopilot agent.

        Returns the agent's (control, traffic_light_color) pair."""
        return self.agent.run_step()

    def step(self, action=None, traffic_light_color=""):
        """Advance the simulator `frame_skip` frames with the same action.

        Returns (obs, mean reward over the frames, done, info)."""
        rewards = []
        for _ in range(self.frame_skip):  # default 1
            next_obs, reward, done, info = self._simulator_step(action, traffic_light_color)
            rewards.append(reward)
            if done:
                break
        return next_obs, np.mean(rewards), done, info

    def _is_vehicle_hazard(self, vehicle, vehicle_list):
        """
        :param vehicle_list: list of potential obstacle to check
        :return: a tuple given by (bool_flag, reward, vehicle), where
            - bool_flag is True if there is a vehicle ahead blocking us
              and False otherwise
            - reward is -1.0 when a blocker is found, 0.0 otherwise
            - vehicle is the blocker object itself (None when no hazard)
        """
        ego_vehicle_location = vehicle.get_location()
        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)

        for target_vehicle in vehicle_list:
            # do not account for the ego vehicle
            if target_vehicle.id == vehicle.id:
                continue

            # if the object is not in our lane it's not an obstacle
            target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())
            if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \
                    target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:
                continue

            if is_within_distance_ahead(target_vehicle.get_transform(),
                                        vehicle.get_transform(),
                                        self._proximity_threshold/10.0):
                return (True, -1.0, target_vehicle)

        return (False, 0.0, None)

    def _is_object_hazard(self, vehicle, object_list):
        """
        :param object_list: list of potential obstacle to check
        :return: a tuple given by (bool_flag, reward, object), where
            - bool_flag is True if there is an object ahead blocking us
              and False otherwise
            - reward is -1.0 when a blocker is found, 0.0 otherwise
            - object is the blocker itself (None when no hazard)
        """
        ego_vehicle_location = vehicle.get_location()
        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)

        for target_vehicle in object_list:
            # do not account for the ego vehicle
            if target_vehicle.id == vehicle.id:
                continue

            # if the object is not in our lane it's not an obstacle
            target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())
            if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \
                    target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:
                continue

            # Tighter range than the vehicle check (threshold / 40 vs / 10).
            if is_within_distance_ahead(target_vehicle.get_transform(),
                                        vehicle.get_transform(),
                                        self._proximity_threshold/40.0):
                return (True, -1.0, target_vehicle)

        return (False, 0.0, None)

    def _is_light_red(self, vehicle):
        """
        Method to check if there is a red light affecting us. This version of
        the method is compatible with both European and US style traffic lights.

        :return: a tuple given by (bool_flag, reward, traffic_light), where
            - bool_flag is True if there is a traffic light in RED
              affecting us and False otherwise
            - reward is -0.1 on a red light, 0.0 otherwise
            - traffic_light is the object itself or None if there is no
              red traffic light affecting us
        """
        ego_vehicle_location = vehicle.get_location()
        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)

        for traffic_light in self.lights_list:
            object_location = self._get_trafficlight_trigger_location(traffic_light)
            object_waypoint = self.map.get_waypoint(object_location)

            if object_waypoint.road_id != ego_vehicle_waypoint.road_id:
                continue

            # Skip lights whose trigger volume faces away from our heading.
            ve_dir = ego_vehicle_waypoint.transform.get_forward_vector()
            wp_dir = object_waypoint.transform.get_forward_vector()
            dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z

            if dot_ve_wp < 0:
                continue

            if is_within_distance_ahead(object_waypoint.transform,
                                        vehicle.get_transform(),
                                        self._traffic_light_threshold):
                if traffic_light.state == carla.TrafficLightState.Red:
                    return (True, -0.1, traffic_light)

        return (False, 0.0, None)

    def _get_trafficlight_trigger_location(self, traffic_light):  # pylint: disable=no-self-use
        """
        Calculates the yaw of the waypoint that represents the trigger volume of the traffic light
        """
        def rotate_point(point, radians):
            """
            rotate a given point by a given angle
            """
            rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y
            # NOTE(review): a standard 2D rotation would use
            # sin(r)*x + cos(r)*y here; the minus sign deviates from that —
            # confirm whether this is intentional.
            rotated_y = math.sin(radians) * point.x - math.cos(radians) * point.y
            return carla.Vector3D(rotated_x, rotated_y, point.z)

        base_transform = traffic_light.get_transform()
        base_rot = base_transform.rotation.yaw
        area_loc = base_transform.transform(traffic_light.trigger_volume.location)
        area_ext = traffic_light.trigger_volume.extent

        point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot))
        point_location = area_loc + carla.Location(x=point.x, y=point.y)

        return carla.Location(point_location.x, point_location.y, point_location.z)

    def _get_collision_reward(self, vehicle):
        """Return (hazard_flag, reward) for a vehicle blocking us ahead."""
        vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list)
        return vehicle_hazard, reward

    def _get_traffic_light_reward(self, vehicle):
        """Return (red_light_flag, 0.0).

        NOTE(review): the -0.1 reward from _is_light_red is discarded and 0.0
        is returned instead — confirm this is intentional."""
        traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle)
        return traffic_light_hazard, 0.0

    def _get_object_collided_reward(self, vehicle):
        """Return (hazard_flag, reward) for a static traffic object ahead."""
        object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list)
        return object_hazard, reward

    def goal_reaching_reward(self, vehicle):
        """Reward for driving toward self.target_location.

        Returns (total_reward, reward_dict, done_dict). The base reward is the
        velocity component along the planned route; a detected vehicle hazard
        contributes 100 * (-1.0). Traffic-light and object terms are computed
        and logged but excluded from the total (commented out below)."""
        # Now we will write goal_reaching_rewards
        vehicle_location = vehicle.get_location()
        target_location = self.target_location

        # This is the distance computation
        """
        dist = self.route_planner.compute_distance(vehicle_location, target_location)
        base_reward = -1.0 * dist
        collided_done, collision_reward = self._get_collision_reward(vehicle)
        traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)
        object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)
        total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward
        """

        vehicle_velocity = vehicle.get_velocity()
        # dist is computed but not used in the current reward formulation.
        dist = self.route_planner.compute_distance(vehicle_location, target_location)
        vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location)

        #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward)))
        #base_reward = -1.0 * (dist / 100.0) + 5.0
        base_reward = vel_forward

        collided_done, collision_reward = self._get_collision_reward(vehicle)
        traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)
        object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)
        total_reward = base_reward + 100 * collision_reward  # + 100 * traffic_light_reward + 100.0 * object_collided_reward

        # Per-term breakdown, also written into the saved PNG metadata.
        reward_dict = dict()
        reward_dict['collision'] = collision_reward
        reward_dict['traffic_light'] = traffic_light_reward
        reward_dict['object_collision'] = object_collided_reward
        reward_dict['base_reward'] = base_reward
        reward_dict['vel_forward'] = vel_forward
        reward_dict['vel_perp'] = vel_perp

        done_dict = dict()
        done_dict['collided_done'] = collided_done
        done_dict['traffic_light_done'] = traffic_light_done
        done_dict['object_collided_done'] = object_collided_done

        return total_reward, reward_dict, done_dict

    def _simulator_step(self, action, traffic_light_color):
        """Apply one control, tick the synchronous simulator once, optionally
        save the annotated frame, and return (obs, reward, done, info)."""
        if self.render_display:
            if should_quit():
                return
            self.clock.tick()

        if action is None:
            # No-op control (used by reset()).
            throttle, steer, brake = 0., 0., 0.
        else:
            throttle, steer, brake = action.throttle, action.steer, action.brake
            # throttle = clamp(throttle, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003)
            # steer = clamp(steer, minimum=-0.995, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003)
            # brake = clamp(brake, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003)

        vehicle_control = carla.VehicleControl(
            throttle=throttle,  # [0,1]
            steer=steer,  # [-1,1]
            brake=brake,  # [0,1]
            hand_brake=False,
            reverse=False,
            manual_gear_shift=False
        )
        self.vehicle.apply_control(vehicle_control)

        # Advance the simulation and wait for the data.
        if self.render_display:
            snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0)
        else:
            snapshot, vision_image = self.sync_mode.tick(timeout=2.0)

        # Weather evolves
        self.weather.tick()

        # Draw the display.
        if self.render_display:
            draw_image(self.render_display, display_image)
            self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10))
            self.render_display.blit(self.font.render('Control: %5.2f thottle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28))
            self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46))
            self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64))
            pygame.display.flip()

        # Format rl image
        bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4)  # BGRA format
        bgr = bgra[:, :, :3]  # BGR format (84 x 84 x 3)
        rgb = np.flip(bgr, axis=2)  # RGB format (84 x 84 x 3)

        reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle)

        if self.render_display and self.record_display:
            image_name = os.path.join(self.record_dir, "display%08d.jpg" % self.count)
            pygame.image.save(self.render_display, image_name)
            # # Can animate with:
            # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4

        if self.record_vision:
            image_name = os.path.join(self.record_dir, "vision_%d_%08d.png" % (self.ts, self.count))
            im = Image.fromarray(rgb)

            # add any metadata you like into the image before we save it:
            metadata = PngInfo()
            # control
            metadata.add_text("control_throttle", str(throttle))
            metadata.add_text("control_steer", str(steer))
            metadata.add_text("control_brake", str(brake))
            metadata.add_text("control_repeat", str(self.frame_skip))
            # acceleration
            acceleration = self.vehicle.get_acceleration()
            metadata.add_text("acceleration_x", str(acceleration.x))
            metadata.add_text("acceleration_y", str(acceleration.y))
            metadata.add_text("acceleration_z", str(acceleration.z))
            # angular velocity
            angular_velocity = self.vehicle.get_angular_velocity()
            metadata.add_text("angular_velocity_x", str(angular_velocity.x))
            metadata.add_text("angular_velocity_y", str(angular_velocity.y))
            metadata.add_text("angular_velocity_z", str(angular_velocity.z))
            # location
            location = self.vehicle.get_location()
            print('Location:', location)
            metadata.add_text("location_x", str(location.x))
            metadata.add_text("location_y", str(location.y))
            metadata.add_text("location_z", str(location.z))
            # rotation
            rotation = self.vehicle.get_transform().rotation
            metadata.add_text("rotation_pitch", str(rotation.pitch))
            metadata.add_text("rotation_yaw", str(rotation.yaw))
            metadata.add_text("rotation_roll", str(rotation.roll))
            forward_vector = rotation.get_forward_vector()
            metadata.add_text("forward_vector_x", str(forward_vector.x))
            metadata.add_text("forward_vector_y", str(forward_vector.y))
            metadata.add_text("forward_vector_z", str(forward_vector.z))
            # velocity
            velocity = self.vehicle.get_velocity()
            metadata.add_text("velocity_x", str(velocity.x))
            metadata.add_text("velocity_y", str(velocity.y))
            metadata.add_text("velocity_z", str(velocity.z))
            # weather
            metadata.add_text("weather_cloudiness ", str(self.weather.weather.cloudiness))
            metadata.add_text("weather_precipitation", str(self.weather.weather.precipitation))
            metadata.add_text("weather_precipitation_deposits", str(self.weather.weather.precipitation_deposits))
            metadata.add_text("weather_wind_intensity", str(self.weather.weather.wind_intensity))
            metadata.add_text("weather_fog_density", str(self.weather.weather.fog_density))
            metadata.add_text("weather_wetness", str(self.weather.weather.wetness))
            metadata.add_text("weather_sun_azimuth_angle", str(self.weather.weather.sun_azimuth_angle))
            # settings
            metadata.add_text("settings_map", self.map.name)
            metadata.add_text("settings_vision_size", str(self.vision_size))
            metadata.add_text("settings_vision_fov", str(self.vision_fov))
            metadata.add_text("settings_changing_weather_speed", str(self.changing_weather_speed))
            metadata.add_text("settings_multiagent", str(self.multiagent))
            # traffic lights
            metadata.add_text("traffic_lights_color", "UNLABELED")
            metadata.add_text("reward", str(reward))

            ## Add in reward dict
            for key in reward_dict:
                metadata.add_text("reward_" + str(key), str(reward_dict[key]))

            for key in done_dict:
                metadata.add_text("done_" + str(key), str(done_dict[key]))

            ## Save the target location as well
            metadata.add_text('target_location_x', str(self.target_location.x))
            metadata.add_text('target_location_y', str(self.target_location.y))
            metadata.add_text('target_location_z', str(self.target_location.z))

            im.save(image_name, "PNG", pnginfo=metadata)

        # # To read these images later, you can run something like this:
        # from PIL.PngImagePlugin import PngImageFile
        # im = PngImageFile("vision00001234.png")
        # throttle = float(im.text['throttle'])  # range [0, 1]
        # steer = float(im.text['steer'])  # range [-1, 1]
        # brake = float(im.text['brake'])  # range [0, 1]
        # lights = im.text['lights']  # traffic lights color, [NONE, JUNCTION, RED, YELLOW, GREEN]

        self.count += 1

        next_obs = rgb  # 84 x 84 x 3
        # # To inspect images, run:
        # import pdb; pdb.set_trace()
        # import matplotlib.pyplot as plt
        # plt.imshow(next_obs)
        # plt.show()

        # The horizon-based termination is disabled; only the hazard flags in
        # done_dict end an episode (see below).
        done = False #self.count >= self.max_episode_steps
        if done:
            print("Episode success: I've reached the episode horizon ({}).".format(self.max_episode_steps))
        # print ('reward: ', reward)

        info = reward_dict
        info.update(done_dict)

        # Episode ends when any hazard flag fired.
        done = False
        for key in done_dict:
            done = (done or done_dict[key])
        return next_obs, reward, done, info

    def finish(self):
        """Destroy all tracked actors and background vehicles, then quit pygame."""
        print('destroying actors.')
        for actor in self.actor_list:
            actor.destroy()
        print('\ndestroying %d vehicles' % len(self.vehicles_list))
        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
        time.sleep(0.5)
        pygame.quit()
        print('done.')
class LocalPlannerModified(LocalPlanner):
    """LocalPlanner with a no-op destructor and debug drawing disabled."""

    def __del__(self):
        # Intentionally empty: the parent destructor would delete our vehicle object.
        pass

    def run_step(self):
        # Force debug=False: the default waypoint overlays would show up in
        # (and pollute) our camera observations.
        return super(LocalPlannerModified, self).run_step(debug=False)
class RoamingAgent(Agent):
    """
    RoamingAgent implements a basic agent that navigates scenes making random
    choices when facing an intersection.

    This agent respects traffic lights and other vehicles.
    """

    def __init__(self, vehicle, follow_traffic_lights=True):
        """
        :param vehicle: actor to apply to local planner logic onto
        :param follow_traffic_lights: if True, a RED light triggers an emergency stop
        """
        super(RoamingAgent, self).__init__(vehicle)
        self._proximity_threshold = 10.0  # meters
        self._state = AgentState.NAVIGATING
        self._local_planner = LocalPlannerModified(self._vehicle)
        self._follow_traffic_lights = follow_traffic_lights

    def run_step(self):
        """
        Execute one step of navigation.

        :return: (carla.VehicleControl, traffic light color string)
        """
        # is there an obstacle in front of us?
        hazard_detected = False

        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles
        actor_list = self._world.get_actors()
        vehicle_list = actor_list.filter("*vehicle*")
        lights_list = actor_list.filter("*traffic_light*")

        # check possible obstacles
        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
        if vehicle_state:
            self._state = AgentState.BLOCKED_BY_VEHICLE
            hazard_detected = True

        # check for the state of the traffic lights
        traffic_light_color = self._is_light_red(lights_list)
        if traffic_light_color == 'RED' and self._follow_traffic_lights:
            self._state = AgentState.BLOCKED_RED_LIGHT
            hazard_detected = True

        if hazard_detected:
            control = self.emergency_stop()
        else:
            self._state = AgentState.NAVIGATING
            # standard local planner behavior
            control = self._local_planner.run_step()

        return control, traffic_light_color

    # override case class
    def _is_light_red_europe_style(self, lights_list):
        """
        This method is specialized to check European style traffic lights.
        Only suitable for Towns 03 -- 07.

        :return: light color string: "NONE", "RED", "YELLOW" or "GREEN"
        """
        ego_vehicle_location = self._vehicle.get_location()
        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)

        traffic_light_color = "NONE"  # default, if no traffic lights are seen
        for traffic_light in lights_list:
            object_waypoint = self._map.get_waypoint(traffic_light.get_location())
            # only consider lights on our own road and lane
            if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \
                    object_waypoint.lane_id != ego_vehicle_waypoint.lane_id:
                continue
            if is_within_distance_ahead(traffic_light.get_transform(),
                                        self._vehicle.get_transform(),
                                        self._proximity_threshold):
                if traffic_light.state == carla.TrafficLightState.Red:
                    return "RED"
                elif traffic_light.state == carla.TrafficLightState.Yellow:
                    traffic_light_color = "YELLOW"
                elif traffic_light.state == carla.TrafficLightState.Green:
                    # BUG FIX: was `is not "YELLOW"` -- identity comparison against a
                    # str literal (SyntaxWarning on CPython >= 3.8, undefined result);
                    # use value inequality instead.
                    if traffic_light_color != "YELLOW":  # (more severe)
                        traffic_light_color = "GREEN"
                else:
                    import pdb; pdb.set_trace()
                    # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate
        return traffic_light_color

    # override case class
    def _is_light_red_us_style(self, lights_list, debug=False):
        """US-style traffic-light check; same color-string convention as the Europe variant."""
        ego_vehicle_location = self._vehicle.get_location()
        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)

        traffic_light_color = "NONE"  # default, if no traffic lights are seen

        if ego_vehicle_waypoint.is_junction:
            # It is too late. Do not block the intersection! Keep going!
            return "JUNCTION"

        if self._local_planner.target_waypoint is not None:
            if self._local_planner.target_waypoint.is_junction:
                min_angle = 180.0
                sel_magnitude = 0.0
                sel_traffic_light = None
                # pick the closest light within 60 m and a 25-degree cone
                for traffic_light in lights_list:
                    loc = traffic_light.get_location()
                    magnitude, angle = compute_magnitude_angle(loc,
                                                               ego_vehicle_location,
                                                               self._vehicle.get_transform().rotation.yaw)
                    if magnitude < 60.0 and angle < min(25.0, min_angle):
                        sel_magnitude = magnitude
                        sel_traffic_light = traffic_light
                        min_angle = angle

                if sel_traffic_light is not None:
                    if debug:
                        print('=== Magnitude = {} | Angle = {} | ID = {}'.format(
                            sel_magnitude, min_angle, sel_traffic_light.id))

                    if self._last_traffic_light is None:
                        self._last_traffic_light = sel_traffic_light

                    if self._last_traffic_light.state == carla.TrafficLightState.Red:
                        return "RED"
                    elif self._last_traffic_light.state == carla.TrafficLightState.Yellow:
                        traffic_light_color = "YELLOW"
                    elif self._last_traffic_light.state == carla.TrafficLightState.Green:
                        # BUG FIX: same `is not "YELLOW"` identity-comparison defect as in
                        # _is_light_red_europe_style; use value inequality.
                        if traffic_light_color != "YELLOW":  # (more severe)
                            traffic_light_color = "GREEN"
                    else:
                        import pdb; pdb.set_trace()
                        # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate
                else:
                    self._last_traffic_light = None

        return traffic_light_color
if __name__ == '__main__':
    # Example setup:
    #   ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05
    #   python PythonAPI/carla/agents/navigation/data_collection_agent.py \
    #       --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights
    args = parse_args()
    env = CarlaEnv(args)
    curr_steps = 0
    try:
        done = False
        while not done:
            curr_steps += 1
            action, traffic_light_color = env.compute_action()
            next_obs, reward, done, info = env.step(action, traffic_light_color)
            print('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location())
            if done:
                # Episodes are never actually terminated here; keep collecting.
                done = False
            # Periodically re-initialize the environment to avoid drift.
            if curr_steps % 5000 == 4999:
                env.reset_init()
                env.reset()
    finally:
        # Always release CARLA actors / pygame, even on Ctrl-C.
        env.finish()
================================================
FILE: d4rl/d4rl/carla/town_agent.py
================================================
# A baseline town agent.
from agents.navigation.agent import Agent, AgentState
import numpy as np
from agents.navigation.local_planner import LocalPlanner
class RoamingAgent(Agent):
    """
    RoamingAgent implements a basic agent that navigates scenes making random
    choices when facing an intersection.

    This agent respects traffic lights and other vehicles.

    NOTE: need to re-create after each env reset
    """

    def __init__(self, env):
        """
        :param env: environment providing `vehicle` and `follow_traffic_lights`
        """
        vehicle = env.vehicle
        follow_traffic_lights = env.follow_traffic_lights
        super(RoamingAgent, self).__init__(vehicle)
        self._proximity_threshold = 10.0  # meters
        self._state = AgentState.NAVIGATING
        self._local_planner = LocalPlannerModified(self._vehicle)
        self._follow_traffic_lights = follow_traffic_lights

    def compute_action(self):
        # BUG FIX: run_step() in this class returns the action array directly
        # (np.array([throttle, steer]) or np.array([-brake, steer])), not a
        # (control, traffic_light) tuple. The previous tuple-unpacking here
        # unpacked the 2-element ndarray into two floats and then crashed on
        # `action.throttle`. Simply delegate to run_step().
        return self.run_step()

    def run_step(self):
        """
        Execute one step of navigation.

        :return: np.ndarray [throttle, steer] when not braking, else [-brake, steer]
        """
        # is there an obstacle in front of us?
        hazard_detected = False

        # retrieve the other vehicles for safe navigation
        actor_list = self._world.get_actors()
        vehicle_list = actor_list.filter("*vehicle*")

        # check possible obstacles
        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
        if vehicle_state:
            self._state = AgentState.BLOCKED_BY_VEHICLE
            hazard_detected = True

        if hazard_detected:
            control = self.emergency_stop()
        else:
            self._state = AgentState.NAVIGATING
            # standard local planner behavior
            control = self._local_planner.run_step()

        throttle = control.throttle
        brake = control.brake
        steer = control.steer
        # Encode braking as a negative first component so the action stays 2-D.
        if brake == 0.0:
            return np.array([throttle, steer])
        else:
            return np.array([-brake, steer])
class LocalPlannerModified(LocalPlanner):
    """LocalPlanner variant that keeps our vehicle alive on GC and never draws debug waypoints."""

    def __del__(self):
        # Intentionally a no-op: the parent destructor would delete our vehicle object.
        pass

    def run_step(self):
        # Force debug=False; the default waypoint overlay interferes with our camera.
        return super().run_step(debug=False)
class DummyTownAgent(Agent):
    """
    A simple agent for the town driving task.

    If the car is currently facing on a path towards the goal, drive forward.
    If the car would start driving away, apply maximum brakes.
    """

    def __init__(self, env):
        """
        :param env: environment whose vehicle, target location and route planner are used
        """
        self.env = env
        super(DummyTownAgent, self).__init__(self.env.vehicle)
        self._proximity_threshold = 10.0  # meters
        self._state = AgentState.NAVIGATING
        self._local_planner = LocalPlannerModified(self._vehicle)

    def compute_action(self):
        """Return a 2-D action array: [throttle, steer] or [-brake, steer]."""
        hazard_detected = False

        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles
        actor_list = self._world.get_actors()
        vehicle_list = actor_list.filter("*vehicle*")
        lights_list = actor_list.filter("*traffic_light*")

        # check possible obstacles
        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
        if vehicle_state:
            self._state = AgentState.BLOCKED_BY_VEHICLE
            hazard_detected = True

        # Also brake when the vehicle's heading points away from the next
        # node on the planned route towards the goal.
        rotation = self.env.vehicle.get_transform().rotation
        forward_vector = rotation.get_forward_vector()
        origin = self.env.vehicle.get_location()
        destination = self.env.target_location
        node_list = self.env.route_planner._path_search(origin=origin, destination=destination)

        origin_xy = np.array([origin.x, origin.y])
        forward_xy = np.array([forward_vector.x, forward_vector.y])
        next_vertex = self.env.route_planner._graph.nodes[node_list[0]]['vertex']
        next_xy = np.array([next_vertex[0], next_vertex[1]])

        to_target = next_xy - origin_xy
        to_target_unit = np.array(to_target) / np.linalg.norm(to_target)
        # Signed projection of the heading onto the route direction.
        vel_s = np.dot(forward_xy, to_target_unit)
        if vel_s < 0:
            hazard_detected = True

        if hazard_detected:
            control = self.emergency_stop()
        else:
            self._state = AgentState.NAVIGATING
            # standard local planner behavior
            control = self._local_planner.run_step()

        # Encode braking as a negative first component so the action stays 2-D.
        if control.brake == 0.0:
            return np.array([control.throttle, control.steer])
        return np.array([-control.brake, control.steer])
================================================
FILE: d4rl/d4rl/flow/__init__.py
================================================
import gym
import os
from d4rl import offline_env
from gym.envs.registration import register
from copy import deepcopy
import flow
import flow.envs
from flow.networks.ring import RingNetwork
from flow.core.params import NetParams, VehicleParams, EnvParams, InFlows
from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams
from flow.networks.ring import ADDITIONAL_NET_PARAMS
from flow.controllers.car_following_models import IDMController
from flow.controllers.routing_controllers import ContinuousRouter
from flow.controllers import SimCarFollowingController, SimLaneChangeController
from flow.controllers import RLController
from flow.core.params import InitialConfig
from flow.core.params import TrafficLightParams
from flow.envs.ring.accel import AccelEnv
from flow.core.params import SumoParams
from flow.utils.registry import make_create_env
from flow.envs import WaveAttenuationPOEnv
from flow.envs import BayBridgeEnv, TrafficLightGridPOEnv
from d4rl.flow import traffic_light_grid
from d4rl.flow import merge
from d4rl.flow import bottleneck
def flow_register(flow_params, render=None, **kwargs):
    """Instantiate the flow environment described by `flow_params` and wrap it for offline use.

    :param flow_params: dict produced by one of the *_env() builders below
    :param render: optional render-mode override for the simulator
    :param kwargs: forwarded to OfflineEnvWrapper (dataset_url, ref scores, ...)
    """
    exp_tag = flow_params["exp_tag"]
    env_params = flow_params['env']
    net_params = flow_params['net']
    env_class = flow_params['env_name']
    initial_config = flow_params.get('initial', InitialConfig())
    traffic_lights = flow_params.get("tls", TrafficLightParams())

    # Copy mutable pieces so repeated registrations don't share state.
    sim_params = deepcopy(flow_params['sim'])
    vehicles = deepcopy(flow_params['veh'])
    sim_params.render = render or sim_params.render

    network_spec = flow_params["network"]
    if isinstance(network_spec, str):
        print("""Passing of strings for network will be deprecated.
Please pass the Network instance instead.""")
        module = __import__("flow.networks", fromlist=[network_spec])
        network_class = getattr(module, network_spec)
    else:
        network_class = network_spec

    network = network_class(
        name=exp_tag,
        vehicles=vehicles,
        net_params=net_params,
        initial_config=initial_config,
        traffic_lights=traffic_lights,
    )
    flow_env = env_class(
        env_params=env_params,
        sim_params=sim_params,
        network=network,
        simulator=flow_params['simulator'],
    )
    return offline_env.OfflineEnvWrapper(flow_env, **kwargs)
def ring_env(render='drgb'):
    """Build the flow_params dict for the ring-road wave-attenuation task."""
    # 21 IDM-controlled humans plus a single RL-controlled vehicle.
    vehicles = VehicleParams()
    vehicles.add("human",
                 acceleration_controller=(IDMController, {}),
                 routing_controller=(ContinuousRouter, {}),
                 num_vehicles=21)
    vehicles.add(veh_id="rl",
                 acceleration_controller=(RLController, {}),
                 routing_controller=(ContinuousRouter, {}),
                 num_vehicles=1)

    horizon = 100  # length of one rollout
    env_params = EnvParams(
        horizon=horizon,
        additional_params={
            # maximum acceleration / deceleration of autonomous vehicles
            "max_accel": 1,
            "max_decel": 1,
            # bounds on the ring road lengths the autonomous vehicle is trained on
            "ring_length": [220, 270],
        },
    )

    return dict(
        exp_tag="ring",
        env_name=WaveAttenuationPOEnv,
        network=RingNetwork,
        simulator='traci',
        sim=SumoParams(sim_step=0.5, render=render, save_render=True),
        env=env_params,
        net=NetParams(additional_params=ADDITIONAL_NET_PARAMS),
        veh=vehicles,
        initial=InitialConfig(spacing="uniform", shuffle=False),
    )
# Reference scores for d4rl normalized-score computation on the ring task.
RING_RANDOM_SCORE = -165.22
RING_EXPERT_SCORE = 24.42

# Base ring env: no rendering, no dataset attached.
register(
    id='flow-ring-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=500,
    kwargs={
        'flow_params': ring_env(render=False),
        'dataset_url': None,
        'ref_min_score': RING_RANDOM_SCORE,
        'ref_max_score': RING_EXPERT_SCORE
    }
)

# Same env with direct RGB rendering enabled ('drgb'), still no dataset.
register(
    id='flow-ring-render-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=500,
    kwargs={
        'flow_params': ring_env(render='drgb'),
        'dataset_url': None,
        'ref_min_score': RING_RANDOM_SCORE,
        'ref_max_score': RING_EXPERT_SCORE
    }
)

# Ring env paired with the random-policy offline dataset.
register(
    id='flow-ring-random-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=500,
    kwargs={
        'flow_params': ring_env(render=False),
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5',
        'ref_min_score': RING_RANDOM_SCORE,
        'ref_max_score': RING_EXPERT_SCORE
    }
)

# Ring env paired with the IDM-controller offline dataset.
register(
    id='flow-ring-controller-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=500,
    kwargs={
        'flow_params': ring_env(render=False),
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5',
        'ref_min_score': RING_RANDOM_SCORE,
        'ref_max_score': RING_EXPERT_SCORE
    }
)

# Reference scores for the highway-merge task.
MERGE_RANDOM_SCORE = 118.67993
MERGE_EXPERT_SCORE = 330.03179

# Base merge env: no rendering, no dataset attached.
register(
    id='flow-merge-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=750,
    kwargs={
        'flow_params': merge.gen_env(render=False),
        'dataset_url': None,
        'ref_min_score': MERGE_RANDOM_SCORE,
        'ref_max_score': MERGE_EXPERT_SCORE
    }
)

# Merge env with direct RGB rendering enabled.
register(
    id='flow-merge-render-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=750,
    kwargs={
        'flow_params': merge.gen_env(render='drgb'),
        'dataset_url': None,
        'ref_min_score': MERGE_RANDOM_SCORE,
        'ref_max_score': MERGE_EXPERT_SCORE
    }
)

# Merge env paired with the random-policy offline dataset.
register(
    id='flow-merge-random-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=750,
    kwargs={
        'flow_params': merge.gen_env(render=False),
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5',
        'ref_min_score': MERGE_RANDOM_SCORE,
        'ref_max_score': MERGE_EXPERT_SCORE
    }
)

# Merge env paired with the IDM-controller offline dataset.
register(
    id='flow-merge-controller-v0',
    entry_point='d4rl.flow:flow_register',
    max_episode_steps=750,
    kwargs={
        'flow_params': merge.gen_env(render=False),
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5',
        'ref_min_score': MERGE_RANDOM_SCORE,
        'ref_max_score': MERGE_EXPERT_SCORE
    }
)
================================================
FILE: d4rl/d4rl/flow/bottleneck.py
================================================
import flow
import flow.envs
from flow.core.params import NetParams, VehicleParams, EnvParams, InFlows
from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams
from flow.networks.ring import ADDITIONAL_NET_PARAMS
from flow.controllers.routing_controllers import ContinuousRouter
from flow.controllers import SimCarFollowingController, SimLaneChangeController
from flow.controllers import RLController
from flow.core.params import InitialConfig
from flow.core.params import TrafficLightParams
from flow.core.params import SumoParams
from flow.envs import BottleneckDesiredVelocityEnv
from flow.networks import BottleneckNetwork
def bottleneck(render='drgb'):
    """Build the flow_params dict for the bottleneck desired-velocity task.

    :param render: flow render mode, forwarded to SumoParams
    :return: dict of flow parameters consumed by flow_register
    """
    # time horizon of a single rollout
    HORIZON = 1500
    SCALING = 1
    NUM_LANES = 4 * SCALING  # number of lanes in the widest highway (informational)
    DISABLE_TB = True
    DISABLE_RAMP_METER = True
    AV_FRAC = 0.10  # fraction of the inflow that is autonomous

    vehicles = VehicleParams()
    vehicles.add(
        veh_id="human",
        routing_controller=(ContinuousRouter, {}),
        car_following_params=SumoCarFollowingParams(
            speed_mode=9,
        ),
        lane_change_params=SumoLaneChangeParams(
            lane_change_mode=0,
        ),
        num_vehicles=1 * SCALING)
    vehicles.add(
        veh_id="rl",
        acceleration_controller=(RLController, {}),
        routing_controller=(ContinuousRouter, {}),
        car_following_params=SumoCarFollowingParams(
            speed_mode=9,
        ),
        lane_change_params=SumoLaneChangeParams(
            lane_change_mode=0,
        ),
        num_vehicles=1 * SCALING)

    # (edge, num controlled lanes/groups, controlled?) per segment
    controlled_segments = [("1", 1, False), ("2", 2, True), ("3", 2, True),
                           ("4", 2, True), ("5", 1, False)]
    num_observed_segments = [("1", 1), ("2", 3), ("3", 3), ("4", 3), ("5", 1)]
    additional_env_params = {
        "target_velocity": 40,
        "disable_tb": True,
        "disable_ramp_metering": True,
        "controlled_segments": controlled_segments,
        "symmetric": False,
        "observed_segments": num_observed_segments,
        "reset_inflow": False,
        "lane_change_duration": 5,
        "max_accel": 3,
        "max_decel": 3,
        "inflow_range": [1200, 2500]
    }

    # total inflow rate, split between human and RL vehicles on edge "1"
    flow_rate = 2500 * SCALING
    inflow = InFlows()
    inflow.add(
        veh_type="human",
        edge="1",
        vehs_per_hour=flow_rate * (1 - AV_FRAC),
        depart_lane="random",
        depart_speed=10)
    inflow.add(
        veh_type="rl",
        edge="1",
        vehs_per_hour=flow_rate * AV_FRAC,
        depart_lane="random",
        depart_speed=10)

    # Toll booth / ramp meter lights are disabled by default via the flags above.
    traffic_lights = TrafficLightParams()
    if not DISABLE_TB:
        traffic_lights.add(node_id="2")
    if not DISABLE_RAMP_METER:
        traffic_lights.add(node_id="3")

    additional_net_params = {"scaling": SCALING, "speed_limit": 23}
    # BUG FIX: previously an identical NetParams object was constructed twice --
    # once into a dead local and once inline in flow_params. Build it once and reuse.
    net_params = NetParams(
        inflows=inflow,
        additional_params=additional_net_params)

    flow_params = dict(
        # name of the experiment
        exp_tag="bottleneck_0",
        # name of the flow environment the experiment is running on
        env_name=BottleneckDesiredVelocityEnv,
        # name of the network class the experiment is running on
        network=BottleneckNetwork,
        # simulator that is used by the experiment
        simulator='traci',
        # sumo-related parameters (see flow.core.params.SumoParams)
        sim=SumoParams(
            sim_step=0.5,
            render=render,
            save_render=True,
            print_warnings=False,
            restart_instance=True,
        ),
        # environment related parameters (see flow.core.params.EnvParams)
        env=EnvParams(
            warmup_steps=40,
            sims_per_step=1,
            horizon=HORIZON,
            additional_params=additional_env_params,
        ),
        # network-related parameters (see flow.core.params.NetParams)
        net=net_params,
        # vehicles placed in the network at the start of a rollout
        veh=vehicles,
        # positioning of vehicles upon initialization/reset
        initial=InitialConfig(
            spacing="uniform",
            min_gap=5,
            lanes_distribution=float("inf"),
            edges_distribution=["2", "3", "4", "5"],
        ),
        # traffic lights introduced to specific nodes
        tls=traffic_lights,
    )
    return flow_params
================================================
FILE: d4rl/d4rl/flow/merge.py
================================================
"""Open merge example.
Trains a a small percentage of rl vehicles to dissipate shockwaves caused by
on-ramp merge to a single lane open highway network.
"""
from flow.envs import MergePOEnv
from flow.networks import MergeNetwork
from copy import deepcopy
from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \
InFlows, SumoCarFollowingParams
from flow.networks.merge import ADDITIONAL_NET_PARAMS
from flow.core.params import VehicleParams
from flow.controllers import SimCarFollowingController, RLController
def gen_env(render='drgb'):
    """Build the flow_params dict for the open-merge shockwave-dissipation task."""
    HORIZON = 750         # time horizon of a single rollout
    FLOW_RATE = 2000      # inflow rate at the highway
    RL_PENETRATION = 0.1  # percent of autonomous vehicles
    NUM_RL = 5            # num_rl term (see ADDITIONAL_ENV_PARAMs)

    # Highway network with an upstream merging lane producing shockwaves.
    additional_net_params = deepcopy(ADDITIONAL_NET_PARAMS)
    additional_net_params.update({
        "merge_lanes": 1,
        "highway_lanes": 1,
        "pre_merge_length": 500,
    })

    # RL vehicles constitute a small share of the total fleet.
    vehicles = VehicleParams()
    vehicles.add(
        veh_id="human",
        acceleration_controller=(SimCarFollowingController, {}),
        car_following_params=SumoCarFollowingParams(speed_mode=9),
        num_vehicles=5)
    vehicles.add(
        veh_id="rl",
        acceleration_controller=(RLController, {}),
        car_following_params=SumoCarFollowingParams(speed_mode=9),
        num_vehicles=0)

    # Vehicles enter from both sides of the merge; RL vehicles only from the highway.
    inflow = InFlows()
    inflow.add(
        veh_type="human",
        edge="inflow_highway",
        vehs_per_hour=(1 - RL_PENETRATION) * FLOW_RATE,
        depart_lane="free",
        depart_speed=10)
    inflow.add(
        veh_type="rl",
        edge="inflow_highway",
        vehs_per_hour=RL_PENETRATION * FLOW_RATE,
        depart_lane="free",
        depart_speed=10)
    inflow.add(
        veh_type="human",
        edge="inflow_merge",
        vehs_per_hour=100,
        depart_lane="free",
        depart_speed=7.5)

    return dict(
        exp_tag="merge_0",
        env_name=MergePOEnv,
        network=MergeNetwork,
        simulator='traci',
        sim=SumoParams(
            restart_instance=True,
            sim_step=0.5,
            render=render,
            save_render=True
        ),
        env=EnvParams(
            horizon=HORIZON,
            sims_per_step=2,
            warmup_steps=0,
            additional_params={
                "max_accel": 1.5,
                "max_decel": 1.5,
                "target_velocity": 20,
                "num_rl": NUM_RL,
            },
        ),
        net=NetParams(
            inflows=inflow,
            additional_params=additional_net_params,
        ),
        veh=vehicles,
        initial=InitialConfig(),
    )
================================================
FILE: d4rl/d4rl/flow/traffic_light_grid.py
================================================
"""Traffic Light Grid example."""
from flow.envs import TrafficLightGridBenchmarkEnv
from flow.networks import TrafficLightGridNetwork
from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \
InFlows, SumoCarFollowingParams
from flow.core.params import VehicleParams
from flow.controllers import SimCarFollowingController, GridRouter
def gen_env(render='drgb'):
    """Build the flow_params dict for the 3x3 traffic-light grid benchmark."""
    HORIZON = 400       # time horizon of a single rollout
    EDGE_INFLOW = 300   # inflow rate of vehicles at every edge
    V_ENTER = 30        # enter speed for departing vehicles
    N_ROWS = 3          # rows of bidirectional lanes
    N_COLUMNS = 3       # columns of bidirectional lanes
    INNER_LENGTH = 300  # length of inner edges in the grid network
    LONG_LENGTH = 100   # length of final edge in route
    SHORT_LENGTH = 300  # length of edges that vehicles start on
    # vehicles originating in the left, right, top, and bottom edges
    N_LEFT, N_RIGHT, N_TOP, N_BOTTOM = 1, 1, 1, 1

    # Enough vehicles to match the totals above; the "right_of_way" speed mode
    # keeps them compliant with the traffic lights.
    vehicles = VehicleParams()
    vehicles.add(
        veh_id="human",
        acceleration_controller=(SimCarFollowingController, {}),
        car_following_params=SumoCarFollowingParams(
            min_gap=2.5,
            max_speed=V_ENTER,
            decel=7.5,  # avoid collisions at emergency stops
            speed_mode="right_of_way",
        ),
        routing_controller=(GridRouter, {}),
        num_vehicles=(N_LEFT + N_RIGHT) * N_COLUMNS + (N_BOTTOM + N_TOP) * N_ROWS)

    # Inflows are placed on all outer edges of the grid.
    outer_edges = (
        ["left{}_{}".format(N_ROWS, i) for i in range(N_COLUMNS)]
        + ["right0_{}".format(i) for i in range(N_ROWS)]
        + ["bot{}_0".format(i) for i in range(N_ROWS)]
        + ["top{}_{}".format(i, N_COLUMNS) for i in range(N_ROWS)]
    )
    inflow = InFlows()
    for edge in outer_edges:
        inflow.add(
            veh_type="human",
            edge=edge,
            vehs_per_hour=EDGE_INFLOW,  # equal inflow on every outer edge
            depart_lane="free",
            depart_speed=V_ENTER)

    return dict(
        exp_tag="grid_0",
        env_name=TrafficLightGridBenchmarkEnv,
        network=TrafficLightGridNetwork,
        simulator='traci',
        sim=SumoParams(
            restart_instance=True,
            sim_step=1,
            render=render,
            save_render=True,
        ),
        env=EnvParams(
            horizon=HORIZON,
            additional_params={
                "target_velocity": 50,
                "switch_time": 3,
                "num_observed": 2,
                "discrete": False,
                "tl_type": "actuated"
            },
        ),
        net=NetParams(
            inflows=inflow,
            additional_params={
                "speed_limit": V_ENTER + 5,
                "grid_array": {
                    "short_length": SHORT_LENGTH,
                    "inner_length": INNER_LENGTH,
                    "long_length": LONG_LENGTH,
                    "row_num": N_ROWS,
                    "col_num": N_COLUMNS,
                    "cars_left": N_LEFT,
                    "cars_right": N_RIGHT,
                    "cars_top": N_TOP,
                    "cars_bot": N_BOTTOM,
                },
                "horizontal_lanes": 1,
                "vertical_lanes": 1,
            },
        ),
        veh=vehicles,
        initial=InitialConfig(
            spacing='custom',
            shuffle=True,
        ),
    )
================================================
FILE: d4rl/d4rl/gym_bullet/__init__.py
================================================
from gym.envs.registration import register
from d4rl.gym_bullet import gym_envs
from d4rl import infos
# Register a plain env plus one env per offline dataset for each bullet agent.
for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']:
    entry_point = 'd4rl.gym_bullet.gym_envs:get_%s_env' % agent
    register(
        id='bullet-%s-v0' % agent,
        entry_point=entry_point,
        max_episode_steps=1000,
    )
    for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay']:
        env_name = 'bullet-%s-%s-v0' % (agent, dataset)
        register(
            id=env_name,
            entry_point=entry_point,
            max_episode_steps=1000,
            kwargs={
                'ref_min_score': infos.REF_MIN_SCORE[env_name],
                'ref_max_score': infos.REF_MAX_SCORE[env_name],
                'dataset_url': infos.DATASET_URLS[env_name],
            },
        )
================================================
FILE: d4rl/d4rl/gym_bullet/gym_envs.py
================================================
from .. import offline_env
from pybullet_envs.gym_locomotion_envs import HopperBulletEnv, HalfCheetahBulletEnv, Walker2DBulletEnv, AntBulletEnv
from ..utils.wrappers import NormalizedBoxEnv
class OfflineAntEnv(AntBulletEnv, offline_env.OfflineEnv):
    # Ant bullet env combined with the d4rl offline-dataset machinery.
    # NOTE: both bases are initialized explicitly (not via super()) so the
    # bullet env and the OfflineEnv mixin each run their own __init__.
    def __init__(self, **kwargs):
        AntBulletEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)
class OfflineHopperEnv(HopperBulletEnv, offline_env.OfflineEnv):
    # Hopper bullet env combined with the d4rl offline-dataset machinery.
    # NOTE: explicit dual-base init, same pattern as OfflineAntEnv.
    def __init__(self, **kwargs):
        HopperBulletEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)
class OfflineHalfCheetahEnv(HalfCheetahBulletEnv, offline_env.OfflineEnv):
    # HalfCheetah bullet env combined with the d4rl offline-dataset machinery.
    # NOTE: explicit dual-base init, same pattern as OfflineAntEnv.
    def __init__(self, **kwargs):
        HalfCheetahBulletEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)
class OfflineWalker2dEnv(Walker2DBulletEnv, offline_env.OfflineEnv):
    # Walker2D bullet env combined with the d4rl offline-dataset machinery.
    # NOTE: explicit dual-base init, same pattern as OfflineAntEnv.
    def __init__(self, **kwargs):
        Walker2DBulletEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)
def get_ant_env(**kwargs):
    """Factory: offline Ant env wrapped with action-space normalization."""
    env = OfflineAntEnv(**kwargs)
    return NormalizedBoxEnv(env)
def get_halfcheetah_env(**kwargs):
    """Factory: offline HalfCheetah env wrapped with action-space normalization."""
    env = OfflineHalfCheetahEnv(**kwargs)
    return NormalizedBoxEnv(env)
def get_hopper_env(**kwargs):
    """Factory: offline Hopper env wrapped with action-space normalization."""
    env = OfflineHopperEnv(**kwargs)
    return NormalizedBoxEnv(env)
def get_walker2d_env(**kwargs):
    """Factory: offline Walker2d env wrapped with action-space normalization."""
    env = OfflineWalker2dEnv(**kwargs)
    return NormalizedBoxEnv(env)
================================================
FILE: d4rl/d4rl/gym_minigrid/__init__.py
================================================
from gym.envs.registration import register
# Both four-rooms variants share the entry point and reference scores; only
# the attached offline dataset differs.
for env_id, dataset_file in [
    ('minigrid-fourrooms-v0', 'minigrid4rooms.hdf5'),
    ('minigrid-fourrooms-random-v0', 'minigrid4rooms_random.hdf5'),
]:
    register(
        id=env_id,
        entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv',
        max_episode_steps=50,
        kwargs={
            'ref_min_score': 0.01442,
            'ref_max_score': 2.89685,
            'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/' + dataset_file,
        },
    )
================================================
FILE: d4rl/d4rl/gym_minigrid/envs/__init__.py
================================================
from d4rl.gym_minigrid.envs.fourrooms import *
from d4rl.gym_minigrid.envs.empty import *
================================================
FILE: d4rl/d4rl/gym_minigrid/envs/empty.py
================================================
from d4rl.gym_minigrid.minigrid import *
from d4rl.gym_minigrid.register import register
class EmptyEnv(MiniGridEnv):
    """
    Empty grid environment, no obstacles, sparse reward
    """

    def __init__(self, size=8, agent_start_pos=(1, 1), agent_start_dir=0):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        super().__init__(
            grid_size=size,
            max_steps=4 * size * size,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

    def _gen_grid(self, width, height):
        """Lay out the walls, goal and agent for an empty width x height room."""
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Goal square sits in the bottom-right corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Place the agent: fixed pose if configured, otherwise randomly.
        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"
class EmptyEnv5x5(EmptyEnv):
    """Fixed-start empty room on a 5x5 grid."""

    def __init__(self):
        super().__init__(size=5)
class EmptyRandomEnv5x5(EmptyEnv):
    """Random-start empty room on a 5x5 grid."""

    def __init__(self):
        super().__init__(size=5, agent_start_pos=None)
class EmptyEnv6x6(EmptyEnv):
    """Fixed-start empty room on a 6x6 grid."""

    def __init__(self):
        super().__init__(size=6)
class EmptyRandomEnv6x6(EmptyEnv):
    """Random-start empty room on a 6x6 grid."""

    def __init__(self):
        super().__init__(size=6, agent_start_pos=None)
class EmptyEnv16x16(EmptyEnv):
    """Fixed-start empty room on a 16x16 grid."""

    def __init__(self):
        super().__init__(size=16)
# One gym id per board-size variant; "Random" ids randomize the start pose.
for env_id, class_name in [
    ('MiniGrid-Empty-5x5-v0', 'EmptyEnv5x5'),
    ('MiniGrid-Empty-Random-5x5-v0', 'EmptyRandomEnv5x5'),
    ('MiniGrid-Empty-6x6-v0', 'EmptyEnv6x6'),
    ('MiniGrid-Empty-Random-6x6-v0', 'EmptyRandomEnv6x6'),
    ('MiniGrid-Empty-8x8-v0', 'EmptyEnv'),
    ('MiniGrid-Empty-16x16-v0', 'EmptyEnv16x16'),
]:
    register(id=env_id, entry_point='gym_minigrid.envs:' + class_name)
================================================
FILE: d4rl/d4rl/gym_minigrid/envs/fourrooms.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from d4rl.gym_minigrid.minigrid import *
from d4rl.gym_minigrid.register import register
class FourRoomsEnv(MiniGridEnv):
    """
    Classic 4 rooms gridworld environment.
    Can specify agent and goal position, if not it set at random.
    """

    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
        # Default goal lies inside the bottom-right room of the 19x19 grid.
        self._agent_default_pos = agent_pos
        if goal_pos is None:
            goal_pos = (12, 12)
        self._goal_default_pos = goal_pos
        super().__init__(grid_size=19, max_steps=100, **kwargs)

    def get_target(self):
        # The (x, y) goal position configured at construction time.
        return self._goal_default_pos

    def _gen_grid(self, width, height):
        # Door positions are drawn from self._rand_int, so the call order in
        # this method fixes the RNG sequence (and thus layout reproducibility);
        # do not reorder these calls.
        # Create the grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.horz_wall(0, 0)
        self.grid.horz_wall(0, height - 1)
        self.grid.vert_wall(0, 0)
        self.grid.vert_wall(width - 1, 0)

        room_w = width // 2
        room_h = height // 2

        # For each row of rooms
        for j in range(0, 2):
            # For each column
            for i in range(0, 2):
                xL = i * room_w
                yT = j * room_h
                xR = xL + room_w
                yB = yT + room_h

                # Right wall and door (between horizontally adjacent rooms)
                if i + 1 < 2:
                    self.grid.vert_wall(xR, yT, room_h)
                    pos = (xR, self._rand_int(yT + 1, yB))
                    self.grid.set(*pos, None)

                # Bottom wall and door
                if j + 1 < 2:
                    self.grid.horz_wall(xL, yB, room_w)
                    pos = (self._rand_int(xL + 1, xR), yB)
                    self.grid.set(*pos, None)

        # Randomize the player start position and orientation
        if self._agent_default_pos is not None:
            self.agent_pos = self._agent_default_pos
            self.grid.set(*self._agent_default_pos, None)
            self.agent_dir = self._rand_int(0, 4)  # assuming random start direction
        else:
            self.place_agent()

        if self._goal_default_pos is not None:
            goal = Goal()
            self.put_obj(goal, *self._goal_default_pos)
            # NOTE(review): this unpacks the (x, y) tuple into two scalar
            # attributes (init_pos = x, cur_pos = y) instead of assigning the
            # tuple to both; looks unintended but is kept as-is -- confirm
            # against gym_minigrid's Goal usage before changing.
            goal.init_pos, goal.cur_pos = self._goal_default_pos
        else:
            self.place_obj(Goal())

        self.mission = 'Reach the goal'

    def step(self, action):
        # Thin pass-through to the base implementation; kept as a subclass hook.
        obs, reward, done, info = MiniGridEnv.step(self, action)
        return obs, reward, done, info
# Register the environment with gym under its canonical id.
register(
    id='MiniGrid-FourRooms-v0',
    entry_point='gym_minigrid.envs:FourRoomsEnv'
)
================================================
FILE: d4rl/d4rl/gym_minigrid/fourroom_controller.py
================================================
import numpy as np
import random
from d4rl.pointmaze import q_iteration
from d4rl.pointmaze.gridcraft import grid_env
from d4rl.pointmaze.gridcraft import grid_spec
# ASCII layout of the 19x19 four-rooms maze fed to gridcraft's
# spec_from_string: '#' = wall, 'O' = walkable cell.
# NOTE(review): each row ends with a literal '\' (the "\\" escape), which
# appears to act as the row separator for spec_from_string -- confirm there.
MAZE = \
    "###################\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOOOOOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "####O#########O####\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "#OOOOOOOOOOOOOOOOO#\\"+\
    "#OOOOOOOO#OOOOOOOO#\\"+\
    "###################\\"
# NLUDR -> RDLU
# Maps the argmax index of gridcraft q-values onto a minigrid facing
# direction (Right=0, Down=1, Left=2, Up=3). Index 0 (presumably the
# planner's no-op action -- confirm against q_iteration's action ordering)
# maps to None, meaning "no movement commanded".
TRANSLATE_DIRECTION = {
    0: None,
    1: 3,
    2: 1,
    3: 2,
    4: 0,
}

# Minigrid primitive action ids (match MiniGridEnv.Actions: left=0,
# right=1, forward=2).
RIGHT = 1
LEFT = 0
FORWARD = 2
class FourRoomController(object):
    """Waypoint controller for the FourRooms gridworld.

    Plans with tabular Q-iteration on a gridcraft replica of the 19x19
    four-rooms maze (``MAZE``), then translates the greedy planner action at
    the agent's cell into the minigrid action set (LEFT / RIGHT / FORWARD).
    """

    def __init__(self):
        self.env = grid_env.GridEnv(grid_spec.spec_from_string(MAZE))
        # All walkable cells, used as the pool of candidate targets.
        self.reset_locations = list(zip(*np.where(self.env.gs.spec == grid_spec.EMPTY)))

    def sample_target(self):
        """Return a uniformly random walkable cell."""
        return random.choice(self.reset_locations)

    def set_target(self, target):
        """Plan toward ``target`` by running Q-iteration with a reward there.

        The reward marker is removed afterwards so the grid spec stays clean
        for subsequent calls. Stores ``self.target`` and ``self.q_values``.
        """
        self.target = target
        self.env.gs[target] = grid_spec.REWARD
        self.q_values = q_iteration.q_iteration(env=self.env, num_itrs=32, discount=0.99)
        self.env.gs[target] = grid_spec.EMPTY

    def get_action(self, pos, orientation):
        """Return ``(action, done)`` steering the agent at ``pos`` to the target.

        ``set_target`` must have been called first. ``done`` is True once
        ``pos`` already equals the target.
        """
        done = tuple(pos) == tuple(self.target)

        env_pos_idx = self.env.gs.xy_to_idx(pos)
        qvalues = self.q_values[env_pos_idx]
        direction = TRANSLATE_DIRECTION[np.argmax(qvalues)]

        # Already facing the commanded direction (or no move commanded):
        # step forward. Otherwise rotate toward it first.
        if direction is None or orientation == direction:
            return FORWARD, done
        return get_turn(orientation, direction), done
#RDLU
# Turn lookup indexed as [current_orientation][target_orientation], with
# orientations ordered Right=0, Down=1, Left=2, Up=3. None on the diagonal
# means "already facing the target"; 180-degree reversals resolve to RIGHT.
TURN_DIRS = [
    [None, RIGHT, RIGHT, LEFT], #R
    [LEFT, None, RIGHT, RIGHT], #D
    [RIGHT, LEFT, None, RIGHT], #L
    [RIGHT, RIGHT, LEFT, None], #U
]
def get_turn(ori, tgt_ori):
    """Look up the turn action rotating facing ``ori`` toward ``tgt_ori``."""
    row = TURN_DIRS[ori]
    return row[tgt_ori]
================================================
FILE: d4rl/d4rl/gym_minigrid/minigrid.py
================================================
import math
import gym
from enum import IntEnum
import numpy as np
from gym import error, spaces, utils
from gym.utils import seeding
from d4rl.gym_minigrid.rendering import *
from d4rl import offline_env
# Size in pixels of a tile in the full-scale human view
TILE_PIXELS = 32

# Map of color names to RGB values
COLORS = {
    'red'   : np.array([255, 0, 0]),
    'green' : np.array([0, 255, 0]),
    'blue'  : np.array([0, 0, 255]),
    'purple': np.array([112, 39, 195]),
    'yellow': np.array([255, 255, 0]),
    'grey'  : np.array([100, 100, 100])
}

COLOR_NAMES = sorted(list(COLORS.keys()))

# Used to map colors to integers (second channel of the grid encoding)
COLOR_TO_IDX = {
    'red'   : 0,
    'green' : 1,
    'blue'  : 2,
    'purple': 3,
    'yellow': 4,
    'grey'  : 5
}

IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))

# Map of object type to integers (first channel of the grid encoding,
# consumed by WorldObj.encode/decode and Grid.encode/decode)
OBJECT_TO_IDX = {
    'unseen'        : 0,
    'empty'         : 1,
    'wall'          : 2,
    'floor'         : 3,
    'door'          : 4,
    'key'           : 5,
    'ball'          : 6,
    'box'           : 7,
    'goal'          : 8,
    'lava'          : 9,
    'agent'         : 10,
}

IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))

# Map of state names to integers (third channel; door open/closed/locked)
STATE_TO_IDX = {
    'open'  : 0,
    'closed': 1,
    'locked': 2,
}

# Map of agent direction indices to vectors
DIR_TO_VEC = [
    # Pointing right (positive X)
    np.array((1, 0)),
    # Down (positive Y)
    np.array((0, 1)),
    # Pointing left (negative X)
    np.array((-1, 0)),
    # Up (negative Y)
    np.array((0, -1)),
]
class WorldObj:
    """
    Base class for grid world objects.

    Subclasses override the ``can_*`` capability predicates, ``toggle`` and
    ``render``; ``init_pos``/``cur_pos`` are maintained by the env when the
    object is placed (see ``place_obj``/``put_obj``).
    """

    def __init__(self, type, color):
        assert type in OBJECT_TO_IDX, type
        assert color in COLOR_TO_IDX, color
        self.type = type
        self.color = color
        self.contains = None

        # Initial position of the object
        self.init_pos = None

        # Current position of the object
        self.cur_pos = None

    def can_overlap(self):
        """Can the agent overlap with this?"""
        return False

    def can_pickup(self):
        """Can the agent pick this up?"""
        return False

    def can_contain(self):
        """Can this contain another object?"""
        return False

    def see_behind(self):
        """Can the agent see behind this object?"""
        return True

    def toggle(self, env, pos):
        """Method to trigger/toggle an action this object performs"""
        return False

    def encode(self):
        """Encode a description of this object as a 3-tuple of integers"""
        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)

    @staticmethod
    def decode(type_idx, color_idx, state):
        """Create an object from a 3-tuple state description"""
        obj_type = IDX_TO_OBJECT[type_idx]
        color = IDX_TO_COLOR[color_idx]

        if obj_type == 'empty' or obj_type == 'unseen':
            return None

        # State, 0: open, 1: closed, 2: locked
        is_open = state == 0
        is_locked = state == 2

        if obj_type == 'wall':
            v = Wall(color)
        elif obj_type == 'floor':
            v = Floor(color)
        elif obj_type == 'ball':
            v = Ball(color)
        elif obj_type == 'key':
            v = Key(color)
        elif obj_type == 'box':
            v = Box(color)
        elif obj_type == 'door':
            v = Door(color, is_open, is_locked)
        elif obj_type == 'goal':
            v = Goal()
        elif obj_type == 'lava':
            v = Lava()
        else:
            # BUG FIX: the message referenced an undefined name ``objType``,
            # which raised NameError instead of the intended AssertionError.
            assert False, "unknown object type in decode '%s'" % obj_type

        return v

    def render(self, r):
        """Draw this object with the given renderer"""
        raise NotImplementedError
class Goal(WorldObj):
    """Green goal tile; the agent may stand on it."""

    def __init__(self):
        super().__init__('goal', 'green')

    def can_overlap(self):
        # Walking onto the goal is allowed (the env ends the episode there).
        return True

    def render(self, img):
        # Fill the whole tile with the goal color.
        whole_tile = point_in_rect(0, 1, 0, 1)
        fill_coords(img, whole_tile, COLORS[self.color])
class Floor(WorldObj):
    """
    Colored floor tile the agent can walk over
    """

    def __init__(self, color='blue'):
        super().__init__('floor', color)

    def can_overlap(self):
        return True

    def render(self, img):
        # BUG FIX: the old body used the retired Renderer API
        # (r.setLineColor / r.drawPolygon), but Grid.render_tile passes a
        # numpy image buffer to obj.render(img) and this module's rendering
        # helpers only provide fill_coords-style drawing -- so rendering a
        # Floor crashed. Draw a pale (half-intensity) tile instead, leaving
        # the thin top/left grid border like other tiles.
        color = COLORS[self.color] / 2
        fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
class Lava(WorldObj):
    """Hazard tile; it can be stepped on (the env terminates the episode)."""

    def __init__(self):
        super().__init__('lava', 'red')

    def can_overlap(self):
        return True

    def render(self, img):
        orange = (255, 128, 0)
        black = (0, 0, 0)

        # Background color
        fill_coords(img, point_in_rect(0, 1, 0, 1), orange)

        # Three black zig-zag "waves" drawn over the background, one per row.
        for k in range(3):
            ylo = 0.3 + 0.2 * k
            yhi = 0.4 + 0.2 * k
            segments = (
                (0.1, ylo, 0.3, yhi),
                (0.3, yhi, 0.5, ylo),
                (0.5, ylo, 0.7, yhi),
                (0.7, yhi, 0.9, ylo),
            )
            for x0, ya, x1, yb in segments:
                fill_coords(img, point_in_line(x0, ya, x1, yb, r=0.03), black)
class Wall(WorldObj):
    """Opaque, impassable wall tile."""

    def __init__(self, color='grey'):
        super().__init__('wall', color)

    def see_behind(self):
        # Walls block line of sight.
        return False

    def render(self, img):
        # Solid block of the wall color.
        whole_tile = point_in_rect(0, 1, 0, 1)
        fill_coords(img, whole_tile, COLORS[self.color])
class Door(WorldObj):
    """Door tile; may be open or closed, and optionally locked to its color's key."""

    def __init__(self, color, is_open=False, is_locked=False):
        super().__init__('door', color)
        self.is_open = is_open
        self.is_locked = is_locked

    def can_overlap(self):
        """The agent can only walk over this cell when the door is open"""
        return self.is_open

    def see_behind(self):
        # A closed door also blocks line of sight.
        return self.is_open

    def toggle(self, env, pos):
        # If the player has the right key to open the door
        if self.is_locked:
            if isinstance(env.carrying, Key) and env.carrying.color == self.color:
                self.is_locked = False
                self.is_open = True
                return True
            return False

        # Unlocked doors simply flip between open and closed.
        self.is_open = not self.is_open
        return True

    def encode(self):
        """Encode a description of this object as a 3-tuple of integers"""

        # State, 0: open, 1: closed, 2: locked
        # NOTE: branch order means a door flagged both open and locked
        # encodes as open.
        if self.is_open:
            state = 0
        elif self.is_locked:
            state = 2
        elif not self.is_open:
            state = 1

        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)

    def render(self, img):
        c = COLORS[self.color]

        if self.is_open:
            # Open door: thin colored strip along the right edge of the tile.
            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
            return

        # Door frame and door
        if self.is_locked:
            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
            # Dimmed inner panel marks the locked state.
            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))

            # Draw key slot
            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
        else:
            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))

            # Draw door handle
            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
class Key(WorldObj):
    """Pickable key; Door.toggle accepts it for the same-colored locked door."""

    def __init__(self, color='blue'):
        super().__init__('key', color)

    def can_pickup(self):
        return True

    def render(self, img):
        col = COLORS[self.color]

        # Shaft (vertical quad)
        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), col)

        # Teeth
        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), col)
        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), col)

        # Ring: outer disc with a black hole punched in the middle.
        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), col)
        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
class Ball(WorldObj):
    """Pickable ball object."""

    def __init__(self, color='blue'):
        super().__init__('ball', color)

    def can_pickup(self):
        return True

    def render(self, img):
        # A filled circle centered on the tile.
        disc = point_in_circle(0.5, 0.5, 0.31)
        fill_coords(img, disc, COLORS[self.color])
class Box(WorldObj):
    """Pickable box that may hold another object, revealed on toggle."""

    def __init__(self, color, contains=None):
        super().__init__('box', color)
        self.contains = contains

    def can_pickup(self):
        return True

    def render(self, img):
        col = COLORS[self.color]

        # Square outline: colored frame with a hollow interior.
        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), col)
        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))

        # Horizontal slit across the middle.
        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), col)

    def toggle(self, env, pos):
        # Opening the box replaces it in the grid by its contents.
        env.grid.set(*pos, self.contains)
        return True
class Grid:
    """
    Represent a grid and operations on it
    """

    # Static cache of pre-rendered tiles, keyed by
    # (object encoding, agent_dir, highlight, tile_size). Shared by all
    # Grid instances.
    tile_cache = {}

    def __init__(self, width, height):
        assert width >= 3
        assert height >= 3

        self.width = width
        self.height = height

        # Flat row-major cell storage: cell (i, j) lives at j * width + i.
        self.grid = [None] * width * height

    def __contains__(self, key):
        # Membership either by object identity (WorldObj key) or by a
        # (color, type) / (None, type) tuple.
        if isinstance(key, WorldObj):
            for e in self.grid:
                if e is key:
                    return True
        elif isinstance(key, tuple):
            for e in self.grid:
                if e is None:
                    continue
                if (e.color, e.type) == key:
                    return True
                if key[0] is None and key[1] == e.type:
                    return True
        return False

    def __eq__(self, other):
        # Grids compare equal when their compact encodings match.
        grid1 = self.encode()
        grid2 = other.encode()
        return np.array_equal(grid2, grid1)

    def __ne__(self, other):
        return not self == other

    def copy(self):
        from copy import deepcopy
        return deepcopy(self)

    def set(self, i, j, v):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        self.grid[j * self.width + i] = v

    def get(self, i, j):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        return self.grid[j * self.width + i]

    def horz_wall(self, x, y, length=None, obj_type=Wall):
        """Lay a horizontal run of obj_type; length defaults to the right edge."""
        if length is None:
            length = self.width - x
        for i in range(0, length):
            self.set(x + i, y, obj_type())

    def vert_wall(self, x, y, length=None, obj_type=Wall):
        """Lay a vertical run of obj_type; length defaults to the bottom edge."""
        if length is None:
            length = self.height - y
        for j in range(0, length):
            self.set(x, y + j, obj_type())

    def wall_rect(self, x, y, w, h):
        """Outline the rectangle at (x, y) of size (w, h) with walls."""
        self.horz_wall(x, y, w)
        self.horz_wall(x, y+h-1, w)
        self.vert_wall(x, y, h)
        self.vert_wall(x+w-1, y, h)

    def rotate_left(self):
        """
        Rotate the grid to the left (counter-clockwise)
        """
        grid = Grid(self.height, self.width)

        for i in range(self.width):
            for j in range(self.height):
                v = self.get(i, j)
                grid.set(j, grid.height - 1 - i, v)

        return grid

    def slice(self, topX, topY, width, height):
        """
        Get a subset of the grid
        """
        grid = Grid(width, height)

        for j in range(0, height):
            for i in range(0, width):
                x = topX + i
                y = topY + j

                if x >= 0 and x < self.width and \
                   y >= 0 and y < self.height:
                    v = self.get(x, y)
                else:
                    # Out-of-bounds cells are treated as walls.
                    v = Wall()

                grid.set(i, j, v)

        return grid

    @classmethod
    def render_tile(
        cls,
        obj,
        agent_dir=None,
        highlight=False,
        tile_size=TILE_PIXELS,
        subdivs=3
    ):
        """
        Render a tile and cache the result
        """
        # Hash map lookup key for the cache
        key = (agent_dir, highlight, tile_size)
        key = obj.encode() + key if obj else key

        if key in cls.tile_cache:
            return cls.tile_cache[key]

        # Render at subdivs x the resolution, then downsample (anti-aliasing).
        img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)

        # Draw the grid lines (top and left edges)
        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

        if obj is not None:
            obj.render(img)

        # Overlay the agent on top
        if agent_dir is not None:
            tri_fn = point_in_triangle(
                (0.12, 0.19),
                (0.87, 0.50),
                (0.12, 0.81),
            )

            # Rotate the agent based on its direction
            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
            fill_coords(img, tri_fn, (255, 0, 0))

        # Highlight the cell if needed
        if highlight:
            highlight_img(img)

        # Downsample the image to perform supersampling/anti-aliasing
        img = downsample(img, subdivs)

        # Cache the rendered tile
        cls.tile_cache[key] = img

        return img

    def render(
        self,
        tile_size,
        agent_pos=None,
        agent_dir=None,
        highlight_mask=None
    ):
        """
        Render this grid at a given scale
        :param tile_size: tile size in pixels
        """
        if highlight_mask is None:
            # BUG FIX: np.bool was removed in NumPy >= 1.24; the builtin
            # bool is the correct dtype.
            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # Compute the total grid size
        width_px = self.width * tile_size
        height_px = self.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Render the grid
        for j in range(0, self.height):
            for i in range(0, self.width):
                cell = self.get(i, j)

                agent_here = np.array_equal(agent_pos, (i, j))
                tile_img = Grid.render_tile(
                    cell,
                    agent_dir=agent_dir if agent_here else None,
                    highlight=highlight_mask[i, j],
                    tile_size=tile_size
                )

                ymin = j * tile_size
                ymax = (j+1) * tile_size
                xmin = i * tile_size
                xmax = (i+1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

        return img

    def encode(self, vis_mask=None):
        """
        Produce a compact numpy encoding of the grid
        """
        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)

        array = np.zeros((self.width, self.height, 3), dtype='uint8')

        for i in range(self.width):
            for j in range(self.height):
                if vis_mask[i, j]:
                    v = self.get(i, j)

                    if v is None:
                        array[i, j, 0] = OBJECT_TO_IDX['empty']
                        array[i, j, 1] = 0
                        array[i, j, 2] = 0
                    else:
                        array[i, j, :] = v.encode()

        return array

    @staticmethod
    def decode(array):
        """
        Decode an array grid encoding back into a grid
        """
        width, height, channels = array.shape
        assert channels == 3

        # BUG FIX: dtype=bool instead of the removed np.bool alias.
        vis_mask = np.ones(shape=(width, height), dtype=bool)

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                type_idx, color_idx, state = array[i, j]
                v = WorldObj.decode(type_idx, color_idx, state)
                grid.set(i, j, v)
                vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])

        return grid, vis_mask

    def process_vis(grid, agent_pos):
        """Flood visibility outward from agent_pos and clear hidden cells.

        NOTE: defined without ``self``; the ``grid`` parameter plays that
        role when called as a bound method (grid.process_vis(agent_pos=...)).
        """
        # BUG FIX: dtype=bool instead of the removed np.bool alias.
        mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)

        mask[agent_pos[0], agent_pos[1]] = True

        # Sweep rows from the agent's row upward, propagating visibility
        # sideways and one row up past cells that can be seen through.
        for j in reversed(range(0, grid.height)):
            for i in range(0, grid.width-1):
                if not mask[i, j]:
                    continue

                cell = grid.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i+1, j] = True
                if j > 0:
                    mask[i+1, j-1] = True
                    mask[i, j-1] = True

            for i in reversed(range(1, grid.width)):
                if not mask[i, j]:
                    continue

                cell = grid.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i-1, j] = True
                if j > 0:
                    mask[i-1, j-1] = True
                    mask[i, j-1] = True

        # Blank out everything the agent cannot see.
        for j in range(0, grid.height):
            for i in range(0, grid.width):
                if not mask[i, j]:
                    grid.set(i, j, None)

        return mask
class MiniGridEnv(offline_env.OfflineEnv):
    """
    2D grid world game environment.

    Observations are dicts with a partial 'image' view, the agent's
    'direction' and a textual 'mission'. Subclasses must implement
    ``_gen_grid`` and set ``self.mission`` there.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second' : 10
    }

    # Enumeration of possible actions
    class Actions(IntEnum):
        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2

        # Pick up an object
        pickup = 3
        # Drop an object
        drop = 4
        # Toggle/activate an object
        toggle = 5

        # Done completing task
        done = 6

    def __init__(
        self,
        grid_size=None,
        width=None,
        height=None,
        max_steps=100,
        see_through_walls=False,
        seed=1337,
        agent_view_size=7,
        **kwargs
    ):
        # Extra kwargs are forwarded to the d4rl offline-env machinery.
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Can't set both grid_size and width/height
        if grid_size:
            assert width is None and height is None
            width = grid_size
            height = grid_size

        # Action enumeration for this environment
        self.actions = MiniGridEnv.Actions

        # Actions are discrete integer values
        self.action_space = spaces.Discrete(len(self.actions))

        # Number of cells (width and height) in the agent view
        self.agent_view_size = agent_view_size

        # Observations are dictionaries containing an
        # encoding of the grid and a textual 'mission' string
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict({
            'image': self.observation_space
        })

        # Range of possible rewards
        self.reward_range = (0, 1)

        # Window to use for human rendering mode
        self.window = None

        # Environment configuration
        self.width = width
        self.height = height
        self.max_steps = max_steps
        self.see_through_walls = see_through_walls

        # Current position and direction of the agent
        self.agent_pos = None
        self.agent_dir = None

        # Initialize the RNG
        self.seed(seed=seed)

        # Initialize the state
        self.reset()

    def reset(self):
        # Current position and direction of the agent
        self.agent_pos = None
        self.agent_dir = None

        # Generate a new random grid at the start of each episode
        # To keep the same grid for each episode, call env.seed() with
        # the same seed before calling env.reset()
        self._gen_grid(self.width, self.height)

        # These fields should be defined by _gen_grid
        assert self.agent_pos is not None
        assert self.agent_dir is not None

        # Check that the agent doesn't overlap with an object
        start_cell = self.grid.get(*self.agent_pos)
        assert start_cell is None or start_cell.can_overlap()

        # Item picked up, being carried, initially nothing
        self.carrying = None

        # Step count since episode start
        self.step_count = 0

        # Return first observation
        obs = self.gen_obs()
        return obs

    def seed(self, seed=1337):
        # Seed the random number generator
        self.np_random, _ = seeding.np_random(seed)
        return [seed]

    @property
    def steps_remaining(self):
        return self.max_steps - self.step_count

    def __str__(self):
        """
        Produce a pretty string of the environment's grid along with the agent.
        A grid cell is represented by 2-character string, the first one for
        the object and the second one for the color.
        """

        # Map of object types to short string
        OBJECT_TO_STR = {
            'wall'          : 'W',
            'floor'         : 'F',
            'door'          : 'D',
            'key'           : 'K',
            'ball'          : 'A',
            'box'           : 'B',
            'goal'          : 'G',
            'lava'          : 'V',
        }

        # Map agent's direction to short string
        AGENT_DIR_TO_STR = {
            0: '>',
            1: 'V',
            2: '<',
            3: '^'
        }

        # Renamed from ``str`` to avoid shadowing the builtin.
        output = ''

        for j in range(self.grid.height):

            for i in range(self.grid.width):
                if i == self.agent_pos[0] and j == self.agent_pos[1]:
                    output += 2 * AGENT_DIR_TO_STR[self.agent_dir]
                    continue

                c = self.grid.get(i, j)

                if c is None:
                    output += '  '
                    continue

                if c.type == 'door':
                    if c.is_open:
                        output += '__'
                    elif c.is_locked:
                        output += 'L' + c.color[0].upper()
                    else:
                        output += 'D' + c.color[0].upper()
                    continue

                output += OBJECT_TO_STR[c.type] + c.color[0].upper()

            if j < self.grid.height - 1:
                output += '\n'

        return output

    def _gen_grid(self, width, height):
        assert False, "_gen_grid needs to be implemented by each environment"

    def _reward(self):
        """
        Compute the reward to be given upon success
        """
        return 1 - 0.9 * (self.step_count / self.max_steps)

    def _rand_int(self, low, high):
        """
        Generate random integer in [low,high[
        """
        return self.np_random.randint(low, high)

    def _rand_float(self, low, high):
        """
        Generate random float in [low,high[
        """
        return self.np_random.uniform(low, high)

    def _rand_bool(self):
        """
        Generate random boolean value
        """
        return (self.np_random.randint(0, 2) == 0)

    def _rand_elem(self, iterable):
        """
        Pick a random element in a list
        """
        lst = list(iterable)
        idx = self._rand_int(0, len(lst))
        return lst[idx]

    def _rand_subset(self, iterable, num_elems):
        """
        Sample a random subset of distinct elements of a list
        """
        lst = list(iterable)
        assert num_elems <= len(lst)

        out = []

        while len(out) < num_elems:
            elem = self._rand_elem(lst)
            lst.remove(elem)
            out.append(elem)

        return out

    def _rand_color(self):
        """
        Generate a random color name (string)
        """
        return self._rand_elem(COLOR_NAMES)

    def _rand_pos(self, xLow, xHigh, yLow, yHigh):
        """
        Generate a random (x,y) position tuple
        """
        return (
            self.np_random.randint(xLow, xHigh),
            self.np_random.randint(yLow, yHigh)
        )

    def place_obj(self,
        obj,
        top=None,
        size=None,
        reject_fn=None,
        max_tries=math.inf
    ):
        """
        Place an object at an empty position in the grid

        :param top: top-left position of the rectangle where to place
        :param size: size of the rectangle where to place
        :param reject_fn: function to filter out potential positions
        """

        if top is None:
            top = (0, 0)
        else:
            top = (max(top[0], 0), max(top[1], 0))

        if size is None:
            size = (self.grid.width, self.grid.height)

        num_tries = 0

        while True:
            # This is to handle with rare cases where rejection sampling
            # gets stuck in an infinite loop
            if num_tries > max_tries:
                raise RecursionError('rejection sampling failed in place_obj')

            num_tries += 1

            pos = np.array((
                self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
                self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
            ))

            # Don't place the object on top of another object
            if self.grid.get(*pos) is not None:
                continue

            # Don't place the object where the agent is
            if np.array_equal(pos, self.agent_pos):
                continue

            # Check if there is a filtering criterion
            if reject_fn and reject_fn(self, pos):
                continue

            break

        self.grid.set(*pos, obj)

        if obj is not None:
            obj.init_pos = pos
            obj.cur_pos = pos

        return pos

    def put_obj(self, obj, i, j):
        """
        Put an object at a specific position in the grid
        """
        self.grid.set(i, j, obj)
        obj.init_pos = (i, j)
        obj.cur_pos = (i, j)

    def place_agent(
        self,
        top=None,
        size=None,
        rand_dir=True,
        max_tries=math.inf
    ):
        """
        Set the agent's starting point at an empty position in the grid
        """

        self.agent_pos = None
        pos = self.place_obj(None, top, size, max_tries=max_tries)
        self.agent_pos = pos

        if rand_dir:
            self.agent_dir = self._rand_int(0, 4)

        return pos

    @property
    def dir_vec(self):
        """
        Get the direction vector for the agent, pointing in the direction
        of forward movement.
        """
        assert self.agent_dir >= 0 and self.agent_dir < 4
        return DIR_TO_VEC[self.agent_dir]

    @property
    def right_vec(self):
        """
        Get the vector pointing to the right of the agent.
        """
        dx, dy = self.dir_vec
        return np.array((-dy, dx))

    @property
    def front_pos(self):
        """
        Get the position of the cell that is right in front of the agent
        """
        return self.agent_pos + self.dir_vec

    def get_view_coords(self, i, j):
        """
        Translate and rotate absolute grid coordinates (i, j) into the
        agent's partially observable view (sub-grid). Note that the resulting
        coordinates may be negative or outside of the agent's view size.
        """

        ax, ay = self.agent_pos
        dx, dy = self.dir_vec
        rx, ry = self.right_vec

        # Compute the absolute coordinates of the top-left view corner
        sz = self.agent_view_size
        hs = self.agent_view_size // 2
        tx = ax + (dx * (sz-1)) - (rx * hs)
        ty = ay + (dy * (sz-1)) - (ry * hs)

        lx = i - tx
        ly = j - ty

        # Project the coordinates of the object relative to the top-left
        # corner onto the agent's own coordinate system
        vx = (rx*lx + ry*ly)
        vy = -(dx*lx + dy*ly)

        return vx, vy

    def get_view_exts(self):
        """
        Get the extents of the square set of tiles visible to the agent
        Note: the bottom extent indices are not included in the set
        """

        # Facing right
        if self.agent_dir == 0:
            topX = self.agent_pos[0]
            topY = self.agent_pos[1] - self.agent_view_size // 2
        # Facing down
        elif self.agent_dir == 1:
            topX = self.agent_pos[0] - self.agent_view_size // 2
            topY = self.agent_pos[1]
        # Facing left
        elif self.agent_dir == 2:
            topX = self.agent_pos[0] - self.agent_view_size + 1
            topY = self.agent_pos[1] - self.agent_view_size // 2
        # Facing up
        elif self.agent_dir == 3:
            topX = self.agent_pos[0] - self.agent_view_size // 2
            topY = self.agent_pos[1] - self.agent_view_size + 1
        else:
            assert False, "invalid agent direction"

        botX = topX + self.agent_view_size
        botY = topY + self.agent_view_size

        return (topX, topY, botX, botY)

    def relative_coords(self, x, y):
        """
        Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
        """

        vx, vy = self.get_view_coords(x, y)

        if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
            return None

        return vx, vy

    def in_view(self, x, y):
        """
        check if a grid position is visible to the agent
        """

        return self.relative_coords(x, y) is not None

    def agent_sees(self, x, y):
        """
        Check if a non-empty grid position is visible to the agent
        """

        coordinates = self.relative_coords(x, y)
        if coordinates is None:
            return False
        vx, vy = coordinates

        obs = self.gen_obs()
        obs_grid, _ = Grid.decode(obs['image'])
        obs_cell = obs_grid.get(vx, vy)
        world_cell = self.grid.get(x, y)

        return obs_cell is not None and obs_cell.type == world_cell.type

    def step(self, action):
        self.step_count += 1

        reward = 0
        done = False

        # Get the position in front of the agent
        fwd_pos = self.front_pos

        # Get the contents of the cell in front of the agent
        fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            self.agent_dir -= 1
            if self.agent_dir < 0:
                self.agent_dir += 4

        # Rotate right
        elif action == self.actions.right:
            self.agent_dir = (self.agent_dir + 1) % 4

        # Move forward
        elif action == self.actions.forward:
            if fwd_cell is None or fwd_cell.can_overlap():
                self.agent_pos = fwd_pos
            if fwd_cell is not None and fwd_cell.type == 'goal':
                done = True
                reward = self._reward()
            if fwd_cell is not None and fwd_cell.type == 'lava':
                done = True

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if self.carrying is None:
                    self.carrying = fwd_cell
                    self.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(*fwd_pos, None)

        # Drop an object
        elif action == self.actions.drop:
            if not fwd_cell and self.carrying:
                self.grid.set(*fwd_pos, self.carrying)
                self.carrying.cur_pos = fwd_pos
                self.carrying = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        # Done action (not used by default)
        elif action == self.actions.done:
            pass

        else:
            assert False, "unknown action"

        if self.step_count >= self.max_steps:
            done = True

        obs = self.gen_obs()

        return obs, reward, done, {}

    def gen_obs_grid(self):
        """
        Generate the sub-grid observed by the agent.
        This method also outputs a visibility mask telling us which grid
        cells the agent can actually see.
        """

        topX, topY, botX, botY = self.get_view_exts()

        grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)

        for i in range(self.agent_dir + 1):
            grid = grid.rotate_left()

        # Process occluders and visibility
        # Note that this incurs some performance cost
        if not self.see_through_walls:
            vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
        else:
            # BUG FIX: np.bool was removed in NumPy >= 1.24; use builtin bool.
            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

        # Make it so the agent sees what it's carrying
        # We do this by placing the carried object at the agent's position
        # in the agent's partially observable view
        agent_pos = grid.width // 2, grid.height - 1
        if self.carrying:
            grid.set(*agent_pos, self.carrying)
        else:
            grid.set(*agent_pos, None)

        return grid, vis_mask

    def gen_obs(self):
        """
        Generate the agent's view (partially observable, low-resolution encoding)
        """

        grid, vis_mask = self.gen_obs_grid()

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        assert hasattr(self, 'mission'), "environments must define a textual mission string"

        # Observations are dictionaries containing:
        # - an image (partially observable view of the environment)
        # - the agent's direction/orientation (acting as a compass)
        # - a textual mission string (instructions for the agent)
        obs = {
            'image': image,
            'direction': self.agent_dir,
            'mission': self.mission
        }

        return obs

    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
        """
        Render an agent observation for visualization
        """

        grid, vis_mask = Grid.decode(obs)

        # Render the whole grid
        img = grid.render(
            tile_size,
            agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
            agent_dir=3,
            highlight_mask=vis_mask
        )

        return img

    def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
        """
        Render the whole-grid human view
        """

        if close:
            if self.window:
                self.window.close()
            return

        if mode == 'human' and not self.window:
            import d4rl.gym_minigrid.window
            self.window = d4rl.gym_minigrid.window.Window('gym_minigrid')
            self.window.show(block=False)

        # Compute which cells are visible to the agent
        _, vis_mask = self.gen_obs_grid()

        # Compute the world coordinates of the bottom-left corner
        # of the agent's view area
        f_vec = self.dir_vec
        r_vec = self.right_vec
        top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)

        # Mask of which cells to highlight
        # BUG FIX: np.bool was removed in NumPy >= 1.24; use builtin bool.
        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # For each cell in the visibility mask
        for vis_j in range(0, self.agent_view_size):
            for vis_i in range(0, self.agent_view_size):
                # If this cell is not visible, don't highlight it
                if not vis_mask[vis_i, vis_j]:
                    continue

                # Compute the world coordinates of this cell
                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

                if abs_i < 0 or abs_i >= self.width:
                    continue
                if abs_j < 0 or abs_j >= self.height:
                    continue

                # Mark this cell to be highlighted
                highlight_mask[abs_i, abs_j] = True

        # Render the whole grid
        img = self.grid.render(
            tile_size,
            self.agent_pos,
            self.agent_dir,
            highlight_mask=highlight_mask if highlight else None
        )

        if mode == 'human':
            self.window.show_img(img)
            self.window.set_caption(self.mission)

        return img
================================================
FILE: d4rl/d4rl/gym_minigrid/register.py
================================================
from gym.envs.registration import register as gym_register
# Ids of every MiniGrid environment registered through this module.
env_list = []


def register(id, entry_point, reward_threshold=0.95):
    """Register a MiniGrid environment with gym and remember its id."""
    assert id.startswith("MiniGrid-")
    assert id not in env_list

    # Register the environment with OpenAI gym
    gym_register(
        id=id,
        entry_point=entry_point,
        reward_threshold=reward_threshold
    )

    # Add the environment to the set
    env_list.append(id)
================================================
FILE: d4rl/d4rl/gym_minigrid/rendering.py
================================================
import math
import numpy as np
def downsample(img, factor):
    """Shrink *img* by *factor* along both spatial axes via block averaging."""
    rows, cols = img.shape[0], img.shape[1]
    assert rows % factor == 0
    assert cols % factor == 0

    # Split each spatial axis into (blocks, factor), then average the two
    # factor-sized axes one at a time (matches the original reduction order).
    blocks = img.reshape([rows // factor, factor, cols // factor, factor, 3])
    blocks = blocks.mean(axis=3)
    return blocks.mean(axis=1)
def fill_coords(img, fn, color):
    """
    Fill pixels of an image with coordinates matching a filter function
    """
    height, width = img.shape[0], img.shape[1]
    for row in range(height):
        # Sample the filter at pixel centers, in normalized [0, 1] coords
        row_f = (row + 0.5) / height
        for col in range(width):
            col_f = (col + 0.5) / width
            if fn(col_f, row_f):
                img[row, col] = color
    return img
def rotate_fn(fin, cx, cy, theta):
    """Wrap filter *fin* so it is evaluated in a frame rotated by theta
    around the point (cx, cy)."""
    cos_t = math.cos(-theta)
    sin_t = math.sin(-theta)

    def fout(x, y):
        # Translate to the rotation center, rotate, translate back
        dx = x - cx
        dy = y - cy
        return fin(cx + dx * cos_t - dy * sin_t,
                   cy + dy * cos_t + dx * sin_t)

    return fout
def point_in_line(x0, y0, x1, y1, r):
    """Return a filter fn(x, y) that is True within distance r of the
    segment from (x0, y0) to (x1, y1).

    A bounding-box rejection test runs before the exact point-to-segment
    distance computation.
    """
    p0 = np.array([x0, y0])
    p1 = np.array([x1, y1])
    seg = p1 - p0
    seg_len = np.linalg.norm(seg)

    # Unit direction along the segment. For a degenerate (zero-length)
    # segment use a zero vector so the test degrades to a circle around p0;
    # previously the division produced NaNs and the filter was always False.
    # (renamed from `dir`, which shadowed the builtin)
    seg_dir = seg / seg_len if seg_len > 0 else seg * 0.0

    # Bounding box of the thickened segment, for a fast early-out
    xmin = min(x0, x1) - r
    xmax = max(x0, x1) + r
    ymin = min(y0, y1) - r
    ymax = max(y0, y1) + r

    def fn(x, y):
        # Fast, early escape test
        if x < xmin or x > xmax or y < ymin or y > ymax:
            return False

        q = np.array([x, y])
        pq = q - p0

        # Project q onto the segment, clamped to its endpoints
        a = np.clip(np.dot(pq, seg_dir), 0, seg_len)
        closest = p0 + a * seg_dir

        return np.linalg.norm(q - closest) <= r

    return fn
def point_in_circle(cx, cy, r):
    """Return a filter fn(x, y) that is True inside the closed disc of
    radius r centered at (cx, cy)."""
    r_sq = r * r

    def fn(x, y):
        dx = x - cx
        dy = y - cy
        return dx * dx + dy * dy <= r_sq

    return fn
def point_in_rect(xmin, xmax, ymin, ymax):
    """Return a filter fn(x, y) that is True inside the closed rectangle
    [xmin, xmax] x [ymin, ymax]."""
    def fn(x, y):
        return xmin <= x <= xmax and ymin <= y <= ymax

    return fn
def point_in_triangle(a, b, c):
    """Return a filter fn(x, y) that is True for points inside the
    triangle with vertices a, b, c (barycentric-coordinate test)."""
    a = np.array(a)
    b = np.array(b)
    c = np.array(c)

    def fn(x, y):
        # Edge vectors from vertex a, and the query point relative to a
        v0 = c - a
        v1 = b - a
        v2 = np.array((x, y)) - a

        d00 = np.dot(v0, v0)
        d01 = np.dot(v0, v1)
        d02 = np.dot(v0, v2)
        d11 = np.dot(v1, v1)
        d12 = np.dot(v1, v2)

        # Barycentric coordinates (u, v) of the query point
        inv_denom = 1 / (d00 * d11 - d01 * d01)
        u = (d11 * d02 - d01 * d12) * inv_denom
        v = (d00 * d12 - d01 * d02) * inv_denom

        # Inside iff both coordinates are non-negative and u + v < 1
        return (u >= 0) and (v >= 0) and (u + v) < 1

    return fn
def highlight_img(img, color=(255, 255, 255), alpha=0.30):
    """
    Add highlighting to an image, in place.

    Blends every pixel alpha of the way towards *color*:
    img <- img + alpha * (color - img), clipped back to uint8.

    The difference is computed in floating point: with the previous
    uint8-typed color array, any channel where color < pixel wrapped
    around and brightened the pixel instead of darkening it.
    """
    color = np.asarray(color, dtype=np.float64)
    blend_img = img + alpha * (color - img)
    blend_img = blend_img.clip(0, 255).astype(np.uint8)
    img[:, :, :] = blend_img
================================================
FILE: d4rl/d4rl/gym_minigrid/roomgrid.py
================================================
from d4rl.gym_minigrid.minigrid import *
def reject_next_to(env, pos):
    """
    Function to filter out object positions that are right next to
    the agent's starting point
    """
    agent_x, agent_y = env.agent_pos
    pos_x, pos_y = pos

    # Manhattan distance between the candidate position and the agent
    manhattan = abs(agent_x - pos_x) + abs(agent_y - pos_y)
    return manhattan < 2
class Room:
    """A single rectangular room of the grid, with up to four doors
    connecting it to its neighbors (indexed right, down, left, up)."""

    def __init__(self, top, size):
        # Top-left corner and size (tuples)
        self.top = top
        self.size = size

        # Door objects and their wall positions, indexed right/down/left/up
        self.doors = [None] * 4
        self.door_pos = [None] * 4

        # Adjacent rooms, indexed right/down/left/up
        self.neighbors = [None] * 4

        # Whether this room sits behind a locked door
        self.locked = False

        # Objects placed inside this room
        self.objs = []

    def rand_pos(self, env):
        """Sample a random position strictly inside the room walls."""
        left, top = self.top
        width, height = self.size
        return env._randPos(
            left + 1, left + width - 1,
            top + 1, top + height - 1
        )

    def pos_inside(self, x, y):
        """
        Check if a position is within the bounds of this room
        """
        left, top = self.top
        width, height = self.size
        return (left <= x < left + width) and (top <= y < top + height)
class RoomGrid(MiniGridEnv):
    """
    Environment with multiple rooms and random objects.
    This is meant to serve as a base class for other environments.
    """

    def __init__(
        self,
        room_size=7,
        num_rows=3,
        num_cols=3,
        max_steps=100,
        seed=0
    ):
        assert room_size > 0
        assert room_size >= 3
        assert num_rows > 0
        assert num_cols > 0
        self.room_size = room_size
        self.num_rows = num_rows
        self.num_cols = num_cols

        # Adjacent rooms share a wall, hence the (room_size - 1) stride
        height = (room_size - 1) * num_rows + 1
        width = (room_size - 1) * num_cols + 1

        # By default, this environment has no mission
        self.mission = ''

        super().__init__(
            width=width,
            height=height,
            max_steps=max_steps,
            see_through_walls=False,
            seed=seed
        )

    def room_from_pos(self, x, y):
        """Get the room a given position maps to"""
        assert x >= 0
        assert y >= 0

        i = x // (self.room_size-1)
        j = y // (self.room_size-1)

        assert i < self.num_cols
        assert j < self.num_rows

        return self.room_grid[j][i]

    def get_room(self, i, j):
        """Return the room at column i, row j."""
        assert i < self.num_cols
        assert j < self.num_rows
        return self.room_grid[j][i]

    def _gen_grid(self, width, height):
        """Build walls, door positions, neighbor links and the agent start."""
        # Create the grid
        self.grid = Grid(width, height)
        self.room_grid = []

        # For each row of rooms
        for j in range(0, self.num_rows):
            row = []
            # For each column of rooms
            for i in range(0, self.num_cols):
                room = Room(
                    (i * (self.room_size-1), j * (self.room_size-1)),
                    (self.room_size, self.room_size)
                )
                row.append(room)
                # Generate the walls for this room
                self.grid.wall_rect(*room.top, *room.size)
            self.room_grid.append(row)

        # For each row of rooms
        for j in range(0, self.num_rows):
            # For each column of rooms
            for i in range(0, self.num_cols):
                room = self.room_grid[j][i]

                x_l, y_l = (room.top[0] + 1, room.top[1] + 1)
                x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 1)

                # Door positions, order is right, down, left, up
                if i < self.num_cols - 1:
                    room.neighbors[0] = self.room_grid[j][i+1]
                    room.door_pos[0] = (x_m, self._rand_int(y_l, y_m))
                if j < self.num_rows - 1:
                    room.neighbors[1] = self.room_grid[j+1][i]
                    room.door_pos[1] = (self._rand_int(x_l, x_m), y_m)
                if i > 0:
                    # Shared wall: reuse the right-door position of the left neighbor
                    room.neighbors[2] = self.room_grid[j][i-1]
                    room.door_pos[2] = room.neighbors[2].door_pos[0]
                if j > 0:
                    # Shared wall: reuse the down-door position of the upper neighbor
                    room.neighbors[3] = self.room_grid[j-1][i]
                    room.door_pos[3] = room.neighbors[3].door_pos[1]

        # The agent starts in the middle, facing right
        self.agent_pos = (
            (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2),
            (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2)
        )
        self.agent_dir = 0

    def place_in_room(self, i, j, obj):
        """
        Add an existing object to room (i, j)
        """
        room = self.get_room(i, j)

        pos = self.place_obj(
            obj,
            room.top,
            room.size,
            reject_fn=reject_next_to,
            max_tries=1000
        )

        room.objs.append(obj)

        return obj, pos

    def add_object(self, i, j, kind=None, color=None):
        """
        Add a new object to room (i, j)
        """
        if kind is None:
            kind = self._rand_elem(['key', 'ball', 'box'])

        if color is None:
            color = self._rand_color()

        # TODO: we probably want to add an Object.make helper function
        assert kind in ['key', 'ball', 'box']
        if kind == 'key':
            obj = Key(color)
        elif kind == 'ball':
            obj = Ball(color)
        elif kind == 'box':
            obj = Box(color)

        return self.place_in_room(i, j, obj)

    def add_door(self, i, j, door_idx=None, color=None, locked=None):
        """
        Add a door to a room, connecting it to a neighbor
        """
        room = self.get_room(i, j)

        if door_idx is None:
            # Need to make sure that there is a neighbor along this wall
            # and that there is not already a door
            while True:
                door_idx = self._rand_int(0, 4)
                if room.neighbors[door_idx] and room.doors[door_idx] is None:
                    break

        if color is None:
            color = self._rand_color()

        if locked is None:
            locked = self._rand_bool()

        assert room.doors[door_idx] is None, "door already exists"

        room.locked = locked
        door = Door(color, is_locked=locked)

        pos = room.door_pos[door_idx]
        self.grid.set(*pos, door)
        door.cur_pos = pos

        # The same door object is shared with the neighbor, on the opposite wall
        neighbor = room.neighbors[door_idx]
        room.doors[door_idx] = door
        neighbor.doors[(door_idx+2) % 4] = door

        return door, pos

    def remove_wall(self, i, j, wall_idx):
        """
        Remove a wall between two rooms
        """
        room = self.get_room(i, j)

        assert wall_idx >= 0 and wall_idx < 4
        assert room.doors[wall_idx] is None, "door exists on this wall"
        assert room.neighbors[wall_idx], "invalid wall"

        neighbor = room.neighbors[wall_idx]

        tx, ty = room.top
        w, h = room.size

        # Ordering of walls is right, down, left, up
        if wall_idx == 0:
            for i in range(1, h - 1):
                self.grid.set(tx + w - 1, ty + i, None)
        elif wall_idx == 1:
            for i in range(1, w - 1):
                self.grid.set(tx + i, ty + h - 1, None)
        elif wall_idx == 2:
            for i in range(1, h - 1):
                self.grid.set(tx, ty + i, None)
        elif wall_idx == 3:
            for i in range(1, w - 1):
                self.grid.set(tx + i, ty, None)
        else:
            assert False, "invalid wall index"

        # Mark the rooms as connected
        room.doors[wall_idx] = True
        neighbor.doors[(wall_idx+2) % 4] = True

    def place_agent(self, i=None, j=None, rand_dir=True):
        """
        Place the agent in a room
        """
        if i is None:
            i = self._rand_int(0, self.num_cols)
        if j is None:
            j = self._rand_int(0, self.num_rows)

        room = self.room_grid[j][i]

        # Find a position that is not right in front of an object
        while True:
            super().place_agent(room.top, room.size, rand_dir, max_tries=1000)
            front_cell = self.grid.get(*self.front_pos)
            # Was `front_cell.type is 'wall'`: identity comparison with a str
            # literal only works via interning; use a real equality test
            if front_cell is None or front_cell.type == 'wall':
                break

        return self.agent_pos

    def connect_all(self, door_colors=COLOR_NAMES, max_itrs=5000):
        """
        Make sure that all rooms are reachable by the agent from its
        starting position
        """
        start_room = self.room_from_pos(*self.agent_pos)

        added_doors = []

        def find_reach():
            # Rooms reachable from the start room via existing doors (DFS)
            reach = set()
            stack = [start_room]
            while len(stack) > 0:
                room = stack.pop()
                if room in reach:
                    continue
                reach.add(room)
                for i in range(0, 4):
                    if room.doors[i]:
                        stack.append(room.neighbors[i])
            return reach

        num_itrs = 0

        while True:
            # This is to handle rare situations where random sampling produces
            # a level that cannot be connected, producing in an infinite loop
            if num_itrs > max_itrs:
                raise RecursionError('connect_all failed')
            num_itrs += 1

            # If all rooms are reachable, stop
            reach = find_reach()
            if len(reach) == self.num_rows * self.num_cols:
                break

            # Pick a random room and door position
            i = self._rand_int(0, self.num_cols)
            j = self._rand_int(0, self.num_rows)
            k = self._rand_int(0, 4)
            room = self.get_room(i, j)

            # If there is already a door there, skip
            if not room.door_pos[k] or room.doors[k]:
                continue

            # Never punch an unlocked door into a locked room
            if room.locked or room.neighbors[k].locked:
                continue

            color = self._rand_elem(door_colors)
            door, _ = self.add_door(i, j, k, color, False)
            added_doors.append(door)

        return added_doors

    def add_distractors(self, i=None, j=None, num_distractors=10, all_unique=True):
        """
        Add random objects that can potentially distract/confuse the agent.
        """
        # Collect a list of existing objects
        objs = []
        for row in self.room_grid:
            for room in row:
                for obj in room.objs:
                    objs.append((obj.type, obj.color))

        # List of distractors added
        dists = []

        while len(dists) < num_distractors:
            color = self._rand_elem(COLOR_NAMES)
            type = self._rand_elem(['key', 'ball', 'box'])
            obj = (type, color)

            # Optionally reject (type, color) pairs that already exist
            if all_unique and obj in objs:
                continue

            # Add the object to a random room if no room specified
            room_i = i
            room_j = j
            if room_i is None:
                room_i = self._rand_int(0, self.num_cols)
            if room_j is None:
                room_j = self._rand_int(0, self.num_rows)

            dist, pos = self.add_object(room_i, room_j, *obj)

            objs.append(obj)
            dists.append(dist)

        return dists
================================================
FILE: d4rl/d4rl/gym_minigrid/window.py
================================================
import sys
import numpy as np
# Only ask users to install matplotlib if they actually need it
try:
import matplotlib.pyplot as plt
except:
print('To display the environment in a window, please install matplotlib, eg:')
print('pip3 install --user matplotlib')
sys.exit(-1)
class Window:
    """
    Window to draw a gridworld instance using Matplotlib
    """

    def __init__(self, title):
        self.fig = None
        self.imshow_obj = None

        # Flag indicating the window was closed
        self.closed = False

        # Create the figure and axes
        self.fig, self.ax = plt.subplots()

        # Show the env name in the window title
        self.fig.canvas.set_window_title(title)

        # Turn off x/y axis numbering/ticks
        self.ax.set_xticks([], [])
        self.ax.set_yticks([], [])

        # Mark the window as closed when the user dismisses it
        self.fig.canvas.mpl_connect('close_event', self._mark_closed)

    def _mark_closed(self, evt):
        # matplotlib close-event callback
        self.closed = True

    def show_img(self, img):
        """
        Show an image or update the image being shown
        """
        # Lazily create the image artist on the first frame
        if self.imshow_obj is None:
            self.imshow_obj = self.ax.imshow(img, interpolation='bilinear')

        self.imshow_obj.set_data(img)
        self.fig.canvas.draw()

        # Let matplotlib process UI events
        # This is needed for interactive mode to work properly
        plt.pause(0.001)

    def set_caption(self, text):
        """
        Set/update the caption text below the image
        """
        plt.xlabel(text)

    def reg_key_handler(self, key_handler):
        """
        Register a keyboard event handler
        """
        self.fig.canvas.mpl_connect('key_press_event', key_handler)

    def show(self, block=True):
        """
        Show the window, and start an event loop
        """
        # Interactive mode makes plt.show() non-blocking; otherwise it
        # enters the matplotlib event loop
        if not block:
            plt.ion()

        plt.show()

    def close(self):
        """
        Close the window
        """
        plt.close()
================================================
FILE: d4rl/d4rl/gym_minigrid/wrappers.py
================================================
import math
import operator
from functools import reduce
import numpy as np
import gym
from gym import error, spaces, utils
from d4rl.gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        # Cycle through the seed list, reseeding before every reset
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        # Pass the action straight through to the wrapped env
        return self.env.step(action)
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state,action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counts keyed on (agent position, agent direction, action)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        key = (tuple(env.agent_pos), env.agent_dir, action)

        # Increment the visitation count for this (state, action) pair
        new_count = self.counts.get(key, 0) + 1
        self.counts[key] = new_count

        # Bonus decays with the square root of the visit count
        reward += 1 / math.sqrt(new_count)

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counts keyed on the agent position
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Key on the agent position after the update
        key = tuple(self.unwrapped.agent_pos)

        # Increment the visitation count for this position
        new_count = self.counts.get(key, 0) + 1
        self.counts[key] = new_count

        # Bonus decays with the square root of the visit count
        reward += 1 / math.sqrt(new_count)

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        # Expose only the image part of the observation space
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        # Drop everything except the image
        return obs['image']
class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space['image'].shape

        # One bit per object type, per color, and per state
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )

    def observation(self, obs):
        img = obs['image']
        out = np.zeros(self.observation_space.shape, dtype='uint8')

        n_types = len(OBJECT_TO_IDX)
        n_colors = len(COLOR_TO_IDX)

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                obj_type, obj_color, obj_state = img[i, j]

                # Set one bit in each of the three sub-ranges
                out[i, j, obj_type] = 1
                out[i, j, n_types + obj_color] = 1
                out[i, j, n_types + n_colors + obj_state] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }
class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use fully observable RGB image as the only observation output,
    no language/mission. This can be used to have the agent to solve the
    gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        # Full-grid render: tile_size pixels per grid cell
        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width*tile_size, self.env.height*tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        # Render the whole grid without the partial-view highlight
        rgb_img = self.unwrapped.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img
        }
class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use partially observable RGB image as the only observation output
    This can be used to have the agent to solve the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        # The rendered view scales each observed cell to tile_size pixels
        obs_shape = env.observation_space['image'].shape
        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        # Render only the agent's partial view
        rgb_img_partial = self.unwrapped.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img_partial
        }
class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        super().__init__(env)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()

        # Overwrite the agent's cell with an agent marker
        agent_x, agent_y = env.agent_pos
        full_grid[agent_x][agent_y] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            'mission': obs['mission'],
            'image': full_grid
        }
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        # 26 letters plus the space character
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        # Cache of the last encoded mission string
        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Re-encode only when the mission string changes
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Previously an unsupported character silently reused the
                    # code of the preceding character (or raised NameError on
                    # the first character); fail loudly instead
                    raise ValueError('unsupported character in mission string: %r' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # Override default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Image space matching the requested view size
        img_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': img_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
================================================
FILE: d4rl/d4rl/gym_mujoco/__init__.py
================================================
from gym.envs.registration import register
from d4rl.gym_mujoco import gym_envs
from d4rl import infos
# V1 envs
# Map each agent to the suffix used by its entry-point factory function
_AGENT_TO_ENTRY = {
    'hopper': 'hopper',
    'halfcheetah': 'cheetah',
    'ant': 'ant',
    'walker2d': 'walker',
}
for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']:
    for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay', 'full-replay']:
        for version in ['v1', 'v2']:
            env_name = '%s-%s-%s' % (agent, dataset, version)
            register(
                id=env_name,
                entry_point='d4rl.gym_mujoco.gym_envs:get_%s_env' % _AGENT_TO_ENTRY[agent],
                max_episode_steps=1000,
                kwargs={
                    # Only the latest (v2) datasets are non-deprecated
                    'deprecated': version != 'v2',
                    'ref_min_score': infos.REF_MIN_SCORE[env_name],
                    'ref_max_score': infos.REF_MAX_SCORE[env_name],
                    'dataset_url': infos.DATASET_URLS[env_name]
                }
            )
# Per-agent reference scores, used as ref_min_score / ref_max_score in the
# v0 dataset registrations in this module.
HOPPER_RANDOM_SCORE = -20.272305
HALFCHEETAH_RANDOM_SCORE = -280.178953
WALKER_RANDOM_SCORE = 1.629008
ANT_RANDOM_SCORE = -325.6
HOPPER_EXPERT_SCORE = 3234.3
HALFCHEETAH_EXPERT_SCORE = 12135.0
WALKER_EXPERT_SCORE = 4592.3
ANT_EXPERT_SCORE = 3879.7
# All v0 locomotion datasets share a URL prefix and per-agent score ranges
_V0_URL_PREFIX = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/'

# Per agent: entry-point suffix, reference min score, reference max score
_V0_AGENT_INFO = {
    'hopper': ('hopper', HOPPER_RANDOM_SCORE, HOPPER_EXPERT_SCORE),
    'halfcheetah': ('cheetah', HALFCHEETAH_RANDOM_SCORE, HALFCHEETAH_EXPERT_SCORE),
    'walker2d': ('walker', WALKER_RANDOM_SCORE, WALKER_EXPERT_SCORE),
    'ant': ('ant', ANT_RANDOM_SCORE, ANT_EXPERT_SCORE),
}

# (env id, agent, dataset file) for every deprecated v0 dataset:
# single-policy datasets, mixed (medium-replay) datasets, and mixtures of
# random/medium and expert data.
_V0_DATASETS = [
    # Single Policy datasets
    ('hopper-medium-v0', 'hopper', 'hopper_medium.hdf5'),
    ('halfcheetah-medium-v0', 'halfcheetah', 'halfcheetah_medium.hdf5'),
    ('walker2d-medium-v0', 'walker2d', 'walker2d_medium.hdf5'),
    ('hopper-expert-v0', 'hopper', 'hopper_expert.hdf5'),
    ('halfcheetah-expert-v0', 'halfcheetah', 'halfcheetah_expert.hdf5'),
    ('walker2d-expert-v0', 'walker2d', 'walker2d_expert.hdf5'),
    ('hopper-random-v0', 'hopper', 'hopper_random.hdf5'),
    ('halfcheetah-random-v0', 'halfcheetah', 'halfcheetah_random.hdf5'),
    ('walker2d-random-v0', 'walker2d', 'walker2d_random.hdf5'),
    # Mixed datasets
    ('hopper-medium-replay-v0', 'hopper', 'hopper_mixed.hdf5'),
    ('walker2d-medium-replay-v0', 'walker2d', 'walker_mixed.hdf5'),
    ('halfcheetah-medium-replay-v0', 'halfcheetah', 'halfcheetah_mixed.hdf5'),
    # Mixtures of random/medium and experts
    ('walker2d-medium-expert-v0', 'walker2d', 'walker2d_medium_expert.hdf5'),
    ('halfcheetah-medium-expert-v0', 'halfcheetah', 'halfcheetah_medium_expert.hdf5'),
    ('hopper-medium-expert-v0', 'hopper', 'hopper_medium_expert.hdf5'),
    ('ant-medium-expert-v0', 'ant', 'ant_medium_expert.hdf5'),
    ('ant-medium-replay-v0', 'ant', 'ant_mixed.hdf5'),
    ('ant-medium-v0', 'ant', 'ant_medium.hdf5'),
    ('ant-random-v0', 'ant', 'ant_random.hdf5'),
    ('ant-expert-v0', 'ant', 'ant_expert.hdf5'),
    ('ant-random-expert-v0', 'ant', 'ant_random_expert.hdf5'),
]

for _env_id, _agent, _fname in _V0_DATASETS:
    _entry_suffix, _ref_min, _ref_max = _V0_AGENT_INFO[_agent]
    register(
        id=_env_id,
        entry_point='d4rl.gym_mujoco.gym_envs:get_%s_env' % _entry_suffix,
        max_episode_steps=1000,
        kwargs={
            'deprecated': True,
            'ref_min_score': _ref_min,
            'ref_max_score': _ref_max,
            'dataset_url': _V0_URL_PREFIX + _fname,
        }
    )
================================================
FILE: d4rl/d4rl/gym_mujoco/gym_envs.py
================================================
from .. import offline_env
from gym.envs.mujoco import HalfCheetahEnv, AntEnv, HopperEnv, Walker2dEnv
from ..utils.wrappers import NormalizedBoxEnv
class OfflineAntEnv(AntEnv, offline_env.OfflineEnv):
    """Ant env combined with the D4RL offline-dataset mixin."""

    def __init__(self, **kwargs):
        # Initialize both bases explicitly; dataset kwargs go to OfflineEnv
        AntEnv.__init__(self)
        offline_env.OfflineEnv.__init__(self, **kwargs)
class OfflineHopperEnv(HopperEnv, offline_env.OfflineEnv):
    """Hopper env combined with the D4RL offline-dataset mixin."""

    def __init__(self, **kwargs):
        # Initialize both bases explicitly; dataset kwargs go to OfflineEnv
        HopperEnv.__init__(self)
        offline_env.OfflineEnv.__init__(self, **kwargs)
class OfflineHalfCheetahEnv(HalfCheetahEnv, offline_env.OfflineEnv):
    """HalfCheetah env combined with the D4RL offline-dataset mixin."""

    def __init__(self, **kwargs):
        # Initialize both bases explicitly; dataset kwargs go to OfflineEnv
        HalfCheetahEnv.__init__(self)
        offline_env.OfflineEnv.__init__(self, **kwargs)
class OfflineWalker2dEnv(Walker2dEnv, offline_env.OfflineEnv):
    """Walker2d env combined with the D4RL offline-dataset mixin."""

    def __init__(self, **kwargs):
        # Initialize both bases explicitly; dataset kwargs go to OfflineEnv
        Walker2dEnv.__init__(self)
        offline_env.OfflineEnv.__init__(self, **kwargs)
def get_ant_env(**kwargs):
    """Return an OfflineAntEnv wrapped in NormalizedBoxEnv."""
    env = OfflineAntEnv(**kwargs)
    return NormalizedBoxEnv(env)
def get_cheetah_env(**kwargs):
    """Return an OfflineHalfCheetahEnv wrapped in NormalizedBoxEnv."""
    env = OfflineHalfCheetahEnv(**kwargs)
    return NormalizedBoxEnv(env)
def get_hopper_env(**kwargs):
    """Return an OfflineHopperEnv wrapped in NormalizedBoxEnv."""
    env = OfflineHopperEnv(**kwargs)
    return NormalizedBoxEnv(env)
def get_walker_env(**kwargs):
    """Build an offline Walker2d env wrapped with action-range normalization."""
    base_env = OfflineWalker2dEnv(**kwargs)
    return NormalizedBoxEnv(base_env)
if __name__ == '__main__':
    # Example usage of these envs would go here; nothing runs directly.
    pass
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore
================================================
*.DS_Store
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/README.md
================================================
# Adroit Manipulation Platform
Adroit manipulation platform is a reconfigurable, tendon-driven, pneumatically-actuated platform designed and developed by [Vikash Kumar](https://vikashplus.github.io/) during his Ph.D. ([Thesis: Manipulators and Manipulation in high dimensional spaces](https://digital.lib.washington.edu/researchworks/handle/1773/38104)) to study dynamic dexterous manipulation. Adroit is comprised of the [Shadow Hand](https://www.shadowrobot.com/products/dexterous-hand/) skeleton (developed by [Shadow Robot company](https://www.shadowrobot.com/)) and a custom arm, and is powered by a custom actuation system. This custom actuation system allows Adroit to move the ShadowHand skeleton faster than a human hand (70 msec limit-to-limit movement, 30 msec overall reflex latency), generate sufficient forces (40 N at each finger tendon, 125N at each wrist tendon), and achieve high compliance on the mechanism level (6 grams of external force at the fingertip displaces the finger when the system is powered.) This combination of speed, force, and compliance is a prerequisite for dexterous manipulation, yet it has never before been achieved with a tendon-driven system, let alone a system with 24 degrees of freedom and 40 tendons.
## Mujoco Model
Adroit is a 28 degree of freedom system which consists of a 24 degrees of freedom **ShadowHand** and a 4 degree of freedom arm. This repository contains the Mujoco Models of the system developed with extreme care and great attention to the details.
## In Projects
Adroit has been used in a wide variety of projects. A small list is appended below. Details of these projects can be found [here](https://vikashplus.github.io/).
[](https://vikashplus.github.io/)
## In News and Media
Adroit has found quite some attention in the world media. Details can be found [here](https://vikashplus.github.io/news.html)
[](https://vikashplus.github.io/news.html)
## Citation
If the contents of this repo helped you, please consider citing
```
@phdthesis{Kumar2016thesis,
title = {Manipulators and Manipulation in high dimensional spaces},
school = {University of Washington, Seattle},
author = {Kumar, Vikash},
year = {2016},
url = {https://digital.lib.washington.edu/researchworks/handle/1773/38104}
}
```
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/__init__.py
================================================
from gym.envs.registration import register
from mjrl.envs.mujoco_env import MujocoEnv
from d4rl.hand_manipulation_suite.door_v0 import DoorEnvV0
from d4rl.hand_manipulation_suite.hammer_v0 import HammerEnvV0
from d4rl.hand_manipulation_suite.pen_v0 import PenEnvV0
from d4rl.hand_manipulation_suite.relocate_v0 import RelocateEnvV0
from d4rl import infos
# V1 envs
MAX_STEPS = {'hammer': 200, 'relocate': 200, 'door': 200, 'pen': 100}
LONG_HORIZONS = {'hammer': 600, 'pen': 200, 'relocate': 500, 'door': 300}
ENV_MAPPING = {'hammer': 'HammerEnvV0', 'relocate': 'RelocateEnvV0', 'door': 'DoorEnvV0', 'pen': 'PenEnvV0'}

# Register every (agent, dataset) combination; each human dataset also gets a
# long-horizon variant that lifts the per-episode step cap.
for agent in ['hammer', 'pen', 'relocate', 'door']:
    for dataset in ['human', 'expert', 'cloned']:
        env_name = f'{agent}-{dataset}-v1'
        entry = f'd4rl.hand_manipulation_suite:{ENV_MAPPING[agent]}'
        dataset_kwargs = {
            'ref_min_score': infos.REF_MIN_SCORE[env_name],
            'ref_max_score': infos.REF_MAX_SCORE[env_name],
            'dataset_url': infos.DATASET_URLS[env_name],
        }
        register(
            id=env_name,
            entry_point=entry,
            max_episode_steps=MAX_STEPS[agent],
            kwargs=dataset_kwargs,
        )
        if dataset == 'human':
            # Long-horizon variant reuses the human dataset and its reference
            # scores; pass a copy so the two registrations don't share a dict.
            longhorizon_env_name = f'{agent}-human-longhorizon-v1'
            register(
                id=longhorizon_env_name,
                entry_point=entry,
                max_episode_steps=LONG_HORIZONS[agent],
                kwargs=dict(dataset_kwargs),
            )
# Reference scores used below as ref_min_score / ref_max_score for the
# deprecated v0 registrations (D4RL score normalization endpoints).
DOOR_RANDOM_SCORE, DOOR_EXPERT_SCORE = -56.512833, 2880.5693087298737
HAMMER_RANDOM_SCORE, HAMMER_EXPERT_SCORE = -274.856578, 12794.134825156867
PEN_RANDOM_SCORE, PEN_EXPERT_SCORE = 96.262799, 3076.8331017826877
RELOCATE_RANDOM_SCORE, RELOCATE_EXPERT_SCORE = -6.425911, 4233.877797728884
# Swing the door open
register(
    id='door-v0',
    entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',
    max_episode_steps=200,
)

# Deprecated v0 offline variants: (id suffix, episode cap, dataset file name).
_DOOR_BASE_URL = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/'
for _suffix, _steps, _fname in [
    ('human', 200, 'door-v0_demos_clipped.hdf5'),
    ('human-longhorizon', 300, 'door-v0_demos_clipped.hdf5'),
    ('cloned', 200, 'door-demos-v0-bc-combined.hdf5'),
    ('expert', 200, 'door-v0_expert_clipped.hdf5'),
]:
    register(
        id='door-%s-v0' % _suffix,
        entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',
        max_episode_steps=_steps,
        kwargs={
            'deprecated': True,
            'ref_min_score': DOOR_RANDOM_SCORE,
            'ref_max_score': DOOR_EXPERT_SCORE,
            'dataset_url': _DOOR_BASE_URL + _fname,
        },
    )
# Hammer a nail into the board
register(
    id='hammer-v0',
    entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',
    max_episode_steps=200,
)

# Deprecated v0 offline variants: (id suffix, episode cap, dataset file name).
_HAMMER_BASE_URL = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/'
for _suffix, _steps, _fname in [
    ('human', 200, 'hammer-v0_demos_clipped.hdf5'),
    ('human-longhorizon', 600, 'hammer-v0_demos_clipped.hdf5'),
    ('cloned', 200, 'hammer-demos-v0-bc-combined.hdf5'),
    ('expert', 200, 'hammer-v0_expert_clipped.hdf5'),
]:
    register(
        id='hammer-%s-v0' % _suffix,
        entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',
        max_episode_steps=_steps,
        kwargs={
            'deprecated': True,
            'ref_min_score': HAMMER_RANDOM_SCORE,
            'ref_max_score': HAMMER_EXPERT_SCORE,
            'dataset_url': _HAMMER_BASE_URL + _fname,
        },
    )
# Reposition a pen in hand
register(
    id='pen-v0',
    entry_point='d4rl.hand_manipulation_suite:PenEnvV0',
    max_episode_steps=100,
)

# Deprecated v0 offline variants: (id suffix, episode cap, dataset file name).
_PEN_BASE_URL = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/'
for _suffix, _steps, _fname in [
    ('human', 100, 'pen-v0_demos_clipped.hdf5'),
    ('human-longhorizon', 200, 'pen-v0_demos_clipped.hdf5'),
    ('cloned', 100, 'pen-demos-v0-bc-combined.hdf5'),
    ('expert', 100, 'pen-v0_expert_clipped.hdf5'),
]:
    register(
        id='pen-%s-v0' % _suffix,
        entry_point='d4rl.hand_manipulation_suite:PenEnvV0',
        max_episode_steps=_steps,
        kwargs={
            'deprecated': True,
            'ref_min_score': PEN_RANDOM_SCORE,
            'ref_max_score': PEN_EXPERT_SCORE,
            'dataset_url': _PEN_BASE_URL + _fname,
        },
    )
# Relocate an object to the target
register(
    id='relocate-v0',
    entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',
    max_episode_steps=200,
)

# Deprecated v0 offline variants: (id suffix, episode cap, dataset file name).
_RELOCATE_BASE_URL = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/'
for _suffix, _steps, _fname in [
    ('human', 200, 'relocate-v0_demos_clipped.hdf5'),
    ('human-longhorizon', 500, 'relocate-v0_demos_clipped.hdf5'),
    ('cloned', 200, 'relocate-demos-v0-bc-combined.hdf5'),
    ('expert', 200, 'relocate-v0_expert_clipped.hdf5'),
]:
    register(
        id='relocate-%s-v0' % _suffix,
        entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',
        max_episode_steps=_steps,
        kwargs={
            'deprecated': True,
            'ref_min_score': RELOCATE_RANDOM_SCORE,
            'ref_max_score': RELOCATE_EXPERT_SCORE,
            'dataset_url': _RELOCATE_BASE_URL + _fname,
        },
    )
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml
================================================
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/door_v0.py
================================================
import numpy as np
from gym import utils
from gym import spaces
from mjrl.envs import mujoco_env
from mujoco_py import MjViewer
from d4rl import offline_env
import os
ADD_BONUS_REWARDS = True
class DoorEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):
    """Adroit hand task: grab the handle and swing the door open.

    Actions are normalized to [-1, 1] and rescaled to the actuator control
    ranges inside ``step``. The reward combines a reach term (palm to
    handle), a door-opening term, a small velocity penalty, and staged
    bonuses as the door hinge angle increases.
    """

    def __init__(self, **kwargs):
        # OfflineEnv consumes the dataset-related kwargs (dataset_url,
        # ref_min_score, ref_max_score, deprecated, ...).
        offline_env.OfflineEnv.__init__(self, **kwargs)
        # Placeholder ids; resolved from the model once MuJoCo has loaded it.
        self.door_hinge_did = 0
        self.door_bid = 0
        self.grasp_sid = 0
        self.handle_sid = 0
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_door.xml', 5)

        # Override action_space to -1, 1
        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)

        # Change actuator sensitivity: stiffer gains for the wrist actuator
        # range (A_WRJ1..A_WRJ0) than for the finger/thumb range
        # (A_FFJ3..A_THJ0).
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])

        utils.EzPickle.__init__(self)
        ob = self.reset_model()
        # Mid-point and half-range of each actuator's control range; used in
        # step() to map normalized [-1, 1] actions to raw control values.
        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)
        self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0])
        self.door_hinge_did = self.model.jnt_dofadr[self.model.joint_name2id('door_hinge')]
        self.grasp_sid = self.model.site_name2id('S_grasp')
        self.handle_sid = self.model.site_name2id('S_handle')
        self.door_bid = self.model.body_name2id('frame')

    def step(self, a):
        """Apply normalized action ``a`` and simulate ``frame_skip`` frames.

        Returns ``(ob, reward, done, info)``. ``done`` is always False;
        ``info['goal_achieved']`` flags a (nearly) fully open door.
        """
        a = np.clip(a, -1.0, 1.0)
        try:
            a = self.act_mid + a*self.act_rng  # mean center and scale
        except AttributeError:
            # BUGFIX: was a bare `except:`, which silently swallowed every
            # error. Only AttributeError is expected here: MujocoEnv.__init__
            # calls step() before act_mid/act_rng are assigned, in which case
            # the raw action is used unscaled.
            pass
        self.do_simulation(a, self.frame_skip)
        ob = self.get_obs()
        handle_pos = self.data.site_xpos[self.handle_sid].ravel()
        palm_pos = self.data.site_xpos[self.grasp_sid].ravel()
        door_pos = self.data.qpos[self.door_hinge_did]
        # get to handle
        reward = -0.1*np.linalg.norm(palm_pos-handle_pos)
        # open door (quadratic penalty on distance from the 1.57 rad target)
        reward += -0.1*(door_pos - 1.57)*(door_pos - 1.57)
        # velocity cost
        reward += -1e-5*np.sum(self.data.qvel**2)
        if ADD_BONUS_REWARDS:
            # Staged bonuses as the door swings progressively further open.
            if door_pos > 0.2:
                reward += 2
            if door_pos > 1.0:
                reward += 8
            if door_pos > 1.35:
                reward += 10
        # bool() keeps the info value a plain Python bool (as the original
        # `True if ... else False` did) rather than a numpy bool.
        goal_achieved = bool(door_pos >= 1.35)
        return ob, reward, False, dict(goal_achieved=goal_achieved)

    def get_obs(self):
        """Return the observation vector.

        Concatenation of hand joint positions (excluding the door dofs), the
        latch position, the door hinge angle, palm and handle positions,
        their difference, and a +/-1 door-open indicator.
        """
        qp = self.data.qpos.ravel()
        handle_pos = self.data.site_xpos[self.handle_sid].ravel()
        palm_pos = self.data.site_xpos[self.grasp_sid].ravel()
        door_pos = np.array([self.data.qpos[self.door_hinge_did]])
        # Binary indicator: +1 once the hinge has swung past 1 rad, else -1.
        if door_pos > 1.0:
            door_open = 1.0
        else:
            door_open = -1.0
        latch_pos = qp[-1]
        return np.concatenate([qp[1:-2], [latch_pos], door_pos, palm_pos, handle_pos, palm_pos-handle_pos, [door_open]])

    def reset_model(self):
        """Reset the hand to its initial pose and randomize the door frame."""
        qp = self.init_qpos.copy()
        qv = self.init_qvel.copy()
        self.set_state(qp, qv)
        # Randomize the door frame position within a small box.
        self.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2)
        self.model.body_pos[self.door_bid,1] = self.np_random.uniform(low=0.25, high=0.35)
        self.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35)
        self.sim.forward()
        return self.get_obs()

    def get_env_state(self):
        """
        Get state of hand as well as objects and targets in the scene
        """
        qp = self.data.qpos.ravel().copy()
        qv = self.data.qvel.ravel().copy()
        door_body_pos = self.model.body_pos[self.door_bid].ravel().copy()
        return dict(qpos=qp, qvel=qv, door_body_pos=door_body_pos)

    def set_env_state(self, state_dict):
        """
        Set the state which includes hand as well as objects and targets in the scene
        """
        qp = state_dict['qpos']
        qv = state_dict['qvel']
        self.set_state(qp, qv)
        self.model.body_pos[self.door_bid] = state_dict['door_body_pos']
        self.sim.forward()

    def mj_viewer_setup(self):
        """Configure the interactive MuJoCo viewer camera."""
        self.viewer = MjViewer(self.sim)
        self.viewer.cam.azimuth = 90
        self.sim.forward()
        self.viewer.cam.distance = 1.5

    def evaluate_success(self, paths):
        """Percentage of paths where the door stayed open for more than 25 steps."""
        num_success = 0
        num_paths = len(paths)
        # success if door open for 25 steps
        for path in paths:
            if np.sum(path['env_infos']['goal_achieved']) > 25:
                num_success += 1
        success_percentage = num_success*100.0/num_paths
        return success_percentage
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/hammer_v0.py
================================================
import numpy as np
from gym import utils
from gym import spaces
from mjrl.envs import mujoco_env
from mujoco_py import MjViewer
from d4rl.utils.quatmath import quat2euler
from d4rl import offline_env
import os
ADD_BONUS_REWARDS = True
class HammerEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):
    """Adroit hand task: pick up the hammer and drive the nail into the board.

    Actions are normalized to [-1, 1] and rescaled to the actuator control
    ranges inside ``step``. The reward combines reach (palm to hammer),
    tool-to-nail distance, nail-to-goal distance, a velocity penalty, and
    bonuses for lifting the hammer and sinking the nail.
    """

    def __init__(self, **kwargs):
        # OfflineEnv consumes the dataset-related kwargs (dataset_url,
        # ref_min_score, ref_max_score, deprecated, ...).
        offline_env.OfflineEnv.__init__(self, **kwargs)
        # Placeholder ids; resolved from the model once MuJoCo has loaded it.
        self.target_obj_sid = -1
        self.S_grasp_sid = -1
        self.obj_bid = -1
        self.tool_sid = -1
        self.goal_sid = -1
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_hammer.xml', 5)

        # Override action_space to -1, 1
        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)

        utils.EzPickle.__init__(self)

        # Change actuator sensitivity: stiffer gains for the wrist actuator
        # range (A_WRJ1..A_WRJ0) than for the finger/thumb range
        # (A_FFJ3..A_THJ0).
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])

        self.target_obj_sid = self.sim.model.site_name2id('S_target')
        self.S_grasp_sid = self.sim.model.site_name2id('S_grasp')
        self.obj_bid = self.sim.model.body_name2id('Object')
        self.tool_sid = self.sim.model.site_name2id('tool')
        self.goal_sid = self.sim.model.site_name2id('nail_goal')
        # Mid-point and half-range of each actuator's control range; used in
        # step() to map normalized [-1, 1] actions to raw control values.
        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)
        self.act_rng = 0.5 * (self.model.actuator_ctrlrange[:, 1] - self.model.actuator_ctrlrange[:, 0])

    def step(self, a):
        """Apply normalized action ``a`` and simulate ``frame_skip`` frames.

        Returns ``(ob, reward, done, info)``. ``done`` is always False;
        ``info['goal_achieved']`` flags the nail being driven to the goal.
        """
        a = np.clip(a, -1.0, 1.0)
        try:
            a = self.act_mid + a * self.act_rng  # mean center and scale
        except AttributeError:
            # BUGFIX: was a bare `except:`, which silently swallowed every
            # error. Only AttributeError is expected here: MujocoEnv.__init__
            # calls step() before act_mid/act_rng are assigned, in which case
            # the raw action is used unscaled.
            pass
        self.do_simulation(a, self.frame_skip)
        ob = self.get_obs()
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
        tool_pos = self.data.site_xpos[self.tool_sid].ravel()
        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()
        goal_pos = self.data.site_xpos[self.goal_sid].ravel()
        # get to hammer
        reward = - 0.1 * np.linalg.norm(palm_pos - obj_pos)
        # take hammer head to nail
        reward -= np.linalg.norm((tool_pos - target_pos))
        # make nail go inside
        reward -= 10 * np.linalg.norm(target_pos - goal_pos)
        # velocity penalty
        reward -= 1e-2 * np.linalg.norm(self.data.qvel.ravel())
        if ADD_BONUS_REWARDS:
            # bonus for lifting up the hammer
            if obj_pos[2] > 0.04 and tool_pos[2] > 0.04:
                reward += 2
            # bonus for hammering the nail
            if (np.linalg.norm(target_pos - goal_pos) < 0.020):
                reward += 25
            if (np.linalg.norm(target_pos - goal_pos) < 0.010):
                reward += 75
        # bool() keeps the info value a plain Python bool (as the original
        # `True if ... else False` did) rather than a numpy bool.
        goal_achieved = bool(np.linalg.norm(target_pos - goal_pos) < 0.010)
        return ob, reward, False, dict(goal_achieved=goal_achieved)

    def get_obs(self):
        """Return the observation vector.

        Concatenation of hand joint positions (excluding the last 6 dofs),
        the clipped velocities of those last 6 dofs, palm position, hammer
        position and orientation (Euler angles), nail target position, and
        the clipped nail-impact sensor reading.
        """
        qp = self.data.qpos.ravel()
        qv = np.clip(self.data.qvel.ravel(), -1.0, 1.0)
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        obj_rot = quat2euler(self.data.body_xquat[self.obj_bid].ravel()).ravel()
        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()
        nail_impact = np.clip(self.sim.data.sensordata[self.sim.model.sensor_name2id('S_nail')], -1.0, 1.0)
        return np.concatenate([qp[:-6], qv[-6:], palm_pos, obj_pos, obj_rot, target_pos, np.array([nail_impact])])

    def reset_model(self):
        """Reset the simulation and randomize the nail board height."""
        self.sim.reset()
        target_bid = self.model.body_name2id('nail_board')
        self.model.body_pos[target_bid,2] = self.np_random.uniform(low=0.1, high=0.25)
        self.sim.forward()
        return self.get_obs()

    def get_env_state(self):
        """
        Get state of hand as well as objects and targets in the scene
        """
        qpos = self.data.qpos.ravel().copy()
        qvel = self.data.qvel.ravel().copy()
        board_pos = self.model.body_pos[self.model.body_name2id('nail_board')].copy()
        target_pos = self.data.site_xpos[self.target_obj_sid].ravel().copy()
        return dict(qpos=qpos, qvel=qvel, board_pos=board_pos, target_pos=target_pos)

    def set_env_state(self, state_dict):
        """
        Set the state which includes hand as well as objects and targets in the scene
        """
        qp = state_dict['qpos']
        qv = state_dict['qvel']
        board_pos = state_dict['board_pos']
        self.set_state(qp, qv)
        self.model.body_pos[self.model.body_name2id('nail_board')] = board_pos
        self.sim.forward()

    def mj_viewer_setup(self):
        """Configure the interactive MuJoCo viewer camera."""
        self.viewer = MjViewer(self.sim)
        self.viewer.cam.azimuth = 45
        self.viewer.cam.distance = 2.0
        self.sim.forward()

    def evaluate_success(self, paths):
        """Percentage of paths where the nail stayed sunk for more than 25 steps."""
        num_success = 0
        num_paths = len(paths)
        # success if nail inside board for 25 steps
        for path in paths:
            if np.sum(path['env_infos']['goal_achieved']) > 25:
                num_success += 1
        success_percentage = num_success*100.0/num_paths
        return success_percentage
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/pen_v0.py
================================================
import numpy as np
from gym import utils
from gym import spaces
from mjrl.envs import mujoco_env
from d4rl.utils.quatmath import quat2euler, euler2quat
from d4rl import offline_env
from mujoco_py import MjViewer
import os
ADD_BONUS_REWARDS = True
class PenEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):
    """Adroit pen-repositioning task (DAPG): the hand must manipulate a pen
    so that its position and orientation match a randomized visual target.

    The reward combines a position-distance term, an orientation-alignment
    term, optional alignment bonuses (``ADD_BONUS_REWARDS``), and a drop
    penalty; the episode terminates early when the pen is dropped.
    """

    def __init__(self, **kwargs):
        offline_env.OfflineEnv.__init__(self, **kwargs)
        # Placeholder ids/lengths; resolved after the MuJoCo model loads below.
        self.target_obj_bid = 0
        self.S_grasp_sid = 0
        self.eps_ball_sid = 0
        self.obj_bid = 0
        self.obj_t_sid = 0
        self.obj_b_sid = 0
        self.tar_t_sid = 0
        self.tar_b_sid = 0
        self.pen_length = 1.0
        self.tar_length = 1.0
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE: MujocoEnv.__init__ probes step() before act_mid/act_rng exist;
        # step() tolerates that (see the starting_up handling there).
        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_pen.xml', 5)
        # Override action_space to -1, 1
        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)
        # change actuator sensitivity
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])
        utils.EzPickle.__init__(self)
        self.target_obj_bid = self.sim.model.body_name2id("target")
        self.S_grasp_sid = self.sim.model.site_name2id('S_grasp')
        self.obj_bid = self.sim.model.body_name2id('Object')
        self.eps_ball_sid = self.sim.model.site_name2id('eps_ball')
        self.obj_t_sid = self.sim.model.site_name2id('object_top')
        self.obj_b_sid = self.sim.model.site_name2id('object_bottom')
        self.tar_t_sid = self.sim.model.site_name2id('target_top')
        self.tar_b_sid = self.sim.model.site_name2id('target_bottom')
        # Pen/target lengths are used to normalize orientation vectors.
        self.pen_length = np.linalg.norm(self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])
        self.tar_length = np.linalg.norm(self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])
        # Action scaling: ctrlrange midpoints and half-ranges.
        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)
        self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0])

    def step(self, a):
        """Advance one control step.

        Args:
            a: normalized action in [-1, 1]; rescaled to the actuator
               control range before simulation.

        Returns:
            (obs, reward, done, info) where info['goal_achieved'] is True
            when the pen is within 0.075 of the target location and the
            orientation similarity exceeds 0.95.
        """
        a = np.clip(a, -1.0, 1.0)
        try:
            starting_up = False
            a = self.act_mid + a*self.act_rng  # mean center and scale
        except Exception:
            # act_mid/act_rng do not exist yet during MujocoEnv.__init__'s
            # probe step; keep the raw (clipped) action. This was a bare
            # ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
            starting_up = True
        self.do_simulation(a, self.frame_skip)
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        desired_loc = self.data.site_xpos[self.eps_ball_sid].ravel()
        obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length
        desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length
        # pos cost
        dist = np.linalg.norm(obj_pos-desired_loc)
        reward = -dist
        # orien cost
        orien_similarity = np.dot(obj_orien, desired_orien)
        reward += orien_similarity
        if ADD_BONUS_REWARDS:
            # bonus for being close to desired orientation
            if dist < 0.075 and orien_similarity > 0.9:
                reward += 10
            if dist < 0.075 and orien_similarity > 0.95:
                reward += 50
        # penalty for dropping the pen
        done = False
        if obj_pos[2] < 0.075:
            reward -= 5
            # never terminate during the initialization probe step
            done = not starting_up
        goal_achieved = bool(dist < 0.075 and orien_similarity > 0.95)
        return self.get_obs(), reward, done, dict(goal_achieved=goal_achieved)

    def get_obs(self):
        """Observation: hand joints, object pose/velocity, desired pose, and
        the position/orientation errors relative to the target."""
        qp = self.data.qpos.ravel()
        obj_vel = self.data.qvel[-6:].ravel()
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        desired_pos = self.data.site_xpos[self.eps_ball_sid].ravel()
        obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length
        desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length
        return np.concatenate([qp[:-6], obj_pos, obj_vel, obj_orien, desired_orien,
                               obj_pos-desired_pos, obj_orien-desired_orien])

    def reset_model(self):
        """Reset joints to the initial pose and randomize the target pen
        orientation (roll/pitch only), then return the observation."""
        qp = self.init_qpos.copy()
        qv = self.init_qvel.copy()
        self.set_state(qp, qv)
        desired_orien = np.zeros(3)
        desired_orien[0] = self.np_random.uniform(low=-1, high=1)
        desired_orien[1] = self.np_random.uniform(low=-1, high=1)
        self.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien)
        self.sim.forward()
        return self.get_obs()

    def get_env_state(self):
        """Snapshot the full environment state (hand, pen, target quat)."""
        qp = self.data.qpos.ravel().copy()
        qv = self.data.qvel.ravel().copy()
        desired_orien = self.model.body_quat[self.target_obj_bid].ravel().copy()
        return dict(qpos=qp, qvel=qv, desired_orien=desired_orien)

    def set_env_state(self, state_dict):
        """Restore a state previously captured by :meth:`get_env_state`."""
        qp = state_dict['qpos']
        qv = state_dict['qvel']
        desired_orien = state_dict['desired_orien']
        self.set_state(qp, qv)
        self.model.body_quat[self.target_obj_bid] = desired_orien
        self.sim.forward()

    def mj_viewer_setup(self):
        """Create the on-screen viewer and position the camera."""
        self.viewer = MjViewer(self.sim)
        self.viewer.cam.azimuth = -45
        self.sim.forward()
        self.viewer.cam.distance = 1.0

    def evaluate_success(self, paths):
        """Return the percentage of rollouts where the pen matched the
        target (goal_achieved flag) for more than 20 steps."""
        num_success = 0
        num_paths = len(paths)
        for path in paths:
            if np.sum(path['env_infos']['goal_achieved']) > 20:
                num_success += 1
        success_percentage = num_success*100.0/num_paths
        return success_percentage
================================================
FILE: d4rl/d4rl/hand_manipulation_suite/relocate_v0.py
================================================
import numpy as np
from gym import utils
from gym import spaces
from mjrl.envs import mujoco_env
from mujoco_py import MjViewer
from d4rl import offline_env
import os
ADD_BONUS_REWARDS = True
class RelocateEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):
    """Adroit relocate task (DAPG): pick up the ball and move it to a
    randomized target location in the air.

    Reward shaping: a reach term (palm to object); once the object is off
    the table, a lift bonus plus transport terms (palm/object to target);
    and proximity bonuses when ``ADD_BONUS_REWARDS`` is enabled. Episodes
    never terminate early.
    """

    def __init__(self, **kwargs):
        offline_env.OfflineEnv.__init__(self, **kwargs)
        # Placeholder ids; resolved after the MuJoCo model loads below.
        self.target_obj_sid = 0
        self.S_grasp_sid = 0
        self.obj_bid = 0
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE: MujocoEnv.__init__ probes step() before act_mid/act_rng exist;
        # step() tolerates that (see the try/except there).
        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_relocate.xml', 5)
        # Override action_space to -1, 1
        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)
        # change actuator sensitivity
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])
        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])
        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])
        self.target_obj_sid = self.sim.model.site_name2id("target")
        self.S_grasp_sid = self.sim.model.site_name2id('S_grasp')
        self.obj_bid = self.sim.model.body_name2id('Object')
        utils.EzPickle.__init__(self)
        # Action scaling: ctrlrange midpoints and half-ranges.
        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)
        self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0])

    def step(self, a):
        """Advance one control step.

        Args:
            a: normalized action in [-1, 1]; rescaled to the actuator
               control range before simulation.

        Returns:
            (obs, reward, done, info) where done is always False and
            info['goal_achieved'] is True when the object is within 0.1 of
            the target.
        """
        a = np.clip(a, -1.0, 1.0)
        try:
            a = self.act_mid + a*self.act_rng  # mean center and scale
        except Exception:
            # act_mid/act_rng do not exist yet during MujocoEnv.__init__'s
            # probe step; keep the raw (clipped) action. This was a bare
            # ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
            pass
        self.do_simulation(a, self.frame_skip)
        ob = self.get_obs()
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()
        reward = -0.1*np.linalg.norm(palm_pos-obj_pos)  # take hand to object
        if obj_pos[2] > 0.04:  # if object off the table
            reward += 1.0  # bonus for lifting the object
            reward += -0.5*np.linalg.norm(palm_pos-target_pos)  # make hand go to target
            reward += -0.5*np.linalg.norm(obj_pos-target_pos)  # make object go to target
        if ADD_BONUS_REWARDS:
            if np.linalg.norm(obj_pos-target_pos) < 0.1:
                reward += 10.0  # bonus for object close to target
            if np.linalg.norm(obj_pos-target_pos) < 0.05:
                reward += 20.0  # bonus for object "very" close to target
        goal_achieved = bool(np.linalg.norm(obj_pos-target_pos) < 0.1)
        return ob, reward, False, dict(goal_achieved=goal_achieved)

    def get_obs(self):
        """Observation: hand joints plus the relative vectors between palm,
        object, and target (positions are encoded only as differences)."""
        qp = self.data.qpos.ravel()
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()
        return np.concatenate([qp[:-6], palm_pos-obj_pos, palm_pos-target_pos, obj_pos-target_pos])

    def reset_model(self):
        """Reset joints and randomize object spawn (on the table) and the
        target location (in the air), then return the observation."""
        qp = self.init_qpos.copy()
        qv = self.init_qvel.copy()
        self.set_state(qp, qv)
        self.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15)
        self.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3)
        self.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2)
        self.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2)
        self.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35)
        self.sim.forward()
        return self.get_obs()

    def get_env_state(self):
        """Snapshot the full environment state (hand, object, target)."""
        qp = self.data.qpos.ravel().copy()
        qv = self.data.qvel.ravel().copy()
        hand_qpos = qp[:30]
        obj_pos = self.data.body_xpos[self.obj_bid].ravel()
        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()
        return dict(hand_qpos=hand_qpos, obj_pos=obj_pos, target_pos=target_pos, palm_pos=palm_pos,
                    qpos=qp, qvel=qv)

    def set_env_state(self, state_dict):
        """Restore a state previously captured by :meth:`get_env_state`."""
        qp = state_dict['qpos']
        qv = state_dict['qvel']
        obj_pos = state_dict['obj_pos']
        target_pos = state_dict['target_pos']
        self.set_state(qp, qv)
        self.model.body_pos[self.obj_bid] = obj_pos
        self.model.site_pos[self.target_obj_sid] = target_pos
        self.sim.forward()

    def mj_viewer_setup(self):
        """Create the on-screen viewer and position the camera."""
        self.viewer = MjViewer(self.sim)
        self.viewer.cam.azimuth = 90
        self.sim.forward()
        self.viewer.cam.distance = 1.5

    def evaluate_success(self, paths):
        """Return the percentage of rollouts where the object stayed close
        to the target (goal_achieved flag) for more than 25 steps."""
        num_success = 0
        num_paths = len(paths)
        for path in paths:
            if np.sum(path['env_infos']['goal_achieved']) > 25:
                num_success += 1
        success_percentage = num_success*100.0/num_paths
        return success_percentage
================================================
FILE: d4rl/d4rl/infos.py
================================================
"""
This file holds all URLs and reference scores.
"""
#TODO(Justin): This is duplicated. Make all __init__ file URLs and scores point to this file.
# Maps each registered dataset id to the HDF5 download URL used when the
# dataset is first requested. Additional v1/v2 Gym-MuJoCo and Adroit v1
# entries are appended programmatically at the bottom of this module.
DATASET_URLS = {
    # maze2d: sparse-reward, eval layouts, and dense-reward variants
    'maze2d-open-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5',
    'maze2d-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5',
    'maze2d-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5',
    'maze2d-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5',
    'maze2d-eval-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5',
    'maze2d-eval-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5',
    'maze2d-eval-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5',
    'maze2d-open-dense-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5',
    'maze2d-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5',
    'maze2d-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5',
    'maze2d-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5',
    'maze2d-eval-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5',
    'maze2d-eval-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5',
    'maze2d-eval-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5',
    # minigrid
    'minigrid-fourrooms-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5',
    'minigrid-fourrooms-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5',
    # Adroit hand (DAPG) v0
    'pen-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5',
    'pen-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5',
    'pen-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5',
    'hammer-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5',
    'hammer-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5',
    'hammer-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5',
    'relocate-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5',
    'relocate-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5',
    'relocate-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5',
    'door-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5',
    'door-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5',
    'door-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5',
    # Gym-MuJoCo locomotion v0
    'halfcheetah-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5',
    'halfcheetah-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5',
    'halfcheetah-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5',
    'halfcheetah-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5',
    'halfcheetah-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5',
    'walker2d-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5',
    'walker2d-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5',
    'walker2d-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5',
    'walker2d-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5',
    'walker2d-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5',
    'hopper-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5',
    'hopper-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5',
    'hopper-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5',
    'hopper-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5',
    'hopper-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5',
    'ant-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5',
    'ant-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5',
    'ant-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5',
    'ant-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5',
    'ant-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5',
    'ant-random-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5',
    # AntMaze v0 and the fixed v2 regeneration
    'antmaze-umaze-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5',
    'antmaze-umaze-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    'antmaze-medium-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',
    'antmaze-medium-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    'antmaze-large-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',
    'antmaze-large-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    'antmaze-umaze-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5',
    'antmaze-umaze-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',
    'antmaze-medium-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',
    'antmaze-medium-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',
    'antmaze-large-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',
    'antmaze-large-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',
    # Flow traffic control
    'flow-ring-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5',
    'flow-ring-controller-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5',
    'flow-merge-random-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5',
    'flow-merge-controller-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5',
    # FrankaKitchen
    'kitchen-complete-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5',
    'kitchen-partial-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5',
    'kitchen-mixed-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5',
    # CARLA driving
    'carla-lane-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5',
    'carla-town-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5',
    'carla-town-full-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',
    # PyBullet ports of the locomotion and maze2d tasks
    'bullet-halfcheetah-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5',
    'bullet-halfcheetah-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5',
    'bullet-halfcheetah-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5',
    'bullet-halfcheetah-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5',
    'bullet-halfcheetah-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5',
    'bullet-hopper-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5',
    'bullet-hopper-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5',
    'bullet-hopper-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5',
    'bullet-hopper-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5',
    'bullet-hopper-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5',
    'bullet-ant-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5',
    'bullet-ant-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5',
    'bullet-ant-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5',
    'bullet-ant-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5',
    'bullet-ant-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5',
    'bullet-walker2d-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5',
    'bullet-walker2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5',
    'bullet-walker2d-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5',
    'bullet-walker2d-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5',
    'bullet-walker2d-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5',
    'bullet-maze2d-open-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5',
    'bullet-maze2d-umaze-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5',
    'bullet-maze2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5',
    'bullet-maze2d-large-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5',
}
# Reference minimum scores per dataset id (paired with REF_MAX_SCORE for
# score normalization). Entries within a task family share one value —
# presumably the random-policy return for that task; TODO confirm against
# the D4RL paper's reported reference scores.
REF_MIN_SCORE = {
    'maze2d-open-v0' : 0.01,
    'maze2d-umaze-v1' : 23.85,
    'maze2d-medium-v1' : 13.13,
    'maze2d-large-v1' : 6.7,
    'maze2d-open-dense-v0' : 11.17817,
    'maze2d-umaze-dense-v1' : 68.537689,
    'maze2d-medium-dense-v1' : 44.264742,
    'maze2d-large-dense-v1' : 30.569041,
    'minigrid-fourrooms-v0' : 0.01442,
    'minigrid-fourrooms-random-v0' : 0.01442,
    'pen-human-v0' : 96.262799,
    'pen-cloned-v0' : 96.262799,
    'pen-expert-v0' : 96.262799,
    'hammer-human-v0' : -274.856578,
    'hammer-cloned-v0' : -274.856578,
    'hammer-expert-v0' : -274.856578,
    'relocate-human-v0' : -6.425911,
    'relocate-cloned-v0' : -6.425911,
    'relocate-expert-v0' : -6.425911,
    'door-human-v0' : -56.512833,
    'door-cloned-v0' : -56.512833,
    'door-expert-v0' : -56.512833,
    'halfcheetah-random-v0' : -280.178953,
    'halfcheetah-medium-v0' : -280.178953,
    'halfcheetah-expert-v0' : -280.178953,
    'halfcheetah-medium-replay-v0' : -280.178953,
    'halfcheetah-medium-expert-v0' : -280.178953,
    'walker2d-random-v0' : 1.629008,
    'walker2d-medium-v0' : 1.629008,
    'walker2d-expert-v0' : 1.629008,
    'walker2d-medium-replay-v0' : 1.629008,
    'walker2d-medium-expert-v0' : 1.629008,
    'hopper-random-v0' : -20.272305,
    'hopper-medium-v0' : -20.272305,
    'hopper-expert-v0' : -20.272305,
    'hopper-medium-replay-v0' : -20.272305,
    'hopper-medium-expert-v0' : -20.272305,
    'ant-random-v0' : -325.6,
    'ant-medium-v0' : -325.6,
    'ant-expert-v0' : -325.6,
    'ant-medium-replay-v0' : -325.6,
    'ant-medium-expert-v0' : -325.6,
    # sparse-reward tasks: score is success (1) / failure (0)
    'antmaze-umaze-v0' : 0.0,
    'antmaze-umaze-diverse-v0' : 0.0,
    'antmaze-medium-play-v0' : 0.0,
    'antmaze-medium-diverse-v0' : 0.0,
    'antmaze-large-play-v0' : 0.0,
    'antmaze-large-diverse-v0' : 0.0,
    'antmaze-umaze-v2' : 0.0,
    'antmaze-umaze-diverse-v2' : 0.0,
    'antmaze-medium-play-v2' : 0.0,
    'antmaze-medium-diverse-v2' : 0.0,
    'antmaze-large-play-v2' : 0.0,
    'antmaze-large-diverse-v2' : 0.0,
    # kitchen: number of subtasks completed (0-4)
    'kitchen-complete-v0' : 0.0,
    'kitchen-partial-v0' : 0.0,
    'kitchen-mixed-v0' : 0.0,
    'flow-ring-random-v0' : -165.22,
    'flow-ring-controller-v0' : -165.22,
    'flow-merge-random-v0' : 118.67993,
    'flow-merge-controller-v0' : 118.67993,
    'carla-lane-v0': -0.8503839912088142,
    'carla-town-v0': -114.81579500772153, # random score
    'bullet-halfcheetah-random-v0': -1275.766996,
    'bullet-halfcheetah-medium-v0': -1275.766996,
    'bullet-halfcheetah-expert-v0': -1275.766996,
    'bullet-halfcheetah-medium-expert-v0': -1275.766996,
    'bullet-halfcheetah-medium-replay-v0': -1275.766996,
    'bullet-hopper-random-v0': 20.058972,
    'bullet-hopper-medium-v0': 20.058972,
    'bullet-hopper-expert-v0': 20.058972,
    'bullet-hopper-medium-expert-v0': 20.058972,
    'bullet-hopper-medium-replay-v0': 20.058972,
    'bullet-ant-random-v0': 373.705955,
    'bullet-ant-medium-v0': 373.705955,
    'bullet-ant-expert-v0': 373.705955,
    'bullet-ant-medium-expert-v0': 373.705955,
    'bullet-ant-medium-replay-v0': 373.705955,
    'bullet-walker2d-random-v0': 16.523877,
    'bullet-walker2d-medium-v0': 16.523877,
    'bullet-walker2d-expert-v0': 16.523877,
    'bullet-walker2d-medium-expert-v0': 16.523877,
    'bullet-walker2d-medium-replay-v0': 16.523877,
    'bullet-maze2d-open-v0': 8.750000,
    'bullet-maze2d-umaze-v0': 32.460000,
    'bullet-maze2d-medium-v0': 14.870000,
    'bullet-maze2d-large-v0': 1.820000,
}
# Reference maximum scores per dataset id (paired with REF_MIN_SCORE for
# score normalization). Entries within a task family share one value —
# presumably an expert-level return for that task; TODO confirm against
# the D4RL paper's reported reference scores.
REF_MAX_SCORE = {
    'maze2d-open-v0' : 20.66,
    'maze2d-umaze-v1' : 161.86,
    'maze2d-medium-v1' : 277.39,
    'maze2d-large-v1' : 273.99,
    'maze2d-open-dense-v0' : 27.166538620695782,
    'maze2d-umaze-dense-v1' : 193.66285642381482,
    'maze2d-medium-dense-v1' : 297.4552547777125,
    'maze2d-large-dense-v1' : 303.4857382709002,
    'minigrid-fourrooms-v0' : 2.89685,
    'minigrid-fourrooms-random-v0' : 2.89685,
    'pen-human-v0' : 3076.8331017826877,
    'pen-cloned-v0' : 3076.8331017826877,
    'pen-expert-v0' : 3076.8331017826877,
    'hammer-human-v0' : 12794.134825156867,
    'hammer-cloned-v0' : 12794.134825156867,
    'hammer-expert-v0' : 12794.134825156867,
    'relocate-human-v0' : 4233.877797728884,
    'relocate-cloned-v0' : 4233.877797728884,
    'relocate-expert-v0' : 4233.877797728884,
    'door-human-v0' : 2880.5693087298737,
    'door-cloned-v0' : 2880.5693087298737,
    'door-expert-v0' : 2880.5693087298737,
    'halfcheetah-random-v0' : 12135.0,
    'halfcheetah-medium-v0' : 12135.0,
    'halfcheetah-expert-v0' : 12135.0,
    'halfcheetah-medium-replay-v0' : 12135.0,
    'halfcheetah-medium-expert-v0' : 12135.0,
    'walker2d-random-v0' : 4592.3,
    'walker2d-medium-v0' : 4592.3,
    'walker2d-expert-v0' : 4592.3,
    'walker2d-medium-replay-v0' : 4592.3,
    'walker2d-medium-expert-v0' : 4592.3,
    'hopper-random-v0' : 3234.3,
    'hopper-medium-v0' : 3234.3,
    'hopper-expert-v0' : 3234.3,
    'hopper-medium-replay-v0' : 3234.3,
    'hopper-medium-expert-v0' : 3234.3,
    'ant-random-v0' : 3879.7,
    'ant-medium-v0' : 3879.7,
    'ant-expert-v0' : 3879.7,
    'ant-medium-replay-v0' : 3879.7,
    'ant-medium-expert-v0' : 3879.7,
    # sparse-reward tasks: score is success (1) / failure (0)
    'antmaze-umaze-v0' : 1.0,
    'antmaze-umaze-diverse-v0' : 1.0,
    'antmaze-medium-play-v0' : 1.0,
    'antmaze-medium-diverse-v0' : 1.0,
    'antmaze-large-play-v0' : 1.0,
    'antmaze-large-diverse-v0' : 1.0,
    'antmaze-umaze-v2' : 1.0,
    'antmaze-umaze-diverse-v2' : 1.0,
    'antmaze-medium-play-v2' : 1.0,
    'antmaze-medium-diverse-v2' : 1.0,
    'antmaze-large-play-v2' : 1.0,
    'antmaze-large-diverse-v2' : 1.0,
    # kitchen: number of subtasks completed (0-4)
    'kitchen-complete-v0' : 4.0,
    'kitchen-partial-v0' : 4.0,
    'kitchen-mixed-v0' : 4.0,
    'flow-ring-random-v0' : 24.42,
    'flow-ring-controller-v0' : 24.42,
    'flow-merge-random-v0' : 330.03179,
    'flow-merge-controller-v0' : 330.03179,
    'carla-lane-v0': 1023.5784385429523,
    'carla-town-v0': 2440.1772022247314, # avg dataset score
    'bullet-halfcheetah-random-v0': 2381.6725,
    'bullet-halfcheetah-medium-v0': 2381.6725,
    'bullet-halfcheetah-expert-v0': 2381.6725,
    'bullet-halfcheetah-medium-expert-v0': 2381.6725,
    'bullet-halfcheetah-medium-replay-v0': 2381.6725,
    'bullet-hopper-random-v0': 1441.8059623430963,
    'bullet-hopper-medium-v0': 1441.8059623430963,
    'bullet-hopper-expert-v0': 1441.8059623430963,
    'bullet-hopper-medium-expert-v0': 1441.8059623430963,
    'bullet-hopper-medium-replay-v0': 1441.8059623430963,
    'bullet-ant-random-v0': 2650.495,
    'bullet-ant-medium-v0': 2650.495,
    'bullet-ant-expert-v0': 2650.495,
    'bullet-ant-medium-expert-v0': 2650.495,
    'bullet-ant-medium-replay-v0': 2650.495,
    'bullet-walker2d-random-v0': 1623.6476303317536,
    'bullet-walker2d-medium-v0': 1623.6476303317536,
    'bullet-walker2d-expert-v0': 1623.6476303317536,
    'bullet-walker2d-medium-expert-v0': 1623.6476303317536,
    'bullet-walker2d-medium-replay-v0': 1623.6476303317536,
    'bullet-maze2d-open-v0': 64.15,
    'bullet-maze2d-umaze-v0': 153.99,
    'bullet-maze2d-medium-v0': 238.05,
    'bullet-maze2d-large-v0': 285.92,
}
#Gym-MuJoCo V1/V2 envs
# Generated entries: both regenerated dataset versions share the v0
# reference min/max scores for their task.
for env in ['halfcheetah', 'hopper', 'walker2d', 'ant']:
    for dset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']:
        for version in ['v1', 'v2']:
            dset_name = env + '_' + dset.replace('-', '_') + '-' + version
            env_name = dset_name.replace('_', '-')
            DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_%s/%s.hdf5' % (version, dset_name)
            REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env + '-random-v0']
            REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env + '-random-v0']
#Adroit v1 envs
# Generated entries: v1 Adroit datasets reuse the v0 human reference scores.
for env in ['hammer', 'pen', 'relocate', 'door']:
    for dset in ['human', 'expert', 'cloned']:
        env_name = '-'.join([env, dset, 'v1'])
        DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/%s.hdf5' % env_name
        REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env + '-human-v0']
        REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env + '-human-v0']
================================================
FILE: d4rl/d4rl/kitchen/__init__.py
================================================
from .kitchen_envs import KitchenMicrowaveKettleLightSliderV0, KitchenMicrowaveKettleBottomBurnerLightV0
from gym.envs.registration import register
# Register the three kitchen datasets. All share the same episode limit and
# reference-score range; they differ in env id, entry point, and dataset:
#   - complete: smaller dataset with only positive demonstrations.
#   - partial:  undirected demonstrations; a subset solves the task.
#   - mixed:    undirected demonstrations; none completely solves the task,
#               but each partially solves different components of it.
_KITCHEN_REGISTRATIONS = [
    ('kitchen-complete-v0',
     'd4rl.kitchen:KitchenMicrowaveKettleLightSliderV0',
     'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5'),
    ('kitchen-partial-v0',
     'd4rl.kitchen:KitchenMicrowaveKettleLightSliderV0',
     'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5'),
    ('kitchen-mixed-v0',
     'd4rl.kitchen:KitchenMicrowaveKettleBottomBurnerLightV0',
     'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5'),
]
for _env_id, _entry_point, _dataset_url in _KITCHEN_REGISTRATIONS:
    register(
        id=_env_id,
        entry_point=_entry_point,
        max_episode_steps=280,
        kwargs={
            'ref_min_score': 0.0,
            'ref_max_score': 4.0,
            'dataset_url': _dataset_url
        }
    )
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/.pylintrc
================================================
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=relative-beyond-top-level
[REPORTS]
# Python expression which should return a score less than or equal to 10 (10
# is the highest score). You have access to the variables 'error', 'warning',
# 'refactor', and 'convention', which contain the number of messages in each
# category, as well as 'statement', which is the total number of statements
# analyzed. This is used by the global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit (e.g. a bare
# URL, optionally in a comment).
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=80
# Maximum number of lines in a module
max-module-lines=99999
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[BASIC]
# Naming style matching correct argument names
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style
argument-rgx=^[a-z][a-z0-9_]*$
# Naming style matching correct attribute names
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Naming style matching correct class attribute names
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Naming style matching correct class names
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-style
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Naming style matching correct constant names
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
# Naming style matching correct function names
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Good variable names which should always be accepted, separated by a comma
good-names=main,
_
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# Naming style matching correct inline iteration names
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Naming style matching correct method names
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Naming style matching correct module names
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style
module-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=function:method
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main)
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,google3.pyglib.function_utils.cached.property
# Naming style matching correct variable names
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style
variable-rgx=^[a-z][a-z0-9_]*$
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception".
overgeneral-exceptions=Exception
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/.style.yapf
================================================
[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=False
# Allow dictionary keys to exist on multiple lines. For example:
#
# x = {
# ('this is the first element of a tuple',
# 'this is the second element of a tuple'):
# value,
# }
allow_multiline_dictionary_keys=False
# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False
# Allow splitting before a default / named assignment in an argument list.
allow_split_before_default_or_named_assigns=True
# Allow splits before the dictionary value.
allow_split_before_dict_value=True
# Let spacing indicate operator precedence. For example:
#
# a = 1 * 2 + 3 / 4
# b = 1 / 2 - 3 * 4
# c = (1 + 2) * (3 - 4)
# d = (1 - 2) / (3 + 4)
# e = 1 * 2 - 3
# f = 1 + 2 + 3 + 4
#
# will be formatted as follows to indicate precedence:
#
# a = 1*2 + 3/4
# b = 1/2 - 3*4
# c = (1+2) * (3-4)
# d = (1-2) / (3+4)
# e = 1*2 - 3
# f = 1 + 2 + 3 + 4
#
arithmetic_precedence_indication=False
# Number of blank lines surrounding top-level function and class
# definitions.
blank_lines_around_top_level_definition=2
# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False
# Insert a blank line before a module docstring.
blank_line_before_module_docstring=False
# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
# class Foo:
# # <------ this blank line
# def method():
# ...
blank_line_before_nested_class_or_def=True
# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
# call_func_that_takes_a_dict(
# {
# 'key1': 'value1',
# 'key2': 'value2',
# }
# )
#
# would reformat to:
#
# call_func_that_takes_a_dict({
# 'key1': 'value1',
# 'key2': 'value2',
# })
coalesce_brackets=False
# The column limit.
column_limit=80
# The style for continuation alignment. Possible values are:
#
# - SPACE: Use spaces for continuation alignment. This is default behavior.
# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs) for continuation
# alignment.
# - LESS: Slightly left if cannot vertically align continuation lines with
# indent characters.
# - VALIGN-RIGHT: Vertically align continuation lines with indent
# characters. Slightly right (one more indent character) if cannot
# vertically align continuation lines with indent characters.
#
# For options FIXED, and VALIGN-RIGHT are only available when USE_TABS is
# enabled.
continuation_align_style=SPACE
# Indent width used for line continuations.
continuation_indent_width=4
# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is dedented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False
# Disable the heuristic which places each list element on a separate line
# if the list is comma-terminated.
disable_ending_comma_heuristic=False
# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True
# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=#\..*
# The i18n function call names. The presence of this function stops
# reformatting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=N_, _
# Indent blank lines.
indent_blank_lines=False
# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
# config = {
# 'key1':
# 'value1',
# 'key2': value1 +
# value2,
# }
indent_dictionary_value=False
# The number of columns to use for indentation.
indent_width=4
# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=True
# Do not include spaces around selected binary operators. For example:
#
# 1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with "*,/":
#
# 1 + 2*3 - 4/5
#
no_spaces_around_selected_binary_operators=
# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False
# Use spaces around the power operator.
spaces_around_power_operator=False
# The number of spaces required before a trailing comment.
# This can be a single value (representing the number of spaces
# before each trailing comment) or list of values (representing
# alignment column values; trailing comments within a block will
# be aligned to the first column value that is greater than the maximum
# line length within the block). For example:
#
# With spaces_before_comment=5:
#
# 1 + 1 # Adding values
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment
#
# With spaces_before_comment=15, 20:
#
# 1 + 1 # Adding values
# two + two # More adding
#
# longer_statement # This is a longer statement
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment
# short # This is a shorter statement
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15
# two + two # More adding
#
# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
# short # This is a shorter statement
#
spaces_before_comment=2
# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=False
# Split before arguments
split_all_comma_separated_values=False
# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False
# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=False
# Split before the closing bracket if a list or dict literal doesn't fit on
# a single line.
split_before_closing_bracket=True
# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
# foo = {
# variable: 'Hello world, have a nice day!'
# for variable in bar if variable != 42
# }
split_before_dict_set_generator=False
# Split before the '.' if we need to split a longer expression:
#
# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
#
# would reformat to something like:
#
# foo = ('This is a really long string: {}, {}, {}, {}'
# .format(a, b, c, d))
split_before_dot=False
# Split after the opening paren which surrounds an expression if it doesn't
# fit on a single line.
split_before_expression_after_opening_paren=False
# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False
# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=False
# Split named assignments onto individual lines.
split_before_named_assigns=True
# Set to True to split list comprehensions and generators that have
# non-trivial expressions and multiple clauses before each of these
# clauses. For example:
#
# result = [
# a_long_var + 100 for a_long_var in xrange(1000)
# if a_long_var % 10]
#
# would reformat to something like:
#
# result = [
# a_long_var + 100
# for a_long_var in xrange(1000)
# if a_long_var % 10]
split_complex_comprehension=True
# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=30
# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000
# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0
# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300
# The penalty for splitting a list comprehension or generator
# expression.
split_penalty_comprehension=2100
# The penalty for characters over the column limit.
split_penalty_excess_character=7000
# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30
# The penalty of splitting a list of "import as" names. For example:
#
# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
# long_argument_2,
# long_argument_3)
#
# would reformat to something like:
#
# from a_very_long_or_indented_module_name_yada_yad import (
# long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0
# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300
# Use the Tab character for indentation.
use_tabs=False
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/__init__.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import d4rl.kitchen.adept_envs.franka
from d4rl.kitchen.adept_envs.utils.configurable import global_config
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/base_robot.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from collections import deque
class BaseRobot(object):
    """Common base class for robot interfaces.

    Stores the robot/object DOF layout, optional position and velocity
    bounds, hardware-related flags, and a bounded cache of recent
    observations.
    """

    def __init__(self,
                 n_jnt,
                 n_obj,
                 pos_bounds=None,
                 vel_bounds=None,
                 calibration_path=None,
                 is_hardware=False,
                 device_name=None,
                 overlay=False,
                 calibration_mode=False,
                 observation_cache_maxsize=5):
        """Create a new robot.

        Args:
            n_jnt: Number of DOFs in the robot (must be positive).
            n_obj: Number of DOFs in the object (must be non-negative).
            pos_bounds: Optional (n_jnt + n_obj, 2) matrix of [min, max]
                joint positions per DOF.
            vel_bounds: Optional (n_jnt + n_obj, 2) matrix of [min, max]
                joint velocities per DOF.
            calibration_path: File path to the calibration configuration
                file to use.
            is_hardware: Whether to run on hardware or not.
            device_name: The device path for the robot hardware. Only
                required in legacy mode.
            overlay: Whether to show a simulation overlay of the hardware.
            calibration_mode: Start with motors disengaged.
            observation_cache_maxsize: Maximum number of observations kept
                in the rolling cache.
        """
        assert n_jnt > 0
        assert n_obj >= 0

        self._n_jnt = n_jnt
        self._n_obj = n_obj
        self._n_dofs = n_jnt + n_obj

        # Validate the optional bounds against the total DOF count.
        self._pos_bounds = self._validated_bounds(pos_bounds)
        self._vel_bounds = self._validated_bounds(vel_bounds)

        self._is_hardware = is_hardware
        self._device_name = device_name
        self._calibration_path = calibration_path
        self._overlay = overlay
        self._calibration_mode = calibration_mode
        self._observation_cache_maxsize = observation_cache_maxsize

        # Rolling window of recent observations; filled in by subclasses.
        self._observation_cache = deque(
            [], maxlen=self._observation_cache_maxsize)

    def _validated_bounds(self, bounds):
        """Return `bounds` as a float32 (n_dofs, 2) array, or None.

        Each row must satisfy min < max.
        """
        if bounds is None:
            return None
        bounds = np.array(bounds, dtype=np.float32)
        assert bounds.shape == (self._n_dofs, 2)
        for lo, hi in bounds:
            assert lo < hi
        return bounds

    @property
    def n_jnt(self):
        """Number of robot DOFs."""
        return self._n_jnt

    @property
    def n_obj(self):
        """Number of object DOFs."""
        return self._n_obj

    @property
    def n_dofs(self):
        """Total number of DOFs (robot + object)."""
        return self._n_dofs

    @property
    def pos_bounds(self):
        """(n_dofs, 2) position bounds matrix, or None."""
        return self._pos_bounds

    @property
    def vel_bounds(self):
        """(n_dofs, 2) velocity bounds matrix, or None."""
        return self._vel_bounds

    @property
    def is_hardware(self):
        """Whether this robot runs on real hardware."""
        return self._is_hardware

    @property
    def device_name(self):
        """Device path for the robot hardware (legacy mode)."""
        return self._device_name

    @property
    def calibration_path(self):
        """Path to the calibration configuration file."""
        return self._calibration_path

    @property
    def overlay(self):
        """Whether a simulation overlay of the hardware is shown."""
        return self._overlay

    @property
    def has_obj(self):
        """True when the robot manipulates an object (n_obj > 0)."""
        return self._n_obj > 0

    @property
    def calibration_mode(self):
        """Whether the robot starts with motors disengaged."""
        return self._calibration_mode

    @property
    def observation_cache_maxsize(self):
        """Maximum number of cached observations."""
        return self._observation_cache_maxsize

    @property
    def observation_cache(self):
        """Deque of the most recent observations."""
        return self._observation_cache

    def clip_positions(self, positions):
        """Clamp the given joint positions to the position bounds.

        Args:
            positions: The joint positions (length n_jnt or n_dofs).

        Returns:
            The bounded joint positions (unchanged if no bounds are set).
        """
        bounds = self.pos_bounds
        if bounds is None:
            return positions
        assert len(positions) in (self.n_jnt, self.n_dofs)
        bounds = bounds[:len(positions)]
        return np.clip(positions, bounds[:, 0], bounds[:, 1])
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/franka/__init__.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from gym.envs.registration import register
# Relax the robot
register(
    id='kitchen_relax-v1',
    # NOTE(review): this entry point is rooted at 'adept_envs.franka...'
    # rather than 'd4rl.kitchen.adept_envs.franka...' like the sibling
    # modules — it only resolves if 'adept_envs' is importable as a
    # top-level package; confirm this is intentional.
    entry_point='adept_envs.franka.kitchen_multitask_v0:KitchenTaskRelaxV1',
    max_episode_steps=280,
)
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py
================================================
""" Kitchen environment for long horizon manipulation """
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from d4rl.kitchen.adept_envs import robot_env
from d4rl.kitchen.adept_envs.utils.configurable import configurable
from gym import spaces
from dm_control.mujoco import engine
@configurable(pickleable=True)
class KitchenV0(robot_env.RobotEnv):
CALIBRATION_PATHS = {
'default':
os.path.join(os.path.dirname(__file__), 'robot/franka_config.xml')
}
# Converted to velocity actuation
ROBOTS = {'robot': 'd4rl.kitchen.adept_envs.franka.robot.franka_robot:Robot_VelAct'}
MODEl = os.path.join(
os.path.dirname(__file__),
'../franka/assets/franka_kitchen_jntpos_act_ab.xml')
N_DOF_ROBOT = 9
N_DOF_OBJECT = 21
def __init__(self, robot_params={}, frame_skip=40):
self.goal_concat = True
self.obs_dict = {}
self.robot_noise_ratio = 0.1 # 10% as per robot_config specs
self.goal = np.zeros((30,))
super().__init__(
self.MODEl,
robot=self.make_robot(
n_jnt=self.N_DOF_ROBOT, #root+robot_jnts
n_obj=self.N_DOF_OBJECT,
**robot_params),
frame_skip=frame_skip,
camera_settings=dict(
distance=4.5,
azimuth=-66,
elevation=-65,
),
)
self.init_qpos = self.sim.model.key_qpos[0].copy()
# For the microwave kettle slide hinge
self.init_qpos = np.array([ 1.48388023e-01, -1.76848573e+00, 1.84390296e+00, -2.47685760e+00,
2.60252026e-01, 7.12533105e-01, 1.59515394e+00, 4.79267505e-02,
3.71350919e-02, -2.66279850e-04, -5.18043486e-05, 3.12877220e-05,
-4.51199853e-05, -3.90842156e-06, -4.22629655e-05, 6.28065475e-05,
4.04984708e-05, 4.62730939e-04, -2.26906415e-04, -4.65501369e-04,
-6.44129196e-03, -1.77048263e-03, 1.08009684e-03, -2.69397440e-01,
3.50383255e-01, 1.61944683e+00, 1.00618764e+00, 4.06395120e-03,
-6.62095997e-03, -2.68278933e-04])
self.init_qvel = self.sim.model.key_qvel[0].copy()
self.act_mid = np.zeros(self.N_DOF_ROBOT)
self.act_amp = 2.0 * np.ones(self.N_DOF_ROBOT)
act_lower = -1*np.ones((self.N_DOF_ROBOT,))
act_upper = 1*np.ones((self.N_DOF_ROBOT,))
self.action_space = spaces.Box(act_lower, act_upper)
obs_upper = 8. * np.ones(self.obs_dim)
obs_lower = -obs_upper
self.observation_space = spaces.Box(obs_lower, obs_upper)
def _get_reward_n_score(self, obs_dict):
raise NotImplementedError()
def step(self, a, b=None):
a = np.clip(a, -1.0, 1.0)
if not self.initializing:
a = self.act_mid + a * self.act_amp # mean center and scale
else:
self.goal = self._get_task_goal() # update goal if init
self.robot.step(
self, a, step_duration=self.skip * self.model.opt.timestep)
# observations
obs = self._get_obs()
#rewards
reward_dict, score = self._get_reward_n_score(self.obs_dict)
# termination
done = False
# finalize step
env_info = {
'time': self.obs_dict['t'],
'obs_dict': self.obs_dict,
'rewards': reward_dict,
'score': score,
'images': np.asarray(self.render(mode='rgb_array'))
}
# self.render()
return obs, reward_dict['r_total'], done, env_info
    def _get_obs(self):
        """Query the robot for state and build the flat observation vector.

        Side effect: rebuilds `self.obs_dict` from the latest robot reading.

        Returns:
            np.concatenate([qp, obj_qp, goal]) when `self.goal_concat` is
            True; otherwise falls through and implicitly returns None
            (upstream behavior, preserved).
        """
        t, qp, qv, obj_qp, obj_qv = self.robot.get_obs(
            self, robot_noise_ratio=self.robot_noise_ratio)
        self.obs_dict = {}
        self.obs_dict['t'] = t
        self.obs_dict['qp'] = qp
        self.obs_dict['qv'] = qv
        self.obs_dict['obj_qp'] = obj_qp
        self.obs_dict['obj_qv'] = obj_qv
        self.obs_dict['goal'] = self.goal
        if self.goal_concat:
            return np.concatenate([self.obs_dict['qp'], self.obs_dict['obj_qp'], self.obs_dict['goal']])
    def reset_model(self):
        """Reset robot/object state to the initial pose and resample the goal.

        Returns:
            The first observation after the reset (see `_get_obs`).
        """
        reset_pos = self.init_qpos[:].copy()
        reset_vel = self.init_qvel[:].copy()
        self.robot.reset(self, reset_pos, reset_vel)
        self.sim.forward()
        self.goal = self._get_task_goal()  # sample a new goal on reset
        return self._get_obs()
def evaluate_success(self, paths):
# score
mean_score_per_rollout = np.zeros(shape=len(paths))
for idx, path in enumerate(paths):
mean_score_per_rollout[idx] = np.mean(path['env_infos']['score'])
mean_score = np.mean(mean_score_per_rollout)
# success percentage
num_success = 0
num_paths = len(paths)
for path in paths:
num_success += bool(path['env_infos']['rewards']['bonus'][-1])
success_percentage = num_success * 100.0 / num_paths
# fuse results
return np.sign(mean_score) * (
1e6 * round(success_percentage, 2) + abs(mean_score))
    def close_env(self):
        """Close the underlying robot interface (see Robot.close)."""
        self.robot.close()
    def set_goal(self, goal):
        """Overwrite the current goal vector."""
        self.goal = goal
    def _get_task_goal(self):
        """Return the task goal; the base class simply echoes `self.goal`."""
        return self.goal
# Only include goal
@property
def goal_space(self):
len_obs = self.observation_space.low.shape[0]
env_lim = np.abs(self.observation_space.low[0])
return spaces.Box(low=-env_lim, high=env_lim, shape=(len_obs//2,))
    def convert_to_active_observation(self, observation):
        """Identity hook: subclasses may map a full observation to its active part."""
        return observation
class KitchenTaskRelaxV1(KitchenV0):
    """Kitchen environment with proper camera and goal setup"""

    def __init__(self):
        super(KitchenTaskRelaxV1, self).__init__()

    def _get_reward_n_score(self, obs_dict):
        """Relaxed task: every reward component is identically zero."""
        reward_dict = {
            'true_reward': 0.,
            'bonus': 0.,
            'r_total': 0.,
        }
        score = 0.
        return reward_dict, score

    def render(self, mode='human'):
        """Render an RGB frame from a fixed movable camera, or defer to base."""
        if mode == 'rgb_array':
            camera = engine.MovableCamera(self.sim, 1920, 2560)
            camera.set_pose(
                distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35)
            return camera.render()
        else:
            # Non-RGB modes fall through to the base renderer (returns None).
            super(KitchenTaskRelaxV1, self).render()
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/__init__.py
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, getpass
import numpy as np
from termcolor import cprint
import time
import copy
import click
from d4rl.kitchen.adept_envs import base_robot
from d4rl.kitchen.adept_envs.utils.config import (get_config_root_node, read_config_from_node)
# Observation structure returned by Robot.get_obs / get_obs_from_cache.
from collections import namedtuple
observation = namedtuple('observation', ['time', 'qpos_robot', 'qvel_robot', 'qpos_object', 'qvel_object'])
# Module-level cache of the hardware interface so multiple Robot instances
# can share one session; '' means "not yet initialized".
franka_interface = ''
class Robot(base_robot.BaseRobot):
    """
    Abstracts away the differences between the robot_simulation and robot_hardware
    """

    def __init__(self, *args, **kwargs):
        """Set up either a hardware session or a simulated Franka robot.

        Per-DOF calibration (scale, offset, bounds, noise amplitudes) is read
        from the XML file at `self.calibration_path` (set by BaseRobot).
        """
        super(Robot, self).__init__(*args, **kwargs)
        global franka_interface

        # Read robot configurations
        self._read_specs_from_config(robot_configs=self.calibration_path)

        # Robot: Hardware
        if self.is_hardware:
            # BUG FIX: was `franka_interface is ''` — identity comparison with
            # a str literal is implementation-dependent; use equality.
            if franka_interface == '':
                raise NotImplementedError()
                # Unreachable upstream stub kept for reference: hardware
                # support was never released.
                from handware.franka import franka
                # initialize franka
                self.franka_interface = franka()
                franka_interface = self.franka_interface
                # BUG FIX (dead code): `self.franka` never exists; the
                # interface object is `self.franka_interface`.
                cprint("Initializing %s Hardware (Status:%d)" % (self.robot_name, self.franka_interface.okay(self.robot_hardware_dof)), 'white', 'on_grey')
            else:
                self.franka_interface = franka_interface
                cprint("Reusing previours Franka session", 'white', 'on_grey')
        # Robot: Simulation
        else:
            self.robot_name = "Franka"
            cprint("Initializing %s sim" % self.robot_name, 'white', 'on_grey')

        # Robot's time, measured relative to construction.
        self.time_start = time.time()
        self.time = time.time() - self.time_start
        self.time_render = -1  # time of rendering

    # read specs from the calibration file
    def _read_specs_from_config(self, robot_configs):
        """Populate per-DOF calibration arrays from the config XML file."""
        root, root_name = get_config_root_node(config_file_name=robot_configs)
        self.robot_name = root_name[0]
        self.robot_mode = np.zeros(self.n_dofs, dtype=int)
        self.robot_mj_dof = np.zeros(self.n_dofs, dtype=int)
        self.robot_hardware_dof = np.zeros(self.n_dofs, dtype=int)
        self.robot_scale = np.zeros(self.n_dofs, dtype=float)
        self.robot_offset = np.zeros(self.n_dofs, dtype=float)
        self.robot_pos_bound = np.zeros([self.n_dofs, 2], dtype=float)
        self.robot_vel_bound = np.zeros([self.n_dofs, 2], dtype=float)
        self.robot_pos_noise_amp = np.zeros(self.n_dofs, dtype=float)
        self.robot_vel_noise_amp = np.zeros(self.n_dofs, dtype=float)

        print("Reading configurations for %s" % self.robot_name)
        for i in range(self.n_dofs):
            self.robot_mode[i] = read_config_from_node(root, "qpos" + str(i), "mode", int)
            self.robot_mj_dof[i] = read_config_from_node(root, "qpos" + str(i), "mj_dof", int)
            self.robot_hardware_dof[i] = read_config_from_node(root, "qpos" + str(i), "hardware_dof", int)
            self.robot_scale[i] = read_config_from_node(root, "qpos" + str(i), "scale", float)
            self.robot_offset[i] = read_config_from_node(root, "qpos" + str(i), "offset", float)
            self.robot_pos_bound[i] = read_config_from_node(root, "qpos" + str(i), "pos_bound", float)
            self.robot_vel_bound[i] = read_config_from_node(root, "qpos" + str(i), "vel_bound", float)
            self.robot_pos_noise_amp[i] = read_config_from_node(root, "qpos" + str(i), "pos_noise_amp", float)
            self.robot_vel_noise_amp[i] = read_config_from_node(root, "qpos" + str(i), "vel_noise_amp", float)

    # convert to hardware space
    def _de_calib(self, qp_mj, qv_mj=None):
        """Map MuJoCo positions (and optionally velocities) to hardware units."""
        qp_ad = (qp_mj - self.robot_offset) / self.robot_scale
        if qv_mj is not None:
            qv_ad = qv_mj / self.robot_scale
            return qp_ad, qv_ad
        else:
            return qp_ad

    # convert to mujoco space
    def _calib(self, qp_ad, qv_ad):
        """Map hardware positions/velocities to MuJoCo units."""
        qp_mj = qp_ad * self.robot_scale + self.robot_offset
        qv_mj = qv_ad * self.robot_scale
        return qp_mj, qv_mj

    # refresh the observation cache
    def _observation_cache_refresh(self, env):
        """Fill the whole observation cache with fresh readings."""
        for _ in range(self.observation_cache_maxsize):
            self.get_obs(env, sim_mimic_hardware=False)

    # get past observation
    def get_obs_from_cache(self, env, index=-1):
        """Return a cached observation; `index` may be negative (from the end).

        BUG FIX: the original assertion was garbled into invalid syntax
        (`index=-...`); restored to validate both positive and negative
        cache indices against the cache size.
        """
        assert (index >= 0 and index < self.observation_cache_maxsize) or \
               (index < 0 and index >= -self.observation_cache_maxsize), \
            "cache index out of bound. (cache size is %2d)" % self.observation_cache_maxsize
        obs = self.observation_cache[index]
        if self.has_obj:
            return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object
        else:
            return obs.time, obs.qpos_robot, obs.qvel_robot

    # get observation
    def get_obs(self, env, robot_noise_ratio=1, object_noise_ratio=1, sim_mimic_hardware=True):
        """Read (time, qpos, qvel[, obj_qpos, obj_qvel]) from the simulator.

        Observation noise, scaled by the per-DOF amplitudes from the config,
        is added only once the env has finished initializing.  Every reading
        is appended to the observation cache.
        """
        if self.is_hardware:
            raise NotImplementedError()
        else:
            # Gather simulated observation
            qp = env.sim.data.qpos[:self.n_jnt].copy()
            qv = env.sim.data.qvel[:self.n_jnt].copy()
            if self.has_obj:
                qp_obj = env.sim.data.qpos[-self.n_obj:].copy()
                qv_obj = env.sim.data.qvel[-self.n_obj:].copy()
            else:
                qp_obj = None
                qv_obj = None
            self.time = env.sim.data.time

            # Simulate observation noise
            if not env.initializing:
                qp += robot_noise_ratio * self.robot_pos_noise_amp[:self.n_jnt] * env.np_random.uniform(low=-1., high=1., size=self.n_jnt)
                qv += robot_noise_ratio * self.robot_vel_noise_amp[:self.n_jnt] * env.np_random.uniform(low=-1., high=1., size=self.n_jnt)
                if self.has_obj:
                    # NOTE(review): object noise is scaled by robot_noise_ratio,
                    # not object_noise_ratio — kept as-is to preserve behavior.
                    qp_obj += robot_noise_ratio * self.robot_pos_noise_amp[-self.n_obj:] * env.np_random.uniform(low=-1., high=1., size=self.n_obj)
                    qv_obj += robot_noise_ratio * self.robot_vel_noise_amp[-self.n_obj:] * env.np_random.uniform(low=-1., high=1., size=self.n_obj)

        # cache observations
        obs = observation(time=self.time, qpos_robot=qp, qvel_robot=qv, qpos_object=qp_obj, qvel_object=qv_obj)
        self.observation_cache.append(obs)

        if self.has_obj:
            return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object
        else:
            return obs.time, obs.qpos_robot, obs.qvel_robot

    # enforce position specs.
    def ctrl_position_limits(self, ctrl_position):
        """Clip desired joint positions to the calibrated position bounds."""
        ctrl_feasible_position = np.clip(ctrl_position, self.robot_pos_bound[:self.n_jnt, 0], self.robot_pos_bound[:self.n_jnt, 1])
        return ctrl_feasible_position

    # step the robot env
    def step(self, env, ctrl_desired, step_duration, sim_override=False):
        """Apply one control after enforcing velocity and position limits.

        Returns 1 on completion (kept from upstream).
        """
        # Populate observation cache during startup
        if env.initializing:
            self._observation_cache_refresh(env)

        # enforce velocity limits (subclass-specific; see Robot_PosAct/Robot_VelAct)
        ctrl_feasible = self.ctrl_velocity_limits(ctrl_desired, step_duration)

        # enforce position limits
        ctrl_feasible = self.ctrl_position_limits(ctrl_feasible)

        # Send controls to the robot
        if self.is_hardware and (not sim_override):
            raise NotImplementedError()
        else:
            env.do_simulation(ctrl_feasible, int(step_duration / env.sim.model.opt.timestep))  # render is folded in here

        # Update current robot state on the overlay
        if self.overlay:
            env.sim.data.qpos[self.n_jnt:2 * self.n_jnt] = env.desired_pose.copy()
            env.sim.forward()

        # synchronize wall-clock time when running on hardware
        if self.is_hardware:
            time_now = (time.time() - self.time_start)
            time_left_in_step = step_duration - (time_now - self.time)
            if (time_left_in_step > 0.0001):
                time.sleep(time_left_in_step)
        return 1

    def reset(self, env, reset_pose, reset_vel, overlay_mimic_reset_pose=True, sim_override=False):
        """Reset sim state to the given pose/velocity and refresh the cache."""
        # NOTE(review): clip_positions is provided by BaseRobot — confirm.
        reset_pose = self.clip_positions(reset_pose)

        if self.is_hardware:
            raise NotImplementedError()
        else:
            env.sim.reset()
            env.sim.data.qpos[:self.n_jnt] = reset_pose[:self.n_jnt].copy()
            env.sim.data.qvel[:self.n_jnt] = reset_vel[:self.n_jnt].copy()
            if self.has_obj:
                env.sim.data.qpos[-self.n_obj:] = reset_pose[-self.n_obj:].copy()
                env.sim.data.qvel[-self.n_obj:] = reset_vel[-self.n_obj:].copy()
            env.sim.forward()

        if self.overlay:
            env.sim.data.qpos[self.n_jnt:2 * self.n_jnt] = env.desired_pose[:self.n_jnt].copy()
            env.sim.forward()

        # refresh observation cache before exit
        self._observation_cache_refresh(env)

    def close(self):
        """Shut down the hardware link (stubbed), or just log for sim."""
        if self.is_hardware:
            cprint("Closing Franka hardware... ", 'white', 'on_grey', end='', flush=True)
            status = 0
            raise NotImplementedError()
            cprint("Closed (Status: {})".format(status), 'white', 'on_grey', flush=True)
        else:
            cprint("Closing Franka sim", 'white', 'on_grey', flush=True)
class Robot_PosAct(Robot):
    """Position-actuated robot: actions are desired joint positions."""

    # ALERT: This depends on previous observation. This is not ideal as it
    # breaks MDP assumptions. Be careful.
    def ctrl_velocity_limits(self, ctrl_position, step_duration):
        """Clamp a desired position so the implied velocity stays in bounds."""
        prev_qpos = self.observation_cache[-1].qpos_robot[:self.n_jnt]
        desired_vel = (ctrl_position - prev_qpos) / step_duration
        feasible_vel = np.clip(
            desired_vel,
            self.robot_vel_bound[:self.n_jnt, 0],
            self.robot_vel_bound[:self.n_jnt, 1])
        return prev_qpos + feasible_vel * step_duration
class Robot_VelAct(Robot):
    """Velocity-actuated robot: actions are desired joint velocities."""

    # ALERT: This depends on previous observation. This is not ideal as it
    # breaks MDP assumptions. Be careful.
    def ctrl_velocity_limits(self, ctrl_velocity, step_duration):
        """Clip the commanded velocity and integrate it into a target position."""
        prev_qpos = self.observation_cache[-1].qpos_robot[:self.n_jnt]
        feasible_vel = np.clip(
            ctrl_velocity,
            self.robot_vel_bound[:self.n_jnt, 0],
            self.robot_vel_bound[:self.n_jnt, 1])
        return prev_qpos + feasible_vel * step_duration
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/mujoco_env.py
================================================
"""Base environment for MuJoCo-based environments."""
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import collections.abc
import os
import time
from typing import Dict, Optional

import gym
import numpy as np
from gym import spaces
from gym.utils import seeding

from d4rl.kitchen.adept_envs.simulation.sim_robot import MujocoSimRobot, RenderMode
# Default square viewport size (pixels) for render().
DEFAULT_RENDER_SIZE = 480
# Prefer the dm_control backend when the caller does not specify one.
USE_DM_CONTROL = True
class MujocoEnv(gym.Env):
    """Superclass for all MuJoCo environments."""

    def __init__(self,
                 model_path: str,
                 frame_skip: int,
                 camera_settings: Optional[Dict] = None,
                 use_dm_backend: Optional[bool] = None,
                 ):
        """Initializes a new MuJoCo environment.

        Args:
            model_path: The path to the MuJoCo XML file.
            frame_skip: The number of simulation steps per environment step. On
                hardware this influences the duration of each environment step.
            camera_settings: Settings to initialize the simulation camera. This
                can contain the keys `distance`, `azimuth`, and `elevation`.
            use_dm_backend: A boolean to switch between mujoco-py and dm_control.

        Raises:
            IOError: If `model_path` does not point to an existing file.
        """
        self._seed()
        if not os.path.isfile(model_path):
            raise IOError(
                '[MujocoEnv]: Model path does not exist: {}'.format(model_path))
        self.frame_skip = frame_skip

        self.sim_robot = MujocoSimRobot(
            model_path,
            use_dm_backend=use_dm_backend or USE_DM_CONTROL,
            camera_settings=camera_settings)
        self.sim = self.sim_robot.sim
        self.model = self.sim_robot.model
        self.data = self.sim_robot.data

        self.metadata = {
            'render.modes': ['human', 'rgb_array', 'depth_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }
        self.mujoco_render_frames = False

        self.init_qpos = self.data.qpos.ravel().copy()
        self.init_qvel = self.data.qvel.ravel().copy()
        # Probe step with a zero action to discover the observation layout.
        observation, _reward, done, _info = self.step(np.zeros(self.model.nu))
        assert not done

        bounds = self.model.actuator_ctrlrange.copy()
        act_upper = bounds[:, 1]
        act_lower = bounds[:, 0]

        # Define the action and observation spaces.
        # HACK: MJRL is still using gym 0.9.x so we can't provide a dtype.
        try:
            self.action_space = spaces.Box(
                act_lower, act_upper, dtype=np.float32)
            # BUG FIX: `collections.Mapping` was removed in Python 3.10; the
            # ABC lives in `collections.abc` (available since Python 3.3).
            if isinstance(observation, collections.abc.Mapping):
                self.observation_space = spaces.Dict({
                    k: spaces.Box(-np.inf, np.inf, shape=v.shape, dtype=np.float32) for k, v in observation.items()})
            else:
                self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size
                self.observation_space = spaces.Box(
                    -np.inf, np.inf, observation.shape, dtype=np.float32)
        except TypeError:
            # Fallback case for gym 0.9.x
            self.action_space = spaces.Box(act_lower, act_upper)
            assert not isinstance(observation, collections.abc.Mapping), 'gym 0.9.x does not support dictionary observation.'
            self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size
            self.observation_space = spaces.Box(
                -np.inf, np.inf, observation.shape)

    def seed(self, seed=None):  # Compatibility with new gym
        return self._seed(seed)

    def _seed(self, seed=None):
        """Seed the env's RNG; returns the list-wrapped seed (old gym API)."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # methods to override:
    # ----------------------------
    def reset_model(self):
        """Reset the robot degrees of freedom (qpos and qvel).

        Implement this in each subclass.
        """
        raise NotImplementedError

    # -----------------------------
    def reset(self):  # compatibility with new gym
        return self._reset()

    def _reset(self):
        self.sim.reset()
        self.sim.forward()
        ob = self.reset_model()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite the full simulator state with the given qpos/qvel."""
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        state = self.sim.get_state()
        for i in range(self.model.nq):
            state.qpos[i] = qpos[i]
        for i in range(self.model.nv):
            state.qvel[i] = qvel[i]
        self.sim.set_state(state)
        self.sim.forward()

    @property
    def dt(self):
        # Duration of one environment step in simulated seconds.
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames):
        """Apply `ctrl` to the actuators and advance `n_frames` sim steps."""
        for i in range(self.model.nu):
            self.sim.data.ctrl[i] = ctrl[i]

        for _ in range(n_frames):
            self.sim.step()

            # TODO(michaelahn): Remove this; render should be called separately.
            if self.mujoco_render_frames is True:
                self.mj_render()

    def render(self,
               mode='human',
               width=DEFAULT_RENDER_SIZE,
               height=DEFAULT_RENDER_SIZE,
               camera_id=-1):
        """Renders the environment.

        Args:
            mode: The type of rendering to use.
                - 'human': Renders to a graphical window.
                - 'rgb_array': Returns the RGB image as an np.ndarray.
                - 'depth_array': Returns the depth image as an np.ndarray.
            width: The width of the rendered image. This only affects offscreen
                rendering.
            height: The height of the rendered image. This only affects
                offscreen rendering.
            camera_id: The ID of the camera to use. By default, this is the free
                camera. If specified, only affects offscreen rendering.

        Raises:
            NotImplementedError: For an unrecognized `mode`.
        """
        if mode == 'human':
            self.sim_robot.renderer.render_to_window()
        elif mode == 'rgb_array':
            assert width and height
            return self.sim_robot.renderer.render_offscreen(
                width, height, mode=RenderMode.RGB, camera_id=camera_id)
        elif mode == 'depth_array':
            assert width and height
            return self.sim_robot.renderer.render_offscreen(
                width, height, mode=RenderMode.DEPTH, camera_id=camera_id)
        else:
            raise NotImplementedError(mode)

    def close(self):
        self.sim_robot.close()

    def mj_render(self):
        """Backwards compatibility with MJRL."""
        self.render(mode='human')

    def state_vector(self):
        """Return the concatenated (qpos, qvel) simulator state."""
        state = self.sim.get_state()
        return np.concatenate([state.qpos.flat, state.qvel.flat])
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/robot_env.py
================================================
"""Base class for robotics environments."""
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from typing import Dict, Optional
import numpy as np
from d4rl.kitchen.adept_envs import mujoco_env
from d4rl.kitchen.adept_envs.base_robot import BaseRobot
from d4rl.kitchen.adept_envs.utils.configurable import import_class_from_path
from d4rl.kitchen.adept_envs.utils.constants import MODELS_PATH
class RobotEnv(mujoco_env.MujocoEnv):
    """Base environment for all adept robots."""

    # Mapping of robot name to fully qualified class path.
    # e.g. 'robot': 'adept_envs.dclaw.robot.Robot'
    # Subclasses should override this to specify the Robot classes they support.
    ROBOTS = {}

    # Mapping of device path to the calibration file to use. If the device path
    # is not found, the 'default' key is used.
    # This can be overriden by subclasses.
    CALIBRATION_PATHS = {}

    def __init__(self,
                 model_path: str,
                 robot: BaseRobot,
                 frame_skip: int,
                 camera_settings: Optional[Dict] = None):
        """Initializes a robotics environment.

        Args:
            model_path: The path to the model to run. Relative paths will be
                interpreted as relative to the 'adept_models' folder.
            robot: The Robot object to use.
            frame_skip: The number of simulation steps per environment step. On
                hardware this influences the duration of each environment step.
            camera_settings: Settings to initialize the simulation camera. This
                can contain the keys `distance`, `azimuth`, and `elevation`.
        """
        self._robot = robot

        # Initial pose for first step.
        # NOTE: the n_jnt property reads self._robot, so _robot must be set first.
        self.desired_pose = np.zeros(self.n_jnt)

        if not model_path.startswith('/'):
            model_path = os.path.abspath(os.path.join(MODELS_PATH, model_path))

        # Optional remote visualization; best-effort, silently skipped when
        # the module is unavailable.
        self.remote_viz = None
        try:
            from adept_envs.utils.remote_viz import RemoteViz
            self.remote_viz = RemoteViz(model_path)
        except ImportError:
            pass

        # True while MujocoEnv.__init__ runs its probe step; subclasses check
        # the `initializing` property to skip work during setup.
        self._initializing = True
        super(RobotEnv, self).__init__(
            model_path, frame_skip, camera_settings=camera_settings)
        self._initializing = False

    @property
    def robot(self):
        # The Robot interface object supplied at construction.
        return self._robot

    @property
    def n_jnt(self):
        # Number of robot joints (delegated to the Robot object).
        return self._robot.n_jnt

    @property
    def n_obj(self):
        # Number of object joints (delegated to the Robot object).
        return self._robot.n_obj

    @property
    def skip(self):
        """Alias for frame_skip. Needed for MJRL."""
        return self.frame_skip

    @property
    def initializing(self):
        # True only during the base-class constructor's probe step.
        return self._initializing

    def close_env(self):
        if self._robot is not None:
            self._robot.close()

    def make_robot(self,
                   n_jnt,
                   n_obj=0,
                   is_hardware=False,
                   device_name=None,
                   legacy=False,
                   **kwargs):
        """Creates a new robot for the environment.

        Args:
            n_jnt: The number of joints in the robot.
            n_obj: The number of object joints in the robot environment.
            is_hardware: Whether to run on hardware or not.
            device_name: The device path for the robot hardware.
            legacy: If true, runs using direct dynamixel communication rather
                than DDS.
            kwargs: See BaseRobot for other parameters.

        Returns:
            A Robot object.

        Raises:
            NotImplementedError: If the subclass did not define ROBOTS.
            ValueError: If running on hardware without a device name.
            KeyError: If the resolved robot name is not in ROBOTS.
            OSError: If the resolved calibration file does not exist.
        """
        if not self.ROBOTS:
            raise NotImplementedError('Subclasses must override ROBOTS.')

        if is_hardware and not device_name:
            raise ValueError('Must provide device name if running on hardware.')

        # Hardware without legacy mode talks DDS; everything else uses 'robot'.
        robot_name = 'dds_robot' if not legacy and is_hardware else 'robot'
        if robot_name not in self.ROBOTS:
            raise KeyError("Unsupported robot '{}', available: {}".format(
                robot_name, list(self.ROBOTS.keys())))

        cls = import_class_from_path(self.ROBOTS[robot_name])

        # Resolve the calibration file: exact device match, else 'default'.
        calibration_path = None
        if self.CALIBRATION_PATHS:
            if not device_name:
                calibration_name = 'default'
            elif device_name not in self.CALIBRATION_PATHS:
                print('Device "{}" not in CALIBRATION_PATHS; using default.'
                      .format(device_name))
                calibration_name = 'default'
            else:
                calibration_name = device_name

            calibration_path = self.CALIBRATION_PATHS[calibration_name]
            if not os.path.isfile(calibration_path):
                raise OSError('Could not find calibration file at: {}'.format(
                    calibration_path))

        return cls(
            n_jnt,
            n_obj,
            is_hardware=is_hardware,
            device_name=device_name,
            calibration_path=calibration_path,
            **kwargs)
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/module.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for caching Python modules related to simulation."""
import sys
# Lazily-imported module handles, cached so each heavy import happens at most once.
_MUJOCO_PY_MODULE = None
_DM_MUJOCO_MODULE = None
_DM_VIEWER_MODULE = None
_DM_RENDER_MODULE = None
_GLFW_MODULE = None
def get_mujoco_py():
    """Return the cached mujoco_py module, importing it on first use."""
    global _MUJOCO_PY_MODULE
    if not _MUJOCO_PY_MODULE:
        try:
            import mujoco_py
            from mujoco_py.builder import cymj
            # Route MuJoCo warnings through our own handler.
            cymj.set_warning_callback(_mj_warning_fn)
        except ImportError:
            print(
                'Failed to import mujoco_py. Ensure that mujoco_py (using MuJoCo '
                'v1.50) is installed.',
                file=sys.stderr)
            sys.exit(1)
        _MUJOCO_PY_MODULE = mujoco_py
    return _MUJOCO_PY_MODULE
def get_mujoco_py_mjlib():
    """Returns the mujoco_py mjlib module."""

    class MjlibDelegate:
        """Wrapper that forwards mjlib calls."""

        def __init__(self, lib):
            self._lib = lib

        def __getattr__(self, name: str):
            # mjlib symbols are exposed on cymj with a leading underscore.
            if not name.startswith('mj'):
                raise AttributeError(name)
            return getattr(self._lib, '_' + name)

    return MjlibDelegate(get_mujoco_py().cymj)
def get_dm_mujoco():
    """Return the cached dm_control.mujoco module, importing it on first use."""
    global _DM_MUJOCO_MODULE
    if not _DM_MUJOCO_MODULE:
        try:
            from dm_control import mujoco
        except ImportError:
            print(
                'Failed to import dm_control.mujoco. Ensure that dm_control (using '
                'MuJoCo v2.00) is installed.',
                file=sys.stderr)
            sys.exit(1)
        _DM_MUJOCO_MODULE = mujoco
    return _DM_MUJOCO_MODULE
def get_dm_viewer():
    """Return the cached dm_control.viewer module, importing it on first use."""
    global _DM_VIEWER_MODULE
    if not _DM_VIEWER_MODULE:
        try:
            from dm_control import viewer
        except ImportError:
            print(
                'Failed to import dm_control.viewer. Ensure that dm_control (using '
                'MuJoCo v2.00) is installed.',
                file=sys.stderr)
            sys.exit(1)
        _DM_VIEWER_MODULE = viewer
    return _DM_VIEWER_MODULE
def get_dm_render():
    """Return the cached DM Control render module, importing it on first use."""
    global _DM_RENDER_MODULE
    if not _DM_RENDER_MODULE:
        try:
            try:
                # Newer dm_control exposes the module as `_render`.
                from dm_control import _render as render
            except ImportError:
                print('Warning: DM Control is out of date.')
                from dm_control import render
        except ImportError:
            print(
                'Failed to import dm_control.render. Ensure that dm_control (using '
                'MuJoCo v2.00) is installed.',
                file=sys.stderr)
            sys.exit(1)
        _DM_RENDER_MODULE = render
    return _DM_RENDER_MODULE
def _mj_warning_fn(warn_data: bytes):
"""Warning function override for mujoco_py."""
print('WARNING: Mujoco simulation is unstable (has NaNs): {}'.format(
warn_data.decode()))
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for viewing Physics objects in the DM Control viewer."""
import abc
import enum
import sys
from typing import Dict, Optional
import numpy as np
from d4rl.kitchen.adept_envs.simulation import module
# Default window dimensions.
DEFAULT_WINDOW_WIDTH = 1024
DEFAULT_WINDOW_HEIGHT = 768
DEFAULT_WINDOW_TITLE = 'MuJoCo Viewer'
# Maximum offscreen framebuffer size, in pixels per side.
_MAX_RENDERBUFFER_SIZE = 2048
class RenderMode(enum.Enum):
    """Rendering modes for offscreen rendering."""
    RGB = 0           # color image
    DEPTH = 1         # per-pixel depth values
    SEGMENTATION = 2  # segmentation rendering (dm_control backend only)
class Renderer(abc.ABC):
    """Base interface for rendering simulations."""

    def __init__(self, camera_settings: Optional[Dict] = None):
        # May be None; _update_camera treats any falsy value as "no settings".
        self._camera_settings = camera_settings

    @abc.abstractmethod
    def close(self):
        """Cleans up any resources being used by the renderer."""

    @abc.abstractmethod
    def render_to_window(self):
        """Renders the simulation to a window."""

    @abc.abstractmethod
    def render_offscreen(self,
                         width: int,
                         height: int,
                         mode: RenderMode = RenderMode.RGB,
                         camera_id: int = -1) -> np.ndarray:
        """Renders the camera view as a NumPy array of pixels.

        Args:
            width: The viewport width (pixels).
            height: The viewport height (pixels).
            mode: The rendering mode.
            camera_id: The ID of the camera to render from. By default, uses
                the free camera.

        Returns:
            A NumPy array of the pixels.
        """

    def _update_camera(self, camera):
        """Applies the stored camera settings to `camera`, if any were given."""
        settings = self._camera_settings
        if not settings:
            return
        for attr in ('distance', 'azimuth', 'elevation'):
            value = settings.get(attr)
            if value is not None:
                setattr(camera, attr, value)
        lookat = settings.get('lookat')
        if lookat is not None:
            camera.lookat[:] = lookat
class MjPyRenderer(Renderer):
    """Class for rendering mujoco_py simulations."""

    def __init__(self, sim, **kwargs):
        assert isinstance(sim, module.get_mujoco_py().MjSim), \
            'MjPyRenderer takes a mujoco_py MjSim object.'
        super().__init__(**kwargs)
        self._sim = sim
        # Both renderers are created lazily on first use.
        self._onscreen_renderer = None
        self._offscreen_renderer = None

    def render_to_window(self):
        """Renders the simulation to a window."""
        if not self._onscreen_renderer:
            self._onscreen_renderer = module.get_mujoco_py().MjViewer(self._sim)
            # Camera settings are applied once, when the viewer is created.
            self._update_camera(self._onscreen_renderer.cam)

        self._onscreen_renderer.render()

    def render_offscreen(self,
                         width: int,
                         height: int,
                         mode: RenderMode = RenderMode.RGB,
                         camera_id: int = -1) -> np.ndarray:
        """Renders the camera view as a NumPy array of pixels.

        Args:
            width: The viewport width (pixels).
            height: The viewport height (pixels).
            mode: The rendering mode.
            camera_id: The ID of the camera to render from. By default, uses
                the free camera.

        Returns:
            A NumPy array of the pixels.
        """
        if not self._offscreen_renderer:
            self._offscreen_renderer = module.get_mujoco_py() \
                .MjRenderContextOffscreen(self._sim)

        # Update the camera configuration for the free-camera.
        if camera_id == -1:
            self._update_camera(self._offscreen_renderer.cam)

        self._offscreen_renderer.render(width, height, camera_id)
        if mode == RenderMode.RGB:
            data = self._offscreen_renderer.read_pixels(
                width, height, depth=False)
            # Original image is upside-down, so flip it
            return data[::-1, :, :]
        elif mode == RenderMode.DEPTH:
            # read_pixels with depth=True returns (rgb, depth); take depth.
            data = self._offscreen_renderer.read_pixels(
                width, height, depth=True)[1]
            # Original image is upside-down, so flip it
            return data[::-1, :]
        else:
            # SEGMENTATION (and anything else) is unsupported by this backend.
            raise NotImplementedError(mode)

    def close(self):
        """Cleans up any resources being used by the renderer."""
class DMRenderer(Renderer):
    """Class for rendering DM Control Physics objects."""

    def __init__(self, physics, **kwargs):
        assert isinstance(physics, module.get_dm_mujoco().Physics), \
            'DMRenderer takes a DM Control Physics object.'
        super().__init__(**kwargs)
        self._physics = physics
        self._window = None

        # Set the camera to lookat the center of the geoms (mujoco_py does
        # this automatically; dm_control does not).
        # BUG FIX: the base class defaults camera_settings to None, and the
        # original membership test (`'lookat' not in self._camera_settings`)
        # raised TypeError in that case.  Default to an empty dict first.
        if self._camera_settings is None:
            self._camera_settings = {}
        if 'lookat' not in self._camera_settings:
            self._camera_settings['lookat'] = [
                np.median(self._physics.data.geom_xpos[:, i]) for i in range(3)
            ]

    def render_to_window(self):
        """Renders the Physics object to a window.

        The window continuously renders the Physics in a separate thread.
        This function is a no-op if the window was already created.
        """
        if not self._window:
            self._window = DMRenderWindow()
            self._window.load_model(self._physics)
            self._update_camera(self._window.camera)
        self._window.run_frame()

    def render_offscreen(self,
                         width: int,
                         height: int,
                         mode: RenderMode = RenderMode.RGB,
                         camera_id: int = -1) -> np.ndarray:
        """Renders the camera view as a NumPy array of pixels.

        Args:
            width: The viewport width (pixels).
            height: The viewport height (pixels).
            mode: The rendering mode.
            camera_id: The ID of the camera to render from. By default, uses
                the free camera.

        Returns:
            A NumPy array of the pixels.
        """
        mujoco = module.get_dm_mujoco()
        # TODO(michaelahn): Consider caching the camera.
        camera = mujoco.Camera(
            physics=self._physics,
            height=height,
            width=width,
            camera_id=camera_id)

        # Update the camera configuration for the free-camera.
        if camera_id == -1:
            self._update_camera(
                camera._render_camera,  # pylint: disable=protected-access
            )

        image = camera.render(
            depth=(mode == RenderMode.DEPTH),
            segmentation=(mode == RenderMode.SEGMENTATION))
        # Free the scene explicitly rather than waiting for GC.
        camera._scene.free()  # pylint: disable=protected-access
        return image

    def close(self):
        """Cleans up any resources being used by the renderer."""
        if self._window:
            self._window.close()
            self._window = None
class DMRenderWindow:
    """Class that encapsulates a graphical window."""

    def __init__(self,
                 width: int = DEFAULT_WINDOW_WIDTH,
                 height: int = DEFAULT_WINDOW_HEIGHT,
                 title: str = DEFAULT_WINDOW_TITLE):
        """Creates a graphical render window.

        Args:
            width: The width of the window.
            height: The height of the window.
            title: The title of the window.
        """
        dmv = module.get_dm_viewer()
        self._viewport = dmv.renderer.Viewport(width, height)
        self._window = dmv.gui.RenderWindow(width, height, title)
        self._viewer = dmv.viewer.Viewer(self._viewport, self._window.mouse,
                                         self._window.keyboard)
        self._draw_surface = None
        # Placeholder renderer until load_model() installs a real one.
        self._renderer = dmv.renderer.NullRenderer()

    @property
    def camera(self):
        # The dm_control viewer's internal camera (private attribute access).
        return self._viewer._camera._camera

    def close(self):
        self._viewer.deinitialize()
        self._renderer.release()
        # NOTE(review): assumes load_model() was called first; otherwise
        # _draw_surface is None and this raises — confirm intended usage.
        self._draw_surface.free()
        self._window.close()

    def load_model(self, physics):
        """Loads the given Physics object to render."""
        self._viewer.deinitialize()
        self._draw_surface = module.get_dm_render().Renderer(
            max_width=_MAX_RENDERBUFFER_SIZE, max_height=_MAX_RENDERBUFFER_SIZE)
        self._renderer = module.get_dm_viewer().renderer.OffScreenRenderer(
            physics.model, self._draw_surface)
        self._viewer.initialize(physics, self._renderer, touchpad=False)

    def run_frame(self):
        """Renders one frame of the simulation.

        NOTE: This is extremely slow at the moment.
        """
        glfw = module.get_dm_viewer().gui.glfw_gui.glfw
        glfw_window = self._window._context.window
        # Exit the whole process if the user closed the window.
        if glfw.window_should_close(glfw_window):
            sys.exit(0)

        self._viewport.set_size(*self._window.shape)
        self._viewer.render()
        pixels = self._renderer.pixels

        with self._window._context.make_current() as ctx:
            ctx.call(self._window._update_gui_on_render_thread, glfw_window,
                     pixels)
        self._window._mouse.process_events()
        self._window._keyboard.process_events()
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for loading MuJoCo models."""
import os
from typing import Dict, Optional
from d4rl.kitchen.adept_envs.simulation import module
from d4rl.kitchen.adept_envs.simulation.renderer import DMRenderer, MjPyRenderer, RenderMode
class MujocoSimRobot:
    """Class that encapsulates a MuJoCo simulation.

    This class exposes methods that are agnostic to the simulation backend.
    Two backends are supported:
    1. mujoco_py - MuJoCo v1.50
    2. dm_control - MuJoCo v2.00
    """

    def __init__(self,
                 model_file: str,
                 use_dm_backend: bool = False,
                 camera_settings: Optional[Dict] = None):
        """Initializes a new simulation.

        Args:
            model_file: The MuJoCo XML model file to load.
            use_dm_backend: If True, uses DM Control's Physics (MuJoCo v2.0) as
                the backend for the simulation. Otherwise, uses mujoco_py
                (MuJoCo v1.5) as the backend.
            camera_settings: Settings to initialize the renderer's camera. This
                can contain the keys `distance`, `azimuth`, and `elevation`.

        Raises:
            ValueError: If `model_file` does not exist.
        """
        self._use_dm_backend = use_dm_backend
        if not os.path.isfile(model_file):
            raise ValueError(
                '[MujocoSimRobot] Invalid model file path: {}'.format(
                    model_file))
        if self._use_dm_backend:
            dm_mujoco = module.get_dm_mujoco()
            # .mjb is MuJoCo's compiled binary model format.
            if model_file.endswith('.mjb'):
                self.sim = dm_mujoco.Physics.from_binary_path(model_file)
            else:
                self.sim = dm_mujoco.Physics.from_xml_path(model_file)
            self.model = self.sim.model
            # Shim dm_control objects so code written against the
            # mujoco_py API keeps working.
            self._patch_mjlib_accessors(self.model, self.sim.data)
            self.renderer = DMRenderer(
                self.sim, camera_settings=camera_settings)
        else:  # Use mujoco_py
            mujoco_py = module.get_mujoco_py()
            self.model = mujoco_py.load_model_from_path(model_file)
            self.sim = mujoco_py.MjSim(self.model)
            self.renderer = MjPyRenderer(
                self.sim, camera_settings=camera_settings)
        self.data = self.sim.data

    def close(self):
        """Cleans up any resources being used by the simulation."""
        self.renderer.close()

    def save_binary(self, path: str):
        """Saves the loaded model to a binary .mjb file.

        Raises:
            ValueError: If `path` already exists.
        """
        if os.path.exists(path):
            raise ValueError(
                '[MujocoSimRobot] Path already exists: {}'.format(path))
        if not path.endswith('.mjb'):
            path = path + '.mjb'
        if self._use_dm_backend:
            self.model.save_binary(path)
        else:
            with open(path, 'wb') as f:
                f.write(self.model.get_mjb())

    def get_mjlib(self):
        """Returns an object that exposes the low-level MuJoCo API."""
        if self._use_dm_backend:
            return module.get_dm_mujoco().wrapper.mjbindings.mjlib
        else:
            return module.get_mujoco_py_mjlib()

    def _patch_mjlib_accessors(self, model, data):
        """Adds accessors to the DM Control objects to support mujoco_py API."""
        assert self._use_dm_backend
        mjlib = self.get_mjlib()

        def name2id(type_name, name):
            # Resolve a named entity to its integer ID via the raw C API.
            obj_id = mjlib.mj_name2id(model.ptr,
                                      mjlib.mju_str2Type(type_name.encode()),
                                      name.encode())
            if obj_id < 0:
                raise ValueError('No {} with name "{}" exists.'.format(
                    type_name, name))
            return obj_id

        # Only add each mujoco_py-style helper when missing, so dm_control
        # versions that provide them natively are left untouched.
        if not hasattr(model, 'body_name2id'):
            model.body_name2id = lambda name: name2id('body', name)
        if not hasattr(model, 'geom_name2id'):
            model.geom_name2id = lambda name: name2id('geom', name)
        if not hasattr(model, 'site_name2id'):
            model.site_name2id = lambda name: name2id('site', name)
        if not hasattr(model, 'joint_name2id'):
            model.joint_name2id = lambda name: name2id('joint', name)
        if not hasattr(model, 'actuator_name2id'):
            model.actuator_name2id = lambda name: name2id('actuator', name)
        if not hasattr(model, 'camera_name2id'):
            model.camera_name2id = lambda name: name2id('camera', name)
        # mujoco_py-style aliases for dm_control's xpos/xquat arrays.
        if not hasattr(data, 'body_xpos'):
            data.body_xpos = data.xpos
        if not hasattr(data, 'body_xquat'):
            data.body_xquat = data.xquat
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/__init__.py
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/config.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
# Import the fastest available ElementTree implementation.
try:
    import cElementTree as ET
except ImportError:
    try:
        # Python 2.5 needs to import a different module
        import xml.etree.cElementTree as ET
    except ImportError:
        # xml.etree.cElementTree was removed in Python 3.9; fall back to
        # the plain implementation (C-accelerated under the hood on py3).
        # The original called an undefined `exit_err` here, which would
        # itself raise NameError.
        import xml.etree.ElementTree as ET
CONFIG_XML_DATA = """
"""
# Read config from root
def read_config_from_node(root_node, parent_name, child_name, dtype=int):
    """Read an attribute array from `<parent_name child_name="...">`.

    Args:
        root_node: ElementTree element to search under.
        parent_name: Tag name of the direct child element to locate.
        child_name: Name of the attribute on that element.
        dtype: NumPy dtype used to parse the whitespace-separated values.

    Returns:
        np.ndarray of the attribute's whitespace-separated tokens.

    Exits (quit -> SystemExit, matching the original behavior) when the
    parent element or the attribute is missing.
    """
    # find parent
    parent_node = root_node.find(parent_name)
    if parent_node is None:  # identity check, not ==, per PEP 8
        quit("Parent %s not found" % parent_name)
    # get child data
    child_data = parent_node.get(child_name)
    if child_data is None:
        quit("Child %s not found" % child_name)
    config_val = np.array(child_data.split(), dtype=dtype)
    return config_val
# get config from file or string
def get_config_root_node(config_file_name=None, config_file_data=None):
    """Parse configuration XML and return (root_node, root_name).

    Args:
        config_file_name: Path of an XML file to parse (used when
            `config_file_data` is None).
        config_file_data: Raw XML string to parse instead of a file.

    Returns:
        Tuple of the ElementTree root element and an np.ndarray of the
        whitespace-split tokens of its `name` attribute.

    Exits (quit -> SystemExit, matching the original behavior) when
    parsing fails or the `name` attribute is missing.
    """
    try:
        # get root
        if config_file_data is None:
            # Context manager closes the handle even on a parse error
            # (the original leaked the open file object).
            with open(config_file_name, "r") as config_file_content:
                config = ET.parse(config_file_content)
            root_node = config.getroot()
        else:
            root_node = ET.fromstring(config_file_data)
        # get root data
        root_data = root_node.get('name')
        root_name = np.array(root_data.split(), dtype=str)
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # are no longer swallowed; the error message is unchanged.
        quit("ERROR: Unable to process config file %s" % config_file_name)
    return root_node, root_name
# Read config from config_file
def read_config_from_xml(config_file_name, parent_name, child_name, dtype=int):
    """Parse `config_file_name` and return one attribute as an array.

    Convenience wrapper around get_config_root_node() plus
    read_config_from_node().
    """
    root, _ = get_config_root_node(config_file_name=config_file_name)
    return read_config_from_node(root, parent_name, child_name, dtype)
# tests
if __name__ == '__main__':
    # NOTE(review): this self-test expects CONFIG_XML_DATA to hold a
    # populated config (it is currently an empty string) and a local
    # duh.xml file dumped from that data — it will abort without them.
    print("Read config and parse -------------------------")
    root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA)
    print("Root:name \t", root_name)
    print("limit:low \t", read_config_from_node(root, "limits", "low", float))
    print("limit:high \t", read_config_from_node(root, "limits", "high", float))
    print("scale:joint \t", read_config_from_node(root, "scale", "joint",
                                                  float))
    print("data:type \t", read_config_from_node(root, "data", "type", str))
    # read straight from xml (dump the XML data as duh.xml for this test)
    root, root_name = get_config_root_node(config_file_name="duh.xml")
    print("Read from xml --------------------------------")
    print("limit:low \t", read_config_from_xml("duh.xml", "limits", "low",
                                               float))
    print("limit:high \t",
          read_config_from_xml("duh.xml", "limits", "high", float))
    print("scale:joint \t",
          read_config_from_xml("duh.xml", "scale", "joint", float))
    print("data:type \t", read_config_from_xml("duh.xml", "data", "type", str))
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/configurable.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from gym.envs.registration import registry as gym_registry
def import_class_from_path(class_path):
    """Given 'path.to.module:object', imports and returns the object."""
    module_name, attr_name = class_path.split(":")
    target_module = importlib.import_module(module_name)
    return getattr(target_module, attr_name)
class ConfigCache(object):
    """Configuration class to store constructor arguments.

    This is used to store parameters to pass to Gym environments at init
    time: a shared default dict plus per-class (or per-env-id) overrides,
    merged on lookup.
    """

    def __init__(self):
        self._configs = {}         # per-key override dicts
        self._default_config = {}  # shared defaults

    def set_default_config(self, config):
        """Sets the default configuration used for all RobotEnv envs."""
        self._default_config = dict(config)

    def set_config(self, cls_or_env_id, config):
        """Sets the configuration for the given environment within a context.

        Args:
            cls_or_env_id (Class | str): A class type or Gym environment ID
                to configure.
            config (dict): The configuration parameters (copied).
        """
        self._configs[self._get_config_key(cls_or_env_id)] = dict(config)

    def get_config(self, cls_or_env_id):
        """Returns defaults merged with stored overrides for the given key.

        Args:
            cls_or_env_id (Class | str): A class type or Gym environment ID
                to get the configuration of.
        """
        key = self._get_config_key(cls_or_env_id)
        merged = dict(self._default_config)
        merged.update(self._configs.get(key, {}))
        return merged

    def clear_config(self, cls_or_env_id):
        """Clears the configuration for the given ID (no-op when absent)."""
        self._configs.pop(self._get_config_key(cls_or_env_id), None)

    def _get_config_key(self, cls_or_env_id):
        # Classes key on themselves; strings are resolved through the Gym
        # registry to the entry-point class.
        if inspect.isclass(cls_or_env_id):
            return cls_or_env_id
        env_id = cls_or_env_id
        assert isinstance(env_id, str)
        if env_id not in gym_registry.env_specs:
            raise ValueError("Unregistered environment name {}.".format(env_id))
        entry_point = gym_registry.env_specs[env_id]._entry_point
        if callable(entry_point):
            return entry_point
        return import_class_from_path(entry_point)
# Global robot config.
global_config = ConfigCache()

def configurable(config_id=None, pickleable=False, config_cache=global_config):
    """Class decorator to allow injection of constructor arguments.

    This allows constructor arguments to be passed via ConfigCache.
    Example usage:

    @configurable()
    class A:
        def __init__(b=None, c=2, d='Wow'):
            ...

    global_config.set_config(A, {'b': 10, 'c': 20})
    a = A()      # b=10, c=20, d='Wow'
    a = A(b=30)  # b=30, c=20, d='Wow'

    Args:
        config_id: ID of the config to use. This defaults to the class type.
        pickleable: Whether this class is pickleable. If true, causes the
            pickle state to include the config and constructor arguments.
        config_cache: The ConfigCache to use to read config data from. Uses
            the global ConfigCache by default.
    """
    def cls_decorator(cls):
        assert inspect.isclass(cls)
        # Overwrite the class constructor to pass arguments from the config.
        base_init = cls.__init__
        def __init__(self, *args, **kwargs):
            # Config is looked up at construction time, so set_config()
            # calls made after decoration still take effect.
            config = config_cache.get_config(config_id or type(self))
            # Allow kwargs to override the config.
            kwargs = {**config, **kwargs}
            # print('Initializing {} with params: {}'.format(type(self).__name__,
            #       kwargs))
            if pickleable:
                # Stash the effective args so __getstate__ can pickle them.
                self._pkl_env_args = args
                self._pkl_env_kwargs = kwargs
            base_init(self, *args, **kwargs)
        cls.__init__ = __init__
        # If the class is pickleable, overwrite the state methods to save
        # the constructor arguments and config.
        if pickleable:
            # Use same pickle keys as gym.utils.ezpickle for backwards compat.
            PKL_ARGS_KEY = '_ezpickle_args'
            PKL_KWARGS_KEY = '_ezpickle_kwargs'
            def __getstate__(self):
                return {
                    PKL_ARGS_KEY: self._pkl_env_args,
                    PKL_KWARGS_KEY: self._pkl_env_kwargs,
                }
            cls.__getstate__ = __getstate__
            def __setstate__(self, data):
                saved_args = data[PKL_ARGS_KEY]
                saved_kwargs = data[PKL_KWARGS_KEY]
                # Override the saved state with the current config.
                config = config_cache.get_config(config_id or type(self))
                # Here the current config takes precedence over the saved
                # kwargs (note: opposite merge order from __init__ above).
                kwargs = {**saved_kwargs, **config}
                # Rebuild via the constructor, then adopt its state.
                inst = type(self)(*saved_args, **kwargs)
                self.__dict__.update(inst.__dict__)
            cls.__setstate__ = __setstate__
        return cls
    return cls_decorator
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/constants.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Absolute path of the adept_envs package root (two levels above utils/).
ENVS_ROOT_PATH = os.path.abspath(os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "../../"))
# Sibling directory expected to hold the MuJoCo model/asset XML files.
MODELS_PATH = os.path.abspath(os.path.join(ENVS_ROOT_PATH, "../adept_models/"))
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
import glob
import pickle
import numpy as np
from parse_mjl import parse_mjl_logs, viz_parsed_mjl_logs
from mjrl.utils.gym_env import GymEnv
import adept_envs
import time as timer
import skvideo.io
import gym
# headless renderer
render_buffer = []  # rendering buffer

def viewer(env,
           mode='initialize',
           filename='video',
           frame_size=(640, 480),
           camera_id=0,
           render=None):
    """Rendering helper shared by the demo scripts.

    Modes: 'initialize' clears the offscreen buffer (then renders one
    frame), 'render' appends a frame, 'save' writes the buffer to
    `filename`. `render` selects the backend: 'onscreen', 'offscreen', or
    the literal string 'None' for a no-op. `frame_size` and `camera_id`
    are accepted for interface compatibility but unused here.
    """
    if render == 'onscreen':
        env.mj_render()
        return
    if render == 'None':
        # Deliberate no-op (note: the *string* 'None', not the None object).
        return
    if render != 'offscreen':
        print("unknown render: ", render)
        return
    # Offscreen path: accumulate frames in the module-level buffer.
    global render_buffer
    if mode == 'initialize':
        render_buffer = []
        mode = 'render'
    if mode == 'render':
        render_buffer.append(env.render(mode='rgb_array'))
    if mode == 'save':
        skvideo.io.vwrite(filename, np.asarray(render_buffer))
        print("\noffscreen buffer saved", filename)
# view demos (physics ignored)
def render_demos(env, data, filename='demo_rendering.mp4', render=None):
    """Replays logged qpos/qvel frames kinematically and renders them.

    Physics is ignored: the sim state is overwritten from the log each
    frame and only forward kinematics is run.

    Args:
        env: Env exposing .sim (MuJoCo) and .frame_skip.
        data: Parsed log dict with 'ctrl', 'qpos' and 'qvel' arrays.
        filename: Output video filename when rendering offscreen.
        render: 'onscreen'/'offscreen'/'None' — forwarded to viewer().
    """
    FPS = 30
    # Render every render_skip-th frame so playback approximates FPS.
    render_skip = max(1, round(1. / \
        (FPS * env.sim.model.opt.timestep * env.frame_skip)))
    t0 = timer.time()
    viewer(env, mode='initialize', render=render)
    for i_frame in range(data['ctrl'].shape[0]):
        # Overwrite simulator state with logged state (no stepping).
        env.sim.data.qpos[:] = data['qpos'][i_frame].copy()
        env.sim.data.qvel[:] = data['qvel'][i_frame].copy()
        env.sim.forward()
        if i_frame % render_skip == 0:
            viewer(env, mode='render', render=render)
            print(i_frame, end=', ', flush=True)
    viewer(env, mode='save', filename=filename, render=render)
    print("time taken = %f" % (timer.time() - t0))
# playback demos and get data (physics respected)
def gather_training_data(env, data, filename='demo_playback.mp4', render=None):
    """Replays a demo through env.step() and records (obs, act) pairs.

    Unlike render_demos(), physics is respected: actions are reconstructed
    from the logged controls and actually stepped through the environment.

    Args:
        env: Wrapped env; the underlying env must expose act_mid, act_amp,
            skip, sim and _get_obs().
        data: Parsed log dict with 'ctrl', 'qpos' and 'qvel' arrays.
        filename: Output video name (used only when `render` is truthy).
        render: Render mode forwarded to viewer(), or falsy to skip saving.

    Returns:
        Tuple (path_obs, path_act, init_qpos, init_qvel).
    """
    env = env.env  # unwrap to the raw adept env
    FPS = 30
    render_skip = max(1, round(1. / \
        (FPS * env.sim.model.opt.timestep * env.frame_skip)))
    t0 = timer.time()
    # initialize
    env.reset()
    init_qpos = data['qpos'][0].copy()
    init_qvel = data['qvel'][0].copy()
    act_mid = env.act_mid
    act_rng = env.act_amp
    # prepare env: force the sim into the demo's initial state
    env.sim.data.qpos[:] = init_qpos
    env.sim.data.qvel[:] = init_qvel
    env.sim.forward()
    viewer(env, mode='initialize', render=render)
    # step the env and gather data
    path_obs = None
    for i_frame in range(data['ctrl'].shape[0] - 1):
        # Reset every time step
        # if i_frame % 1 == 0:
        #     qp = data['qpos'][i_frame].copy()
        #     qv = data['qvel'][i_frame].copy()
        #     env.sim.data.qpos[:] = qp
        #     env.sim.data.qvel[:] = qv
        #     env.sim.forward()
        obs = env._get_obs()
        # Construct the action: a velocity-like control from the logged
        # ctrl, normalized to the env's action range and clipped to stay
        # strictly inside [-1, 1].
        # ctrl = (data['qpos'][i_frame + 1][:9] - obs[:9]) / (env.skip * env.model.opt.timestep)
        ctrl = (data['ctrl'][i_frame] - obs[:9])/(env.skip*env.model.opt.timestep)
        act = (ctrl - act_mid) / act_rng
        act = np.clip(act, -0.999, 0.999)
        next_obs, reward, done, env_info = env.step(act)
        if path_obs is None:
            path_obs = obs
            path_act = act
        else:
            path_obs = np.vstack((path_obs, obs))
            path_act = np.vstack((path_act, act))
        # render when needed to maintain FPS
        if i_frame % render_skip == 0:
            viewer(env, mode='render', render=render)
            print(i_frame, end=', ', flush=True)
    # finalize
    if render:
        viewer(env, mode='save', filename=filename, render=render)
    t1 = timer.time()
    print("time taken = %f" % (t1 - t0))
    # note that we stop one step before the final logged frame
    return path_obs, path_act, init_qpos, init_qvel
# MAIN =========================================================
@click.command(help="parse tele-op demos")
@click.option('--env', '-e', type=str, help='gym env name', required=True)
@click.option(
    '--demo_dir',
    '-d',
    type=str,
    help='directory with tele-op logs',
    required=True)
@click.option(
    '--skip',
    '-s',
    type=int,
    help='number of frames to skip (1:no skip)',
    default=1)
@click.option('--graph', '-g', type=bool, help='plot logs', default=False)
@click.option('--save_logs', '-l', type=bool, help='save logs', default=False)
@click.option(
    '--view', '-v', type=str, help='render/playback', default='render')
@click.option(
    '--render', '-r', type=str, help='onscreen/offscreen', default='onscreen')
def main(env, demo_dir, skip, graph, save_logs, view, render):
    """CLI entry: parse .mjl tele-op logs, then render or play them back."""
    gym_env = gym.make(env)
    paths = []
    print("Scanning demo_dir: " + demo_dir + "=========")
    for ind, file in enumerate(glob.glob(demo_dir + "*.mjl")):
        # process logs
        print("processing: " + file, end=': ')
        data = parse_mjl_logs(file, skip)
        print("log duration %0.2f" % (data['time'][-1] - data['time'][0]))
        # plot logs
        if (graph):
            print("plotting: " + file)
            viz_parsed_mjl_logs(data)
        # save logs
        if (save_logs):
            pickle.dump(data, open(file[:-4] + ".pkl", 'wb'))
        # render logs to video
        if view == 'render':
            render_demos(
                gym_env,
                data,
                filename=data['logName'][:-4] + '_demo_render.mp4',
                render=render)
        # playback logs and gather data
        elif view == 'playback':
            try:
                obs, act,init_qpos, init_qvel = gather_training_data(gym_env, data,\
                    filename=data['logName'][:-4]+'_playback.mp4', render=render)
            except Exception as e:
                # Best-effort: skip demos that fail to play back.
                print(e)
                continue
            path = {
                'observations': obs,
                'actions': act,
                'goals': obs,
                'init_qpos': init_qpos,
                'init_qvel': init_qvel
            }
            paths.append(path)
            # accept = input('accept demo?')
            # if accept == 'n':
            #     continue
            pickle.dump(path, open(demo_dir + env + str(ind) + "_path.pkl", 'wb'))
            # NOTE(review): this prints `file` (the full log path) in the
            # name, which does not match the filename written just above.
            print(demo_dir + env + file + "_path.pkl")

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
# For testing whether a number is close to zero
_FLOAT_EPS = np.finfo(np.float64).eps
_EPS4 = _FLOAT_EPS * 4.0
def mulQuat(qa, qb):
    """Hamilton product of two (w, x, y, z) quaternions."""
    w1, x1, y1, z1 = qa[0], qa[1], qa[2], qa[3]
    w2, x2, y2, z2 = qb[0], qb[1], qb[2], qb[3]
    out = np.zeros(4)
    out[0] = w1*w2 - x1*x2 - y1*y2 - z1*z2
    out[1] = w1*x2 + x1*w2 + y1*z2 - z1*y2
    out[2] = w1*y2 - x1*z2 + y1*w2 + z1*x2
    out[3] = w1*z2 + x1*y2 - y1*x2 + z1*w2
    return out
def negQuat(quat):
    """Conjugate of a (w, x, y, z) quaternion (its inverse for unit norm)."""
    w, x, y, z = quat[0], quat[1], quat[2], quat[3]
    return np.array([w, -x, -y, -z])
def quat2Vel(quat, dt=1):
    """Convert a rotation quaternion to (angular speed, rotation axis).

    The 1e-8 term guards division by zero for the identity rotation,
    where the axis is returned as (approximately) the zero vector.
    """
    vec = quat[1:].copy()
    half_sin = np.sqrt(np.sum(vec ** 2))
    vec = vec / (half_sin + 1e-8)
    speed = 2 * np.arctan2(half_sin, quat[0]) / dt
    return speed, vec
def quatDiff2Vel(quat1, quat2, dt):
    """Angular velocity of the rotation taking quat1 to quat2 over dt."""
    delta = mulQuat(quat2, negQuat(quat1))
    return quat2Vel(delta, dt)
def axis_angle2quat(axis, angle):
    """Quaternion (w, x, y, z) for a rotation of `angle` about `axis`."""
    half = angle / 2
    s = np.sin(half)
    return np.array([np.cos(half), s * axis[0], s * axis[1], s * axis[2]])
def euler2mat(euler):
    """Convert Euler angles to a 3x3 rotation matrix. See rotation.py for notes."""
    euler = np.asarray(euler, dtype=np.float64)
    assert euler.shape[-1] == 3, "Invalid shaped euler {}".format(euler)
    # This convention negates and reverses the axis order.
    a_i, a_j, a_k = -euler[..., 2], -euler[..., 1], -euler[..., 0]
    s_i, s_j, s_k = np.sin(a_i), np.sin(a_j), np.sin(a_k)
    c_i, c_j, c_k = np.cos(a_i), np.cos(a_j), np.cos(a_k)
    cc, cs = c_i * c_k, c_i * s_k
    sc, ss = s_i * c_k, s_i * s_k
    mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64)
    mat[..., 0, 0] = c_j * c_i
    mat[..., 0, 1] = c_j * s_i
    mat[..., 0, 2] = -s_j
    mat[..., 1, 0] = s_j * cs - sc
    mat[..., 1, 1] = s_j * ss + cc
    mat[..., 1, 2] = c_j * s_k
    mat[..., 2, 0] = s_j * cc + ss
    mat[..., 2, 1] = s_j * sc - cs
    mat[..., 2, 2] = c_j * c_k
    return mat
def euler2quat(euler):
    """Convert Euler angles to a (w, x, y, z) quaternion. See rotation.py for notes."""
    euler = np.asarray(euler, dtype=np.float64)
    assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler)
    a_i, a_j, a_k = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2
    s_i, s_j, s_k = np.sin(a_i), np.sin(a_j), np.sin(a_k)
    c_i, c_j, c_k = np.cos(a_i), np.cos(a_j), np.cos(a_k)
    cc, cs = c_i * c_k, c_i * s_k
    sc, ss = s_i * c_k, s_i * s_k
    quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64)
    quat[..., 0] = c_j * cc + s_j * ss
    quat[..., 1] = c_j * cs - s_j * sc
    quat[..., 2] = -(c_j * ss + s_j * cc)
    quat[..., 3] = c_j * sc - s_j * cs
    return quat
def mat2euler(mat):
    """Convert a rotation matrix to Euler angles. See rotation.py for notes."""
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)
    cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2])
    stable = cy > _EPS4  # elementwise mask: away from gimbal lock
    euler = np.empty(mat.shape[:-1], dtype=np.float64)
    euler[..., 2] = np.where(stable,
                             -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
                             -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]))
    # The original np.where here had identical branches, so the select
    # is unnecessary — the value is the same either way.
    euler[..., 1] = -np.arctan2(-mat[..., 0, 2], cy)
    euler[..., 0] = np.where(stable,
                             -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]),
                             0.0)
    return euler
def mat2quat(mat):
    """Convert a rotation matrix to a (w, x, y, z) quaternion. See rotation.py for notes."""
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)
    Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2]
    Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2]
    Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2]
    # Symmetric 4x4 matrix K whose dominant eigenvector is the quaternion.
    # Only the lower half is filled (eigh reads the lower triangle).
    K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64)
    K[..., 0, 0] = Qxx - Qyy - Qzz
    K[..., 1, 0] = Qyx + Qxy
    K[..., 1, 1] = Qyy - Qxx - Qzz
    K[..., 2, 0] = Qzx + Qxz
    K[..., 2, 1] = Qzy + Qyz
    K[..., 2, 2] = Qzz - Qxx - Qyy
    K[..., 3, 0] = Qyz - Qzy
    K[..., 3, 1] = Qzx - Qxz
    K[..., 3, 2] = Qxy - Qyx
    K[..., 3, 3] = Qxx + Qyy + Qzz
    K /= 3.0
    # TODO: vectorize this -- probably could be made faster
    q = np.empty(K.shape[:-2] + (4,))
    for idx in np.ndindex(*K.shape[:-2]):
        # Use Hermitian eigenvectors, values for speed
        vals, vecs = np.linalg.eigh(K[idx])
        # Select largest eigenvector, reorder to w,x,y,z quaternion
        q[idx] = vecs[[3, 0, 1, 2], np.argmax(vals)]
        # Prefer quaternion with positive w
        # (q * -1 corresponds to same rotation as q)
        if q[idx][0] < 0:
            q[idx] *= -1
    return q
def quat2euler(quat):
    """Convert a quaternion to Euler angles. See rotation.py for notes."""
    rot = quat2mat(quat)
    return mat2euler(rot)
def quat2mat(quat):
    """Convert a (w, x, y, z) quaternion to a 3x3 rotation matrix.

    (The original docstring said "Euler Angles"; the function returns a
    rotation matrix.) Near-zero-norm quaternions yield the identity.
    """
    quat = np.asarray(quat, dtype=np.float64)
    assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat)
    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]
    Nq = np.sum(quat * quat, axis=-1)
    s = 2.0 / Nq
    X, Y, Z = x * s, y * s, z * s
    wX, wY, wZ = w * X, w * Y, w * Z
    xX, xY, xZ = x * X, x * Y, x * Z
    yY, yZ, zZ = y * Y, y * Z, z * Z
    mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64)
    mat[..., 0, 0] = 1.0 - (yY + zZ)
    mat[..., 0, 1] = xY - wZ
    mat[..., 0, 2] = xZ + wY
    mat[..., 1, 0] = xY + wZ
    mat[..., 1, 1] = 1.0 - (xX + zZ)
    mat[..., 1, 2] = yZ - wX
    mat[..., 2, 0] = xZ - wY
    mat[..., 2, 1] = yZ + wX
    mat[..., 2, 2] = 1.0 - (xX + yY)
    degenerate = ~(Nq > _FLOAT_EPS)
    return np.where(degenerate[..., np.newaxis, np.newaxis], np.eye(3), mat)
================================================
FILE: d4rl/d4rl/kitchen/adept_models/.gitignore
================================================
# General
.DS_Store
*.swp
*.profraw
# Editors
.vscode
.idea
================================================
FILE: d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md
================================================
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
================================================
FILE: d4rl/d4rl/kitchen/adept_models/LICENSE
================================================
Copyright 2019 The DSuite Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: d4rl/d4rl/kitchen/adept_models/README.public.md
================================================
# D'Suite Scenes
This repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation
scenes and common assets for D'Suite environments. Based on code in the ROBEL suite
https://github.com/google-research/robel
## Disclaimer
This is not an official Google product.
================================================
FILE: d4rl/d4rl/kitchen/adept_models/__init__.py
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/kitchen_envs.py
================================================
"""Environments using kitchen and Franka robot."""
import os
import numpy as np
from d4rl.kitchen.adept_envs.utils.configurable import configurable
from d4rl.kitchen.adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1
from d4rl.offline_env import OfflineEnv
# For each manipulable kitchen element: the indices of its degrees of
# freedom within the environment's qpos-style observation vector.
# In `_get_reward_n_score` these indices are offset by len(obs_dict['qp'])
# to index into the object-only observation slice.
OBS_ELEMENT_INDICES = {
    'bottom burner': np.array([11, 12]),
    'top burner': np.array([15, 16]),
    'light switch': np.array([17, 18]),
    'slide cabinet': np.array([19]),
    'hinge cabinet': np.array([20, 21]),
    'microwave': np.array([22]),
    'kettle': np.array([23, 24, 25, 26, 27, 28, 29]),
}
# Target values for the degrees of freedom above; a task element counts as
# "complete" when its observed values are within BONUS_THRESH (Euclidean
# norm) of these goals.
OBS_ELEMENT_GOALS = {
    'bottom burner': np.array([-0.88, -0.01]),
    'top burner': np.array([-0.92, -0.01]),
    'light switch': np.array([-0.69, -0.05]),
    'slide cabinet': np.array([0.37]),
    'hinge cabinet': np.array([0., 1.45]),
    'microwave': np.array([-0.75]),
    'kettle': np.array([-0.23, 0.75, 1.62, 0.99, 0., 0., -0.06]),
}
# Distance threshold under which a task element is considered achieved.
BONUS_THRESH = 0.3
@configurable(pickleable=True)
class KitchenBase(KitchenTaskRelaxV1, OfflineEnv):
    """Base class for multitask Franka-kitchen environments.

    Subclasses specify TASK_ELEMENTS, a list of element names the robot must
    manipulate to their goal configurations. Reward is the number of elements
    completed on the current step.
    """

    # A string of element names. The robot's task is then to modify each of
    # these elements appropriately.
    TASK_ELEMENTS = []
    # If True, a completed element is dropped from the remaining task set so
    # it is only rewarded once.
    REMOVE_TASKS_WHEN_COMPLETE = True
    # If True, episodes end once every task element has been completed.
    TERMINATE_ON_TASK_COMPLETE = True

    def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None, **kwargs):
        """Initialize both parent environments.

        Args:
            dataset_url: URL of the offline dataset (forwarded to OfflineEnv).
            ref_max_score: reference max score for normalization.
            ref_min_score: reference min score for normalization.
            **kwargs: forwarded to the underlying kitchen environment.
        """
        self.tasks_to_complete = set(self.TASK_ELEMENTS)
        super(KitchenBase, self).__init__(**kwargs)
        # OfflineEnv is initialized explicitly because the MRO chain above
        # does not forward the dataset-related arguments.
        OfflineEnv.__init__(
            self,
            dataset_url=dataset_url,
            ref_max_score=ref_max_score,
            ref_min_score=ref_min_score)

    def _get_task_goal(self):
        """Return a goal vector with each task element's target values set."""
        new_goal = np.zeros_like(self.goal)
        for element in self.TASK_ELEMENTS:
            element_idx = OBS_ELEMENT_INDICES[element]
            element_goal = OBS_ELEMENT_GOALS[element]
            new_goal[element_idx] = element_goal
        return new_goal

    def reset_model(self):
        """Reset the pending task set along with the simulation state."""
        self.tasks_to_complete = set(self.TASK_ELEMENTS)
        return super(KitchenBase, self).reset_model()

    def _get_reward_n_score(self, obs_dict):
        """Compute sparse reward: count of task elements completed this step.

        An element is complete when its object observation is within
        BONUS_THRESH (Euclidean norm) of its goal configuration.
        """
        reward_dict, score = super(KitchenBase, self)._get_reward_n_score(obs_dict)
        next_q_obs = obs_dict['qp']
        next_obj_obs = obs_dict['obj_qp']
        next_goal = obs_dict['goal']
        # Object observations start after the robot joints in the full
        # observation indexing, hence the offset.
        idx_offset = len(next_q_obs)
        completions = []
        for element in self.tasks_to_complete:
            element_idx = OBS_ELEMENT_INDICES[element]
            distance = np.linalg.norm(
                next_obj_obs[..., element_idx - idx_offset] -
                next_goal[element_idx])
            if distance < BONUS_THRESH:
                completions.append(element)
        if self.REMOVE_TASKS_WHEN_COMPLETE:
            # Plain loop instead of the original side-effect list
            # comprehension (which built and discarded a list of Nones).
            for element in completions:
                self.tasks_to_complete.remove(element)
        bonus = float(len(completions))
        reward_dict['bonus'] = bonus
        reward_dict['r_total'] = bonus
        score = bonus
        return reward_dict, score

    def step(self, a, b=None):
        """Step the environment; optionally terminate when all tasks are done."""
        obs, reward, done, env_info = super(KitchenBase, self).step(a, b=b)
        if self.TERMINATE_ON_TASK_COMPLETE:
            # NOTE: this overrides the underlying env's done flag entirely,
            # terminating iff no tasks remain (matches original behavior).
            done = not self.tasks_to_complete
        return obs, reward, done, env_info

    def render(self, mode='human'):
        # Disable rendering to speed up environment evaluation.
        return []
class KitchenMicrowaveKettleLightSliderV0(KitchenBase):
    """Kitchen task: microwave, kettle, light switch, then slide cabinet."""
    TASK_ELEMENTS = ['microwave', 'kettle', 'light switch', 'slide cabinet']
class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):
    """Kitchen task: microwave, kettle, bottom burner, then light switch."""
    TASK_ELEMENTS = ['microwave', 'kettle', 'bottom burner', 'light switch']
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/README.md
================================================
# franka
Franka panda mujoco models
# Environment
franka_panda.xml | coming soon
:-------------------------:|:-------------------------:
 | comming soon
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml
================================================
/
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml
================================================
================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml
================================================
================================================
FILE: d4rl/d4rl/locomotion/__init__.py
================================================
from gym.envs.registration import register
from d4rl.locomotion import ant
from d4rl.locomotion import maze_env


def _register_ant_maze(env_id, maze_map, dataset_url, max_episode_steps,
                       deprecated=False, v2_resets=False, reward_type='sparse',
                       ref_min_score=0.0, ref_max_score=1.0):
    """Register a single AntMaze task with gym.

    Args:
        env_id: gym environment id, e.g. 'antmaze-umaze-v0'.
        maze_map: one of the maze layouts defined in maze_env.
        dataset_url: URL of the offline dataset for this task.
        max_episode_steps: episode step limit (700 for umaze, 1000 otherwise).
        deprecated: mark the registration as deprecated (v0/v1 tasks).
        v2_resets: enable per-rollout goal randomization (v2 tasks).
        reward_type: 'sparse' for all tasks except the dense variant.
        ref_min_score: D4RL normalization lower reference score.
        ref_max_score: D4RL normalization upper reference score.
    """
    kwargs = {
        'maze_map': maze_map,
        'reward_type': reward_type,
        'dataset_url': dataset_url,
        'non_zero_reset': False,
        'eval': True,
        'maze_size_scaling': 4.0,
        'ref_min_score': ref_min_score,
        'ref_max_score': ref_max_score,
    }
    # The original registrations only included these keys when set, so we
    # keep the kwargs dict identical for entries that omitted them.
    if deprecated:
        kwargs['deprecated'] = True
    if v2_resets:
        kwargs['v2_resets'] = True
    register(
        id=env_id,
        entry_point='d4rl.locomotion.ant:make_ant_maze_env',
        max_episode_steps=max_episode_steps,
        kwargs=kwargs,
    )


# ---- v0 tasks (deprecated) --------------------------------------------------
_register_ant_maze(
    'antmaze-umaze-v0', maze_env.U_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5',
    max_episode_steps=700, deprecated=True)
_register_ant_maze(
    'antmaze-umaze-diverse-v0', maze_env.U_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=700, deprecated=True)
_register_ant_maze(
    'antmaze-medium-play-v0', maze_env.BIG_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)
_register_ant_maze(
    'antmaze-medium-diverse-v0', maze_env.BIG_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)
_register_ant_maze(
    'antmaze-large-diverse-v0', maze_env.HARDEST_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)
_register_ant_maze(
    'antmaze-large-play-v0', maze_env.HARDEST_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)

# ---- v1 tasks (deprecated) --------------------------------------------------
_register_ant_maze(
    'antmaze-umaze-v1', maze_env.U_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5',
    max_episode_steps=700, deprecated=True)
_register_ant_maze(
    'antmaze-umaze-diverse-v1', maze_env.U_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=700, deprecated=True)
_register_ant_maze(
    'antmaze-medium-play-v1', maze_env.BIG_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)
_register_ant_maze(
    'antmaze-medium-diverse-v1', maze_env.BIG_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)
_register_ant_maze(
    'antmaze-large-diverse-v1', maze_env.HARDEST_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)
_register_ant_maze(
    'antmaze-large-play-v1', maze_env.HARDEST_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=1000, deprecated=True)

# ---- eval tasks (alternate maze layouts) ------------------------------------
# NOTE(review): the 'play'/'diverse' suffixes and the multigoal flag in the
# URLs below are swapped relative to the v0/v1 groups; this matches the
# original registrations byte-for-byte and is preserved as-is.
_register_ant_maze(
    'antmaze-eval-umaze-v0', maze_env.U_MAZE_EVAL_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=700)
_register_ant_maze(
    'antmaze-eval-umaze-diverse-v0', maze_env.U_MAZE_EVAL_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=700)
_register_ant_maze(
    'antmaze-eval-medium-play-v0', maze_env.BIG_MAZE_EVAL_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=1000)
_register_ant_maze(
    'antmaze-eval-medium-diverse-v0', maze_env.BIG_MAZE_EVAL_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=1000)
_register_ant_maze(
    'antmaze-eval-large-diverse-v0', maze_env.HARDEST_MAZE_EVAL_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5',
    max_episode_steps=1000)
_register_ant_maze(
    'antmaze-eval-large-play-v0', maze_env.HARDEST_MAZE_EVAL_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5',
    max_episode_steps=1000)

# ---- v2 tasks (per-rollout goal randomization, fixed datasets) --------------
_register_ant_maze(
    'antmaze-umaze-v2', maze_env.U_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5',
    max_episode_steps=700, v2_resets=True)
_register_ant_maze(
    'antmaze-umaze-diverse-v2', maze_env.U_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',
    max_episode_steps=700, v2_resets=True)
_register_ant_maze(
    'antmaze-medium-play-v2', maze_env.BIG_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',
    max_episode_steps=1000, v2_resets=True)
_register_ant_maze(
    'antmaze-medium-diverse-v2', maze_env.BIG_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',
    max_episode_steps=1000, v2_resets=True)
_register_ant_maze(
    'antmaze-large-diverse-v2', maze_env.HARDEST_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',
    max_episode_steps=1000, v2_resets=True)
_register_ant_maze(
    'antmaze-large-play-v2', maze_env.HARDEST_MAZE_TEST,
    'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',
    max_episode_steps=1000, v2_resets=True)

#######################################################
# Dense-reward variant (placeholder dataset URL).
_register_ant_maze(
    'antmaze-large-play-dense-v2', maze_env.HARDEST_MAZE_TEST,
    'http://dummy_url/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_dense_fixed.hdf5',
    max_episode_steps=1000, v2_resets=True, reward_type='dense',
    ref_min_score=4.766126556281779e-13, ref_max_score=458.9303516149521)
================================================
FILE: d4rl/d4rl/locomotion/ant.py
================================================
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for creating the ant environment."""
import math
import numpy as np
import mujoco_py
import os
from gym import utils
from gym.envs.mujoco import mujoco_env
from d4rl.locomotion import mujoco_goal_env
from d4rl.locomotion import goal_reaching_env
from d4rl.locomotion import maze_env
from d4rl import offline_env
from d4rl.locomotion import wrappers
# Directory holding the MuJoCo XML assets (e.g. ant.xml), resolved relative
# to the installed d4rl.locomotion package (next to mujoco_goal_env.py).
GYM_ASSETS_DIR = os.path.join(
    os.path.dirname(mujoco_goal_env.__file__),
    'assets')
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Basic ant locomotion environment."""
    FILE = os.path.join(GYM_ASSETS_DIR, 'ant.xml')

    def __init__(self, file_path=None, expose_all_qpos=False,
                 expose_body_coms=None, expose_body_comvels=None, non_zero_reset=False):
        """Set up the ant env.

        Args:
            file_path: path to a MuJoCo XML model; defaults to cls.FILE.
            expose_all_qpos: include the full (x, y) position in observations.
            expose_body_coms: optional body names whose CoM is appended to obs.
            expose_body_comvels: optional body names whose CoM velocity is
                appended to obs.
            non_zero_reset: reset the ant to a sampled non-origin location.
        """
        if file_path is None:
            file_path = self.FILE
        # All bookkeeping must exist before MujocoEnv.__init__, which may
        # call back into _get_obs()/reset_model().
        self._expose_all_qpos = expose_all_qpos
        self._expose_body_coms = expose_body_coms
        self._expose_body_comvels = expose_body_comvels
        self._body_com_indices = {}
        self._body_comvel_indices = {}
        self._non_zero_reset = non_zero_reset
        mujoco_env.MujocoEnv.__init__(self, file_path, 5)
        utils.EzPickle.__init__(self)

    @property
    def physics(self):
        # mujoco-py >= 1.50 exposes state via `sim` (PyMjData); older
        # versions read/write through `model`. See mujoco-py issue #80.
        return self.sim if mujoco_py.get_version() >= '1.50' else self.model

    def _step(self, a):
        # Legacy gym API alias.
        return self.step(a)

    def step(self, a):
        """Advance the simulation and return (obs, reward, done, info)."""
        x_before = self.get_body_com("torso")[0]
        self.do_simulation(a, self.frame_skip)
        x_after = self.get_body_com("torso")[0]

        forward_reward = (x_after - x_before) / self.dt
        ctrl_cost = 0.5 * np.square(a).sum()
        clipped_forces = np.clip(self.sim.data.cfrc_ext, -1, 1)
        contact_cost = 0.5 * 1e-3 * np.sum(np.square(clipped_forces))
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward

        # Episode ends when state blows up or the torso leaves [0.2, 1.0].
        state = self.state_vector()
        healthy = bool(np.isfinite(state).all() and 0.2 <= state[2] <= 1.0)
        info = dict(
            reward_forward=forward_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=survive_reward)
        return self._get_obs(), reward, not healthy, info

    def _get_obs(self):
        """Build the observation vector (no contact-force terms)."""
        qpos = self.physics.data.qpos.flat
        qvel = self.physics.data.qvel.flat
        if self._expose_all_qpos:
            obs = np.concatenate([qpos[:15], qvel[:14]])  # ant entries only
        else:
            obs = np.concatenate([qpos[2:15], qvel[:14]])  # drop global (x, y)

        # Optionally append body CoM (and CoM velocity) values, recording the
        # index range each body occupies the first time it is seen.
        for name in (self._expose_body_coms or []):
            com = self.get_body_com(name)
            if name not in self._body_com_indices:
                self._body_com_indices[name] = range(len(obs), len(obs) + len(com))
            obs = np.concatenate([obs, com])

        for name in (self._expose_body_comvels or []):
            comvel = self.get_body_comvel(name)
            if name not in self._body_comvel_indices:
                self._body_comvel_indices[name] = range(len(obs), len(obs) + len(comvel))
            obs = np.concatenate([obs, comvel])
        return obs

    def reset_model(self):
        """Randomize the ant's pose slightly and return the first obs."""
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1

        if self._non_zero_reset:
            # Place the ant at a sampled non-zero (x, y) location.
            qpos[:2] = self._get_reset_location()

        # Everything other than the ant goes back to its original position
        # with zero velocity.
        qpos[15:] = self.init_qpos[15:]
        qvel[14:] = 0.
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        self.viewer.cam.distance = self.model.stat.extent * 0.5

    def get_xy(self):
        """Return the ant's global (x, y) position."""
        return self.physics.data.qpos[:2]

    def set_xy(self, xy):
        """Teleport the ant to (x, y), keeping the rest of the state."""
        qpos = np.copy(self.physics.data.qpos)
        qpos[0], qpos[1] = xy[0], xy[1]
        self.set_state(qpos, self.physics.data.qvel)
class GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv):
    """Ant locomotion rewarded for goal-reaching."""
    # GoalReachingEnv dispatches _get_obs/step/reset_model through this.
    BASE_ENV = AntEnv

    # NOTE: `eval` shadows the builtin, but the name is part of the public
    # signature and is kept for compatibility with callers.
    def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
                 file_path=None,
                 expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type='dense', **kwargs):
        # The goal-reaching mixin is initialized first — presumably so goal
        # state exists before MujocoEnv.__init__ invokes _get_obs/reset_model
        # during AntEnv.__init__ (TODO confirm against mujoco_env).
        goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type)
        AntEnv.__init__(self,
                        file_path=file_path,
                        expose_all_qpos=expose_all_qpos,
                        expose_body_coms=None,
                        expose_body_comvels=None,
                        non_zero_reset=non_zero_reset)
class AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.OfflineEnv):
    """Ant navigating a maze."""
    # MazeEnv builds the maze XML around this locomotion env.
    LOCOMOTION_ENV = GoalReachingAntEnv

    def __init__(self, goal_sampler=None, expose_all_qpos=True,
                 reward_type='dense', v2_resets=False,
                 *args, **kwargs):
        if goal_sampler is None:
            # Default: sample goals from the maze layout's goal cells.
            goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
        # kwargs is deliberately forwarded to BOTH parents; each consumes the
        # keys it recognizes (dataset_url, ref_*_score, etc. go to OfflineEnv).
        maze_env.MazeEnv.__init__(
            self, *args, manual_collision=False,
            goal_sampler=goal_sampler,
            expose_all_qpos=expose_all_qpos,
            reward_type=reward_type,
            **kwargs)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        ## We set the target goal here for evaluation
        self.set_target()
        self.v2_resets = v2_resets

    def reset(self):
        if self.v2_resets:
            """
            The target goal for evaluation in antmazes is randomized.
            antmazes-v0 and -v1 resulted in really high-variance evaluations
            because the target goal was set once at the seed level. This led to
            each run running evaluations with one particular goal. To accurately
            cover each goal, this requires about 50-100 seeds, which might be
            computationally infeasible. As an alternate fix, to reduce variance
            in result reporting, we are creating the v2 environments
            which use the same offline dataset as v0 environments, with the distinction
            that the randomization of goals during evaluation is performed at the level of
            each rollout. Thus running a few seeds, but performing the final evaluation
            over 100-200 episodes will give a valid estimate of an algorithm's performance.
            """
            self.set_target()
        return super().reset()

    def set_target(self, target_location=None):
        # Delegates to MazeEnv.set_target_goal (None -> sample a goal cell).
        return self.set_target_goal(target_location)

    def seed(self, seed=0):
        # Seed the underlying MujocoEnv RNG directly, bypassing wrapper MRO.
        mujoco_env.MujocoEnv.seed(self, seed)
def make_ant_maze_env(**kwargs):
    """Construct an AntMazeEnv and wrap it with action normalization."""
    return wrappers.NormalizedBoxEnv(AntMazeEnv(**kwargs))
================================================
FILE: d4rl/d4rl/locomotion/assets/ant.xml
================================================
================================================
FILE: d4rl/d4rl/locomotion/assets/point.xml
================================================
================================================
FILE: d4rl/d4rl/locomotion/common.py
================================================
def run_policy_on_env(policy_fn, env, truncate_episode_at=None,
                      first_obs=None):
    """Roll out policy_fn in env for a single episode.

    Args:
        policy_fn: callable mapping an observation to an action.
        env: gym-style environment with reset() and step().
        truncate_episode_at: optional hard cap on the number of steps.
        first_obs: optional initial observation; when given, env.reset()
            is not called.

    Returns:
        A list of (observation, action, reward, done) tuples, one per step.
    """
    obs = env.reset() if first_obs is None else first_obs
    trajectory = []
    steps_taken = 0
    finished = False
    while not finished:
        action = policy_fn(obs)
        next_obs, reward, done, _ = env.step(action)
        trajectory.append((obs, action, reward, done))
        obs = next_obs
        steps_taken += 1
        truncated = (truncate_episode_at is not None
                     and steps_taken >= truncate_episode_at)
        finished = done or truncated
    return trajectory
================================================
FILE: d4rl/d4rl/locomotion/generate_dataset.py
================================================
import numpy as np
import pickle
import gzip
import h5py
import argparse
from d4rl.locomotion import maze_env, ant, swimmer
from d4rl.locomotion.wrappers import NormalizedBoxEnv
from rlkit.torch.pytorch_util import set_gpu_mode
import torch
import skvideo.io
from PIL import Image
import os
def reset_data():
    """Return a fresh dataset dict with an empty list per logged field."""
    fields = ('observations', 'actions', 'terminals', 'rewards',
              'infos/goal', 'infos/qpos', 'infos/qvel')
    return {field: [] for field in fields}
def append_data(data, s, a, r, tgt, done, env_data):
    """Append one transition to the dataset dict in place.

    env_data is a MuJoCo data object; its qpos/qvel are flattened and
    copied so later simulation steps cannot mutate the stored arrays.
    """
    records = (
        ('observations', s),
        ('actions', a),
        ('rewards', r),
        ('terminals', done),
        ('infos/goal', tgt),
        ('infos/qpos', env_data.qpos.ravel().copy()),
        ('infos/qvel', env_data.qvel.ravel().copy()),
    )
    for key, value in records:
        data[key].append(value)
def npify(data):
    """Convert every list in data to a numpy array, in place.

    Terminal flags become np.bool_; all other fields become float32.
    """
    for key in data:
        dtype = np.bool_ if key == 'terminals' else np.float32
        data[key] = np.array(data[key], dtype=dtype)
def load_policy(policy_file):
    """Load a pickled rlkit snapshot and return (policy, env), on GPU."""
    snapshot = torch.load(policy_file)
    policy = snapshot['exploration/policy']
    env = snapshot['evaluation/env']
    print("Policy loaded")
    set_gpu_mode(True)
    policy.cuda()
    return policy, env
def save_video(save_dir, file_name, frames, episode_id=0):
    """Write frames as PNGs to save_dir/<file_name>_episode_<id>/frame_<i>.png.

    Each frame is vertically flipped (np.flipud) before saving.
    """
    out_dir = os.path.join(save_dir, file_name + '_episode_{}'.format(episode_id))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for idx, frame in enumerate(frames):
        image = Image.fromarray(np.flipud(frame), 'RGB')
        image.save(os.path.join(out_dir, 'frame_{}.png'.format(idx)))
def main():
    """Collect goal-reaching rollouts in a maze env and save them to HDF5."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--noisy', action='store_true', help='Noisy actions')
    parser.add_argument('--maze', type=str, default='u-maze', help='Maze type. small or default')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    parser.add_argument('--env', type=str, default='Ant', help='Environment type')
    parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name')
    parser.add_argument('--max_episode_steps', default=1000, type=int)
    parser.add_argument('--video', action='store_true')
    parser.add_argument('--multi_start', action='store_true')
    parser.add_argument('--multigoal', action='store_true')
    args = parser.parse_args()

    # Resolve the maze layout.
    if args.maze == 'u-maze':
        maze = maze_env.U_MAZE
    elif args.maze == 'big-maze':
        maze = maze_env.BIG_MAZE
    elif args.maze == 'hardest-maze':
        maze = maze_env.HARDEST_MAZE
    else:
        raise NotImplementedError

    if args.env == 'Ant':
        env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))
    elif args.env == 'Swimmer':
        # BUG FIX: the keyword was misspelled 'mmaze_map', which raised a
        # TypeError whenever the Swimmer env was requested.
        env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))
    else:
        # BUG FIX: previously fell through and crashed later with a NameError.
        raise NotImplementedError('Unsupported env type: %s' % args.env)

    env.set_target_goal()
    s = env.reset()
    print(s.shape)
    act = env.action_space.sample()
    done = False

    # Load the policy
    policy, train_env = load_policy(args.policy_file)

    # Define goal reaching policy fn
    def _goal_reaching_policy_fn(obs, goal):
        goal_x, goal_y = goal
        obs_new = obs[2:-2]
        goal_tuple = np.array([goal_x, goal_y])
        # normalize the norm of the relative goals to in-distribution values
        goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0
        new_obs = np.concatenate([obs_new, goal_tuple], -1)
        return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1])

    data = reset_data()
    # create waypoint generating policy integrated with high level controller
    data_collection_policy = env.create_navigation_policy(
        _goal_reaching_policy_fn,
    )

    if args.video:
        frames = []
    ts = 0
    num_episodes = 0
    for _ in range(args.num_samples):
        act, waypoint_goal = data_collection_policy(s)
        if args.noisy:
            act = act + np.random.randn(*act.shape) * 0.2
        # Clip regardless of noise so actions stay in the normalized box.
        act = np.clip(act, -1.0, 1.0)
        ns, r, done, info = env.step(act)
        if ts >= args.max_episode_steps:
            done = True
        # Strip the 2-dim goal suffix from the stored observation.
        append_data(data, s[:-2], act, r, env.target_goal, done, env.physics.data)
        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))
        ts += 1
        if done:
            done = False
            ts = 0
            s = env.reset()
            env.set_target_goal()
            if args.video:
                frames = np.array(frames)
                save_video('./videos/', args.env + '_navigation', frames, num_episodes)
                num_episodes += 1
                frames = []
        else:
            s = ns
        if args.video:
            curr_frame = env.physics.render(width=500, height=500, depth=False)
            frames.append(curr_frame)

    if args.noisy:
        fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))
    else:
        # BUG FIX: the non-noisy branch was missing the '_' separator,
        # producing names like 'Antmaze_...' instead of 'Ant_maze_...'.
        fname = args.env + '_maze_%s_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))
    npify(data)
    # BUG FIX: use a context manager so the HDF5 file is flushed and closed.
    with h5py.File(fname, 'w') as dataset:
        for k in data:
            dataset.create_dataset(k, data=data[k], compression='gzip')


if __name__ == '__main__':
    main()
================================================
FILE: d4rl/d4rl/locomotion/goal_reaching_env.py
================================================
import numpy as np
def disk_goal_sampler(np_random, goal_region_radius=10.):
    """Sample a random (x, y) goal on a disk of the given radius.

    Draws the angle first and the radius second, preserving the RNG call
    order of the original implementation.
    """
    angle = 2 * np.pi * np_random.uniform()
    distance = goal_region_radius * np_random.uniform()
    direction = np.array([np.cos(angle), np.sin(angle)])
    return direction * distance
def constant_goal_sampler(np_random, location=None):
    """Return a fixed goal location (defaults to (10, 10)).

    Args:
        np_random: unused; kept for signature compatibility with the other
            goal samplers.
        location: explicit goal to return; None selects the default.

    The default used to be a shared module-level array (a mutable default
    argument), so a caller mutating the returned goal corrupted every later
    call. A fresh array is now returned for the default case.
    """
    if location is None:
        return 10.0 * np.ones([2])
    return location
class GoalReachingEnv(object):
    """General goal-reaching environment.

    Mixin that augments a locomotion env (BASE_ENV) with a 2D goal:
    observations get a goal-direction suffix (unless in eval mode), and the
    reward is computed from the distance between get_xy() and target_goal.
    """
    BASE_ENV = None  # Must be specified by child class.

    def __init__(self, goal_sampler, eval=False, reward_type='dense'):
        self._goal_sampler = goal_sampler
        self._goal = np.ones([2])
        self.target_goal = self._goal
        # This flag is used to make sure that when using this environment
        # for evaluation, no goals are appended to the state.
        self.eval = eval
        # This is the reward type fed as input to the goal-conditioned policy.
        self.reward_type = reward_type

    def _get_obs(self):
        base_obs = self.BASE_ENV._get_obs(self)
        # Relative direction from the agent to the current goal.
        goal_direction = self._goal - self.get_xy()
        if not self.eval:
            obs = np.concatenate([base_obs, goal_direction])
            return obs
        else:
            return base_obs

    def step(self, a):
        # The base env's (obs, reward, ...) are discarded; only its state
        # update matters — the goal-conditioned reward replaces its reward.
        self.BASE_ENV.step(self, a)
        # NOTE(review): reward is only assigned for 'dense'/'sparse'; any
        # other reward_type would raise NameError below — confirm callers
        # never pass other values.
        if self.reward_type == 'dense':
            reward = np.exp(-np.linalg.norm(self.target_goal - self.get_xy()))
        elif self.reward_type == 'sparse':
            reward = 1.0 if np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5 else 0.0
        done = False
        # Terminate episode when we reach a goal
        if self.eval and np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5:
            done = True
        obs = self._get_obs()
        return obs, reward, done, {}

    def reset_model(self):
        # NOTE(review): target_goal is initialized to ones([2]) above, so the
        # sampler branch only runs when a subclass has set it to None.
        if self.target_goal is not None or self.eval:
            self._goal = self.target_goal
        else:
            self._goal = self._goal_sampler(self.np_random)
        return self.BASE_ENV.reset_model(self)
================================================
FILE: d4rl/d4rl/locomotion/maze_env.py
================================================
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapted from efficient-hrl maze_env.py."""
import os
import tempfile
import xml.etree.ElementTree as ET
import math
import numpy as np
import gym
from copy import deepcopy
# Cell legend: 1 = wall block, 0 = free cell, R = reset (start) cell,
# G = candidate goal cell.
RESET = R = 'r'  # Reset position.
GOAL = G = 'g'

# Maze specifications for dataset generation
U_MAZE = [[1, 1, 1, 1, 1],
          [1, R, 0, 0, 1],
          [1, 1, 1, 0, 1],
          [1, G, 0, 0, 1],
          [1, 1, 1, 1, 1]]

BIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1],
            [1, R, 0, 1, 1, 0, 0, 1],
            [1, 0, 0, 1, 0, 0, G, 1],
            [1, 1, 0, 0, 0, 1, 1, 1],
            [1, 0, 0, 1, 0, 0, 0, 1],
            [1, G, 1, 0, 0, 1, 0, 1],
            [1, 0, 0, 0, 1, G, 0, 1],
            [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, R, 0, 0, 0, 1, G, 0, 0, 0, 0, 1],
                [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
                [1, 0, 0, 0, 0, G, 0, 1, 0, 0, G, 1],
                [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],
                [1, 0, G, 1, 0, 1, 0, 0, 0, 0, 0, 1],
                [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                [1, 0, 0, 1, G, 0, G, 1, 0, G, 0, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

# Maze specifications with a single target goal
U_MAZE_TEST = [[1, 1, 1, 1, 1],
               [1, R, 0, 0, 1],
               [1, 1, 1, 0, 1],
               [1, G, 0, 0, 1],
               [1, 1, 1, 1, 1]]

BIG_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1],
                 [1, R, 0, 1, 1, 0, 0, 1],
                 [1, 0, 0, 1, 0, 0, 0, 1],
                 [1, 1, 0, 0, 0, 1, 1, 1],
                 [1, 0, 0, 1, 0, 0, 0, 1],
                 [1, 0, 1, 0, 0, 1, 0, 1],
                 [1, 0, 0, 0, 1, 0, G, 1],
                 [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, R, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
                     [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
                     [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                     [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],
                     [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
                     [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                     [1, 0, 0, 1, 0, 0, 0, 1, 0, G, 0, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

# Maze specifications for evaluation
U_MAZE_EVAL = [[1, 1, 1, 1, 1],
               [1, 0, 0, R, 1],
               [1, 0, 1, 1, 1],
               [1, 0, 0, G, 1],
               [1, 1, 1, 1, 1]]

BIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1],
                 [1, R, 0, 0, 0, 0, G, 1],
                 [1, 0, 1, 0, 1, 1, 0, 1],
                 [1, 0, 0, 0, 0, 1, 0, 1],
                 [1, 1, 1, 0, 0, 1, 1, 1],
                 [1, G, 0, 0, 0, 0, 0, 1],
                 [1, 0, 0, 1, 1, G, 0, 1],
                 [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, R, 0, 1, G, 0, 0, 1, 0, G, 0, 1],
                     [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],
                     [1, 0, 0, 1, 0, 1, G, 0, 0, 0, 0, 1],
                     [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],
                     [1, G, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                     [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                     [1, 0, 0, 0, G, 1, G, 0, 0, 0, G, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

U_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1],
                    [1, 0, 0, R, 1],
                    [1, 0, 1, 1, 1],
                    [1, 0, 0, G, 1],
                    [1, 1, 1, 1, 1]]

BIG_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1],
                      [1, R, 0, 0, 0, 0, G, 1],
                      [1, 0, 1, 0, 1, 1, 0, 1],
                      [1, 0, 0, 0, 0, 1, 0, 1],
                      [1, 1, 1, 0, 0, 1, 1, 1],
                      [1, 0, 0, 0, 0, 0, 0, 1],
                      [1, 0, 0, 1, 1, 0, 0, 1],
                      [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                          [1, R, 0, 1, 0, 0, 0, 1, 0, G, 0, 1],
                          [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],
                          [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
                          [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],
                          [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                          [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                          [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
                          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
class MazeEnv(gym.Env):
    """Mixin that embeds a locomotion env (point/ant/swimmer) inside a block maze.

    At construction time the wall layout in ``maze_map`` is compiled into
    MuJoCo box <geom> elements appended to the locomotion env's XML model;
    the patched model is written to a temp file and handed to
    ``LOCOMOTION_ENV.__init__``. Subclasses combine this with a
    goal-reaching locomotion env via multiple inheritance.
    """

    LOCOMOTION_ENV = None  # Must be specified by child class.

    def __init__(
            self,
            maze_map,
            maze_size_scaling,
            maze_height=0.5,
            manual_collision=False,
            non_zero_reset=False,
            reward_type='dense',
            *args,
            **kwargs):
        """Compile the maze into the locomotion env's model and initialize it.

        Args:
            maze_map: 2D grid; 1 = wall, 0 = free, RESET/GOAL = special cells.
            maze_size_scaling: world-units edge length of one grid cell.
            maze_height: wall height as a fraction of the cell size.
            manual_collision: if True, step() rolls the agent back to its
                pre-step XY position whenever it ends up inside a wall.
            non_zero_reset, reward_type: forwarded to LOCOMOTION_ENV.__init__.
            *args, **kwargs: forwarded to LOCOMOTION_ENV.__init__.
        """
        if self.LOCOMOTION_ENV is None:
            raise ValueError('LOCOMOTION_ENV is unspecified.')

        xml_path = self.LOCOMOTION_ENV.FILE
        tree = ET.parse(xml_path)
        worldbody = tree.find(".//worldbody")

        self._maze_map = maze_map
        self._maze_height = maze_height
        self._maze_size_scaling = maze_size_scaling
        self._manual_collision = manual_collision
        self._maze_map = maze_map  # NOTE: redundant; already assigned above.

        # Obtain a numpy array form for a maze map in case we want to reset
        # to multiple starting states. RESET cells become free (0); GOAL
        # cells become 1, i.e. they are excluded from reset sampling.
        temp_maze_map = deepcopy(self._maze_map)
        for i in range(len(maze_map)):
            for j in range(len(maze_map[0])):
                if temp_maze_map[i][j] in [RESET,]:
                    temp_maze_map[i][j] = 0
                elif temp_maze_map[i][j] in [GOAL,]:
                    temp_maze_map[i][j] = 1
        self._np_maze_map = np.array(temp_maze_map)

        # All geom positions are expressed relative to the robot's start cell
        # so the robot begins at the world origin.
        torso_x, torso_y = self._find_robot()
        self._init_torso_x = torso_x
        self._init_torso_y = torso_y

        for i in range(len(self._maze_map)):
            for j in range(len(self._maze_map[0])):
                struct = self._maze_map[i][j]
                if struct == 1:  # Unmovable block.
                    # Offset all coordinates so that robot starts at the origin.
                    ET.SubElement(
                        worldbody, "geom",
                        name="block_%d_%d" % (i, j),
                        pos="%f %f %f" % (j * self._maze_size_scaling - torso_x,
                                          i * self._maze_size_scaling - torso_y,
                                          self._maze_height / 2 * self._maze_size_scaling),
                        size="%f %f %f" % (0.5 * self._maze_size_scaling,
                                           0.5 * self._maze_size_scaling,
                                           self._maze_height / 2 * self._maze_size_scaling),
                        type="box",
                        material="",
                        contype="1",
                        conaffinity="1",
                        rgba="0.7 0.5 0.3 1.0",
                    )

        torso = tree.find(".//body[@name='torso']")
        geoms = torso.findall(".//geom")  # NOTE: currently unused.

        # Write the patched model to a temp file and boot the locomotion env
        # from it.
        _, file_path = tempfile.mkstemp(text=True, suffix='.xml')
        tree.write(file_path)

        self.LOCOMOTION_ENV.__init__(self, *args, file_path=file_path, non_zero_reset=non_zero_reset, reward_type=reward_type, **kwargs)

        self.target_goal = None

    def _xy_to_rowcol(self, xy):
        """Map world (x, y) to integer grid (row, col); row from y, col from x."""
        size_scaling = self._maze_size_scaling
        xy = (max(xy[0], 1e-4), max(xy[1], 1e-4))
        return (int(1 + (xy[1]) / size_scaling),
                int(1 + (xy[0]) / size_scaling))

    def _get_reset_location(self,):
        """Sample a uniformly random free cell (via _np_maze_map) with jitter."""
        prob = (1.0 - self._np_maze_map) / np.sum(1.0 - self._np_maze_map)
        prob_row = np.sum(prob, 1)
        row_sample = np.random.choice(np.arange(self._np_maze_map.shape[0]), p=prob_row)
        col_sample = np.random.choice(np.arange(self._np_maze_map.shape[1]), p=prob[row_sample] * 1.0 / prob_row[row_sample])
        reset_location = self._rowcol_to_xy((row_sample, col_sample))

        # Add some random noise
        random_x = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling
        random_y = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling

        return (max(reset_location[0] + random_x, 0), max(reset_location[1] + random_y, 0))

    def _rowcol_to_xy(self, rowcol, add_random_noise=False):
        """Inverse of _xy_to_rowcol, with optional intra-cell jitter."""
        row, col = rowcol
        x = col * self._maze_size_scaling - self._init_torso_x
        y = row * self._maze_size_scaling - self._init_torso_y
        if add_random_noise:
            x = x + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25)
            y = y + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25)
        return (x, y)

    def goal_sampler(self, np_random, only_free_cells=True, interpolate=True):
        """Sample a goal (x, y): from GOAL-marked cells if any, else any valid cell."""
        valid_cells = []
        goal_cells = []

        for i in range(len(self._maze_map)):
            for j in range(len(self._maze_map[0])):
                if self._maze_map[i][j] in [0, RESET, GOAL] or not only_free_cells:
                    valid_cells.append((i, j))
                if self._maze_map[i][j] == GOAL:
                    goal_cells.append((i, j))

        # If there is a 'goal' designated, use that. Otherwise, any valid cell can
        # be a goal.
        sample_choices = goal_cells if goal_cells else valid_cells
        cell = sample_choices[np_random.choice(len(sample_choices))]
        xy = self._rowcol_to_xy(cell, add_random_noise=True)

        # Extra jitter on top of the per-cell noise from _rowcol_to_xy.
        random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling
        random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling

        xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0))

        return xy

    def set_target_goal(self, goal_input=None):
        """Set the target goal (sampling one if goal_input is None)."""
        if goal_input is None:
            self.target_goal = self.goal_sampler(np.random)
        else:
            self.target_goal = goal_input

        # print ('Target Goal: ', self.target_goal)
        ## Make sure that the goal used in self._goal is also reset:
        self._goal = self.target_goal

    def _find_robot(self):
        """Return the world (x, y) of the RESET cell; raises if none exists."""
        structure = self._maze_map
        size_scaling = self._maze_size_scaling
        for i in range(len(structure)):
            for j in range(len(structure[0])):
                if structure[i][j] == RESET:
                    return j * size_scaling, i * size_scaling
        raise ValueError('No robot in maze specification.')

    def _is_in_collision(self, pos):
        """True if world position *pos* lies inside any wall block's footprint."""
        x, y = pos
        structure = self._maze_map
        size_scaling = self._maze_size_scaling
        for i in range(len(structure)):
            for j in range(len(structure[0])):
                if structure[i][j] == 1:
                    minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
                    maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
                    miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
                    maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
                    if minx <= x <= maxx and miny <= y <= maxy:
                        return True
        return False

    def step(self, action):
        """Step the locomotion env; optionally undo moves that hit a wall.

        NOTE: in the manual-collision branch the position is rolled back, but
        the reward/done/info from the rejected step are still returned.
        """
        if self._manual_collision:
            old_pos = self.get_xy()
            inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action)
            new_pos = self.get_xy()
            if self._is_in_collision(new_pos):
                self.set_xy(old_pos)
        else:
            inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action)
        next_obs = self._get_obs()
        return next_obs, inner_reward, done, info

    def _get_best_next_rowcol(self, current_rowcol, target_rowcol):
        """Runs BFS to find shortest path to target and returns best next rowcol.
        Add obstacle avoidance"""
        # BFS expands outward from the *target*; the first expanded cell that
        # neighbours the current cell is therefore the best next waypoint.
        current_rowcol = tuple(current_rowcol)
        target_rowcol = tuple(target_rowcol)
        if target_rowcol == current_rowcol:
            return target_rowcol

        visited = {}
        to_visit = [target_rowcol]
        while to_visit:
            next_visit = []
            for rowcol in to_visit:
                visited[rowcol] = True
                row, col = rowcol
                left = (row, col - 1)
                right = (row, col + 1)
                down = (row + 1, col)
                up = (row - 1, col)
                for next_rowcol in [left, right, down, up]:
                    if next_rowcol == current_rowcol:  # Found a shortest path.
                        return rowcol
                    next_row, next_col = next_rowcol
                    # Skip out-of-bounds cells, walls, and already-seen cells.
                    if next_row < 0 or next_row >= len(self._maze_map):
                        continue
                    if next_col < 0 or next_col >= len(self._maze_map[0]):
                        continue
                    if self._maze_map[next_row][next_col] not in [0, RESET, GOAL]:
                        continue
                    if next_rowcol in visited:
                        continue
                    next_visit.append(next_rowcol)
            to_visit = next_visit
        raise ValueError('No path found to target.')

    def create_navigation_policy(self,
                                 goal_reaching_policy_fn,
                                 obs_to_robot=lambda obs: obs[:2],
                                 obs_to_target=lambda obs: obs[-2:],
                                 relative=False):
        """Creates a navigation policy by guiding a sub-policy to waypoints."""
        def policy_fn(obs):
            # import ipdb; ipdb.set_trace()
            robot_x, robot_y = obs_to_robot(obs)
            robot_row, robot_col = self._xy_to_rowcol([robot_x, robot_y])
            target_x, target_y = self.target_goal
            if relative:
                target_x += robot_x  # Target is given in relative coordinates.
                target_y += robot_y
            target_row, target_col = self._xy_to_rowcol([target_x, target_y])
            print ('Target: ', target_row, target_col, target_x, target_y)
            print ('Robot: ', robot_row, robot_col, robot_x, robot_y)

            # Next grid cell on the BFS shortest path toward the target.
            waypoint_row, waypoint_col = self._get_best_next_rowcol(
                [robot_row, robot_col], [target_row, target_col])

            if waypoint_row == target_row and waypoint_col == target_col:
                waypoint_x = target_x
                waypoint_y = target_y
            else:
                waypoint_x, waypoint_y = self._rowcol_to_xy([waypoint_row, waypoint_col], add_random_noise=True)

            # The sub-policy receives the waypoint relative to the robot.
            goal_x = waypoint_x - robot_x
            goal_y = waypoint_y - robot_y

            print ('Waypoint: ', waypoint_row, waypoint_col, waypoint_x, waypoint_y)

            return goal_reaching_policy_fn(obs, (goal_x, goal_y))

        return policy_fn
================================================
FILE: d4rl/d4rl/locomotion/mujoco_goal_env.py
================================================
from collections import OrderedDict
import os
from gym import error, spaces
from gym.utils import seeding
import numpy as np
from os import path
import gym
try:
import mujoco_py
except ImportError as e:
raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))
DEFAULT_SIZE = 500  # Default render() width/height in pixels.
def convert_observation_to_space(observation):
    """Build a gym space that mirrors the structure of a sample observation.

    Dicts become spaces.Dict (recursively, preserving key order); ndarrays
    become unbounded Boxes of the same shape/dtype. Anything else raises
    NotImplementedError.
    """
    if isinstance(observation, dict):
        return spaces.Dict(OrderedDict(
            (name, convert_observation_to_space(sub))
            for name, sub in observation.items()
        ))
    if isinstance(observation, np.ndarray):
        bound = np.full(observation.shape, float('inf'), dtype=np.float32)
        return spaces.Box(-bound, bound, dtype=observation.dtype)
    raise NotImplementedError(type(observation), observation)
class MujocoGoalEnv(gym.Env):
    """SuperClass for all MuJoCo goal reaching environments.

    Mirrors gym's MujocoEnv skeleton, but observations are dicts
    ({observation, desired_goal, achieved_goal}); the observation space is
    built accordingly in _set_observation_space. Subclasses must implement
    reset_model() and may override viewer_setup().
    """

    def __init__(self, model_path, frame_skip):
        """Load the MuJoCo model and probe one step to size the spaces.

        Args:
            model_path: absolute path, or a filename under ./assets.
            frame_skip: number of physics steps per environment step.
        """
        if model_path.startswith("/"):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
        if not path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)
        self.frame_skip = frame_skip
        self.model = mujoco_py.load_model_from_path(fullpath)
        self.sim = mujoco_py.MjSim(self.model)
        self.data = self.sim.data
        self.viewer = None
        self._viewers = {}

        self.metadata = {
            'render.modes': ['human', 'rgb_array', 'depth_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }

        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()

        self._set_action_space()

        # Probe with one random action to discover the observation structure,
        # then size the observation space from it. NOTE(review): this runs
        # before seed() and mutates sim state prior to the first reset —
        # confirm callers reset() before relying on the state.
        action = self.action_space.sample()
        observation, _reward, done, _info = self.step(action)
        assert not done
        self._set_observation_space(observation['observation'])

        self.seed()

    def _set_action_space(self):
        # Action bounds come straight from the model's actuator ctrl ranges.
        bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        low, high = bounds.T
        self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)
        return self.action_space

    def _set_observation_space(self, observation):
        # Dict space: free-form 'observation' plus 2D desired/achieved goals.
        temp_observation_space = convert_observation_to_space(observation)
        self.observation_space = spaces.Dict(dict(
            observation=temp_observation_space,
            desired_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32),
            achieved_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32),
        ))
        return self.observation_space

    def seed(self, seed=None):
        """Seed the env's np_random generator; returns [seed] per gym convention."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # methods to override:
    # ----------------------------
    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.
        """
        raise NotImplementedError

    def viewer_setup(self):
        """
        This method is called when the viewer is initialized.
        Optionally implement this method, if you need to tinker with camera position
        and so forth.
        """
        pass

    def reset(self):
        """Reset the simulation and return the subclass's initial observation."""
        self.sim.reset()
        ob = self.reset_model()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite qpos/qvel (shapes must match the model) and re-run forward()."""
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        old_state = self.sim.get_state()
        new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
                                         old_state.act, old_state.udd_state)
        self.sim.set_state(new_state)
        self.sim.forward()

    @property
    def dt(self):
        # Effective control timestep: sim timestep times frame skip.
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames):
        """Apply *ctrl* and advance the simulation n_frames physics steps."""
        self.sim.data.ctrl[:] = ctrl
        for _ in range(n_frames):
            self.sim.step()

    def render(self,
               mode='human',
               width=DEFAULT_SIZE,
               height=DEFAULT_SIZE,
               camera_id=None,
               camera_name=None):
        """Render the scene; returns an image array for rgb/depth modes.

        'human' mode opens an interactive viewer and returns None.
        """
        if mode == 'rgb_array':
            if camera_id is not None and camera_name is not None:
                raise ValueError("Both `camera_id` and `camera_name` cannot be"
                                 " specified at the same time.")

            no_camera_specified = camera_name is None and camera_id is None
            if no_camera_specified:
                camera_name = 'track'

            if camera_id is None and camera_name in self.model._camera_name2id:
                camera_id = self.model.camera_name2id(camera_name)

            self._get_viewer(mode).render(width, height, camera_id=camera_id)
            # window size used for old mujoco-py:
            data = self._get_viewer(mode).read_pixels(width, height, depth=False)
            # original image is upside-down, so flip it
            return data[::-1, :, :]
        elif mode == 'depth_array':
            self._get_viewer(mode).render(width, height)
            # window size used for old mujoco-py:
            # Extract depth part of the read_pixels() tuple
            data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
            # original image is upside-down, so flip it
            return data[::-1, :]
        elif mode == 'human':
            self._get_viewer(mode).render()

    def close(self):
        """Drop viewer references (the viewer itself is not explicitly torn down)."""
        if self.viewer is not None:
            # self.viewer.finish()
            self.viewer = None
            self._viewers = {}

    def _get_viewer(self, mode):
        """Return (lazily creating and caching) the viewer for *mode*."""
        self.viewer = self._viewers.get(mode)
        if self.viewer is None:
            if mode == 'human':
                self.viewer = mujoco_py.MjViewer(self.sim)
            elif mode == 'rgb_array' or mode == 'depth_array':
                self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)

            self.viewer_setup()
            self._viewers[mode] = self.viewer
        return self.viewer

    def get_body_com(self, body_name):
        """World-frame center-of-mass position of the named body."""
        return self.data.get_body_xpos(body_name)

    def state_vector(self):
        """Full state as a flat [qpos, qvel] vector."""
        return np.concatenate([
            self.sim.data.qpos.flat,
            self.sim.data.qvel.flat
        ])
================================================
FILE: d4rl/d4rl/locomotion/point.py
================================================
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for creating the point environment."""
import math
import numpy as np
import mujoco_py
import os
from gym import utils
from gym.envs.mujoco import mujoco_env
from d4rl.locomotion import mujoco_goal_env
from d4rl.locomotion import goal_reaching_env
from d4rl.locomotion import maze_env
# Directory holding this package's MuJoCo XML assets (e.g. point.xml).
MY_ASSETS_DIR = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    'assets')
class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """A 2D point mass driven by [forward speed, rotation delta] actions."""

    FILE = os.path.join(MY_ASSETS_DIR, 'point.xml')

    def __init__(self, file_path=None, expose_all_qpos=False):
        """Create the point env.

        Args:
            file_path: MuJoCo XML model path; defaults to the bundled point.xml.
            expose_all_qpos: if True, observations include x/y position.
        """
        if file_path is None:
            file_path = self.FILE
        self._expose_all_qpos = expose_all_qpos

        mujoco_env.MujocoEnv.__init__(self, file_path, 1)
        # mujoco_goal_env.MujocoGoalEnv.__init__(self, file_path, 1)
        utils.EzPickle.__init__(self)

    @property
    def physics(self):
        # mujoco-py >= 1.50 exposes state via MjSim; older versions kept
        # PyMjData-equivalents on the model object.
        # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.
        if mujoco_py.get_version() >= '1.50':
            return self.sim
        else:
            return self.model

    def _step(self, a):
        # Legacy gym entry point; delegates to step().
        return self.step(a)

    def step(self, action):
        """Advance the point one step.

        action[0] is forward speed (scaled by 0.2), action[1] is a rotation
        increment. Reward is always 0 and done is always False; reward shaping
        is supplied by goal-reaching mixins.

        Bug fix: the speed used to be scaled *in place*
        (``action[0] = 0.2 * action[0]``), mutating the caller's array; the
        scaling now happens on a local value.
        """
        forward = 0.2 * action[0]
        qpos = np.copy(self.physics.data.qpos)
        qpos[2] += action[1]
        ori = qpos[2]
        # Compute increment in each direction.
        dx = math.cos(ori) * forward
        dy = math.sin(ori) * forward
        # Ensure that the robot is within reasonable range.
        qpos[0] = np.clip(qpos[0] + dx, -100, 100)
        qpos[1] = np.clip(qpos[1] + dy, -100, 100)
        qvel = self.physics.data.qvel
        self.set_state(qpos, qvel)
        for _ in range(0, self.frame_skip):
            self.physics.step()
        next_obs = self._get_obs()
        reward = 0
        done = False
        info = {}
        return next_obs, reward, done, info

    def _get_obs(self):
        """Return [qpos(:3), qvel(:3)], dropping x/y unless expose_all_qpos."""
        if self._expose_all_qpos:
            return np.concatenate([
                self.physics.data.qpos.flat[:3],  # Only point-relevant coords.
                self.physics.data.qvel.flat[:3]])
        return np.concatenate([
            self.physics.data.qpos.flat[2:3],
            self.physics.data.qvel.flat[:3]])

    def reset_model(self):
        """Randomize the point's state; zero out any non-point DOFs."""
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.physics.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1

        # Set everything other than point to original position and 0 velocity.
        qpos[3:] = self.init_qpos[3:]
        qvel[3:] = 0.
        self.set_state(qpos, qvel)
        return self._get_obs()

    def get_xy(self):
        """Current (x, y) position.

        Bug fix: return a copy instead of a live view into qpos. MazeEnv's
        manual-collision logic snapshots get_xy() before stepping and restores
        it on collision; a view would silently track the post-step position,
        making the rollback a no-op.
        """
        return np.copy(self.physics.data.qpos[:2])

    def set_xy(self, xy):
        """Teleport the point to (x, y), keeping orientation and velocity."""
        qpos = np.copy(self.physics.data.qpos)
        qpos[0] = xy[0]
        qpos[1] = xy[1]
        qvel = self.physics.data.qvel
        self.set_state(qpos, qvel)
class GoalReachingPointEnv(goal_reaching_env.GoalReachingEnv, PointEnv):
    """Point locomotion rewarded for goal-reaching."""
    BASE_ENV = PointEnv

    def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
                 file_path=None,
                 expose_all_qpos=False, non_zero_reset=False, eval=False,
                 reward_type='dense', **kwargs):
        """Initialize the goal-reaching mixin, then the underlying point env.

        Bug fix / consistency: MazeEnv.__init__ always forwards
        non_zero_reset and reward_type (plus extra kwargs) to
        LOCOMOTION_ENV.__init__, which previously raised TypeError for
        PointMazeEnv. Accept them here the same way GoalReachingSwimmerEnv
        does. PointEnv has no non-zero-reset support, so that flag is
        accepted but unused.
        """
        goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler,
                                                   eval=eval,
                                                   reward_type=reward_type)
        PointEnv.__init__(self,
                          file_path=file_path,
                          expose_all_qpos=expose_all_qpos)
class GoalReachingPointDictEnv(goal_reaching_env.GoalReachingDictEnv, PointEnv):
    """Point locomotion for goal reaching in a dictionary-compatible format."""
    BASE_ENV = PointEnv

    def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
                 file_path=None,
                 expose_all_qpos=False):
        # Initialize the dict-observation goal-reaching mixin first, then the
        # underlying point dynamics.
        goal_reaching_env.GoalReachingDictEnv.__init__(self, goal_sampler)
        PointEnv.__init__(self,
                          file_path=file_path,
                          expose_all_qpos=expose_all_qpos)
class PointMazeEnv(maze_env.MazeEnv, GoalReachingPointEnv):
    """Point navigating a maze."""
    LOCOMOTION_ENV = GoalReachingPointEnv

    def __init__(self, goal_sampler=None, expose_all_qpos=True,
                 *args, **kwargs):
        # Default goal sampler defers to MazeEnv.goal_sampler; the lambda is
        # late-bound on self so it uses this instance's maze layout.
        if goal_sampler is None:
            goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
        # manual_collision=True: the point writes qpos directly via set_state,
        # so wall geoms alone cannot stop it — MazeEnv must roll back moves
        # that land inside a wall.
        maze_env.MazeEnv.__init__(
            self, *args, manual_collision=True,
            goal_sampler=goal_sampler,
            expose_all_qpos=expose_all_qpos,
            **kwargs)
def create_goal_reaching_policy(obs_to_goal=lambda obs: obs[-2:],
                                obs_to_ori=lambda obs: obs[0]):
    """A hard-coded policy for reaching a goal position.

    The returned policy extracts the (relative) goal and the current heading
    from the observation, plans a circular arc toward the goal, and outputs
    np.array([velocity, angular_velocity]) with both components in [-1, 1].
    """
    def policy_fn(obs):
        gx, gy = obs_to_goal(obs)
        dist = np.linalg.norm([gx, gy])
        # Heading error wrapped into [0, 2*pi).
        heading_err = (np.arctan2(gy, gx) - obs_to_ori(obs)) % (2 * np.pi)

        # Arc geometry: radius of the turning circle through the goal, and
        # how much arc remains to traverse (sin clamped away from zero).
        turn_radius = dist / 2. / max(0.1, np.abs(np.sin(heading_err)))
        arc_angle = (2 * heading_err) % np.pi
        arc_len = max(dist, turn_radius * arc_angle)

        speed = min(arc_len * 5., 1.0)
        # Drive backwards when the goal lies behind the agent.
        behind = np.pi / 2 < heading_err < 3 * np.pi / 2
        velocity = -speed if behind else speed

        time_left = min(arc_len / (speed * 0.2), 10.)

        # Fold the heading error into a signed turn in (-pi, pi].
        if heading_err >= 3 * np.pi / 2:
            turn = 2 * np.pi - heading_err
        elif behind:
            turn = heading_err - np.pi
        else:
            turn = heading_err

        return np.array([velocity, np.clip(turn / time_left, -1., 1.)])
    return policy_fn
def create_maze_navigation_policy(maze_env):
    """Creates a hard-coded policy to navigate a maze."""
    # Orientation sits at qpos index 2 when all qpos entries are exposed in
    # the observation; otherwise the observation starts at the orientation.
    ori_index = 2 if maze_env._expose_all_qpos else 0
    reach_goal = create_goal_reaching_policy(
        obs_to_ori=lambda obs: obs[ori_index])

    def subpolicy(obs, goal):
        # The goal-reaching policy reads the goal from the observation tail.
        return reach_goal(np.concatenate([obs, goal]))

    return maze_env.create_navigation_policy(subpolicy)
================================================
FILE: d4rl/d4rl/locomotion/swimmer.py
================================================
"""Wrapper for creating the swimmer environment."""
import math
import numpy as np
import mujoco_py
import os
from gym import utils
from gym.envs.mujoco import mujoco_env
from d4rl.locomotion import mujoco_goal_env
from d4rl.locomotion import goal_reaching_env
from d4rl.locomotion import maze_env
from d4rl import offline_env
# Gym's bundled MuJoCo XML assets directory (provides swimmer.xml).
GYM_ASSETS_DIR = os.path.join(
    os.path.dirname(mujoco_env.__file__),
    'assets')
class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Basic swimmer locomotion environment."""
    FILE = os.path.join(GYM_ASSETS_DIR, 'swimmer.xml')

    def __init__(self, file_path=None, expose_all_qpos=False, non_zero_reset=False):
        # NOTE(review): non_zero_reset is accepted (so callers like MazeEnv
        # can pass it) but is not used anywhere in this class.
        if file_path is None:
            file_path = self.FILE
        self._expose_all_qpos = expose_all_qpos

        mujoco_env.MujocoEnv.__init__(self, file_path, 5)
        utils.EzPickle.__init__(self)

    @property
    def physics(self):
        # Check mujoco version is greater than version 1.50 to call correct physics
        # model containing PyMjData object for getting and setting position/velocity.
        # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.
        if mujoco_py.get_version() >= '1.50':
            return self.sim
        else:
            return self.model

    def _step(self, a):
        # Legacy gym entry point; delegates to step().
        return self.step(a)

    def step(self, a):
        """Standard swimmer reward: forward x-velocity minus control cost."""
        ctrl_cost_coeff = 0.0001
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        xposafter = self.sim.data.qpos[0]
        reward_fwd = (xposafter - xposbefore) / self.dt
        reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
        reward = reward_fwd + reward_ctrl
        ob = self._get_obs()
        return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)

    def _get_obs(self):
        """Return [qpos(:5), qvel(:5)], dropping x/y unless expose_all_qpos."""
        if self._expose_all_qpos:
            obs = np.concatenate([
                self.physics.data.qpos.flat[:5],  # Ensures only swimmer obs.
                self.physics.data.qvel.flat[:5],
            ])
        else:
            obs = np.concatenate([
                self.physics.data.qpos.flat[2:5],
                self.physics.data.qvel.flat[:5],
            ])
        return obs

    def reset_model(self):
        """Randomize the swimmer's state; zero out any non-swimmer DOFs."""
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1

        # Set everything other than swimmer to original position and 0 velocity.
        qpos[5:] = self.init_qpos[5:]
        qvel[5:] = 0.
        self.set_state(qpos, qvel)
        return self._get_obs()

    def get_xy(self):
        # NOTE(review): returns a live view into qpos, not a copy — fine while
        # SwimmerMazeEnv uses manual_collision=False, but a snapshot taken
        # before stepping would not stay fixed; confirm before reusing.
        return self.physics.data.qpos[:2]

    def set_xy(self, xy):
        """Teleport the swimmer to (x, y), keeping remaining state."""
        qpos = np.copy(self.physics.data.qpos)
        qpos[0] = xy[0]
        qpos[1] = xy[1]
        qvel = self.physics.data.qvel
        self.set_state(qpos, qvel)
class GoalReachingSwimmerEnv(goal_reaching_env.GoalReachingEnv, SwimmerEnv):
    """Swimmer locomotion rewarded for goal-reaching."""
    BASE_ENV = SwimmerEnv

    def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
                 file_path=None,
                 expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type="dense", **kwargs):
        # The goal-reaching mixin supplies the reward; SwimmerEnv supplies
        # the dynamics. Extra **kwargs are accepted (and dropped) so MazeEnv
        # can forward its arguments freely.
        goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type)
        SwimmerEnv.__init__(self,
                            file_path=file_path,
                            expose_all_qpos=expose_all_qpos,
                            non_zero_reset=non_zero_reset)
class SwimmerMazeEnv(maze_env.MazeEnv, GoalReachingSwimmerEnv, offline_env.OfflineEnv):
    """Swimmer navigating a maze."""
    LOCOMOTION_ENV = GoalReachingSwimmerEnv

    def __init__(self, goal_sampler=None, expose_all_qpos=True,
                 reward_type='dense',
                 *args, **kwargs):
        # Default goal sampler defers to MazeEnv.goal_sampler; the lambda is
        # late-bound on self so it uses this instance's maze layout.
        if goal_sampler is None:
            goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
        # manual_collision=False: the swimmer is torque-driven, so wall geoms
        # are handled by the physics. Note **kwargs is forwarded to BOTH
        # MazeEnv and OfflineEnv initializers.
        maze_env.MazeEnv.__init__(
            self, *args, manual_collision=False,
            goal_sampler=goal_sampler,
            expose_all_qpos=expose_all_qpos,
            reward_type=reward_type,
            **kwargs)
        offline_env.OfflineEnv.__init__(self, **kwargs)

    def set_target(self, target_location=None):
        # Alias kept for API compatibility with other d4rl envs.
        return self.set_target_goal(target_location)
================================================
FILE: d4rl/d4rl/locomotion/wrappers.py
================================================
import numpy as np
import itertools
from gym import Env
from gym.spaces import Box
from gym.spaces import Discrete
from collections import deque
class ProxyEnv(Env):
    """Base wrapper that forwards the gym API to a wrapped environment.

    Attribute lookups that this class does not define fall through to the
    wrapped env via __getattr__.
    """

    def __init__(self, wrapped_env):
        self._wrapped_env = wrapped_env
        self.action_space = wrapped_env.action_space
        self.observation_space = wrapped_env.observation_space

    @property
    def wrapped_env(self):
        return self._wrapped_env

    def reset(self, **kwargs):
        return self._wrapped_env.reset(**kwargs)

    def step(self, action):
        return self._wrapped_env.step(action)

    def render(self, *args, **kwargs):
        return self._wrapped_env.render(*args, **kwargs)

    @property
    def horizon(self):
        return self._wrapped_env.horizon

    def terminate(self):
        # Best-effort: not every wrapped env implements terminate().
        inner = self.wrapped_env
        if hasattr(inner, "terminate"):
            inner.terminate()

    def __getattr__(self, attr):
        # Guard against infinite recursion while _wrapped_env is not yet
        # set (e.g. during unpickling).
        if attr == '_wrapped_env':
            raise AttributeError()
        return getattr(self._wrapped_env, attr)

    def __getstate__(self):
        """
        This is useful to override in case the wrapped env has some funky
        __getstate__ that doesn't play well with overriding __getattr__.

        The main problematic case is/was gym's EzPickle serialization scheme.
        :return:
        """
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __str__(self):
        return '{}({})'.format(type(self).__name__, self.wrapped_env)
class HistoryEnv(ProxyEnv, Env):
    """Wrapper that stacks the last `history_len` observations into one
    flat vector, zero-padding when fewer observations are available."""

    def __init__(self, wrapped_env, history_len):
        """
        Args:
            wrapped_env: environment to wrap.
            history_len: number of past observations to stack.
        """
        super().__init__(wrapped_env)
        self.history_len = history_len

        # Stacked observation space: history_len copies of the wrapped space,
        # flattened and unbounded.
        high = np.inf * np.ones(
            self.history_len * self.observation_space.low.size)
        low = -high
        self.observation_space = Box(low=low,
                                     high=high,
                                     )
        self.history = deque(maxlen=self.history_len)

    def step(self, action):
        """Step the wrapped env and return the flattened stacked history."""
        state, reward, done, info = super().step(action)
        self.history.append(state)
        flattened_history = self._get_history().flatten()
        return flattened_history, reward, done, info

    def reset(self, **kwargs):
        """Clear the history, reset the wrapped env, and return the stack.

        Bug fix: **kwargs were accepted but silently dropped; they are now
        forwarded to the wrapped env's reset.
        """
        state = super().reset(**kwargs)
        self.history = deque(maxlen=self.history_len)
        self.history.append(state)
        flattened_history = self._get_history().flatten()
        return flattened_history

    def _get_history(self):
        # Pad with zero observations after the real ones until the stack is
        # history_len entries long.
        observations = list(self.history)

        obs_count = len(observations)
        for _ in range(self.history_len - obs_count):
            dummy = np.zeros(self._wrapped_env.observation_space.low.size)
            observations.append(dummy)
        return np.c_[observations]
class DiscretizeEnv(ProxyEnv, Env):
    """Expose a continuous Box action space as a Discrete set of actions.

    Each action dimension is split into `num_bins` evenly spaced values; the
    Cartesian product of those values enumerates every joint action.
    """

    def __init__(self, wrapped_env, num_bins):
        super().__init__(wrapped_env)
        low = self.wrapped_env.action_space.low
        high = self.wrapped_env.action_space.high
        per_dim = [
            np.linspace(lo, hi, num_bins)
            for lo, hi in zip(low, high)
        ]
        self.idx_to_continuous_action = [
            np.array(choice) for choice in itertools.product(*per_dim)
        ]
        self.action_space = Discrete(len(self.idx_to_continuous_action))

    def step(self, action):
        # Translate the discrete index back to its continuous action.
        return super().step(self.idx_to_continuous_action[action])
class NormalizedBoxEnv(ProxyEnv):
    """
    Normalize action to in [-1, 1].

    Optionally normalize observations and scale reward.
    """

    def __init__(
            self,
            env,
            reward_scale=1.,
            obs_mean=None,
            obs_std=None,
    ):
        """
        Args:
            env: environment to wrap (must have a Box action space).
            reward_scale: multiplier applied to every reward.
            obs_mean / obs_std: optional normalization statistics; providing
                either one enables observation normalization (the missing one
                defaults to zeros / ones).
        """
        ProxyEnv.__init__(self, env)
        self._should_normalize = not (obs_mean is None and obs_std is None)
        if self._should_normalize:
            if obs_mean is None:
                obs_mean = np.zeros_like(env.observation_space.low)
            else:
                obs_mean = np.array(obs_mean)
            if obs_std is None:
                obs_std = np.ones_like(env.observation_space.low)
            else:
                obs_std = np.array(obs_std)
        self._reward_scale = reward_scale
        self._obs_mean = obs_mean
        self._obs_std = obs_std
        # Re-declare the action space as the canonical [-1, 1] box.
        ub = np.ones(self._wrapped_env.action_space.shape)
        self.action_space = Box(-1 * ub, ub)

    def estimate_obs_stats(self, obs_batch, override_values=False):
        """Fit normalization stats from a batch and enable obs normalization.

        Raises:
            Exception: if stats were already set and override_values is False.
        """
        if self._obs_mean is not None and not override_values:
            raise Exception("Observation mean and std already set. To "
                            "override, set override_values to True.")
        self._obs_mean = np.mean(obs_batch, axis=0)
        self._obs_std = np.std(obs_batch, axis=0)
        # Bug fix: the normalization flag was never flipped here, so stats
        # estimated after construction were silently ignored by step().
        self._should_normalize = True

    def _apply_normalize_obs(self, obs):
        # Epsilon keeps zero-variance dimensions finite.
        return (obs - self._obs_mean) / (self._obs_std + 1e-8)

    def step(self, action):
        """Rescale action from [-1, 1] to the wrapped env's bounds, step, and
        apply observation normalization / reward scaling."""
        lb = self._wrapped_env.action_space.low
        ub = self._wrapped_env.action_space.high
        scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
        scaled_action = np.clip(scaled_action, lb, ub)

        wrapped_step = self._wrapped_env.step(scaled_action)
        next_obs, reward, done, info = wrapped_step
        if self._should_normalize:
            next_obs = self._apply_normalize_obs(next_obs)
        return next_obs, reward * self._reward_scale, done, info

    def __str__(self):
        return "Normalized: %s" % self._wrapped_env
================================================
FILE: d4rl/d4rl/offline_env.py
================================================
import os
import urllib.request
import warnings
import gym
from gym.utils import colorize
import h5py
from tqdm import tqdm
def set_dataset_path(path):
    """Point the module-level DATASET_PATH at *path*, creating it if needed."""
    global DATASET_PATH
    DATASET_PATH = path
    os.makedirs(path, exist_ok=True)


# Default cache location, overridable via the D4RL_DATASET_DIR env var.
set_dataset_path(os.environ.get('D4RL_DATASET_DIR', os.path.expanduser('~/.d4rl/datasets')))
def get_keys(h5file):
    """Return the names of all datasets (recursively) in an open h5py file."""
    found = []

    def _collect(name, node):
        # visititems walks both groups and datasets; keep dataset paths only.
        if isinstance(node, h5py.Dataset):
            found.append(name)

    h5file.visititems(_collect)
    return found
def filepath_from_url(dataset_url):
    """Local path inside DATASET_PATH where *dataset_url* would be cached."""
    dataset_name = os.path.basename(dataset_url)
    return os.path.join(DATASET_PATH, dataset_name)
def download_dataset_from_url(dataset_url):
    """Ensure the dataset behind *dataset_url* is cached locally.

    Downloads the file on a cache miss and returns the local path.
    Raises IOError when the file still does not exist afterwards.
    """
    local_path = filepath_from_url(dataset_url)
    if not os.path.exists(local_path):
        print('Downloading dataset:', dataset_url, 'to', local_path)
        urllib.request.urlretrieve(dataset_url, local_path)
    if not os.path.exists(local_path):
        raise IOError("Failed to download dataset from %s" % dataset_url)
    return local_path
class OfflineEnv(gym.Env):
    """
    Base class for offline RL envs.

    Args:
        dataset_url: URL pointing to the dataset.
        ref_max_score: Maximum score (for score normalization)
        ref_min_score: Minimum score (for score normalization)
        deprecated: If True, will display a warning that the environment is deprecated.
    """

    def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None,
                 deprecated=False, deprecation_message=None, **kwargs):
        super(OfflineEnv, self).__init__(**kwargs)
        self.dataset_url = self._dataset_url = dataset_url
        self.ref_max_score = ref_max_score
        self.ref_min_score = ref_min_score
        if deprecated:
            if deprecation_message is None:
                deprecation_message = "This environment is deprecated. Please use the most recent version of this environment."
            # stacklevel=2 will bump the warning to the superclass.
            warnings.warn(colorize(deprecation_message, 'yellow'), stacklevel=2)

    def get_normalized_score(self, score):
        """Linearly rescale *score* so ref_min maps to 0 and ref_max to 1.

        Raises:
            ValueError: if either reference score is unset.
        """
        if (self.ref_max_score is None) or (self.ref_min_score is None):
            raise ValueError("Reference score not provided for env")
        return (score - self.ref_min_score) / (self.ref_max_score - self.ref_min_score)

    @property
    def dataset_filepath(self):
        # Local cache path derived from the dataset URL.
        return filepath_from_url(self.dataset_url)

    def get_dataset(self, h5path=None):
        """Download (if necessary) and load the full dataset, with sanity checks.

        Args:
            h5path: optional local HDF5 path; defaults to the cached download.

        Returns:
            Dict containing at least observations/actions/rewards/terminals,
            with (N, 1) reward/terminal arrays squeezed to shape (N,).

        Raises:
            ValueError: if no dataset URL is configured and h5path is None.
        """
        if h5path is None:
            if self._dataset_url is None:
                raise ValueError("Offline env not configured with a dataset URL.")
            h5path = download_dataset_from_url(self.dataset_url)

        data_dict = {}
        with h5py.File(h5path, 'r') as dataset_file:
            for k in tqdm(get_keys(dataset_file), desc="load datafile"):
                try:  # first try loading as an array
                    data_dict[k] = dataset_file[k][:]
                except ValueError:  # try loading as a scalar
                    data_dict[k] = dataset_file[k][()]

        # Run a few quick sanity checks
        for key in ['observations', 'actions', 'rewards', 'terminals']:
            assert key in data_dict, 'Dataset is missing key %s' % key
        N_samples = data_dict['observations'].shape[0]
        if self.observation_space.shape is not None:
            assert data_dict['observations'].shape[1:] == self.observation_space.shape, \
                'Observation shape does not match env: %s vs %s' % (
                    str(data_dict['observations'].shape[1:]), str(self.observation_space.shape))
        assert data_dict['actions'].shape[1:] == self.action_space.shape, \
            'Action shape does not match env: %s vs %s' % (
                str(data_dict['actions'].shape[1:]), str(self.action_space.shape))
        if data_dict['rewards'].shape == (N_samples, 1):
            data_dict['rewards'] = data_dict['rewards'][:, 0]
        assert data_dict['rewards'].shape == (N_samples,), 'Reward has wrong shape: %s' % (
            str(data_dict['rewards'].shape))
        if data_dict['terminals'].shape == (N_samples, 1):
            data_dict['terminals'] = data_dict['terminals'][:, 0]
        # Bug fix: this message previously reported the *rewards* shape.
        assert data_dict['terminals'].shape == (N_samples,), 'Terminals has wrong shape: %s' % (
            str(data_dict['terminals'].shape))
        return data_dict

    def get_dataset_chunk(self, chunk_id, h5path=None):
        """
        Returns a slice of the full dataset.

        Args:
            chunk_id (int): An integer representing which slice of the dataset to return.

        Returns:
            A dictionary containing observations, actions, rewards, and terminals.

        Raises:
            ValueError: if the file is not a chunked dataset, or chunk_id is
                not one of its available chunks.
        """
        if h5path is None:
            if self._dataset_url is None:
                raise ValueError("Offline env not configured with a dataset URL.")
            h5path = download_dataset_from_url(self.dataset_url)

        dataset_file = h5py.File(h5path, 'r')

        if 'virtual' not in dataset_file.keys():
            raise ValueError('Dataset is not a chunked dataset')
        available_chunks = [int(_chunk) for _chunk in list(dataset_file['virtual'].keys())]
        if chunk_id not in available_chunks:
            raise ValueError('Chunk id not found: %d. Available chunks: %s' % (chunk_id, str(available_chunks)))

        load_keys = ['observations', 'actions', 'rewards', 'terminals']
        data_dict = {k: dataset_file['virtual/%d/%s' % (chunk_id, k)][:] for k in load_keys}
        dataset_file.close()
        return data_dict
class OfflineEnvWrapper(gym.Wrapper, OfflineEnv):
    """
    Wrapper class for offline RL envs.
    """

    def __init__(self, env, **kwargs):
        # Initialize both bases explicitly: gym.Wrapper takes the wrapped env,
        # OfflineEnv consumes the dataset kwargs (dataset_url, ref scores, ...).
        gym.Wrapper.__init__(self, env)
        OfflineEnv.__init__(self, **kwargs)

    def reset(self):
        # Delegate straight to the wrapped env.
        return self.env.reset()
================================================
FILE: d4rl/d4rl/ope.py
================================================
"""
Metrics for off-policy evaluation.
"""
from d4rl import infos
import numpy as np
# Reference Monte-Carlo returns of the behavior policies, used as ground
# truth by the OPE metrics below.
UNDISCOUNTED_POLICY_RETURNS = {
    'halfcheetah-medium' : 3985.8150261686337,
    'halfcheetah-random' : -199.26067391425954,
    'halfcheetah-expert' : 12330.945945279545,
    'hopper-medium' : 2260.1983114487352,
    'hopper-random' : 1257.9757846810203,
    'hopper-expert' : 3624.4696022560997,
    'walker2d-medium' : 2760.3310101980005,
    'walker2d-random' : 896.4751989935487,
    'walker2d-expert' : 4005.89370727539,
}

DISCOUNTED_POLICY_RETURNS = {
    'halfcheetah-medium' : 324.83583782709877,
    'halfcheetah-random' : -16.836944753939207,
    'halfcheetah-expert' : 827.7278887047698,
    'hopper-medium' : 235.7441494727478,
    'hopper-random' : 215.04955086664955,
    'hopper-expert' : 271.6925087260701,
    'walker2d-medium' : 202.23983424823822,
    'walker2d-random' : 78.46052021427765,
    'walker2d-expert' : 396.8752247768766
}

def get_returns(policy_id, discounted=False):
    """Look up the reference return for `policy_id`, discounted or not."""
    table = DISCOUNTED_POLICY_RETURNS if discounted else UNDISCOUNTED_POLICY_RETURNS
    return table[policy_id]
def normalize(policy_id, score):
    """Normalize `score` into [0, 1] using the D4RL v0 reference min/max."""
    key = policy_id + '-v0'
    ref_min = infos.REF_MIN_SCORE[key]
    ref_max = infos.REF_MAX_SCORE[key]
    return (score - ref_min) / (ref_max - ref_min)
def ranking_correlation_metric(policies, discounted=False):
    """
    Computes Spearman's rank correlation coefficient.

    A score of 1.0 means the policies are ranked correctly according to their values.
    A score of -1.0 means the policies are ranked inversely.

    Args:
        policies: A list of policy string identifiers, ordered from the
            estimated best to the estimated worst.
        discounted: Whether to rank by discounted returns.

    Returns:
        A correlation value between [-1, 1]
    """
    n = len(policies)
    returns = np.array([get_returns(p, discounted=discounted) for p in policies])
    # positions[i] = list position of the truly i-th best policy.
    positions = np.argsort(-returns)
    rank_diff = positions - np.arange(n)
    return 1.0 - (6 * np.sum(rank_diff ** 2)) / (n * (n ** 2 - 1))
def precision_at_k_metric(policies, k=1, n_rel=None, discounted=False):
    """
    Computes precision@k.

    Args:
        policies: A list of policy string identifiers, estimated best first.
        k (int): Number of top items.
        n_rel (int): Number of relevant items. Default is k.

    Returns:
        Fraction of top k policies in the top n_rel of the true rankings.
    """
    assert len(policies) >= k
    n_rel = k if n_rel is None else n_rel
    # True top-n_rel policies by reference return.
    ranked = sorted(policies, key=lambda p: get_returns(p, discounted=discounted), reverse=True)
    true_top = ranked[:n_rel]
    hits = sum(1 for p in policies[:k] if p in true_top)
    return float(hits) / k
def recall_at_k_metric(policies, k=1, n_rel=None, discounted=False):
    """
    Computes recall@k.

    Args:
        policies: A list of policy string identifiers, estimated best first.
        k (int): Number of top items.
        n_rel (int): Number of relevant items. Default is k.

    Returns:
        Fraction of top n_rel true policy rankings in the top k of the given policies
    """
    assert len(policies) >= k
    n_rel = k if n_rel is None else n_rel
    ranked = sorted(policies, key=lambda p: get_returns(p, discounted=discounted), reverse=True)
    true_top = ranked[:n_rel]
    estimated_top = policies[:k]
    # NOTE(review): divides by k (as precision does) rather than n_rel;
    # matches the original implementation.
    hits = sum(1 for p in true_top if p in estimated_top)
    return float(hits) / k
def value_error_metric(policy, value, discounted=False):
    """
    Returns the absolute error in estimated value.

    Args:
        policy (str): A policy string identifier.
        value (float): Estimated value
        discounted: Whether to compare against discounted reference returns.
    """
    true_value = get_returns(policy, discounted)
    return abs(normalize(policy, value) - normalize(policy, true_value))
def policy_regret_metric(policy, expert_policies, discounted=False):
    """
    Returns the regret of the given policy against a set of expert policies.

    Args:
        policy (str): A policy string identifier.
        expert_policies (list[str]): A list of expert policies

    Returns:
        The regret, which is value of the best expert minus the value of the policy.
    """
    expert_returns = [get_returns(p, discounted=discounted) for p in expert_policies]
    best_expert = max(expert_returns)
    own_return = get_returns(policy, discounted=discounted)
    return normalize(policy, best_expert) - normalize(policy, own_return)
================================================
FILE: d4rl/d4rl/pointmaze/__init__.py
================================================
from .maze_model import MazeEnv, OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL
from gym.envs.registration import register
# Environment registrations for the maze2d (pointmaze) task family.
# Each variant fixes a maze layout, reward type (sparse/dense), and the
# reference min/max scores used for D4RL score normalization.

# --- Sparse-reward mazes, v0 datasets ---
register(
    id='maze2d-open-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=150,
    kwargs={
        'maze_spec':OPEN,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 0.01,
        'ref_max_score': 20.66,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5'
    }
)

register(
    id='maze2d-umaze-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=150,
    kwargs={
        'maze_spec':U_MAZE,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 0.94,
        'ref_max_score': 62.6,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5'
    }
)

register(
    id='maze2d-medium-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=250,
    kwargs={
        'maze_spec':MEDIUM_MAZE,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 5.77,
        'ref_max_score': 85.14,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5'
    }
)

register(
    id='maze2d-large-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=600,
    kwargs={
        'maze_spec':LARGE_MAZE,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 4.83,
        'ref_max_score': 191.99,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5'
    }
)

# --- Sparse-reward mazes, v1 datasets (longer episodes) ---
register(
    id='maze2d-umaze-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=300,
    kwargs={
        'maze_spec':U_MAZE,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 23.85,
        'ref_max_score': 161.86,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5'
    }
)

register(
    id='maze2d-medium-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=600,
    kwargs={
        'maze_spec':MEDIUM_MAZE,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 13.13,
        'ref_max_score': 277.39,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5'
    }
)

register(
    id='maze2d-large-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=800,
    kwargs={
        'maze_spec':LARGE_MAZE,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 6.7,
        'ref_max_score': 273.99,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5'
    }
)

# --- Sparse-reward evaluation mazes (held-out layouts), v1 datasets ---
register(
    id='maze2d-eval-umaze-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=300,
    kwargs={
        'maze_spec':U_MAZE_EVAL,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 36.63,
        'ref_max_score': 141.4,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5'
    }
)

register(
    id='maze2d-eval-medium-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=600,
    kwargs={
        'maze_spec':MEDIUM_MAZE_EVAL,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 13.07,
        'ref_max_score': 204.93,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5'
    }
)

register(
    id='maze2d-eval-large-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=800,
    kwargs={
        'maze_spec':LARGE_MAZE_EVAL,
        'reward_type':'sparse',
        'reset_target': False,
        'ref_min_score': 16.4,
        'ref_max_score': 302.22,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5'
    }
)

# --- Dense-reward mazes, v0 datasets ---
register(
    id='maze2d-open-dense-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=150,
    kwargs={
        'maze_spec':OPEN,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 11.17817,
        'ref_max_score': 27.166538620695782,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5'
    }
)

register(
    id='maze2d-umaze-dense-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=150,
    kwargs={
        'maze_spec':U_MAZE,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 23.249793,
        'ref_max_score': 81.78995240126592,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5'
    }
)

register(
    id='maze2d-medium-dense-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=250,
    kwargs={
        'maze_spec':MEDIUM_MAZE,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 19.477620,
        'ref_max_score': 96.03474232952358,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5'
    }
)

register(
    id='maze2d-large-dense-v0',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=600,
    kwargs={
        'maze_spec':LARGE_MAZE,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 27.388310,
        'ref_max_score': 215.09965671563742,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5'
    }
)

# --- Dense-reward mazes, v1 datasets ---
register(
    id='maze2d-umaze-dense-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=300,
    kwargs={
        'maze_spec':U_MAZE,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 68.537689,
        'ref_max_score': 193.66285642381482,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5'
    }
)

register(
    id='maze2d-medium-dense-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=600,
    kwargs={
        'maze_spec':MEDIUM_MAZE,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 44.264742,
        'ref_max_score': 297.4552547777125,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5'
    }
)

register(
    id='maze2d-large-dense-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=800,
    kwargs={
        'maze_spec':LARGE_MAZE,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 30.569041,
        'ref_max_score': 303.4857382709002,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5'
    }
)

# --- Dense-reward evaluation mazes (held-out layouts), v1 datasets ---
register(
    id='maze2d-eval-umaze-dense-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=300,
    kwargs={
        'maze_spec':U_MAZE_EVAL,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 56.95455,
        'ref_max_score': 178.21373133248397,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5'
    }
)

register(
    id='maze2d-eval-medium-dense-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=600,
    kwargs={
        'maze_spec':MEDIUM_MAZE_EVAL,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 42.28578,
        'ref_max_score': 235.5658957482388,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5'
    }
)

register(
    id='maze2d-eval-large-dense-v1',
    entry_point='d4rl.pointmaze:MazeEnv',
    max_episode_steps=800,
    kwargs={
        'maze_spec':LARGE_MAZE_EVAL,
        'reward_type':'dense',
        'reset_target': False,
        'ref_min_score': 56.95455,
        'ref_max_score': 326.09647655082637,
        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5'
    }
)
================================================
FILE: d4rl/d4rl/pointmaze/dynamic_mjc.py
================================================
"""
dynamic_mjc.py
A small library for programatically building MuJoCo XML files
"""
from contextlib import contextmanager
import tempfile
import numpy as np
def default_model(name):
    """
    Get a model with basic settings such as gravity and RK4 integration enabled
    """
    model = MJCModel(name)
    root = model.root

    # Global compiler/integrator setup.
    root.compiler(angle="radian", inertiafromgeom="true")
    defaults = root.default()
    defaults.joint(armature=1, damping=1, limited="true")
    defaults.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1')
    root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01)
    return model
def pointmass_model(name):
    """
    Get a model with basic settings such as gravity and Euler integration enabled
    """
    model = MJCModel(name)
    root = model.root

    # Zero-gravity, Euler-integrated setup suitable for a planar point mass.
    root.compiler(angle="radian", inertiafromgeom="true", coordinate="local")
    defaults = root.default()
    defaults.joint(limited="false", damping=1)
    defaults.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002")
    root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler")
    return model
class MJCModel(object):
    """Programmatic MuJoCo XML model rooted at a <mujoco> element."""

    def __init__(self, name):
        self.name = name
        self.root = MJCTreeNode("mujoco").add_attr('model', name)

    @contextmanager
    def asfile(self):
        """
        Yield the model serialized into a temporary XML file.

        Usage:
            model = MJCModel('reacher')
            with model.asfile() as f:
                print f.read() # prints a dump of the model
        """
        with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f:
            self.root.write(f)
            f.seek(0)
            yield f

    def open(self):
        """Serialize to a temp file and return the open handle; pair with close()."""
        self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True)
        self.root.write(self.file)
        self.file.seek(0)
        return self.file

    def close(self):
        """Close the handle opened by open()."""
        self.file.close()

    def find_attr(self, attr, value):
        """Delegate the attribute search to the tree root."""
        return self.root.find_attr(attr, value)

    # Pickle as an empty shell: the temp-file handle is not picklable and the
    # tree is cheap to rebuild.
    def __getstate__(self):
        return {}

    def __setstate__(self, state):
        pass
class MJCTreeNode(object):
    """
    One XML element in a MuJoCo model tree.

    Child elements are created by calling an arbitrary attribute name with
    keyword arguments, e.g. ``node.geom(type='sphere', size=0.1)`` appends a
    <geom ...> child and returns it.
    """

    def __init__(self, name):
        self.name = name
        self.attrs = {}
        self.children = []

    def add_attr(self, key, value):
        """Set XML attribute `key`; lists/arrays become space-joined strings."""
        if isinstance(value, str):
            pass
        elif isinstance(value, (list, np.ndarray)):
            value = ' '.join([str(val).lower() for val in value])
        else:
            # Numbers and booleans are lowercased ("true"/"false" for MuJoCo).
            value = str(value).lower()
        self.attrs[key] = value
        return self

    def __getattr__(self, name):
        # Any unknown attribute access yields a factory that appends a child
        # element with that tag name and returns it.
        def wrapper(**kwargs):
            newnode = MJCTreeNode(name)
            for (k, v) in kwargs.items():
                newnode.add_attr(k, v)
            self.children.append(newnode)
            return newnode
        return wrapper

    def dfs(self):
        """Yield this node and all descendants in depth-first order."""
        yield self
        if self.children:
            for child in self.children:
                for node in child.dfs():
                    yield node

    def find_attr(self, attr, value):
        """ Run DFS to find a matching attr """
        if attr in self.attrs and self.attrs[attr] == value:
            return self
        for child in self.children:
            res = child.find_attr(attr, value)
            if res is not None:
                return res
        return None

    def write(self, ostream, tabs=0):
        """Serialize this subtree as tab-indented XML to `ostream`."""
        contents = ' '.join(['%s="%s"' % (k, v) for (k, v) in self.attrs.items()])
        if self.children:
            ostream.write('\t' * tabs)
            ostream.write('<%s %s>\n' % (self.name, contents))
            for child in self.children:
                child.write(ostream, tabs=tabs + 1)
            ostream.write('\t' * tabs)
            # BUG FIX: the closing tag was emitted as 'name>' (missing '</'),
            # which produced malformed XML for any element with children.
            ostream.write('</%s>\n' % self.name)
        else:
            ostream.write('\t' * tabs)
            ostream.write('<%s %s/>\n' % (self.name, contents))

    def __str__(self):
        # Debug representation. FIX: separate the tag name from its attributes
        # with a space (they were previously concatenated).
        s = "<" + self.name
        if self.attrs:
            s += ' ' + ' '.join(['%s="%s"' % (k, v) for (k, v) in self.attrs.items()])
        return s + ">"
================================================
FILE: d4rl/d4rl/pointmaze/gridcraft/__init__.py
================================================
================================================
FILE: d4rl/d4rl/pointmaze/gridcraft/grid_env.py
================================================
import sys
import numpy as np
import gym
import gym.spaces
from d4rl.pointmaze.gridcraft.grid_spec import REWARD, REWARD2, REWARD3, REWARD4, WALL, LAVA, TILES, START, RENDER_DICT
from d4rl.pointmaze.gridcraft.utils import one_hot_to_flat, flat_to_one_hot
# Discrete action ids for the gridworld.
ACT_NOOP = 0
ACT_UP = 1
ACT_DOWN = 2
ACT_LEFT = 3
ACT_RIGHT = 4

# Action id -> (dx, dy) grid displacement; y grows downward.
ACT_DICT = {
    ACT_NOOP: [0,0],
    ACT_UP: [0, -1],
    ACT_LEFT: [-1, 0],
    ACT_RIGHT: [+1, 0],
    ACT_DOWN: [0, +1]
}

# Action id -> human-readable name (used for verbose logging).
ACT_TO_STR = {
    ACT_NOOP: 'NOOP',
    ACT_UP: 'UP',
    ACT_LEFT: 'LEFT',
    ACT_RIGHT: 'RIGHT',
    ACT_DOWN: 'DOWN'
}
class TransitionModel(object):
    """Stochastic transition model: eps mass is spread over legal moves."""

    def __init__(self, gridspec, eps=0.2):
        self.gs = gridspec
        self.eps = eps

    def get_aprobs(self, s, a):
        """Probability vector over the 5 executed actions from state s."""
        # TODO: could probably output a matrix over all states...
        legal = self.__legal_moves(s)
        probs = np.zeros(len(ACT_DICT))
        probs[list(legal)] = self.eps / (len(legal))
        if a in legal:
            probs[a] += 1.0 - self.eps
        else:
            # An illegal requested action degenerates to NOOP.
            probs[ACT_NOOP] += 1.0 - self.eps
        return probs

    def __legal_moves(self, s):
        """Moves that stay in bounds and avoid walls; NOOP is always legal."""
        xy = np.array(self.gs.idx_to_xy(s))
        legal = {move for move in ACT_DICT
                 if not self.gs.out_of_bounds(xy + ACT_DICT[move])
                 and self.gs[xy + ACT_DICT[move]] != WALL}
        legal.add(ACT_NOOP)
        return legal
class RewardFunction(object):
    """Maps the tile type of the current state to a scalar reward."""

    def __init__(self, rew_map=None, default=0):
        # Default table: graded rewards plus a large lava penalty.
        if rew_map is None:
            rew_map = {
                REWARD: 1.0,
                REWARD2: 2.0,
                REWARD3: 4.0,
                REWARD4: 8.0,
                LAVA: -100.0,
            }
        self.default = default
        self.rew_map = rew_map

    def __call__(self, gridspec, s, a, ns):
        """Reward for (s, a, ns); only the tile under state s matters."""
        tile = gridspec[gridspec.idx_to_xy(s)]
        return self.rew_map.get(tile, self.default)
class GridEnv(gym.Env):
    """
    Tabular gridworld environment over a GridSpec.

    States are flat indices into the grid; actions are the five discrete
    moves in ACT_DICT. Transitions are sampled from TransitionModel and
    rewards come from RewardFunction (or a user-supplied rew_fn).
    """

    def __init__(self, gridspec,
                 tiles=TILES,
                 rew_fn=None,
                 teps=0.0,
                 max_timesteps=None,
                 rew_map=None,
                 terminal_states=None,
                 default_rew=0):
        # teps: transition noise — probability mass spread over legal moves.
        self.num_states = len(gridspec)
        self.num_actions = 5
        self._env_args = {'teps': teps, 'max_timesteps': max_timesteps}
        self.gs = gridspec
        self.model = TransitionModel(gridspec, eps=teps)
        # NOTE(review): terminal_states is stored but never consulted in step().
        self.terminal_states = terminal_states
        if rew_fn is None:
            rew_fn = RewardFunction(rew_map=rew_map, default=default_rew)
        self.rew_fn = rew_fn
        self.possible_tiles = tiles
        self.max_timesteps = max_timesteps
        self._timestep = 0
        self._true_q = None  # q_vals for debugging
        super(GridEnv, self).__init__()

    def get_transitions(self, s, a):
        """Return {next_state: probability} for taking action a in state s."""
        tile_type = self.gs[self.gs.idx_to_xy(s)]
        if tile_type == LAVA: # Lava gets you stuck
            return {s: 1.0}

        aprobs = self.model.get_aprobs(s, a)
        t_dict = {}
        for sa in range(5):
            if aprobs[sa] > 0:
                next_s = self.gs.idx_to_xy(s) + ACT_DICT[sa]
                next_s_idx = self.gs.xy_to_idx(next_s)
                # Accumulate: distinct executed actions can land in the same cell.
                t_dict[next_s_idx] = t_dict.get(next_s_idx, 0.0) + aprobs[sa]
        return t_dict

    def step_stateless(self, s, a, verbose=False):
        """Sample one transition from s under action a; returns (next_state, reward)."""
        aprobs = self.model.get_aprobs(s, a)
        samp_a = np.random.choice(range(5), p=aprobs)

        next_s = self.gs.idx_to_xy(s) + ACT_DICT[samp_a]
        tile_type = self.gs[self.gs.idx_to_xy(s)]
        if tile_type == LAVA: # Lava gets you stuck
            next_s = self.gs.idx_to_xy(s)

        next_s_idx = self.gs.xy_to_idx(next_s)
        rew = self.rew_fn(self.gs, s, samp_a, next_s_idx)

        if verbose:
            print('Act: %s. Act Executed: %s' % (ACT_TO_STR[a], ACT_TO_STR[samp_a]))
        return next_s_idx, rew

    def step(self, a, verbose=False):
        """Gym step; episodes only terminate via the max_timesteps cap."""
        ns, r = self.step_stateless(self.__state, a, verbose=verbose)
        traj_infos = {}
        self.__state = ns
        obs = ns #flat_to_one_hot(ns, len(self.gs))
        done = False
        self._timestep += 1
        if self.max_timesteps is not None:
            if self._timestep >= self.max_timesteps:
                done = True
        return obs, r, done, traj_infos

    def reset(self):
        """Reset to a uniformly sampled START tile; returns the flat state index."""
        start_idxs = np.array(np.where(self.gs.spec == START)).T
        start_idx = start_idxs[np.random.randint(0, start_idxs.shape[0])]
        start_idx = self.gs.xy_to_idx(start_idx)
        self.__state = start_idx
        self._timestep = 0
        return start_idx #flat_to_one_hot(start_idx, len(self.gs))

    def render(self, close=False, ostream=sys.stdout):
        """ASCII-render the grid; '*' marks the agent's current cell."""
        if close:
            return

        state = self.__state
        ostream.write('-'*(self.gs.width+2)+'\n')
        for h in range(self.gs.height):
            ostream.write('|')
            for w in range(self.gs.width):
                if self.gs.xy_to_idx((w,h)) == state:
                    ostream.write('*')
                else:
                    val = self.gs[w, h]
                    ostream.write(RENDER_DICT[val])
            ostream.write('|\n')
        ostream.write('-' * (self.gs.width + 2)+'\n')

    @property
    def action_space(self):
        return gym.spaces.Discrete(5)

    @property
    def observation_space(self):
        # Observations are flat state indices, hence a Discrete space.
        dO = len(self.gs)
        #return gym.spaces.Box(0,1,shape=dO)
        return gym.spaces.Discrete(dO)

    def transition_matrix(self):
        """Constructs this environment's transition matrix.

        Returns:
          A dS x dA x dS array where the entry transition_matrix[s, a, ns]
          corrsponds to the probability of transitioning into state ns after taking
          action a from state s.
        """
        ds = self.num_states
        da = self.num_actions
        transition_matrix = np.zeros((ds, da, ds))
        for s in range(ds):
            for a in range(da):
                transitions = self.get_transitions(s,a)
                for next_s in transitions:
                    transition_matrix[s, a, next_s] = transitions[next_s]
        return transition_matrix

    def reward_matrix(self):
        """Constructs this environment's reward matrix.

        Returns:
          A dS x dA x dS numpy array where the entry reward_matrix[s, a, ns]
          reward given to an agent when transitioning into state ns after taking
          action s from state s.
        """
        ds = self.num_states
        da = self.num_actions
        rew_matrix = np.zeros((ds, da, ds))
        for s in range(ds):
            for a in range(da):
                for ns in range(ds):
                    rew_matrix[s, a, ns] = self.rew_fn(self.gs, s, a, ns)
        return rew_matrix
================================================
FILE: d4rl/d4rl/pointmaze/gridcraft/grid_spec.py
================================================
import numpy as np
# Integer tile codes for the gridworld.
EMPTY = 110
WALL = 111
START = 112
REWARD = 113
OUT_OF_BOUNDS = 114  # sentinel returned by GridSpec.get_neighbors for off-grid cells
REWARD2 = 115
REWARD3 = 116
REWARD4 = 117
LAVA = 118
GOAL = 119

# Tile codes that may appear inside a grid (OUT_OF_BOUNDS is excluded).
TILES = {EMPTY, WALL, START, REWARD, REWARD2, REWARD3, REWARD4, LAVA, GOAL}

# Character -> tile code, used when parsing maze strings.
STR_MAP = {
    'O': EMPTY,
    '#': WALL,
    'S': START,
    'R': REWARD,
    '2': REWARD2,
    '3': REWARD3,
    '4': REWARD4,
    'G': GOAL,
    'L': LAVA
}

# Tile code -> display character; empty and start tiles render as spaces.
RENDER_DICT = {v:k for k, v in STR_MAP.items()}
RENDER_DICT[EMPTY] = ' '
RENDER_DICT[START] = ' '
def spec_from_string(s, valmap=STR_MAP):
    """Parse a '\\'-delimited string of tile characters into a GridSpec."""
    # Tolerate a trailing delimiter.
    if s.endswith('\\'):
        s = s[:-1]

    rows = s.split('\\')
    lengths = np.array([len(row) for row in rows])
    assert np.all(lengths == lengths[0])

    w, h = len(rows), len(rows[0])  # NOTE: the first grid axis indexes rows
    gs = GridSpec(w, h)
    for i, row in enumerate(rows):
        for j, ch in enumerate(row):
            gs[i, j] = valmap[ch]
    return gs
def spec_from_sparse_locations(w, h, tile_to_locs):
    """
    Build a w x h GridSpec by placing tile types at explicit coordinates.

    Example usage:
    >> spec_from_sparse_locations(10, 10, {START: [(0,0)], REWARD: [(7,8), (8,8)]})
    """
    gs = GridSpec(w, h)
    for tile_type, locs in tile_to_locs.items():
        for loc in np.array(locs):
            gs[tuple(loc)] = tile_type
    return gs
def local_spec(map, xpnt):
    """
    Return absolute coordinates of the 'y' tiles relative to the 'x' tile,
    shifted so that 'x' sits at `xpnt`.

    >>> local_spec("yOy\\\\Oxy", xpnt=(5,5))
    array([[4, 4],
           [6, 4],
           [6, 5]])
    """
    Y, X, O = 0, 1, 2
    gs = spec_from_string(map, valmap={'y': Y, 'x': X, 'O': O})
    # gs.find(X) is a 1 x 2 array, so the subtraction broadcasts over all ys.
    return gs.find(Y) - gs.find(X) + np.array(xpnt)
class GridSpec(object):
    """A w x h grid of integer tile codes with (x, y) <-> flat-index helpers."""

    def __init__(self, w, h):
        self.__data = np.zeros((w, h), dtype=np.int32)
        self.__w = w
        self.__h = h

    def __setitem__(self, key, val):
        self.__data[key] = val

    def __getitem__(self, key):
        if self.out_of_bounds(key):
            raise NotImplementedError("Out of bounds:" + str(key))
        return self.__data[tuple(key)]

    def out_of_bounds(self, wh):
        """ Return true if x, y is out of bounds """
        w, h = wh
        return not (0 <= w < self.__w and 0 <= h < self.__h)

    def get_neighbors(self, k, xy=False):
        """ Return values of up, down, left, and right tiles """
        if not xy:
            k = self.idx_to_xy(k)
        offsets = [np.array([0, -1]), np.array([0, 1]),
                   np.array([-1, 0]), np.array([1, 0])]
        return [OUT_OF_BOUNDS if self.out_of_bounds(k + off) else self[k + off]
                for off in offsets]

    def get_value(self, k, xy=False):
        """Tile value at k, where k is a flat index unless xy=True."""
        return self[k if xy else self.idx_to_xy(k)]

    def find(self, value):
        """All (x, y) coordinates holding `value`, as an N x 2 array."""
        return np.array(np.where(self.spec == value)).T

    @property
    def spec(self):
        return self.__data

    @property
    def width(self):
        return self.__w

    @property
    def height(self):
        return self.__h

    def __len__(self):
        return self.__w * self.__h

    def idx_to_xy(self, idx):
        if hasattr(idx, '__len__'):  # vectorized over an array of indices
            xs = idx % self.__w
            ys = np.floor(idx / self.__w).astype(np.int32)
            return np.c_[xs, ys]
        return np.array([idx % self.__w, int(np.floor(idx / self.__w))])

    def xy_to_idx(self, key):
        ndim = len(np.array(key).shape)
        if ndim == 1:
            return key[0] + key[1] * self.__w
        if ndim == 2:
            return key[:, 0] + key[:, 1] * self.__w
        raise NotImplementedError()

    def __hash__(self):
        # Hash on dimensions plus the flattened contents.
        return hash((self.__w, self.__h) + tuple(self.__data.reshape([-1]).tolist()))
================================================
FILE: d4rl/d4rl/pointmaze/gridcraft/utils.py
================================================
import numpy as np
def flat_to_one_hot(val, ndim):
    """
    One-hot encode an integer (or 1-D array of integers) over `ndim` classes.

    >>> flat_to_one_hot(2, ndim=4)
    array([ 0.,  0.,  1.,  0.])
    >>> flat_to_one_hot(4, ndim=5)
    array([ 0.,  0.,  0.,  0.,  1.])
    >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
    array([[ 0.,  0.,  1.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  1.],
           [ 0.,  0.,  0.,  1.,  0.]])
    """
    shape = np.array(val).shape
    out = np.zeros(shape + (ndim,))
    if len(shape) == 1:
        # Batched case: set one entry per row.
        out[np.arange(shape[0]), val] = 1.0
    else:
        out[val] = 1.0
    return out
def one_hot_to_flat(val):
    """
    Invert a one-hot encoding back to integer indices.

    >>> one_hot_to_flat(np.array([0,0,0,0,1]))
    4
    >>> one_hot_to_flat(np.array([0,0,1,0]))
    2
    >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
    array([2, 0, 1])
    """
    idxs = np.array(np.where(val == 1.0))[-1]
    if len(val.shape) == 1:
        # FIX: int() on a 1-element ndarray is deprecated (NumPy >= 1.25);
        # index the scalar out explicitly before converting.
        return int(idxs[0])
    return idxs
================================================
FILE: d4rl/d4rl/pointmaze/gridcraft/wrappers.py
================================================
import numpy as np
from d4rl.pointmaze.gridcraft.grid_env import REWARD, GridEnv
from d4rl.pointmaze.gridcraft.wrappers import ObsWrapper
from gym.spaces import Box
class GridObsWrapper(ObsWrapper):
    """Observation wrapper that forwards rendering to the wrapped grid env."""

    def __init__(self, env):
        super(GridObsWrapper, self).__init__(env)

    def render(self):
        """Delegate ASCII rendering to the underlying environment."""
        self.env.render()
class EyesWrapper(ObsWrapper):
    """
    Appends coarse directional "eye" features to the observation for each
    tile type in `types`. Per type the 5 features are
    [on_target, max_cos, min_cos, max_sin, min_sin] over nearby tiles,
    with entries at or below angle_thresh zeroed out.
    """

    def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8):
        super(EyesWrapper, self).__init__(env)
        self.types = types
        self.range = range  # sensing radius in grid cells (shadows builtin `range`)
        self.angle_thresh = angle_thresh

        # NOTE(review): low and high bounds are both ones here; eyes_low looks
        # like it was meant to be zeros (the features are non-negative) — confirm.
        eyes_low = np.ones(5*len(types))
        eyes_high = np.ones(5*len(types))
        low = np.r_[env.observation_space.low, eyes_low]
        high = np.r_[env.observation_space.high, eyes_high]
        self.__observation_space = Box(low, high)

    def wrap_obs(self, obs, info=None):
        """Compute eye features around the agent's cell and append them to obs."""
        gs = self.env.gs # grid spec
        xy = gs.idx_to_xy(self.env.obs_to_state(obs))
        #xy = np.array([x, y])

        extra_obs = []
        for tile_type in self.types:
            idxs = gs.find(tile_type).astype(np.float32) # N x 2
            # gather all idxs that are close
            diffs = idxs-np.expand_dims(xy, axis=0)
            dists = np.linalg.norm(diffs, axis=1)
            valid_idxs = np.where(dists <= self.range)[0]
            if len(valid_idxs) == 0:
                # Nothing of this type within range: all-zero features.
                eye_data = np.array([0,0,0,0,0], dtype=np.float32)
            else:
                diffs = diffs[valid_idxs, :]
                dists = dists[valid_idxs]+1e-6  # avoid division by zero
                cosines = diffs[:,0]/dists
                cosines = np.r_[cosines, 0]
                sines = diffs[:,1]/dists
                sines = np.r_[sines, 0]
                on_target = 0.0
                if np.any(dists<=1.0):
                    on_target = 1.0
                eye_data = np.abs(np.array([on_target, np.max(cosines), np.min(cosines), np.max(sines), np.min(sines)]))
                # Suppress weak directional signals below the threshold.
                eye_data[np.where(eye_data<=self.angle_thresh)] = 0
            extra_obs.append(eye_data)
        extra_obs = np.concatenate(extra_obs)

        obs = np.r_[obs, extra_obs]
        #if np.any(np.isnan(obs)):
        #    import pdb; pdb.set_trace()
        return obs

    def unwrap_obs(self, obs, info=None):
        """Strip the appended eye features, restoring the wrapped env's obs."""
        if len(obs.shape) == 1:
            return obs[:-5*len(self.types)]
        else:
            return obs[:,:-5*len(self.types)]

    @property
    def observation_space(self):
        return self.__observation_space
"""
class CoordinateWiseWrapper(GridObsWrapper):
def __init__(self, env):
assert isinstance(env, GridEnv)
super(CoordinateWiseWrapper, self).__init__(env)
self.gs = env.gs
self.dO = self.gs.width+self.gs.height
self.__observation_space = Box(0, 1, self.dO)
def wrap_obs(self, obs, info=None):
state = one_hot_to_flat(obs)
xy = self.gs.idx_to_xy(state)
x = flat_to_one_hot(xy[0], self.gs.width)
y = flat_to_one_hot(xy[1], self.gs.height)
obs = np.r_[x, y]
return obs
def unwrap_obs(self, obs, info=None):
if len(obs.shape) == 1:
x = obs[:self.gs.width]
y = obs[self.gs.width:]
x = one_hot_to_flat(x)
y = one_hot_to_flat(y)
state = self.gs.xy_to_idx(np.c_[x,y])
return flat_to_one_hot(state, self.dO)
else:
raise NotImplementedError()
"""
class RandomObsWrapper(GridObsWrapper):
    """Projects the (one-hot) grid observation through a fixed random matrix."""

    def __init__(self, env, dO):
        assert isinstance(env, GridEnv)
        super(RandomObsWrapper, self).__init__(env)
        self.gs = env.gs
        self.dO = dO
        # Fixed Gaussian projection from the one-hot state space to R^dO.
        self.obs_matrix = np.random.randn(self.dO, len(self.gs))
        # NOTE(review): this private attribute is name-mangled and no
        # observation_space property is defined on this class — it may be unused.
        self.__observation_space = Box(np.min(self.obs_matrix), np.max(self.obs_matrix),
                                       shape=(self.dO,), dtype=np.float32)

    def wrap_obs(self, obs, info=None):
        """Project obs into the random feature space."""
        return np.inner(self.obs_matrix, obs)

    def unwrap_obs(self, obs, info=None):
        # The random projection is not inverted here.
        raise NotImplementedError()
================================================
FILE: d4rl/d4rl/pointmaze/maze_model.py
================================================
""" A pointmass maze env."""
from gym.envs.mujoco import mujoco_env
from gym import utils
from d4rl import offline_env
from d4rl.pointmaze.dynamic_mjc import MJCModel
import numpy as np
import random
# Tile codes used by the maze array.
WALL = 10
EMPTY = 11
GOAL = 12


def parse_maze(maze_str):
    """
    Convert a '\\'-delimited maze layout string into an int32 array of
    tile codes ('#' -> WALL, 'G' -> GOAL, ' '/'O'/'0' -> EMPTY).
    """
    rows = maze_str.strip().split('\\')
    width, height = len(rows), len(rows[0])
    maze_arr = np.zeros((width, height), dtype=np.int32)
    for w in range(width):
        for h in range(height):
            tile = rows[w][h]
            if tile == '#':
                maze_arr[w][h] = WALL
            elif tile == 'G':
                maze_arr[w][h] = GOAL
            elif tile in (' ', 'O', '0'):
                maze_arr[w][h] = EMPTY
            else:
                raise ValueError('Unknown tile type: %s' % tile)
    return maze_arr
def point_maze(maze_str):
    """
    Build the MuJoCo model for a point-mass maze from a maze layout string.

    The particle is a 2-DOF (slide x/y) sphere driven by two motors; each
    WALL tile becomes a box geom offset by (+1, +1) from its grid coordinate.

    Args:
        maze_str: '\\'-delimited maze layout (see parse_maze).

    Returns:
        An MJCModel ready to be serialized to XML.
    """
    maze_arr = parse_maze(maze_str)

    mjcmodel = MJCModel('point_maze')
    mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local")
    mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler")
    default = mjcmodel.root.default()
    default.joint(damping=1, limited='false')
    default.geom(friction=".5 .1 .1", density="1000", margin="0.002", condim="1", contype="2", conaffinity="1")

    # Textures and materials for the ground, sky, walls, and target marker.
    asset = mjcmodel.root.asset()
    asset.texture(type="2d",name="groundplane",builtin="checker",rgb1="0.2 0.3 0.4",rgb2="0.1 0.2 0.3",width=100,height=100)
    asset.texture(name="skybox",type="skybox",builtin="gradient",rgb1=".4 .6 .8",rgb2="0 0 0",
        width="800",height="800",mark="random",markrgb="1 1 1")
    asset.material(name="groundplane",texture="groundplane",texrepeat="20 20")
    asset.material(name="wall",rgba=".7 .5 .3 1")
    asset.material(name="target",rgba=".6 .3 .3 1")

    visual = mjcmodel.root.visual()
    visual.headlight(ambient=".4 .4 .4",diffuse=".8 .8 .8",specular="0.1 0.1 0.1")
    visual.map(znear=.01)
    visual.quality(shadowsize=2048)

    worldbody = mjcmodel.root.worldbody()
    worldbody.geom(name='ground',size="40 40 0.25",pos="0 0 -0.1",type="plane",contype=1,conaffinity=0,material="groundplane")

    # The controllable point mass: sphere geom, visualization site, and two
    # slide joints for planar motion.
    particle = worldbody.body(name='particle', pos=[1.2,1.2,0])
    particle.geom(name='particle_geom', type='sphere', size=0.1, rgba='0.0 0.0 1.0 0.0', contype=1)
    particle.site(name='particle_site', pos=[0.0,0.0,0], size=0.2, rgba='0.3 0.6 0.3 1')
    particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0])
    particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0])

    worldbody.site(name='target_site', pos=[0.0,0.0,0], size=0.2, material='target')

    # One box geom per WALL tile, shifted by +1 in both axes.
    width, height = maze_arr.shape
    for w in range(width):
        for h in range(height):
            if maze_arr[w,h] == WALL:
                worldbody.geom(conaffinity=1,
                               type='box',
                               name='wall_%d_%d'%(w,h),
                               material='wall',
                               pos=[w+1.0,h+1.0,0],
                               size=[0.5,0.5,0.2])

    # Motors on both slide joints drive the particle.
    actuator = mjcmodel.root.actuator()
    actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100)
    actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100)

    return mjcmodel
# Maze layout strings: '#' = wall, 'O' = empty, 'G' = goal; rows are
# separated by '\' (see parse_maze). *_EVAL variants are held-out layouts.
LARGE_MAZE = \
        "############\\"+\
        "#OOOO#OOOOO#\\"+\
        "#O##O#O#O#O#\\"+\
        "#OOOOOO#OOO#\\"+\
        "#O####O###O#\\"+\
        "#OO#O#OOOOO#\\"+\
        "##O#O#O#O###\\"+\
        "#OO#OOO#OGO#\\"+\
        "############"

LARGE_MAZE_EVAL = \
        "############\\"+\
        "#OO#OOO#OGO#\\"+\
        "##O###O#O#O#\\"+\
        "#OO#O#OOOOO#\\"+\
        "#O##O#OO##O#\\"+\
        "#OOOOOO#OOO#\\"+\
        "#O##O#O#O###\\"+\
        "#OOOO#OOOOO#\\"+\
        "############"

MEDIUM_MAZE = \
        '########\\'+\
        '#OO##OO#\\'+\
        '#OO#OOO#\\'+\
        '##OOO###\\'+\
        '#OO#OOO#\\'+\
        '#O#OO#O#\\'+\
        '#OOO#OG#\\'+\
        "########"

MEDIUM_MAZE_EVAL = \
        '########\\'+\
        '#OOOOOG#\\'+\
        '#O#O##O#\\'+\
        '#OOOO#O#\\'+\
        '###OO###\\'+\
        '#OOOOOO#\\'+\
        '#OO##OO#\\'+\
        "########"

SMALL_MAZE = \
        "######\\"+\
        "#OOOO#\\"+\
        "#O##O#\\"+\
        "#OOOO#\\"+\
        "######"

U_MAZE = \
        "#####\\"+\
        "#GOO#\\"+\
        "###O#\\"+\
        "#OOO#\\"+\
        "#####"

U_MAZE_EVAL = \
        "#####\\"+\
        "#OOG#\\"+\
        "#O###\\"+\
        "#OOO#\\"+\
        "#####"

OPEN = \
        "#######\\"+\
        "#OOOOO#\\"+\
        "#OOGOO#\\"+\
        "#OOOOO#\\"+\
        "#######"
class MazeEnv(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):
    """
    Point-mass maze environment (maze2d) with an offline-dataset interface.

    Observations concatenate qpos and qvel; rewards are either sparse
    (1.0 within 0.5 of the target) or dense (exp of negative distance).
    """

    def __init__(self,
                 maze_spec=U_MAZE,
                 reward_type='dense',
                 reset_target=False,
                 **kwargs):
        offline_env.OfflineEnv.__init__(self, **kwargs)

        self.reset_target = reset_target
        self.str_maze_spec = maze_spec
        self.maze_arr = parse_maze(maze_spec)
        self.reward_type = reward_type
        # Every EMPTY tile is a candidate reset location (sorted for determinism).
        self.reset_locations = list(zip(*np.where(self.maze_arr == EMPTY)))
        self.reset_locations.sort()

        self._target = np.array([0.0,0.0])

        model = point_maze(maze_spec)
        with model.asfile() as f:
            mujoco_env.MujocoEnv.__init__(self, model_path=f.name, frame_skip=1)
        utils.EzPickle.__init__(self)

        # Set the default goal (overriden by a call to set_target)
        # Try to find a goal if it exists
        self.goal_locations = list(zip(*np.where(self.maze_arr == GOAL)))
        if len(self.goal_locations) == 1:
            self.set_target(self.goal_locations[0])
        elif len(self.goal_locations) > 1:
            raise ValueError("More than 1 goal specified!")
        else:
            # If no goal, use the first empty tile
            self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype))
        self.empty_and_goal_locations = self.reset_locations + self.goal_locations

    def step(self, action):
        """Clip action and velocity, simulate one frame, and compute reward."""
        action = np.clip(action, -1.0, 1.0)
        self.clip_velocity()
        self.do_simulation(action, self.frame_skip)
        self.set_marker()
        ob = self._get_obs()
        if self.reward_type == 'sparse':
            reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0
        elif self.reward_type == 'dense':
            reward = np.exp(-np.linalg.norm(ob[0:2] - self._target))
        else:
            raise ValueError('Unknown reward type %s' % self.reward_type)
        # Episodes end only via the registered time limit, never from here.
        done = False
        return ob, reward, done, {}

    def _get_obs(self):
        # Position (qpos) followed by velocity (qvel), flattened.
        return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()

    def get_target(self):
        return self._target

    def set_target(self, target_location=None):
        """Set the goal; when None, sample a random empty/goal cell with jitter."""
        if target_location is None:
            idx = self.np_random.choice(len(self.empty_and_goal_locations))
            reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)
            target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
        self._target = target_location

    def set_marker(self):
        # Marker is shifted by (+1, +1) to match the wall-geom offset in point_maze.
        self.data.site_xpos[self.model.site_name2id('target_site')] = np.array([self._target[0]+1, self._target[1]+1, 0.0])

    def clip_velocity(self):
        # Cap speed so the point mass cannot tunnel through walls.
        qvel = np.clip(self.sim.data.qvel, -5.0, 5.0)
        self.set_state(self.sim.data.qpos, qvel)

    def reset_model(self):
        """Reset to a random empty/goal cell with small positional/velocity noise."""
        idx = self.np_random.choice(len(self.empty_and_goal_locations))
        reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)
        qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        if self.reset_target:
            self.set_target()
        return self._get_obs()

    def reset_to_location(self, location):
        """Reset the sim and place the particle at `location` (with jitter)."""
        self.sim.reset()
        reset_location = np.array(location).astype(self.observation_space.dtype)
        qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        pass
================================================
FILE: d4rl/d4rl/pointmaze/q_iteration.py
================================================
"""
Use q-iteration to solve for an optimal policy
Usage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus)
"""
import numpy as np
from scipy.special import logsumexp as sp_lse
def softmax(q, alpha=1.0):
    """Return softmax probabilities of q at temperature alpha.

    The max is subtracted before exponentiating for numerical stability.
    """
    scaled = (1.0 / alpha) * q
    stabilized = scaled - np.max(scaled)
    weights = np.exp(stabilized)
    return weights / np.sum(weights)
def logsumexp(q, alpha=1.0, axis=1):
    """Temperature-weighted log-sum-exp of q along `axis`.

    alpha == 0 degenerates to a hard max along the axis.
    """
    if alpha == 0:
        return np.max(q, axis=axis)
    scaled = (1.0 / alpha) * q
    return alpha * sp_lse(scaled, axis=axis)
def get_policy(q_fn, ent_wt=1.0):
    """Return the Boltzmann policy induced by tabular Q-values q_fn (S x A).

    With ent_wt == 0 the policy is a greedy indicator built from the sign of
    the advantage (matching the original behavior, including the possibility
    of ties assigning 1.0 to several actions).
    """
    value = logsumexp(q_fn, alpha=ent_wt)
    advantage = q_fn - np.expand_dims(value, axis=1)
    if ent_wt == 0:
        pol_probs = advantage
        pol_probs[pol_probs >= 0] = 1.0
        pol_probs[pol_probs < 0] = 0.0
    else:
        pol_probs = np.exp((1.0 / ent_wt) * advantage)
        pol_probs /= np.sum(pol_probs, axis=1, keepdims=True)
    assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs)
    return pol_probs
def softq_iteration(env, transition_matrix=None, reward_matrix=None, num_itrs=50, discount=0.99, ent_wt=0.1, warmstart_q=None, policy=None):
    """Perform tabular soft Q-iteration.

    Args:
        env: tabular environment exposing num_states / num_actions and, when
            not supplied explicitly, reward_matrix() and transition_matrix().
        transition_matrix: optional S x A x S transition tensor.
        reward_matrix: optional reward array; its trailing next-state axis is
            collapsed so rewards depend only on (s, a).
        num_itrs: number of Bellman backups.
        discount: discount factor.
        ent_wt: entropy bonus weight (0 recovers hard Q-iteration).
        warmstart_q: optional initial Q-table; zeros otherwise.
        policy: optional fixed policy to evaluate instead of the soft max.

    Returns:
        Q-table of shape (num_states, num_actions).
    """
    n_states = env.num_states
    n_actions = env.num_actions
    if reward_matrix is None:
        reward_matrix = env.reward_matrix()
    # Rewards depend only on (state, action): drop the next-state axis.
    reward_matrix = reward_matrix[:, :, 0]
    q_fn = np.zeros((n_states, n_actions)) if warmstart_q is None else warmstart_q
    t_matrix = env.transition_matrix() if transition_matrix is None else transition_matrix
    for _ in range(num_itrs):
        if policy is None:
            # Soft (entropy-regularized) state value.
            v_fn = logsumexp(q_fn, alpha=ent_wt)
        else:
            # Policy evaluation with an entropy bonus for the given policy.
            v_fn = np.sum((q_fn - ent_wt * np.log(policy)) * policy, axis=1)
        q_fn = reward_matrix + discount * t_matrix.dot(v_fn)
    return q_fn
def q_iteration(env, **kwargs):
    """Hard-max Q-iteration: soft Q-iteration with zero entropy bonus."""
    return softq_iteration(env, ent_wt=0.0, **kwargs)
def compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0):
    """Time-averaged state-action visitation of the policy induced by q_fn.

    Rolls the initial state distribution forward for env_time_limit steps
    under the Boltzmann policy and returns the averaged S x A visitation
    matrix.  The `discount` argument is accepted but unused here (kept for
    interface compatibility).
    """
    policy = get_policy(q_fn, ent_wt=ent_wt)
    n_states = env.num_states
    n_actions = env.num_actions
    state_dist = np.zeros((n_states, 1))
    for state, prob in env.initial_state_distribution.items():
        state_dist[state] = prob
    transitions = env.transition_matrix()  # S x A x S
    per_step = np.zeros((n_states, n_actions, env_time_limit))
    for step in range(env_time_limit):
        joint = state_dist * policy
        per_step[:, :, step] = joint
        # Marginalize over (state, action) to get the next state distribution.
        next_dist = np.einsum('ij,ijk->k', joint, transitions)
        state_dist = np.expand_dims(next_dist, axis=1)
    return np.sum(per_step, axis=2) / float(env_time_limit)
def compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0):
    """Discounted (unnormalized) state-action occupancy of the q_fn policy.

    Same rollout as compute_visitation, but each step's visitation is
    weighted by discount**t and the result is NOT averaged over time.
    """
    policy = get_policy(q_fn, ent_wt=ent_wt)
    n_states = env.num_states
    n_actions = env.num_actions
    state_dist = np.zeros((n_states, 1))
    for state, prob in env.initial_state_distribution.items():
        state_dist[state] = prob
    transitions = env.transition_matrix()  # S x A x S
    per_step = np.zeros((n_states, n_actions, env_time_limit))
    for step in range(env_time_limit):
        joint = state_dist * policy
        per_step[:, :, step] = (discount ** step) * joint
        # Marginalize over (state, action) to get the next state distribution.
        next_dist = np.einsum('ij,ijk->k', joint, transitions)
        state_dist = np.expand_dims(next_dist, axis=1)
    return np.sum(per_step, axis=2)
================================================
FILE: d4rl/d4rl/pointmaze/waypoint_controller.py
================================================
import numpy as np
from d4rl.pointmaze import q_iteration
from d4rl.pointmaze.gridcraft import grid_env
from d4rl.pointmaze.gridcraft import grid_spec
ZEROS = np.zeros((2,), dtype=np.float32)
# NOTE(review): despite its name this is an array of ZEROS, so the sentinel
# self._target below evaluates to (0, 0) rather than (-1000, -1000);
# presumably np.ones was intended -- confirm against upstream before changing.
ONES = np.zeros((2,), dtype=np.float32)


class WaypointController(object):
    """PD controller that steers a point mass through grid waypoints to a target."""

    def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0):
        self.maze_str = maze_str
        self._target = -1000 * ONES  # sentinel forcing waypoint re-planning on first use
        self.p_gain = p_gain   # proportional gain toward the next waypoint
        self.d_gain = d_gain   # derivative (velocity damping) gain
        self.solve_thresh = solve_thresh  # distance below which the task counts as solved
        self.vel_thresh = 0.1             # velocity below which the mass counts as stopped
        self._waypoint_idx = 0
        self._waypoints = []
        self._waypoint_prev_loc = ZEROS
        # Tabular grid environment used to plan waypoints with Q-iteration.
        self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str))

    def current_waypoint(self):
        return self._waypoints[self._waypoint_idx]
def get_action(self, location, velocity, target):
if np.linalg.norm(self._target - np.array(self.gridify_state(target))) > 1e-3:
#print('New target!', target, 'old:', self._target)
self._new_target(location, target)
dist = np.linalg.norm(location - self._target)
vel = self._waypoint_prev_loc - location
vel_norm = np.linalg.norm(vel)
task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh)
if task_not_solved:
next_wpnt = self._waypoints[self._waypoint_idx]
else:
next_wpnt = self._target
# Compute control
prop = next_wpnt - location
action = self.p_gain * prop + self.d_gain * velocity
dist_next_wpnt = np.linalg.norm(location - next_wpnt)
if task_not_solved and (dist_next_wpnt < self.solve_thresh) and (vel_norm 1:
raise ValueError("More than 1 goal specified!")
else:
# If no goal, use the first empty tile
self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype))
self.empty_and_goal_locations = self.reset_locations + self.goal_locations
def create_single_player_scene(self, bullet_client):
    # Empty single-robot scene; timestep and frame_skip mirror the MuJoCo version.
    return scene_abstract.SingleRobotEmptyScene(bullet_client, gravity=9.8, timestep=0.0165, frame_skip=1)
def reset(self):
    """Reset the bullet world (restoring a cached state when available) and the robot."""
    if (self.stateId >= 0):
        # Restore the cached initial world state instead of rebuilding it.
        self._p.restoreState(self.stateId)
    r = env_bases.MJCFBaseBulletEnv.reset(self)
    if (self.stateId < 0):
        # First reset: cache the freshly built world state for later restores.
        self.stateId = self._p.saveState()
    self.reset_model()
    ob = self.robot.calc_state()
    return ob
def step(self, action):
    """Apply the clipped action, advance physics, and return (obs, reward, done, info)."""
    action = np.clip(action, -1.0, 1.0)
    #self.clip_velocity()
    self.robot.apply_action(action)
    self.scene.global_step()
    ob = self.robot.calc_state()
    if self.reward_type == 'sparse':
        # 1 when within 0.5 units of the target, else 0.
        reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0
    elif self.reward_type == 'dense':
        # Shaped reward in (0, 1], decays exponentially with distance.
        reward = np.exp(-np.linalg.norm(ob[0:2] - self._target))
    else:
        raise ValueError('Unknown reward type %s' % self.reward_type)
    done = False  # episodes end only via external time limits
    self.HUD(ob, action, done)
    return ob, reward, done, {}
def camera_adjust(self):
    # Track the point mass from above.
    qpos = self.robot.qpos
    x = qpos[0]
    y = qpos[1]
    self.camera.move_and_look_at(x, y, 1.4, x, y, 1.0)
def get_target(self):
    # Current 2D goal position.
    return self._target
def set_target(self, target_location=None):
    """Set the goal position; samples a jittered random reachable cell when None."""
    if target_location is None:
        choice = self.np_random.choice(len(self.empty_and_goal_locations))
        base = np.array(self.empty_and_goal_locations[choice]).astype(self.observation_space.dtype)
        target_location = base + self.np_random.uniform(low=-.1, high=.1, size=2)
    self._target = target_location
def clip_velocity(self):
    # Keep the point-mass velocity within [-5, 5].
    qvel = np.clip(self.robot.qvel, -5.0, 5.0)
    self.robot.set_state(self.robot.qpos, qvel)
def reset_model(self):
    """Place the robot at a random empty/goal cell with small position/velocity noise."""
    idx = self.np_random.choice(len(self.empty_and_goal_locations))
    reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)
    qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2)
    qvel = self.np_random.randn(2) * .1
    self.robot.set_state(qpos, qvel)
    if self.reset_target:
        self.set_target()
    return self.robot.get_obs()
def reset_to_location(self, location):
    """Reset with the robot placed near `location` (with small noise).

    NOTE(review): self.sim is a MuJoCo-era attribute carried over from the
    mujoco maze env; it does not appear to exist on this bullet-based class,
    so this first line would raise -- confirm whether this method is used.
    """
    self.sim.reset()
    reset_location = np.array(location).astype(self.observation_space.dtype)
    qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2)
    qvel = self.np_random.randn(2) * .1
    self.robot.set_state(qpos, qvel)
    return self.robot.get_obs()
================================================
FILE: d4rl/d4rl/pointmaze_bullet/bullet_robot.py
================================================
import os

import numpy as np
import pybullet
from pybullet_envs import robot_bases
class MJCFBasedRobot(robot_bases.XmlBasedRobot):
    """
    Base class for mujoco .xml based agents.
    """

    def __init__(self, model_xml, robot_name, action_dim, obs_dim, self_collision=True):
        robot_bases.XmlBasedRobot.__init__(self, robot_name, action_dim, obs_dim, self_collision)
        self.model_xml = model_xml
        self.doneLoading = 0  # MJCF model is loaded lazily on first reset()

    def reset(self, bullet_client):
        """Load the MJCF model on first call, then run the robot-specific reset.

        Returns the observation from calc_state().
        """
        self._p = bullet_client
        #print("Created bullet_client with id=", self._p._client)
        if (self.doneLoading == 0):
            # Load the model exactly once; cache parts/joints/body handles.
            self.ordered_joints = []
            self.doneLoading = 1
            if self.self_collision:
                self.objects = self._p.loadMJCF(self.model_xml,
                                                flags=pybullet.URDF_USE_SELF_COLLISION |
                                                pybullet.URDF_USE_SELF_COLLISION_EXCLUDE_ALL_PARENTS |
                                                pybullet.URDF_GOOGLEY_UNDEFINED_COLORS)
                self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene(
                    self._p, self.objects)
            else:
                self.objects = self._p.loadMJCF(self.model_xml, flags=pybullet.URDF_GOOGLEY_UNDEFINED_COLORS)
                self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene(
                    self._p, self.objects)
        self.robot_specific_reset(self._p)
        s = self.calc_state(
        )  # optimization: calc_state() can calculate something in self.* for calc_potential() to use
        return s

    def calc_potential(self):
        return 0
class WalkerBase(MJCFBasedRobot):
    """Locomotion robot base: applies scaled torque actions and builds the
    standard pybullet walker observation (body pose cues, velocities in the
    body frame, joint states, and feet contacts).

    NOTE(review): this class uses np.* but numpy is not imported at the top
    of this module in the visible source -- confirm numpy is in scope.
    """

    def __init__(self, fn, robot_name, action_dim, obs_dim, power):
        MJCFBasedRobot.__init__(self, fn, robot_name, action_dim, obs_dim)
        self.power = power  # global torque scale applied to every joint
        self.camera_x = 0
        self.start_pos_x, self.start_pos_y, self.start_pos_z = 0, 0, 0
        self.walk_target_x = 1e3  # kilometer away
        self.walk_target_y = 0
        self.body_xyz = [0, 0, 0]

    def robot_specific_reset(self, bullet_client):
        """Randomize joint positions slightly and reset contact/height bookkeeping."""
        self._p = bullet_client
        for j in self.ordered_joints:
            # Small random joint perturbation with zero initial velocity.
            j.reset_current_position(self.np_random.uniform(low=-0.1, high=0.1), 0)
        self.feet = [self.parts[f] for f in self.foot_list]
        self.feet_contact = np.array([0.0 for f in self.foot_list], dtype=np.float32)
        self.scene.actor_introduce(self)
        self.initial_z = None  # captured on the first calc_state() call

    def apply_action(self, a):
        assert (np.isfinite(a).all())
        for n, j in enumerate(self.ordered_joints):
            # Torque = global power * per-joint coefficient * clipped action.
            j.set_motor_torque(self.power * j.power_coef * float(np.clip(a[n], -1, +1)))

    def calc_state(self):
        """Build the observation vector, clipped elementwise to [-5, 5]."""
        j = np.array([j.current_relative_position() for j in self.ordered_joints],
                     dtype=np.float32).flatten()
        # even elements [0::2] position, scaled to -1..+1 between limits
        # odd elements [1::2] angular speed, scaled to show -1..+1
        self.joint_speeds = j[1::2]
        self.joints_at_limit = np.count_nonzero(np.abs(j[0::2]) > 0.99)
        body_pose = self.robot_body.pose()
        parts_xyz = np.array([p.pose().xyz() for p in self.parts.values()]).flatten()
        self.body_xyz = (parts_xyz[0::3].mean(), parts_xyz[1::3].mean(), body_pose.xyz()[2]
                         )  # torso z is more informative than mean z
        self.body_real_xyz = body_pose.xyz()
        self.body_rpy = body_pose.rpy()
        z = self.body_xyz[2]
        if self.initial_z == None:
            self.initial_z = z
        r, p, yaw = self.body_rpy
        self.walk_target_theta = np.arctan2(self.walk_target_y - self.body_xyz[1],
                                            self.walk_target_x - self.body_xyz[0])
        self.walk_target_dist = np.linalg.norm(
            [self.walk_target_y - self.body_xyz[1], self.walk_target_x - self.body_xyz[0]])
        angle_to_target = self.walk_target_theta - yaw
        # Rotation by -yaw: express world-frame speed in the body's heading frame.
        rot_speed = np.array([[np.cos(-yaw), -np.sin(-yaw), 0], [np.sin(-yaw),
                                                                 np.cos(-yaw), 0], [0, 0, 1]])
        vx, vy, vz = np.dot(rot_speed,
                            self.robot_body.speed())  # rotate speed back to body point of view
        more = np.array(
            [
                z - self.initial_z,
                np.sin(angle_to_target),
                np.cos(angle_to_target),
                0.3 * vx,
                0.3 * vy,
                0.3 * vz,  # 0.3 is just scaling typical speed into -1..+1, no physical sense here
                r,
                p
            ],
            dtype=np.float32)
        return np.clip(np.concatenate([more] + [j] + [self.feet_contact]), -5, +5)

    def calc_potential(self):
        # progress in potential field is speed*dt, typical speed is about 2-3 meter per second, this potential will change 2-3 per frame (not per second),
        # all rewards have rew/frame units and close to 1.0
        debugmode = 0
        if (debugmode):
            print("calc_potential: self.walk_target_dist")
            print(self.walk_target_dist)
            print("self.scene.dt")
            print(self.scene.dt)
            print("self.scene.frame_skip")
            print(self.scene.frame_skip)
            print("self.scene.timestep")
            print(self.scene.timestep)
        return -self.walk_target_dist / self.scene.dt
================================================
FILE: d4rl/d4rl/utils/__init__.py
================================================
================================================
FILE: d4rl/d4rl/utils/dataset_utils.py
================================================
import h5py
import numpy as np
class DatasetWriter(object):
    """Accumulates (s, a, r, done) transitions and writes them to an HDF5 file.

    Optionally records MuJoCo sim state (infos/qpos, infos/qvel) and goal
    info (infos/goal) alongside the core keys.
    """

    def __init__(self, mujoco=False, goal=False):
        self.mujoco = mujoco  # also log infos/qpos and infos/qvel
        self.goal = goal      # also log infos/goal
        self.data = self._reset_data()
        self._num_samples = 0

    def _reset_data(self):
        """Return a fresh dict of empty per-key sample lists."""
        data = {
            'observations': [],
            'actions': [],
            'terminals': [],
            'rewards': [],
        }
        if self.mujoco:
            data['infos/qpos'] = []
            data['infos/qvel'] = []
        if self.goal:
            data['infos/goal'] = []
        return data

    def __len__(self):
        return self._num_samples

    def append_data(self, s, a, r, done, goal=None, mujoco_env_data=None):
        """Append one transition.

        mujoco_env_data must expose qpos/qvel arrays when mujoco=True;
        goal is recorded only when goal=True.
        """
        self._num_samples += 1
        self.data['observations'].append(s)
        self.data['actions'].append(a)
        self.data['rewards'].append(r)
        self.data['terminals'].append(done)
        if self.goal:
            self.data['infos/goal'].append(goal)
        if self.mujoco:
            self.data['infos/qpos'].append(mujoco_env_data.qpos.ravel().copy())
            self.data['infos/qvel'].append(mujoco_env_data.qvel.ravel().copy())

    def write_dataset(self, fname, max_size=None, compression='gzip'):
        """Write accumulated data to `fname` as HDF5, truncated to max_size samples.

        Fix: the file is opened via a context manager so the handle is closed
        even if dataset creation raises (the original leaked it on error).
        """
        np_data = {}
        for k in self.data:
            # Terminals are stored as booleans; everything else as float32.
            dtype = np.bool_ if k == 'terminals' else np.float32
            data = np.array(self.data[k], dtype=dtype)
            if max_size is not None:
                data = data[:max_size]
            np_data[k] = data
        with h5py.File(fname, 'w') as dataset:
            for k in np_data:
                dataset.create_dataset(k, data=np_data[k], compression=compression)
================================================
FILE: d4rl/d4rl/utils/quatmath.py
================================================
import numpy as np
# For testing whether a number is close to zero
_FLOAT_EPS = np.finfo(np.float64).eps
_EPS4 = _FLOAT_EPS * 4.0
def mulQuat(qa, qb):
    """Hamilton product of two quaternions in (w, x, y, z) order."""
    w1, x1, y1, z1 = qa[0], qa[1], qa[2], qa[3]
    w2, x2, y2, z2 = qb[0], qb[1], qb[2], qb[3]
    return np.array([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
def negQuat(quat):
    """Conjugate of a quaternion: (w, x, y, z) -> (w, -x, -y, -z)."""
    w, x, y, z = quat[0], quat[1], quat[2], quat[3]
    return np.array([w, -x, -y, -z])
def quat2Vel(quat, dt=1):
    """Decompose a unit quaternion into (angular speed, rotation axis) over dt."""
    vec = quat[1:].copy()
    sin_half_angle = np.sqrt(np.sum(vec ** 2))
    # Epsilon guards division by zero for the identity (no-rotation) quaternion.
    unit_axis = vec / (sin_half_angle + 1e-8)
    speed = 2 * np.arctan2(sin_half_angle, quat[0]) / dt
    return speed, unit_axis
def quatDiff2Vel(quat1, quat2, dt):
    """Angular velocity (speed, axis) taking quat1 to quat2 over time dt."""
    relative = mulQuat(quat2, negQuat(quat1))
    return quat2Vel(relative, dt)
def axis_angle2quat(axis, angle):
    """Quaternion (w, x, y, z) for a rotation of `angle` radians about `axis`."""
    half = angle / 2
    s = np.sin(half)
    return np.array([np.cos(half), s * axis[0], s * axis[1], s * axis[2]])
def euler2mat(euler):
    """ Convert Euler Angles to Rotation Matrix. See rotation.py for notes """
    euler = np.asarray(euler, dtype=np.float64)
    assert euler.shape[-1] == 3, "Invalid shaped euler {}".format(euler)
    # Negate and reverse the angle order per the rotation.py convention.
    ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0]
    si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)
    ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)
    # Precompute pairwise products reused across matrix entries.
    cc, cs = ci * ck, ci * sk
    sc, ss = si * ck, si * sk
    # Supports batched input: output shape is euler.shape[:-1] + (3, 3).
    mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64)
    mat[..., 2, 2] = cj * ck
    mat[..., 2, 1] = sj * sc - cs
    mat[..., 2, 0] = sj * cc + ss
    mat[..., 1, 2] = cj * sk
    mat[..., 1, 1] = sj * ss + cc
    mat[..., 1, 0] = sj * cs - sc
    mat[..., 0, 2] = -sj
    mat[..., 0, 1] = cj * si
    mat[..., 0, 0] = cj * ci
    return mat
def euler2quat(euler):
    """ Convert Euler Angles to Quaternions (w, x, y, z). See rotation.py for notes """
    euler = np.asarray(euler, dtype=np.float64)
    assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler)
    # Half angles, with order/sign per the rotation.py convention.
    ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2
    si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)
    ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)
    # Precompute pairwise products reused across components.
    cc, cs = ci * ck, ci * sk
    sc, ss = si * ck, si * sk
    # Supports batched input: output shape is euler.shape[:-1] + (4,).
    quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64)
    quat[..., 0] = cj * cc + sj * ss
    quat[..., 3] = cj * sc - sj * cs
    quat[..., 2] = -(cj * ss + sj * cc)
    quat[..., 1] = cj * cs - sj * sc
    return quat
def mat2euler(mat):
    """ Convert Rotation Matrix to Euler Angles. See rotation.py for notes """
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)
    cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2])
    # condition is False near the gimbal-lock singularity (cy ~ 0).
    condition = cy > _EPS4
    euler = np.empty(mat.shape[:-1], dtype=np.float64)
    euler[..., 2] = np.where(condition,
                             -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
                             -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]))
    # NOTE(review): both np.where branches below are identical; presumably
    # intentional (the formula is valid in both regimes) -- confirm upstream.
    euler[..., 1] = np.where(condition,
                             -np.arctan2(-mat[..., 0, 2], cy),
                             -np.arctan2(-mat[..., 0, 2], cy))
    euler[..., 0] = np.where(condition,
                             -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]),
                             0.0)
    return euler
def mat2quat(mat):
    """ Convert Rotation Matrix to Quaternion (w, x, y, z). See rotation.py for notes """
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)
    Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2]
    Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2]
    Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2]
    # Fill only lower half of symmetric matrix
    K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64)
    K[..., 0, 0] = Qxx - Qyy - Qzz
    K[..., 1, 0] = Qyx + Qxy
    K[..., 1, 1] = Qyy - Qxx - Qzz
    K[..., 2, 0] = Qzx + Qxz
    K[..., 2, 1] = Qzy + Qyz
    K[..., 2, 2] = Qzz - Qxx - Qyy
    K[..., 3, 0] = Qyz - Qzy
    K[..., 3, 1] = Qzx - Qxz
    K[..., 3, 2] = Qxy - Qyx
    K[..., 3, 3] = Qxx + Qyy + Qzz
    K /= 3.0
    # TODO: vectorize this -- probably could be made faster
    q = np.empty(K.shape[:-2] + (4,))
    # Iterate over the (possibly batched) leading dimensions one matrix at a time.
    it = np.nditer(q[..., 0], flags=['multi_index'])
    while not it.finished:
        # Use Hermitian eigenvectors, values for speed
        vals, vecs = np.linalg.eigh(K[it.multi_index])
        # Select largest eigenvector, reorder to w,x,y,z quaternion
        q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)]
        # Prefer quaternion with positive w
        # (q * -1 corresponds to same rotation as q)
        if q[it.multi_index][0] < 0:
            q[it.multi_index] *= -1
        it.iternext()
    return q
def quat2euler(quat):
    """Convert Quaternion to Euler Angles. See rotation.py for notes."""
    rotation = quat2mat(quat)
    return mat2euler(rotation)
def quat2mat(quat):
    """ Convert Quaternion (w, x, y, z) to Rotation Matrix. See rotation.py for notes """
    quat = np.asarray(quat, dtype=np.float64)
    assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat)
    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]
    # Squared norm; used both for normalization and the degenerate-quat guard.
    Nq = np.sum(quat * quat, axis=-1)
    s = 2.0 / Nq
    X, Y, Z = x * s, y * s, z * s
    wX, wY, wZ = w * X, w * Y, w * Z
    xX, xY, xZ = x * X, x * Y, x * Z
    yY, yZ, zZ = y * Y, y * Z, z * Z
    mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64)
    mat[..., 0, 0] = 1.0 - (yY + zZ)
    mat[..., 0, 1] = xY - wZ
    mat[..., 0, 2] = xZ + wY
    mat[..., 1, 0] = xY + wZ
    mat[..., 1, 1] = 1.0 - (xX + zZ)
    mat[..., 1, 2] = yZ - wX
    mat[..., 2, 0] = xZ - wY
    mat[..., 2, 1] = yZ + wX
    mat[..., 2, 2] = 1.0 - (xX + yY)
    # Near-zero quaternions fall back to the identity rotation.
    return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3))
================================================
FILE: d4rl/d4rl/utils/visualize_env.py
================================================
import gym
import d4rl
import click
import os
import gym
import numpy as np
import pickle
from mjrl.utils.gym_env import GymEnv
#from mjrl.policies.gaussian_mlp import MLP
DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
Visualizes policy on the env\n
$ python visualize_policy.py --env_name door-v0 \n
$ python visualize_policy.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n
'''
class RandomPolicy(object):
    """Policy that samples uniformly random actions from the env's action space."""

    def __init__(self, env):
        self.env = env

    def get_action(self, obs):
        """Return [action, info]; both entries are independent random samples."""
        sample = self.env.action_space.sample
        return [sample(), {'evaluation': sample()}]
# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required= True)
@click.option('--policy', type=str, help='absolute path of the policy file', default=None)
@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')
@click.option('--seed', type=int, help='seed for generating environment instances', default=123)
@click.option('--episodes', type=int, help='number of episodes to visualize', default=10)
def main(env_name, policy, mode, seed, episodes):
    """Render a random policy on the requested environment.

    The --policy/--mode options are currently unused: the pickled-policy
    loading path below is disabled and a RandomPolicy is always used.
    """
    e = GymEnv(env_name)
    e.set_seed(seed)
    """
    if policy is not None:
        pi = pickle.load(open(policy, 'rb'))
    else:
        pi = MLP(e.spec, hidden_sizes=(32,32), seed=seed, init_log_std=-1.0)
    """
    pi = RandomPolicy(e)
    # render policy
    e.visualize_policy(pi, num_episodes=episodes, horizon=e.horizon, mode=mode)

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/d4rl/utils/wrappers.py
================================================
import numpy as np
import itertools
from gym import Env
from gym.spaces import Box
from gym.spaces import Discrete
from collections import deque
class ProxyEnv(Env):
    """Gym Env wrapper that forwards everything to a wrapped env.

    Unknown attribute lookups fall through to the wrapped env via __getattr__.
    """

    def __init__(self, wrapped_env):
        self._wrapped_env = wrapped_env
        self.action_space = self._wrapped_env.action_space
        self.observation_space = self._wrapped_env.observation_space

    @property
    def wrapped_env(self):
        return self._wrapped_env

    def reset(self, **kwargs):
        return self._wrapped_env.reset(**kwargs)

    def step(self, action):
        return self._wrapped_env.step(action)

    def render(self, *args, **kwargs):
        return self._wrapped_env.render(*args, **kwargs)

    def seed(self, seed=0):
        return self._wrapped_env.seed(seed=seed)

    @property
    def horizon(self):
        return self._wrapped_env.horizon

    def terminate(self):
        # Best-effort: only some wrapped envs implement terminate().
        if hasattr(self.wrapped_env, "terminate"):
            self.wrapped_env.terminate()

    def __getattr__(self, attr):
        # Guard against infinite recursion when _wrapped_env itself is not yet
        # set on the instance (e.g. during unpickling).
        if attr == '_wrapped_env':
            raise AttributeError()
        return getattr(self._wrapped_env, attr)

    def __getstate__(self):
        """
        This is useful to override in case the wrapped env has some funky
        __getstate__ that doesn't play well with overriding __getattr__.
        The main problematic case is/was gym's EzPickle serialization scheme.
        :return:
        """
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __str__(self):
        return '{}({})'.format(type(self).__name__, self.wrapped_env)
class HistoryEnv(ProxyEnv, Env):
    """Wrapper that stacks the last `history_len` observations into one flat vector.

    Missing history right after a reset is padded with zero observations.
    """

    def __init__(self, wrapped_env, history_len):
        super().__init__(wrapped_env)
        self.history_len = history_len
        # Flattened observation space: history_len copies of the base space,
        # with unbounded limits.
        high = np.inf * np.ones(
            self.history_len * self.observation_space.low.size)
        low = -high
        self.observation_space = Box(low=low,
                                     high=high,
                                     )
        self.history = deque(maxlen=self.history_len)

    def step(self, action):
        state, reward, done, info = super().step(action)
        self.history.append(state)
        flattened_history = self._get_history().flatten()
        return flattened_history, reward, done, info

    def reset(self, **kwargs):
        # Fix: forward **kwargs to the wrapped env; the original accepted
        # them in the signature but silently dropped them.
        state = super().reset(**kwargs)
        self.history = deque(maxlen=self.history_len)
        self.history.append(state)
        flattened_history = self._get_history().flatten()
        return flattened_history

    def _get_history(self):
        """Return the history as a 2D array, zero-padded up to history_len rows."""
        observations = list(self.history)
        obs_count = len(observations)
        for _ in range(self.history_len - obs_count):
            dummy = np.zeros(self._wrapped_env.observation_space.low.size)
            observations.append(dummy)
        return np.c_[observations]
class DiscretizeEnv(ProxyEnv, Env):
    """Wrapper exposing a Discrete action space over a continuous Box space.

    Each discrete index maps to one point of the Cartesian product of
    `num_bins` evenly spaced values per action dimension.
    """

    def __init__(self, wrapped_env, num_bins):
        super().__init__(wrapped_env)
        low = self.wrapped_env.action_space.low
        high = self.wrapped_env.action_space.high
        per_dim = [np.linspace(lo, hi, num_bins) for lo, hi in zip(low, high)]
        self.idx_to_continuous_action = [
            np.array(combo) for combo in itertools.product(*per_dim)
        ]
        self.action_space = Discrete(len(self.idx_to_continuous_action))

    def step(self, action):
        # Translate the discrete index back to its continuous action.
        return super().step(self.idx_to_continuous_action[action])
class NormalizedBoxEnv(ProxyEnv):
    """
    Normalize action to in [-1, 1].
    Optionally normalize observations and scale reward.
    """

    def __init__(
            self,
            env,
            reward_scale=1.,
            obs_mean=None,
            obs_std=None,
    ):
        ProxyEnv.__init__(self, env)
        # Observation normalization activates if either statistic is provided;
        # the missing one defaults to zero mean / unit std.
        self._should_normalize = not (obs_mean is None and obs_std is None)
        if self._should_normalize:
            if obs_mean is None:
                obs_mean = np.zeros_like(env.observation_space.low)
            else:
                obs_mean = np.array(obs_mean)
            if obs_std is None:
                obs_std = np.ones_like(env.observation_space.low)
            else:
                obs_std = np.array(obs_std)
        self._reward_scale = reward_scale
        self._obs_mean = obs_mean
        self._obs_std = obs_std
        # Expose a canonical [-1, 1] action space to callers.
        ub = np.ones(self._wrapped_env.action_space.shape)
        self.action_space = Box(-1 * ub, ub)

    def estimate_obs_stats(self, obs_batch, override_values=False):
        """Estimate observation mean/std from a batch; refuses to overwrite
        existing statistics unless override_values=True."""
        if self._obs_mean is not None and not override_values:
            raise Exception("Observation mean and std already set. To "
                            "override, set override_values to True.")
        self._obs_mean = np.mean(obs_batch, axis=0)
        self._obs_std = np.std(obs_batch, axis=0)

    def _apply_normalize_obs(self, obs):
        # Epsilon avoids division by zero for constant observation dims.
        return (obs - self._obs_mean) / (self._obs_std + 1e-8)

    def step(self, action):
        """Rescale the [-1, 1] action to the wrapped env's bounds, step, and
        optionally normalize the observation and scale the reward."""
        lb = self._wrapped_env.action_space.low
        ub = self._wrapped_env.action_space.high
        scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
        scaled_action = np.clip(scaled_action, lb, ub)
        wrapped_step = self._wrapped_env.step(scaled_action)
        next_obs, reward, done, info = wrapped_step
        if self._should_normalize:
            next_obs = self._apply_normalize_obs(next_obs)
        return next_obs, reward * self._reward_scale, done, info

    def __str__(self):
        return "Normalized: %s" % self._wrapped_env
================================================
FILE: d4rl/scripts/check_antmaze_datasets.py
================================================
"""
This script runs sanity checks all datasets in a directory.
Usage:
python check_antmaze_datasets.py
"""
import numpy as np
import scipy as sp
import scipy.spatial
import h5py
import os
import argparse
def check_identical_values(dset):
    """Check that the first/middle/last rows of key arrays are pairwise distinct."""
    for key in ['actions', 'observations', 'infos/qpos', 'infos/qvel']:
        arr = dset[key][:]
        first = arr[0]
        middle = arr[arr.shape[0] // 2]
        last = arr[-1]
        sampled = np.c_[first, middle, last].T
        # Pairwise distances must all be strictly positive (no duplicates).
        pairwise = sp.spatial.distance.pdist(sampled)
        assert np.all(pairwise > 0)
def check_num_samples(dset):
    """Check that all expected keys contain the same number of samples."""
    keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals', 'infos/qpos', 'infos/qvel']
    lengths = {dset[k].shape[0] for k in keys}
    # All leading dimensions must agree.
    assert len(lengths) == 1
def check_reset_nonterminal(dataset):
    """Check that large position jumps (resets) only happen at episode ends."""
    xy = dataset['observations'][:, 0:2]
    jumps = np.linalg.norm(xy[:-1] - xy[1:], axis=1)
    episode_end = ((dataset['terminals'][:] + dataset['timeouts'][:]) > 0)[:-1]
    # A jump larger than 5 units between consecutive steps indicates a reset.
    reset_mask = jumps > 5.0
    num_resets = np.sum(reset_mask)
    num_nonterminal_reset = np.sum(reset_mask * (1 - episode_end))
    print('num reset:', num_resets)
    print('nonreset term:', num_nonterminal_reset)
    assert num_nonterminal_reset == 0
def print_avg_returns(dset):
    """Print average episode return plus timeout/terminal counts for manual inspection."""
    rewards = dset['rewards'][:]
    terminals = dset['terminals'][:]
    timeouts = dset['timeouts'][:]
    episode_over = (timeouts + terminals) > 0
    all_returns = []
    running = 0
    for step in range(rewards.shape[0]):
        running += float(rewards[step])
        if episode_over[step]:
            # Episode boundary: record the return and start a new episode.
            all_returns.append(running)
            running = 0
    print('Avg returns:', np.mean(all_returns))
    print('# timeout:', np.sum(timeouts))
    print('# terminals:', np.sum(terminals))
# All sanity checks to run against each dataset file.
CHECK_FNS = [print_avg_returns, check_reset_nonterminal, check_identical_values, check_num_samples]

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets')
    args = parser.parse_args()
    dirname = args.dirname
    for fname in os.listdir(dirname):
        if fname.endswith('.hdf5'):
            hfile = h5py.File(os.path.join(dirname, fname))
            print('Checking:', fname)
            for check_fn in CHECK_FNS:
                try:
                    check_fn(hfile)
                except AssertionError as e:
                    # Report the failed check but keep running the others.
                    print('Failed test:', check_fn.__name__)
                    #raise e
================================================
FILE: d4rl/scripts/check_bullet.py
================================================
"""
A quick script to run a sanity check on all environments.
"""
import gym
import d4rl
import numpy as np
# All bullet-backed dataset environments to sanity-check.
ENVS = [
    'bullet-halfcheetah-random-v0',
    'bullet-halfcheetah-medium-v0',
    'bullet-halfcheetah-expert-v0',
    'bullet-halfcheetah-medium-replay-v0',
    'bullet-halfcheetah-medium-expert-v0',
    'bullet-walker2d-random-v0',
    'bullet-walker2d-medium-v0',
    'bullet-walker2d-expert-v0',
    'bullet-walker2d-medium-replay-v0',
    'bullet-walker2d-medium-expert-v0',
    'bullet-hopper-random-v0',
    'bullet-hopper-medium-v0',
    'bullet-hopper-expert-v0',
    'bullet-hopper-medium-replay-v0',
    'bullet-hopper-medium-expert-v0',
    'bullet-ant-random-v0',
    'bullet-ant-medium-v0',
    'bullet-ant-expert-v0',
    'bullet-ant-medium-replay-v0',
    'bullet-ant-medium-expert-v0',
    'bullet-maze2d-open-v0',
    'bullet-maze2d-umaze-v0',
    'bullet-maze2d-medium-v0',
    'bullet-maze2d-large-v0',
]

if __name__ == '__main__':
    for env_name in ENVS:
        print('Checking', env_name)
        try:
            env = gym.make(env_name)
        except Exception as e:
            # Skip environments that fail to build but keep checking the rest.
            print(e)
            continue
        dset = env.get_dataset()
        print('\t Max episode steps:', env._max_episode_steps)
        print('\t', dset['observations'].shape, dset['actions'].shape)
        # Required keys and consistent leading dimensions.
        assert 'observations' in dset, 'Observations not in dataset'
        assert 'actions' in dset, 'Actions not in dataset'
        assert 'rewards' in dset, 'Rewards not in dataset'
        assert 'terminals' in dset, 'Terminals not in dataset'
        N = dset['observations'].shape[0]
        print('\t %d samples' % N)
        assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N)
        assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N)
        assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N)
        print('\t num terminals: %d' % np.sum(dset['terminals']))
        print('\t avg rew: %f' % np.mean(dset['rewards']))
        # Smoke-test the env API and score normalization.
        env.reset()
        env.step(env.action_space.sample())
        score = env.get_normalized_score(0.0)
================================================
FILE: d4rl/scripts/check_envs.py
================================================
"""
A quick script to run a sanity check on all environments.
"""
import gym
import d4rl
import numpy as np
# Every dataset flavour of the four MuJoCo locomotion agents (v1 datasets).
ENVS = [
    '%s-%s-v1' % (agent, dataset)
    for agent in ('halfcheetah', 'hopper', 'walker2d', 'ant')
    for dataset in ('random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert')
]
# Adroit hand-manipulation tasks (v1 datasets).
ENVS += [
    '%s-%s-v1' % (agent, dataset)
    for agent in ('door', 'pen', 'relocate', 'hammer')
    for dataset in ('expert', 'cloned', 'human')
]
# The remaining suites are enumerated explicitly.
ENVS += [
    'maze2d-open-v0',
    'maze2d-umaze-v1',
    'maze2d-medium-v1',
    'maze2d-large-v1',
    'maze2d-open-dense-v0',
    'maze2d-umaze-dense-v1',
    'maze2d-medium-dense-v1',
    'maze2d-large-dense-v1',
    'minigrid-fourrooms-v0',
    'minigrid-fourrooms-random-v0',
    'pen-human-v0',
    'pen-cloned-v0',
    'pen-expert-v0',
    'hammer-human-v0',
    'hammer-cloned-v0',
    'hammer-expert-v0',
    'relocate-human-v0',
    'relocate-cloned-v0',
    'relocate-expert-v0',
    'door-human-v0',
    'door-cloned-v0',
    'door-expert-v0',
    'antmaze-umaze-v0',
    'antmaze-umaze-diverse-v0',
    'antmaze-medium-play-v0',
    'antmaze-medium-diverse-v0',
    'antmaze-large-play-v0',
    'antmaze-large-diverse-v0',
    'mini-kitchen-microwave-kettle-light-slider-v0',
    'kitchen-microwave-kettle-light-slider-v0',
    'kitchen-microwave-kettle-bottomburner-light-v0',
]
if __name__ == '__main__':
    # For every registered env: load the raw dataset, validate key/shape
    # invariants, then re-derive the qlearning view and re-validate it.
    for env_name in ENVS:
        print('Checking', env_name)
        try:
            env = gym.make(env_name)
        except Exception as e:
            # Skip envs that fail to build instead of aborting the sweep.
            print(e)
            continue
        dset = env.get_dataset()
        print('\t Max episode steps:', env._max_episode_steps)
        print('\t',dset['observations'].shape, dset['actions'].shape)
        assert 'observations' in dset, 'Observations not in dataset'
        assert 'actions' in dset, 'Actions not in dataset'
        assert 'rewards' in dset, 'Rewards not in dataset'
        assert 'terminals' in dset, 'Terminals not in dataset'
        N = dset['observations'].shape[0]
        print('\t %d samples' % N)
        assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N)
        assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N)
        assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N)
        # Remember the terminal count so the qlearning view can be cross-checked.
        orig_terminals = np.sum(dset['terminals'])
        print('\t num terminals: %d' % np.sum(dset['terminals']))
        env.reset()
        env.step(env.action_space.sample())
        score = env.get_normalized_score(0.0)
        # Rebuild as (s, a, r, s') tuples and re-run the alignment checks.
        dset = d4rl.qlearning_dataset(env, dataset=dset)
        assert 'observations' in dset, 'Observations not in dataset'
        assert 'next_observations' in dset, 'Observations not in dataset'
        assert 'actions' in dset, 'Actions not in dataset'
        assert 'rewards' in dset, 'Rewards not in dataset'
        assert 'terminals' in dset, 'Terminals not in dataset'
        N = dset['observations'].shape[0]
        print('\t %d samples' % N)
        # NOTE(review): this error message interpolates the actions shape, not
        # next_observations — the assert itself is correct, the message is not.
        assert dset['next_observations'].shape[0] == N, 'NextObs number does not match (%d vs %d)' % (dset['actions'].shape[0], N)
        assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N)
        assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N)
        assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N)
        print('\t num terminals: %d' % np.sum(dset['terminals']))
        assert orig_terminals == np.sum(dset['terminals']), 'Qlearining terminals doesnt match original terminals'
================================================
FILE: d4rl/scripts/check_mujoco_datasets.py
================================================
"""
This script runs sanity checks all datasets in a directory.
Assumes all datasets in the directory are generated via mujoco and contain
the qpos/qvel keys.
Usage:
python check_mujoco_datasets.py
"""
import numpy as np
import scipy as sp
import scipy.spatial
import h5py
import os
import argparse
import tqdm
def check_identical_values(dset):
    """Sanity check: the first, middle, and last samples of each key must all differ."""
    for key in ('actions', 'observations', 'infos/qpos', 'infos/qvel'):
        arr = dset[key][:]
        mid = arr.shape[0] // 2
        # Stack the three probe rows (np.c_ keeps 1-D datasets 2-D for pdist).
        probes = np.c_[arr[0], arr[mid], arr[-1]].T
        pairwise = sp.spatial.distance.pdist(probes)
        assert np.all(pairwise > 0)
def check_qpos_qvel(dset):
    """ Check that qpos/qvel produces correct state"""
    # Re-create the env that generated this dataset and verify that restoring
    # each stored (qpos, qvel) pair reproduces the stored observation.
    import gym
    import d4rl
    N = dset['rewards'].shape[0]
    qpos = dset['infos/qpos']
    qvel = dset['infos/qvel']
    obs = dset['observations']
    # Map dataset filename -> env id by inverting the official URL table.
    # NOTE(review): assumes dset is an open h5py.File (uses .filename) and that
    # the file is named exactly as in DATASET_URLS — confirm for local copies.
    reverse_env_map = {v.split('/')[-1]: k for (k, v) in d4rl.infos.DATASET_URLS.items()}
    env_name = reverse_env_map[dset.filename.split('/')[-1]]
    env = gym.make(env_name)
    env.reset()
    print('checking qpos/qvel')
    for t in tqdm.tqdm(range(N)):
        env.set_state(qpos[t], qvel[t])
        env_obs = env.env.wrapped_env._get_obs()
        # Squared-error tolerance on the reconstructed observation.
        error = ((obs[t] - env_obs)**2).sum()
        assert error < 1e-8
def check_num_samples(dset):
    """All required keys must contain the same number of samples."""
    required = ('actions', 'observations', 'rewards', 'timeouts',
                'terminals', 'infos/qpos', 'infos/qvel')
    # Collapse the per-key lengths into a set: exactly one distinct value.
    lengths = {dset[key].shape[0] for key in required}
    assert len(lengths) == 1
def check_reset_state(dset):
    """Observations right after an episode end should sit near the initial state."""
    obs = dset['observations'][:]
    terminals = dset['terminals'][:]
    timeouts = dset['timeouts'][:]
    episode_over = (timeouts + terminals) > 0
    # The very first observation stands in for the env's reset state.
    reference = obs[0]
    # Drop the final episode end: there is no post-episode observation after it.
    end_idxs = np.where(episode_over)[0][:-1]
    dists = np.linalg.norm(obs[1:] - reference, axis=1)
    floor = np.min(dists)
    # No +1 shift needed: the obs[1:] slice above already moved everything by one.
    reset_dists = dists[end_idxs]
    print('max reset:', np.max(reset_dists))
    print('min reset:', np.min(reset_dists))
    assert np.all(reset_dists < (floor + 1e-2) * 5)
def print_avg_returns(dset):
    """Print per-episode return statistics for manual inspection."""
    rewards = dset['rewards'][:]
    terminals = dset['terminals'][:]
    timeouts = dset['timeouts'][:]
    episode_over = (timeouts + terminals) > 0
    episode_returns = []
    running = 0
    # Accumulate rewards, cutting an episode wherever either flag is set.
    for step, r in enumerate(rewards):
        running += float(r)
        if episode_over[step]:
            episode_returns.append(running)
            running = 0
    print('Avg returns:', np.mean(episode_returns))
    print('# timeout:', np.sum(timeouts))
    print('# terminals:', np.sum(terminals))
# Each check is independent; the order only affects output readability.
CHECK_FNS = [print_avg_returns, check_qpos_qvel, check_reset_state, check_identical_values, check_num_samples]

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets')
    args = parser.parse_args()
    dirname = args.dirname
    # Run every check on every .hdf5 file found in the directory.
    for fname in os.listdir(dirname):
        if fname.endswith('.hdf5'):
            hfile = h5py.File(os.path.join(dirname, fname))
            print('Checking:', fname)
            for check_fn in CHECK_FNS:
                try:
                    check_fn(hfile)
                except AssertionError as e:
                    # Name the failing check before propagating the error.
                    print('Failed test:', check_fn.__name__)
                    raise e
================================================
FILE: d4rl/scripts/generation/flow_idm.py
================================================
import numpy as np
import argparse
import gym
import d4rl.flow
from d4rl.utils import dataset_utils
from flow.controllers import car_following_models
def main():
    """Collect a flow (traffic) dataset by rolling out an IDM or random controller."""
    parser = argparse.ArgumentParser()
    #parser.add_argument('--render', action='store_true', help='Render trajectories')
    #parser.add_argument('--type', action='store_true', help='Noisy actions')
    parser.add_argument('--controller', type=str, default='idm', help='random, idm')
    parser.add_argument('--env_name', type=str, default='flow-ring-v0', help='Maze type. small or default')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    env.reset()
    print(env.action_space)

    if args.controller == 'idm':
        uenv = env.unwrapped
        veh_ids = uenv.k.vehicle.get_rl_ids()
        if hasattr(uenv, 'num_rl'):
            num_rl = uenv.num_rl
        else:
            num_rl = len(veh_ids)
        if num_rl == 0:
            raise ValueError("No RL vehicles")
        controllers = []
        # Borrow car-following parameters from an existing vehicle in the sim.
        acc_controller = uenv.k.vehicle.get_acc_controller(uenv.k.vehicle.get_ids()[0])
        car_following_params = acc_controller.car_following_params
        #for veh_id in veh_ids:
        #    controllers.append(car_following_models.IDMController(veh_id, car_following_params=car_following_params))
        def get_action(s):
            # NOTE(review): a fresh IDMController is constructed per vehicle on
            # every step — presumably because the set of RL ids can change
            # between steps; confirm this is intended over reusing `controllers`.
            actions = np.zeros_like(env.action_space.sample())
            for i, veh_id in enumerate(uenv.k.vehicle.get_rl_ids()):
                if i >= actions.shape[0]:
                    break
                actions[i] = car_following_models.IDMController(veh_id, car_following_params=car_following_params).get_accel(env)
            return actions
    elif args.controller == 'random':
        def get_action(s):
            return env.action_space.sample()
    else:
        raise ValueError("Unknown controller type: %s" % str(args.controller))

    writer = dataset_utils.DatasetWriter()
    # Collect full fixed-length episodes until the sample budget is reached.
    while len(writer) < args.num_samples:
        s = env.reset()
        ret = 0
        for _ in range(env._max_episode_steps):
            action = get_action(s)
            ns , r, done, infos = env.step(action)
            ret += r
            writer.append_data(s, action, r, done)
            s = ns
        print(ret)
        #env.render()
    fname = '%s-%s.hdf5' % (args.env_name, args.controller)
    writer.write_dataset(fname, max_size=args.num_samples)

if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/generation/generate_ant_maze_datasets.py
================================================
import numpy as np
import pickle
import gzip
import h5py
import argparse
from d4rl.locomotion import maze_env, ant, swimmer
from d4rl.locomotion.wrappers import NormalizedBoxEnv
import torch
from PIL import Image
import os
def reset_data():
    """Return an empty trajectory buffer with one fresh list per logged field."""
    fields = ('observations', 'actions', 'terminals', 'timeouts', 'rewards',
              'infos/goal', 'infos/qpos', 'infos/qvel')
    return {field: [] for field in fields}
def append_data(data, s, a, r, tgt, done, timeout, env_data):
    """Append one transition (plus MuJoCo qpos/qvel snapshots) to the buffer."""
    step = {
        'observations': s,
        'actions': a,
        'rewards': r,
        'terminals': done,
        'timeouts': timeout,
        'infos/goal': tgt,
        # copy() so later env steps cannot mutate the stored state
        'infos/qpos': env_data.qpos.ravel().copy(),
        'infos/qvel': env_data.qvel.ravel().copy(),
    }
    for key, value in step.items():
        data[key].append(value)
def npify(data):
    """Convert every buffered list to a numpy array in place (bools for flags)."""
    for key in data:
        is_flag = key in ('terminals', 'timeouts')
        data[key] = np.array(data[key], dtype=np.bool_ if is_flag else np.float32)
def load_policy(policy_file):
    """Load a (policy, env) pair from a torch snapshot and move the policy to CPU.

    NOTE(review): assumes an rlkit-style snapshot dict with the keys
    'exploration/policy' and 'evaluation/env' — verify against the training
    code that produced the file.
    """
    data = torch.load(policy_file)
    policy = data['exploration/policy'].to('cpu')
    env = data['evaluation/env']
    print("Policy loaded")
    return policy, env
def save_video(save_dir, file_name, frames, episode_id=0):
    """Dump every frame of an episode as PNGs under save_dir/<file_name>_episode_<id>/."""
    filename = os.path.join(save_dir, file_name+ '_episode_{}'.format(episode_id))
    if not os.path.exists(filename):
        os.makedirs(filename)
    num_frames = frames.shape[0]
    for i in range(num_frames):
        # flipud: renderer rows are presumably bottom-up — confirm against
        # the physics renderer that produced `frames`.
        img = Image.fromarray(np.flipud(frames[i]), 'RGB')
        img.save(os.path.join(filename, 'frame_{}.png'.format(i)))
def main():
    """Roll out a goal-reaching policy in an ant/swimmer maze and dump an hdf5 dataset."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--noisy', action='store_true', help='Noisy actions')
    parser.add_argument('--maze', type=str, default='umaze', help='Maze type. umaze, medium, or large')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    parser.add_argument('--env', type=str, default='Ant', help='Environment type')
    parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name')
    parser.add_argument('--max_episode_steps', default=1000, type=int)
    parser.add_argument('--video', action='store_true')
    parser.add_argument('--multi_start', action='store_true')
    parser.add_argument('--multigoal', action='store_true')
    args = parser.parse_args()

    # Select the maze layout.
    if args.maze == 'umaze':
        maze = maze_env.U_MAZE
    elif args.maze == 'medium':
        maze = maze_env.BIG_MAZE
    elif args.maze == 'large':
        maze = maze_env.HARDEST_MAZE
    elif args.maze == 'umaze_eval':
        maze = maze_env.U_MAZE_EVAL
    elif args.maze == 'medium_eval':
        maze = maze_env.BIG_MAZE_EVAL
    elif args.maze == 'large_eval':
        maze = maze_env.HARDEST_MAZE_EVAL
    else:
        raise NotImplementedError

    if args.env == 'Ant':
        env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))
    elif args.env == 'Swimmer':
        # NOTE(review): 'mmaze_map' looks like a typo for 'maze_map' (compare
        # the Ant branch) — confirm against SwimmerMazeEnv's signature before
        # relying on the Swimmer path.
        env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(mmaze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))
    else:
        raise NotImplementedError

    env.set_target()
    s = env.reset()
    act = env.action_space.sample()
    done = False

    # Load the policy
    policy, train_env = load_policy(args.policy_file)

    # Define goal reaching policy fn
    def _goal_reaching_policy_fn(obs, goal):
        goal_x, goal_y = goal
        # Strip the leading (x, y) position and the trailing goal slots.
        obs_new = obs[2:-2]
        goal_tuple = np.array([goal_x, goal_y])

        # normalize the norm of the relative goals to in-distribution values
        goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0

        new_obs = np.concatenate([obs_new, goal_tuple], -1)
        # Return the low-level action and the waypoint in absolute coordinates.
        return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1])

    data = reset_data()

    # create waypoint generating policy integrated with high level controller
    data_collection_policy = env.create_navigation_policy(
        _goal_reaching_policy_fn,
    )

    if args.video:
        frames = []

    ts = 0
    num_episodes = 0
    for _ in range(args.num_samples):
        act, waypoint_goal = data_collection_policy(s)

        if args.noisy:
            act = act + np.random.randn(*act.shape)*0.2
        act = np.clip(act, -1.0, 1.0)

        ns, r, done, info = env.step(act)
        timeout = False
        if ts >= args.max_episode_steps:
            timeout = True
            #done = True

        # Observations are stored without their trailing goal slots (s[:-2]).
        append_data(data, s[:-2], act, r, env.target_goal, done, timeout, env.physics.data)

        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done or timeout:
            done = False
            ts = 0
            s = env.reset()
            env.set_target_goal()
            if args.video:
                frames = np.array(frames)
                save_video('./videos/', args.env + '_navigation', frames, num_episodes)
                num_episodes += 1
                frames = []
        else:
            s = ns
            if args.video:
                curr_frame = env.physics.render(width=500, height=500, depth=False)
                frames.append(curr_frame)

    if args.noisy:
        fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))
    else:
        # NOTE(review): missing underscore — this produces e.g. 'Antmaze_...'
        # while the noisy branch produces 'Ant_maze_...'; confirm which naming
        # downstream tooling expects before changing either.
        fname = args.env + 'maze_%s_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/generation/generate_kitchen_datasets.py
================================================
"""Script for generating the datasets for kitchen environments."""
import d4rl.kitchen
import glob
import gym
import h5py
import numpy as np
import os
import pickle
np.set_printoptions(precision=2, suppress=True)

SAVE_DIRECTORY = '~/.offline_rl/datasets'
DEMOS_DIRECTORY = '~/relay-policy-learning/kitchen_demos_multitask'

DEMOS_SUBDIR_PATTERN = '*'
ENVIRONMENTS = ['kitchen_microwave_kettle_light_slider-v0',
                'kitchen_microwave_kettle_bottomburner_light-v0']
# Uncomment lines below for "mini_kitchen_microwave_kettle_light_slider-v0'".
# NOTE(review): the two lines below are NOT commented out, so they always
# override the defaults above — as written, the script generates only the
# "mini" variant. Comment them out to generate the full environments.
DEMOS_SUBDIR_PATTERN = '*microwave_kettle_switch_slide'
ENVIRONMENTS = ['mini_kitchen_microwave_kettle_light_slider-v0']

# Index groups of the controllable elements inside the flat observation.
OBS_ELEMENT_INDICES = [
    [11, 12],  # Bottom burners.
    [15, 16],  # Top burners.
    [17, 18],  # Light switch.
    [19],  # Slide.
    [20, 21],  # Hinge.
    [22],  # Microwave.
    [23, 24, 25, 26, 27, 28, 29],  # Kettle.
]
# Flattened view of the same indices, used for the debugging printouts below.
FLAT_OBS_ELEMENT_INDICES = sum(OBS_ELEMENT_INDICES, [])
def _relabel_obs_with_goal(obs_array, goal):
    """Overwrite the goal slots (columns 30+) of obs_array with `goal`.

    NOTE: mutates obs_array in place and returns the same array.
    """
    obs_array[..., 30:] = goal
    return obs_array
def _obs_array_to_obs_dict(obs_array, goal=None):
obs_dict = {
'qp': obs_array[:9],
'obj_qp': obs_array[9:30],
'goal': goal,
}
if obs_dict['goal'] is None:
obs_dict['goal'] = obs_array[30:]
return obs_dict
def main():
    """Relabel raw relay-kitchen demos with env rewards and save hdf5 datasets."""
    # Load every pickled demo from the configured subdirectories.
    pattern = os.path.join(DEMOS_DIRECTORY, DEMOS_SUBDIR_PATTERN)
    demo_subdirs = sorted(glob.glob(pattern))
    print('Found %d demo subdirs.' % len(demo_subdirs))
    all_demos = {}
    for demo_subdir in demo_subdirs:
        demo_files = glob.glob(os.path.join(demo_subdir, '*.pkl'))
        print('Found %d demos in %s.' % (len(demo_files), demo_subdir))
        demos = []
        for demo_file in demo_files:
            with open(demo_file, 'rb') as f:
                demo = pickle.load(f)
                demos.append(demo)
        all_demos[demo_subdir] = demos

        # For debugging...
        all_observations = [demo['observations'] for demo in demos]
        first_elements = [obs[0, FLAT_OBS_ELEMENT_INDICES]
                          for obs in all_observations]
        last_elements = [obs[-1, FLAT_OBS_ELEMENT_INDICES]
                         for obs in all_observations]
        # End for debugging.

    for env_name in ENVIRONMENTS:
        env = gym.make(env_name).unwrapped
        env.REMOVE_TASKS_WHEN_COMPLETE = False  # This enables a Markovian reward.
        all_obs = []
        all_actions = []
        all_rewards = []
        all_terminals = []
        all_infos = []
        print('Relabelling data for %s.' % env_name)
        for demo_subdir, demos in all_demos.items():
            print('On demo from %s.' % demo_subdir)
            demos_obs = []
            demos_actions = []
            demos_rewards = []
            demos_terminals = []
            demos_infos = []
            for idx, demo in enumerate(demos):
                env_goal = env._get_task_goal()
                rewards = []
                # Stamp this env's goal into the demo observations, then score
                # every step under the env's own reward function.
                relabelled_obs = _relabel_obs_with_goal(demo['observations'], env_goal)
                for obs in relabelled_obs:
                    reward_dict, score = env._get_reward_n_score(
                        _obs_array_to_obs_dict(obs))
                    rewards.append(reward_dict['r_total'])
                # NOTE(review): terminate_at equals the full demo length, so
                # the slices below are currently no-ops — presumably a hook
                # for early truncation that was never wired up.
                terminate_at = len(rewards)
                rewards = rewards[:terminate_at]
                demos_obs.append(relabelled_obs[:terminate_at])
                demos_actions.append(demo['actions'][:terminate_at])
                demos_rewards.append(np.array(rewards))
                # Only the final step of each demo is marked terminal.
                demos_terminals.append(np.arange(len(rewards)) >= len(rewards) - 1)
                demos_infos.append([idx] * len(rewards))
            all_obs.append(np.concatenate(demos_obs))
            all_actions.append(np.concatenate(demos_actions))
            all_rewards.append(np.concatenate(demos_rewards))
            all_terminals.append(np.concatenate(demos_terminals))
            all_infos.append(np.concatenate(demos_infos))
            episode_rewards = [np.sum(rewards) for rewards in demos_rewards]
            last_rewards = [rewards[-1] for rewards in demos_rewards]
            print('Avg episode rewards %f.' % np.mean(episode_rewards))
            print('Avg last step rewards %f.' % np.mean(last_rewards))
        dataset_obs = np.concatenate(all_obs).astype('float32')
        dataset_actions = np.concatenate(all_actions).astype('float32')
        dataset_rewards = np.concatenate(all_rewards).astype('float32')
        dataset_terminals = np.concatenate(all_terminals).astype('float32')
        dataset_infos = np.concatenate(all_infos)
        dataset_size = len(dataset_obs)
        assert dataset_size == len(dataset_actions)
        assert dataset_size == len(dataset_rewards)
        assert dataset_size == len(dataset_terminals)
        assert dataset_size == len(dataset_infos)
        dataset = {
            'observations': dataset_obs,
            'actions': dataset_actions,
            'rewards': dataset_rewards,
            'terminals': dataset_terminals,
            'infos': dataset_infos,
        }
        print('Generated dataset with %d total steps.' % dataset_size)
        save_filename = os.path.join(SAVE_DIRECTORY, '%s.hdf5' % env_name)
        print('Saving dataset to %s.' % save_filename)
        h5_dataset = h5py.File(save_filename, 'w')
        for key in dataset:
            h5_dataset.create_dataset(key, data=dataset[key], compression='gzip')
    print('Done.')

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/generation/generate_maze2d_bullet_datasets.py
================================================
import gym
import logging
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze_bullet import bullet_maze
from d4rl.pointmaze import maze_model
import numpy as np
import pickle
import gzip
import h5py
import argparse
import time
def reset_data():
    """Return an empty trajectory buffer with one fresh list per logged field.

    BUG FIX: includes 'infos/goal_reached' and 'infos/goal_timeout', which
    this file's append_data writes to; omitting them (as the original did)
    raised KeyError on the very first append.
    """
    return {'observations': [],
            'actions': [],
            'terminals': [],
            'timeouts': [],
            'rewards': [],
            'infos/goal': [],
            'infos/goal_reached': [],
            'infos/goal_timeout': [],
            'infos/qpos': [],
            'infos/qvel': [],
            }
def append_data(data, s, a, tgt, done, timeout, robot):
    """Append one transition to the buffer.

    rewards/terminals/timeouts are recorded as placeholders (0.0/False); the
    actual goal-reached / goal-timeout flags go under 'infos/'.
    """
    data['observations'].append(s)
    data['actions'].append(a)
    data['rewards'].append(0.0)
    data['terminals'].append(False)
    data['timeouts'].append(False)
    data['infos/goal'].append(tgt)
    # BUG FIX: reset_data() historically omitted these two keys, which made
    # the first append raise KeyError; create them on demand instead.
    data.setdefault('infos/goal_reached', []).append(done)
    data.setdefault('infos/goal_timeout', []).append(timeout)
    # copy() so later simulation steps cannot mutate the stored state.
    data['infos/qpos'].append(robot.qpos.copy())
    data['infos/qvel'].append(robot.qvel.copy())
def npify(data):
    """In-place conversion of every buffered list to a numpy array."""
    for key in data:
        # Flag fields become bools; everything else is stored as float32.
        dtype = np.bool_ if key in ('terminals', 'timeouts') else np.float32
        data[key] = np.array(data[key], dtype=dtype)
def main():
    """Collect maze2d data in the bullet port of the maze and save an hdf5 file.

    The registered gym env is used only to read the maze spec and the episode
    step limit; rollouts happen in Maze2DBulletEnv, driven by a PD waypoint
    controller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--render', action='store_true', help='Render trajectories')
    parser.add_argument('--noisy', action='store_true', help='Noisy actions')
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    maze = env.str_maze_spec
    max_episode_steps = env._max_episode_steps

    # default: p=10, d=-1
    controller = waypoint_controller.WaypointController(maze, p_gain=10.0, d_gain=-2.0)
    env = bullet_maze.Maze2DBulletEnv(maze)
    if args.render:
        env.render('human')

    env.set_target()
    s = env.reset()
    act = env.action_space.sample()

    data = reset_data()
    ts = 0
    for _ in range(args.num_samples):
        position = s[0:2]
        velocity = s[2:4]
        act, done = controller.get_action(position, velocity, env._target)
        if args.noisy:
            act = act + np.random.randn(*act.shape)*0.5
        act = np.clip(act, -1.0, 1.0)

        # BUG FIX: recompute the timeout flag every step. Previously it was
        # set once when ts first exceeded the limit and never cleared, so every
        # sample after the first timed-out episode was labelled as a timeout.
        timeout = ts >= max_episode_steps

        append_data(data, s, act, env._target, done, timeout, env.robot)

        ns, _, _, _ = env.step(act)
        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done:
            # New target, fresh step counter; the bullet env itself is not
            # reset — the episode stream continues from the current state.
            env.set_target()
            done = False
            ts = 0
        else:
            s = ns
        if args.render:
            env.render('human')

    if args.noisy:
        fname = '%s-noisy-bullet.hdf5' % args.env_name
    else:
        fname = '%s-bullet.hdf5' % args.env_name
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')
if __name__ == "__main__":
main()
================================================
FILE: d4rl/scripts/generation/generate_maze2d_datasets.py
================================================
import gym
import logging
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model
import numpy as np
import pickle
import gzip
import h5py
import argparse
def reset_data():
    """Fresh trajectory buffer: one empty list per logged field."""
    keys = ('observations', 'actions', 'terminals', 'rewards',
            'infos/goal', 'infos/qpos', 'infos/qvel')
    return {k: [] for k in keys}
def append_data(data, s, a, tgt, done, env_data):
    """Log one transition; the reward is a placeholder 0.0 (relabelled later)."""
    record = (
        ('observations', s),
        ('actions', a),
        ('rewards', 0.0),
        ('terminals', done),
        ('infos/goal', tgt),
        # copy() so later sim steps cannot mutate the stored state
        ('infos/qpos', env_data.qpos.ravel().copy()),
        ('infos/qvel', env_data.qvel.ravel().copy()),
    )
    for key, value in record:
        data[key].append(value)
def npify(data):
    """Convert each buffered list to a numpy array in place (terminals -> bool)."""
    for key in data:
        data[key] = np.array(data[key],
                             dtype=np.bool_ if key == 'terminals' else np.float32)
def main():
    """Collect maze2d trajectories with a PD waypoint controller and save hdf5."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--render', action='store_true', help='Render trajectories')
    parser.add_argument('--noisy', action='store_true', help='Noisy actions')
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    args = parser.parse_args()

    # The registered env only supplies the maze layout and episode limit;
    # rollouts use a fresh MazeEnv built from the same layout.
    env = gym.make(args.env_name)
    maze = env.str_maze_spec
    max_episode_steps = env._max_episode_steps

    controller = waypoint_controller.WaypointController(maze)
    env = maze_model.MazeEnv(maze)

    env.set_target()
    s = env.reset()
    act = env.action_space.sample()
    done = False

    data = reset_data()
    ts = 0
    for _ in range(args.num_samples):
        position = s[0:2]
        velocity = s[2:4]
        act, done = controller.get_action(position, velocity, env._target)
        if args.noisy:
            act = act + np.random.randn(*act.shape)*0.5
        act = np.clip(act, -1.0, 1.0)
        # Hitting the step limit is treated as a terminal so the episode resets.
        if ts >= max_episode_steps:
            done = True
        append_data(data, s, act, env._target, done, env.sim.data)

        ns, _, _, _ = env.step(act)
        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done:
            # New goal, fresh counter; the env itself is not reset.
            env.set_target()
            done = False
            ts = 0
        else:
            s = ns
        if args.render:
            env.render()

    if args.noisy:
        fname = '%s-noisy.hdf5' % args.env_name
    else:
        fname = '%s.hdf5' % args.env_name
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')

if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/generation/generate_minigrid_fourroom_data.py
================================================
import logging
from offline_rl.gym_minigrid import fourroom_controller
from offline_rl.gym_minigrid.envs import fourrooms
import numpy as np
import pickle
import gzip
import h5py
import argparse
def reset_data():
    """Fresh buffer of per-step logs for the four-rooms data collection."""
    fields = ('observations', 'actions', 'terminals', 'rewards',
              'infos/goal', 'infos/pos', 'infos/orientation')
    return {f: [] for f in fields}
def append_data(data, s, a, tgt, done, pos, ori):
    """Log one step; the reward is a placeholder 0.0 (relabelled downstream)."""
    entries = (('observations', s),
               ('actions', a),
               ('rewards', 0.0),
               ('terminals', done),
               ('infos/goal', tgt),
               ('infos/pos', pos),
               ('infos/orientation', ori))
    for key, value in entries:
        data[key].append(value)
def npify(data):
    """Cast each buffered list to numpy (terminals -> bool, rest -> float32)."""
    for key in list(data):
        target = np.bool_ if key == 'terminals' else np.float32
        data[key] = np.array(data[key], dtype=target)
def main():
    """Collect four-rooms trajectories with a scripted (or random) controller."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--render', action='store_true', help='Render trajectories')
    parser.add_argument('--random', action='store_true', help='Noisy actions')
    parser.add_argument('--num_samples', type=int, default=int(1e5), help='Num samples to collect')
    args = parser.parse_args()

    controller = fourroom_controller.FourRoomController()
    env = fourrooms.FourRoomsEnv()
    controller.set_target(controller.sample_target())

    s = env.reset()
    act = env.action_space.sample()
    done = False

    data = reset_data()
    ts = 0
    for _ in range(args.num_samples):
        if args.render:
            env.render()
        if args.random:
            act = env.action_space.sample()
        else:
            act, done = controller.get_action(env.agent_pos, env.agent_dir)
        # Hard 50-step cap per target.
        if ts >= 50:
            done = True
        # Only the image channel of the observation dict is stored.
        append_data(data, s['image'], act, controller.target, done, env.agent_pos, env.agent_dir)

        ns, _, _, _ = env.step(act)
        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done:
            # Re-target without resetting the env; the agent continues in place.
            controller.set_target(controller.sample_target())
            done = False
            ts = 0
        else:
            s = ns

    if args.random:
        fname = 'minigrid4rooms_random.hdf5'
    else:
        fname = 'minigrid4rooms.hdf5'
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')

if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/generation/hand_dapg_combined.py
================================================
import gym
import d4rl
import argparse
import os
import numpy as np
import h5py
def get_keys(h5file):
    """Recursively collect the full names of all datasets in an HDF5 file."""
    found = []

    def _collect(name, item):
        # Only leaf datasets count; groups are skipped.
        if isinstance(item, h5py.Dataset):
            found.append(name)

    h5file.visititems(_collect)
    return found
if __name__ == '__main__':
    # Build an '<env>-cloned-v1' dataset: the first ~half comes from BC
    # rollouts (cut at an episode boundary), the rest from tiled human demos.
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--env_name', type=str, default='pen', help='Env name')
    parser.add_argument('--bc', type=str, help='BC hdf5 dataset')
    parser.add_argument('--human', type=str, help='Human demos hdf5 dataset')
    args = parser.parse_args()
    env = gym.make('%s-v0' % args.env_name)
    human_dataset = h5py.File(args.human, 'r')
    bc_dataset = h5py.File(args.bc, 'r')

    N = env._max_episode_steps * 5000
    # search for nearest terminal after the halfway mark
    halfN = N // 2
    terms = bc_dataset['terminals'][:]
    tos = bc_dataset['timeouts'][:]
    last_term = 0
    # NOTE(review): if no terminal/timeout occurs in [halfN, N), last_term
    # stays 0 and halfN collapses to 1 — confirm input datasets always
    # contain one in that range.
    for i in range(halfN, N):
        if terms[i] or tos[i]:
            last_term = i
            break
    halfN = last_term + 1
    remaining_N = N - halfN

    aug_dataset = h5py.File('%s-cloned-v1.hdf5' % args.env_name, 'w')
    for k in get_keys(bc_dataset):
        if 'metadata' not in k:
            human_data = human_dataset[k][:]
            bc_data = bc_dataset[k][:halfN]
            print(k, human_data.shape, bc_data.shape)
            # Tile the (shorter) human data until it can fill the remainder.
            N_tile = int(halfN / human_data.shape[0]) + 1
            if len(human_data.shape) == 1:
                human_data = np.tile(human_data, [N_tile])[:remaining_N]
            elif len(human_data.shape) == 2:
                human_data = np.tile(human_data, [N_tile, 1])[:remaining_N]
            else:
                raise NotImplementedError()
            # clone demo_data
            aug_data = np.concatenate([bc_data, human_data], axis=0)
            assert aug_data.shape[1:] == bc_data.shape[1:]
            assert aug_data.shape[1:] == human_data.shape[1:]
            print('\t',human_data.shape, bc_data.shape, '->',aug_data.shape)
            aug_dataset.create_dataset(k, data=aug_data, compression='gzip')
        else:
            # Copy metadata entries through unchanged (scalar or array).
            shape = bc_dataset[k].shape
            print('metadata:', k, shape)
            if len(shape) == 0:
                aug_dataset[k] = bc_dataset[k][()]
            else:
                aug_dataset[k] = bc_dataset[k][:]
================================================
FILE: d4rl/scripts/generation/hand_dapg_demos.py
================================================
import d4rl
import click
import os
import gym
import numpy as np
import pickle
import h5py
import collections
from mjrl.utils.gym_env import GymEnv
DESC = '''
Helper script to visualize demonstrations.\n
USAGE:\n
Visualizes demonstrations on the env\n
$ python utils/visualize_demos --env_name relocate-v0\n
'''
# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', default='door-v0')
def main(env_name):
    """Load the pickled demos for env_name and replay them into an hdf5 dataset."""
    # BUG FIX: the original used `env_name is ""` — `is` compares identity,
    # not equality, and comparing to a string literal with `is` is unreliable
    # (and a SyntaxWarning on Python 3.8+). Use equality instead.
    if env_name == "":
        print("Unknown env.")
        return
    demos = pickle.load(open('./demonstrations/'+env_name+'_demos.pickle', 'rb'))
    # render demonstrations
    demo_playback(env_name, demos, clip=True)
def demo_playback(env_name, demo_paths, clip=False):
    """Replay recorded demos through the env and write them out as an hdf5 dataset.

    Each demo is restored from its saved init state, its actions are stepped
    through the env (optionally clipped to [-1, 1]), and every transition is
    logged. The last step of each demo is flagged as a timeout, never a
    terminal.
    """
    e = gym.make(env_name)
    e.reset()

    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    info_env_state_ = collections.defaultdict(list)

    for i, path in enumerate(demo_paths):
        # Restore the exact initial state the demo was recorded from.
        e.set_env_state(path['init_state_dict'])
        actions = path['actions']
        returns = 0
        for t in range(actions.shape[0]):
            obs_.append(e.get_obs())
            # copy() so stepping the sim cannot mutate the stored snapshots.
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            [info_env_state_[k].append(v) for k,v in e.get_env_state().items()]
            commanded_action = actions[t]
            if clip:
                commanded_action = np.clip(commanded_action, -1.0, 1.0)
            act_.append(commanded_action)
            _, rew, _, info = e.step(commanded_action)
            returns += rew
            rew_.append(rew)
            done = False
            timeout = False
            # Demos end by running out of recorded actions, so flag a timeout
            # (not an environment terminal) on the final step.
            if t == (actions.shape[0]-1):
                timeout = True
            #if t == (e._max_episode_steps-1):
            #    timeout = True
            #    done = False
            term_.append(done)
            timeout_.append(timeout)
            #e.env.mj_render() # this is much faster
            #e.render()
        print(i, returns, returns/float(actions.shape[0]))

    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)
    if clip:
        dataset = h5py.File('%s_demos_clipped.hdf5' % env_name, 'w')
    else:
        dataset = h5py.File('%s_demos.hdf5' % env_name, 'w')
    #dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    #dataset['infos/qpos'] = info_qpos_
    #dataset['infos/qvel'] = info_qvel_
    for k in info_env_state_:
        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/generation/hand_dapg_jax.py
================================================
import d4rl
import click
import h5py
import os
import gym
import numpy as np
import pickle
import gzip
import collections
from mjrl.utils.gym_env import GymEnv
DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
Visualizes policy on the env\n
$ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n
'''
# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required= True)
@click.option('--snapshot_file', type=str, help='absolute path of the policy file', required=True)
@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)
@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')
def main(env_name, snapshot_file, mode, num_trajs, clip=True):
    """Load a pickled mjrl policy snapshot for inspection.

    NOTE(review): this entry point is unfinished — it drops into an
    interactive pdb session right after loading, and the playback call below
    is commented out. The leftover set_trace should be removed before this
    script is run unattended.
    """
    e = GymEnv(env_name)
    pi = pickle.load(gzip.open(snapshot_file, 'rb'))
    # HACK: leftover interactive debugging stop.
    import pdb; pdb.set_trace()
    pass
    # render policy
    #pol_playback(env_name, pi, num_trajs, clip=clip)
def extract_params(policy):
    """Flatten an mjrl Gaussian MLP policy into a dict of raw numpy weights.

    The mjrl policy normalizes inputs with (in_shift, in_scale) and outputs
    with (out_shift, out_scale); both transforms are folded into the first and
    last layer weights so the exported network maps raw observations directly
    to raw action means.

    BUG FIX: the committed body referenced `_fc0w`, `params`, `_fclw`, etc.
    without ever defining them (NameError on every call). The computation is
    restored to match the working version in hand_dapg_policies.py.

    Args:
        policy: mjrl policy exposing `trainable_params` (fc0 w/b, fc1 w/b,
            last_fc w/b torch tensors) and a `model` with the normalizers.

    Returns:
        dict mapping layer names ('fc0/weight', ...) to numpy arrays; the
        log-std head duplicates the mean head's folded weights, as in the
        sibling script.
    """
    params = policy.trainable_params
    in_shift = policy.model.in_shift.data.numpy()
    in_scale = policy.model.in_scale.data.numpy()
    out_shift = policy.model.out_shift.data.numpy()
    out_scale = policy.model.out_scale.data.numpy()
    fc0w = params[0].data.numpy()
    fc0b = params[1].data.numpy()
    # Fold input normalization: x_norm = (x - in_shift) / in_scale.
    _fc0w = np.dot(fc0w, np.diag(1.0 / in_scale))
    _fc0b = fc0b - np.dot(_fc0w, in_shift)
    assert _fc0w.shape == fc0w.shape
    assert _fc0b.shape == fc0b.shape
    fclw = params[4].data.numpy()
    fclb = params[5].data.numpy()
    # Fold output denormalization: y = y_norm * out_scale + out_shift.
    _fclw = np.dot(np.diag(out_scale), fclw)
    _fclb = fclb * out_scale + out_shift
    assert _fclw.shape == fclw.shape
    assert _fclb.shape == fclb.shape
    out_dict = {
        'fc0/weight': _fc0w,
        'fc0/bias': _fc0b,
        'fc1/weight': params[2].data.numpy(),
        'fc1/bias': params[3].data.numpy(),
        'last_fc/weight': _fclw,
        'last_fc/bias': _fclb,
        'last_fc_log_std/weight': _fclw,
        'last_fc_log_std/bias': _fclb,
    }
    return out_dict
def pol_playback(env_name, pi, num_trajs=100, clip=True):
    """Roll out policy ``pi`` on ``env_name`` and write a D4RL-style HDF5 file.

    Records observations/actions/rewards/terminals/timeouts, per-step env
    state, and the policy's action mean/log_std, then embeds the flattened
    policy weights as metadata.

    Args:
        env_name: gym environment id (Adroit task).
        pi: mjrl policy exposing get_action(obs) -> (sampled_action, infos).
        num_trajs: number of episodes to collect.
        clip: clip actions to [-1, 1] before stepping (also tags the filename).
    """
    e = gym.make(env_name)
    e.reset()
    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    info_mean_ = []
    info_logstd_ = []
    # Env state is a dict (qpos/qvel/...); keep one list per key.
    info_env_state_ = collections.defaultdict(list)
    ravg = []
    for n in range(num_trajs):
        e.reset()
        returns = 0
        for t in range(e._max_episode_steps):
            obs = e.get_obs()
            obs_.append(obs)
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            [info_env_state_[k].append(v) for k,v in e.get_env_state().items()]
            action, infos = pi.get_action(obs)
            # NOTE(review): get_action is called twice, so the stored action is
            # a second, independent sample while `infos` comes from the first
            # call. Presumably mean/log_std depend only on obs -- confirm.
            action = pi.get_action(obs)[0] # eval
            if clip:
                action = np.clip(action, -1, 1)
            act_.append(action)
            info_mean_.append(infos['mean'])
            info_logstd_.append(infos['log_std'])
            _, rew, done, info = e.step(action)
            returns += rew
            rew_.append(rew)
            # The final step of a full-length episode is recorded as a timeout,
            # not a terminal, so consumers can distinguish the two.
            if t == (e._max_episode_steps-1):
                timeout = True
                done = False
            else:
                timeout = False
            term_.append(done)
            timeout_.append(timeout)
            if done or timeout:
                e.reset()
                break
            #e.env.mj_render() # this is much faster
            # e.render()
        ravg.append(returns)
        print(n, returns, t)
    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)
    info_mean_ = np.array(info_mean_).astype(np.float32)
    info_logstd_ = np.array(info_logstd_).astype(np.float32)
    if clip:
        dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w')
    else:
        dataset = h5py.File('%s_expert.hdf5' % env_name, 'w')
    #dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    #dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')
    #dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')
    dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip')
    dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip')
    for k in info_env_state_:
        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')
    # write metadata
    policy_params = extract_params(pi)
    dataset['metadata/algorithm'] = np.string_('DAPG')
    dataset['metadata/policy/nonlinearity'] = np.string_('tanh')
    dataset['metadata/policy/output_distribution'] = np.string_('gaussian')
    for k, v in policy_params.items():
        dataset['metadata/policy/'+k] = v

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/generation/hand_dapg_policies.py
================================================
import d4rl
import click
import h5py
import os
import gym
import numpy as np
import pickle
import collections
from mjrl.utils.gym_env import GymEnv
# Click help text (inherited from the mjrl visualize_policy utility; the usage
# example does not match this script's options).
DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
Visualizes policy on the env\n
$ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n
'''
# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required= True)
#@click.option('--policy', type=str, help='absolute path of the policy file', required=True)
@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)
@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')
def main(env_name, mode, num_trajs, clip=True):
    """Roll out the packaged DAPG expert policy for `env_name` and dump the
    transitions to HDF5 (see pol_playback)."""
    e = GymEnv(env_name)
    policy = './policies/'+env_name+'.pickle'
    # BUG FIX: use a context manager so the pickle file handle is closed
    # deterministically (the original leaked the handle from open()).
    with open(policy, 'rb') as f:
        pi = pickle.load(f)
    # render policy
    pol_playback(env_name, pi, num_trajs, clip=clip)
def extract_params(policy):
    """Flatten an mjrl Gaussian MLP policy into a dict of raw numpy weights.

    Input normalization (in_shift, in_scale) is folded into the first layer
    and output denormalization (out_shift, out_scale) into the last, so the
    exported network maps raw observations directly to raw action means. The
    log-std head duplicates the folded mean-head weights.
    """
    weights = policy.trainable_params
    model = policy.model
    in_shift = model.in_shift.data.numpy()
    in_scale = model.in_scale.data.numpy()
    out_shift = model.out_shift.data.numpy()
    out_scale = model.out_scale.data.numpy()
    # Fold input normalization: x_norm = (x - in_shift) / in_scale.
    w0_raw = weights[0].data.numpy()
    b0_raw = weights[1].data.numpy()
    w0 = np.dot(w0_raw, np.diag(1.0 / in_scale))
    b0 = b0_raw - np.dot(w0, in_shift)
    assert w0.shape == w0_raw.shape
    assert b0.shape == b0_raw.shape
    # Fold output denormalization: y = y_norm * out_scale + out_shift.
    wl_raw = weights[4].data.numpy()
    bl_raw = weights[5].data.numpy()
    wl = np.dot(np.diag(out_scale), wl_raw)
    bl = bl_raw * out_scale + out_shift
    assert wl.shape == wl_raw.shape
    assert bl.shape == bl_raw.shape
    return {
        'fc0/weight': w0,
        'fc0/bias': b0,
        'fc1/weight': weights[2].data.numpy(),
        'fc1/bias': weights[3].data.numpy(),
        'last_fc/weight': wl,
        'last_fc/bias': bl,
        'last_fc_log_std/weight': wl,
        'last_fc_log_std/bias': bl,
    }
def pol_playback(env_name, pi, num_trajs=100, clip=True):
    """Roll out the DAPG expert ``pi`` on ``env_name`` and write an HDF5 file.

    Records observations/actions/rewards/terminals/timeouts, per-step env
    state, and the policy's action mean/log_std; the flattened policy weights
    are embedded as metadata.

    Args:
        env_name: gym environment id (Adroit task).
        pi: mjrl policy exposing get_action(obs) -> (sampled_action, infos).
        num_trajs: number of episodes to collect.
        clip: clip actions to [-1, 1] before stepping (also tags the filename).
    """
    e = gym.make(env_name)
    e.reset()
    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    info_mean_ = []
    info_logstd_ = []
    # Env state is a dict (qpos/qvel/...); keep one list per key.
    info_env_state_ = collections.defaultdict(list)
    ravg = []
    for n in range(num_trajs):
        e.reset()
        returns = 0
        for t in range(e._max_episode_steps):
            obs = e.get_obs()
            obs_.append(obs)
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            [info_env_state_[k].append(v) for k,v in e.get_env_state().items()]
            action, infos = pi.get_action(obs)
            # NOTE(review): get_action is called twice, so the stored action is
            # a second, independent sample while `infos` comes from the first
            # call. Presumably mean/log_std depend only on obs -- confirm.
            action = pi.get_action(obs)[0] # eval
            if clip:
                action = np.clip(action, -1, 1)
            act_.append(action)
            info_mean_.append(infos['mean'])
            info_logstd_.append(infos['log_std'])
            _, rew, done, info = e.step(action)
            returns += rew
            rew_.append(rew)
            # The final step of a full-length episode is recorded as a timeout,
            # not a terminal.
            if t == (e._max_episode_steps-1):
                timeout = True
                done = False
            else:
                timeout = False
            term_.append(done)
            timeout_.append(timeout)
            if done or timeout:
                e.reset()
                break
            #e.env.mj_render() # this is much faster
            # e.render()
        ravg.append(returns)
        print(n, returns, t)
    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)
    info_mean_ = np.array(info_mean_).astype(np.float32)
    info_logstd_ = np.array(info_logstd_).astype(np.float32)
    if clip:
        dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w')
    else:
        dataset = h5py.File('%s_expert.hdf5' % env_name, 'w')
    #dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    #dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')
    #dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')
    dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip')
    dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip')
    for k in info_env_state_:
        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')
    # write metadata
    policy_params = extract_params(pi)
    dataset['metadata/algorithm'] = np.string_('DAPG')
    dataset['metadata/policy/nonlinearity'] = np.string_('tanh')
    dataset['metadata/policy/output_distribution'] = np.string_('gaussian')
    for k, v in policy_params.items():
        dataset['metadata/policy/'+k] = v

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/generation/hand_dapg_random.py
================================================
# stdlib
import collections
import os
import pickle

# third-party
import click
import gym
import h5py
import numpy as np

# project (env registration) / mjrl
import brenvs
from mjrl.utils.gym_env import GymEnv
# Click help text (inherited from the mjrl visualize_policy utility; the usage
# example does not match this script's options).
DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
Visualizes policy on the env\n
$ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n
'''
# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required= True)
@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)
def main(env_name, num_trajs):
    """Collect a uniform-random-action dataset for `env_name` (pol_playback)."""
    e = GymEnv(env_name)
    # render policy
    pol_playback(env_name, num_trajs)
def pol_playback(env_name, num_trajs=100):
    """Collect `num_trajs` fixed-length episodes of uniform random actions and
    write them to "<env_name>_random.hdf5".

    BUG FIX: env state was accumulated as a list of dicts and then passed to
    np.array(..., dtype=np.float32), which raises for dict entries. It is now
    collected per-key (matching hand_dapg.py / hand_dapg_policies.py) and
    written as one 'infos/<key>' dataset per state field.
    """
    e = GymEnv(env_name)
    e.reset()
    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    # One list per env-state key (get_env_state() returns a dict).
    info_env_state_ = collections.defaultdict(list)
    ravg = []
    for n in range(num_trajs):
        e.reset()
        returns = 0
        for t in range(e._horizon):
            obs = e.get_obs()
            obs_.append(obs)
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            for k, v in e.get_env_state().items():
                info_env_state_[k].append(v)
            action = e.action_space.sample()
            act_.append(action)
            _, rew, done, info = e.step(action)
            returns += rew
            rew_.append(rew)
            # The horizon's final step is recorded as a timeout, not a terminal.
            if t == (e._horizon-1):
                timeout = True
                done = False
            else:
                timeout = False
            term_.append(done)
            timeout_.append(timeout)
            # NOTE(review): unlike the expert scripts there is no `break` here;
            # the env is reset mid-rollout and collection continues up to the
            # horizon, yielding fixed-length trajectories. Presumably
            # intentional for random data -- confirm.
            if done or timeout:
                e.reset()
            #e.env.mj_render() # this is much faster
            # e.render()
        ravg.append(returns)
    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)
    dataset = h5py.File('%s_random.hdf5' % env_name, 'w')
    #dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')
    dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')
    for k in info_env_state_:
        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/generation/mujoco/collect_data.py
================================================
import argparse
import re
import h5py
import torch
import gym
import d4rl
import numpy as np
from rlkit.torch import pytorch_util as ptu
# Matches checkpoint filenames like "itr_123.pkl"; the iteration number is
# captured in the named group "itr" used by get_pkl_itr() below.
# BUG FIX: the committed pattern read "(?P[0-9]+)" -- an invalid named group
# that makes re.compile raise at import time.
itr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')

def load(pklfile):
    """Load an rlkit checkpoint and return its policy object."""
    params = torch.load(pklfile)
    return params['trainer/policy']

def get_pkl_itr(pklfile):
    """Return the iteration number (as a string) from an "itr_<N>.pkl" path.

    Raises:
        ValueError: if the filename has no iteration component.
    """
    match = itr_re.search(pklfile)
    if match:
        return match.group('itr')
    raise ValueError(pklfile+" has no iteration number.")
def get_policy_wts(params):
    """Dump an rlkit policy network's layer parameters to a dict of numpy arrays.

    Keys are '<layer>/weight' and '<layer>/bias' for the two hidden layers
    (fc0, fc1), the mean head (last_fc) and the log-std head (last_fc_log_std).
    """
    layers = {
        'fc0': params.fcs[0],
        'fc1': params.fcs[1],
        'last_fc': params.last_fc,
        'last_fc_log_std': params.last_fc_log_std,
    }
    out_dict = {}
    for name, layer in layers.items():
        out_dict[name + '/weight'] = layer.weight.data.numpy()
        out_dict[name + '/bias'] = layer.bias.data.numpy()
    return out_dict
def get_reset_data():
    """Return a fresh trajectory buffer: one independent empty list per field."""
    fields = ('observations', 'next_observations', 'actions', 'rewards',
              'terminals', 'timeouts', 'logprobs', 'qpos', 'qvel')
    return {name: [] for name in fields}
def rollout(policy, env_name, max_path, num_data, random=False):
    """Collect `num_data` transitions from `env_name`.

    Steps either a uniform random policy or the given rlkit `policy`. A
    trajectory's transitions are merged into the output only once it
    terminates or times out, so the returned arrays contain whole
    trajectories, truncated to exactly `num_data` rows at the end.

    Args:
        policy: rlkit policy with .forward(obs_batch) -> distribution, or
            None when random=True.
        env_name: gym environment id (MuJoCo-based; qpos/qvel are recorded).
        max_path: step count at which a trajectory is marked as a timeout.
        num_data: number of transitions to return.
        random: sample uniform random actions instead of querying `policy`.

    Returns:
        dict of numpy arrays: observations, actions, next_observations,
        rewards, terminals, timeouts, infos/{action_log_probs,qpos,qvel}.
    """
    env = gym.make(env_name)
    data = get_reset_data()
    traj_data = get_reset_data()
    _returns = 0
    t = 0
    done = False
    s = env.reset()
    while len(data['rewards']) < num_data:
        if random:
            a = env.action_space.sample()
            # Log-density of a uniform sample over the box action space.
            logprob = np.log(1.0 / np.prod(env.action_space.high - env.action_space.low))
        else:
            torch_s = ptu.from_numpy(np.expand_dims(s, axis=0))
            distr = policy.forward(torch_s)
            a = distr.sample()
            logprob = distr.log_prob(a)
            a = ptu.get_numpy(a).squeeze()
        # mujoco only: snapshot the simulator state before stepping.
        qpos, qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()
        try:
            ns, rew, done, infos = env.step(a)
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallows
            # KeyboardInterrupt/SystemExit. Best-effort recovery: rebuild the
            # env and drop the partial trajectory.
            print('lost connection')
            env.close()
            env = gym.make(env_name)
            s = env.reset()
            traj_data = get_reset_data()
            t = 0
            _returns = 0
            continue
        _returns += rew
        t += 1
        timeout = False
        terminal = False
        if t == max_path:
            timeout = True
        elif done:
            terminal = True
        traj_data['observations'].append(s)
        traj_data['actions'].append(a)
        traj_data['next_observations'].append(ns)
        traj_data['rewards'].append(rew)
        traj_data['terminals'].append(terminal)
        traj_data['timeouts'].append(timeout)
        traj_data['logprobs'].append(logprob)
        traj_data['qpos'].append(qpos)
        traj_data['qvel'].append(qvel)
        s = ns
        if terminal or timeout:
            print('Finished trajectory. Len=%d, Returns=%f. Progress:%d/%d' % (t, _returns, len(data['rewards']), num_data))
            s = env.reset()
            t = 0
            _returns = 0
            for k in data:
                data[k].extend(traj_data[k])
            traj_data = get_reset_data()
    new_data = dict(
        observations=np.array(data['observations']).astype(np.float32),
        actions=np.array(data['actions']).astype(np.float32),
        next_observations=np.array(data['next_observations']).astype(np.float32),
        # BUG FIX: np.bool was removed in NumPy 1.24; np.bool_ is the dtype.
        terminals=np.array(data['terminals']).astype(np.bool_),
        timeouts=np.array(data['timeouts']).astype(np.bool_)
    )
    new_data['infos/action_log_probs'] = np.array(data['logprobs']).astype(np.float32)
    new_data['infos/qpos'] = np.array(data['qpos']).astype(np.float32)
    new_data['infos/qvel'] = np.array(data['qvel']).astype(np.float32)
    for k in new_data:
        new_data[k] = new_data[k][:num_data]
    return new_data
if __name__ == "__main__":
    # CLI: roll out a saved rlkit policy (or a random policy) and save the
    # transitions as an HDF5 dataset with policy metadata attached.
    parser = argparse.ArgumentParser()
    parser.add_argument('env', type=str)
    parser.add_argument('--pklfile', type=str, default=None)  # required unless --random
    parser.add_argument('--output_file', type=str, default='output.hdf5')
    parser.add_argument('--max_path', type=int, default=1000)   # per-episode step cap
    parser.add_argument('--num_data', type=int, default=10000)  # total transitions
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    policy = None
    if not args.random:
        policy = load(args.pklfile)
    data = rollout(policy, args.env, max_path=args.max_path, num_data=args.num_data, random=args.random)
    hfile = h5py.File(args.output_file, 'w')
    for k in data:
        hfile.create_dataset(k, data=data[k], compression='gzip')
    if args.random:
        # Random data has no policy to describe in metadata.
        pass
    else:
        hfile['metadata/algorithm'] = np.string_('SAC')
        # Iteration number recovered from the checkpoint filename (itr_<N>.pkl).
        hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0]
        hfile['metadata/policy/nonlinearity'] = np.string_('relu')
        hfile['metadata/policy/output_distribution'] = np.string_('tanh_gaussian')
        for k, v in get_policy_wts(policy).items():
            hfile['metadata/policy/'+k] = v
    hfile.close()
================================================
FILE: d4rl/scripts/generation/mujoco/convert_buffer.py
================================================
import argparse
import re
import h5py
import torch
import numpy as np
# Matches checkpoint filenames like "itr_123.pkl"; the iteration number is
# captured in the named group "itr" used by get_pkl_itr() below.
# BUG FIX: the committed pattern read "(?P[0-9]+)" -- an invalid named group
# that makes re.compile raise at import time.
itr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')

def load(pklfile):
    """Read an rlkit replay-buffer checkpoint and return D4RL-style arrays.

    Returns a dict with observations/next_observations/actions/rewards plus
    terminal/timeout/log-prob fields pulled from the buffer's env_infos
    (squeezed from (N, 1) to (N,)). qpos/qvel are included only when present.
    """
    params = torch.load(pklfile)
    env_infos = params['replay_buffer/env_infos']
    results = {
        'observations': params['replay_buffer/observations'],
        'next_observations': params['replay_buffer/next_observations'],
        'actions': params['replay_buffer/actions'],
        'rewards': params['replay_buffer/rewards'],
        'terminals': env_infos['terminal'].squeeze(),
        'timeouts': env_infos['timeout'].squeeze(),
        'infos/action_log_probs': env_infos['action_log_prob'].squeeze(),
    }
    if 'qpos' in env_infos:
        results['infos/qpos'] = env_infos['qpos']
        results['infos/qvel'] = env_infos['qvel']
    return results

def get_pkl_itr(pklfile):
    """Return the iteration number (as a string) from an "itr_<N>.pkl" path.

    Raises:
        ValueError: if the filename has no iteration component.
    """
    match = itr_re.search(pklfile)
    if match:
        return match.group('itr')
    raise ValueError(pklfile+" has no iteration number.")
if __name__ == "__main__":
    # Convert a saved rlkit replay buffer (.pkl) into a D4RL-style HDF5 file.
    parser = argparse.ArgumentParser()
    parser.add_argument('pklfile', type=str)
    parser.add_argument('--output_file', type=str, default='output.hdf5')
    args = parser.parse_args()
    data = load(args.pklfile)
    hfile = h5py.File(args.output_file, 'w')
    for k in data:
        hfile.create_dataset(k, data=data[k], compression='gzip')
    hfile['metadata/algorithm'] = np.string_('SAC')
    # Iteration number recovered from the checkpoint filename (itr_<N>.pkl).
    hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0]
    hfile.close()
================================================
FILE: d4rl/scripts/generation/mujoco/fix_qpos_qvel.py
================================================
import numpy as np
import argparse
import d4rl
import d4rl.offline_env
import gym
import h5py
import os
def unwrap_env(env):
    """Peel off the gym wrapper layers to reach the raw wrapped environment."""
    inner = env.env
    return inner.wrapped_env
def set_state_qpos(env, qpos, qvel):
    """Thin wrapper: set the simulator state directly from qpos/qvel arrays."""
    env.set_state(qpos, qvel)
def pad_obs(env, obs, twod=False, scale=0.1):
    """Prepend jittered root position(s) to an observation that omits them.

    MuJoCo locomotion observations exclude the root x (and y for 2-D bases)
    position; this rebuilds a full state vector by prepending the initial
    root position plus uniform noise.

    Note: the 2-D branch uses a fixed +/-0.1 jitter and ignores ``scale``,
    matching the original behavior.
    """
    # TODO: sample val
    if twod:
        dim, jitter = 2, .1
    else:
        dim, jitter = 1, scale
    root = env.init_qpos[0:dim] + np.random.uniform(size=dim, low=-jitter, high=jitter)
    return np.concatenate([np.ones(dim) * root, obs])
def set_state_obs(env, obs):
    """Initialize the simulator state from an observation lacking root position.

    The missing root coordinate(s) are synthesized by pad_obs with a small
    jitter (tighter for Hopper/Walker, whose dynamics are more sensitive);
    Ant uses hard-coded 15/14 qpos/qvel dimensions.
    """
    cls_name = str(unwrap_env(env).__class__)
    is_ant = 'Ant' in cls_name
    tight = 'Hopper' in cls_name or 'Walker' in cls_name
    state = pad_obs(env, obs, twod=is_ant, scale=0.005 if tight else 0.1)
    if is_ant:
        # Ant: qpos is 15-dim, qvel is 14-dim.
        env.set_state(state[:15], state[15:29])
    else:
        n_qpos = env.sim.data.qpos.size
        env.set_state(state[:n_qpos], state[n_qpos:])
def resync_state_obs(env, obs):
    """Overwrite sim qpos/qvel with values taken from a logged observation.

    Prevents the replayed state from drifting away from the dataset over time.
    The root x (and y for Ant) position is kept from the current sim state
    since the observation does not contain it.
    """
    # Inline of unwrap_env(): reach the raw env to inspect its class name.
    raw_env = env.env.wrapped_env
    is_ant = 'Ant' in str(raw_env.__class__)
    qpos = env.sim.data.qpos.ravel().copy()
    if is_ant:
        # Ant obs layout: qpos[2:15] then qvel[0:14].
        qpos[2:15] = obs[0:13]
        qvel = obs[13:27]
    else:
        n_qpos = env.sim.data.qpos.size
        qpos[1:] = obs[0:n_qpos - 1]
        qvel = obs[n_qpos - 1:]
    env.set_state(qpos, qvel)
if __name__ == "__main__":
    # Regenerate a dataset's infos/qpos and infos/qvel by replaying its actions
    # in the simulator, re-synchronizing the state from the logged observations
    # each step to prevent drift. Writes a corrected copy of the dataset file.
    parser = argparse.ArgumentParser()
    parser.add_argument('env', type=str)
    args = parser.parse_args()
    env = gym.make(args.env)
    env.reset()
    fname = unwrap_env(env).dataset_url.split('/')[-1]
    prefix, ext = os.path.splitext(fname)
    #out_fname = prefix+'_qfix'+ext
    out_fname = prefix+ext
    dset = env.get_dataset()
    all_qpos = dset['infos/qpos']
    all_qvel = dset['infos/qvel']
    observations = dset['observations']
    actions = dset['actions']
    dones = dset['terminals']
    timeouts = dset['timeouts']
    # Boolean OR via addition: a step ends a trajectory if either flag is set.
    terminals = dones + timeouts
    start_obs = observations[0]
    set_state_obs(env, start_obs)
    #set_state_qpos(env, all_qpos[0], all_qvel[0])
    new_qpos = []
    new_qvel = []
    for t in range(actions.shape[0]):
        # Record the state BEFORE stepping: it corresponds to observations[t].
        cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()
        new_qpos.append(cur_qpos)
        new_qvel.append(cur_qvel)
        next_obs, reward, done, infos = env.step(actions[t])
        if t == actions.shape[0]-1:
            break
        if terminals[t]:
            # New trajectory: re-seed the state from its first observation.
            set_state_obs(env, observations[t+1])
            #print(t, 'done')
        else:
            true_next_obs = observations[t+1]
            error = ((true_next_obs - next_obs)**2).sum()
            if t % 1000 == 0:
                print(t, error)
            # prevent drifting over time
            resync_state_obs(env, observations[t+1])
    dset_filepath = d4rl.offline_env.download_dataset_from_url(unwrap_env(env).dataset_url)
    inf = h5py.File(dset_filepath, 'r')
    outf = h5py.File(out_fname, 'w')
    for k in d4rl.offline_env.get_keys(inf):
        print('writing', k)
        if 'qpos' in k:
            outf.create_dataset(k, data=np.array(new_qpos), compression='gzip')
        elif 'qvel' in k:
            outf.create_dataset(k, data=np.array(new_qvel), compression='gzip')
        else:
            try:
                if 'reward' in k:
                    outf.create_dataset(k, data=inf[k][:].squeeze().astype(np.float32), compression='gzip')
                else:
                    if 'terminals' in k or 'timeouts' in k:
                        # BUG FIX: np.bool was removed in NumPy 1.24; use np.bool_.
                        outf.create_dataset(k, data=inf[k][:].astype(np.bool_), compression='gzip')
                    else:
                        outf.create_dataset(k, data=inf[k][:].astype(np.float32), compression='gzip')
            except Exception as e:
                # Non-numeric datasets (e.g. string metadata) fall back to raw copy.
                print(e)
                outf.create_dataset(k, data=inf[k])
    outf.close()
================================================
FILE: d4rl/scripts/generation/mujoco/stitch_dataset.py
================================================
import argparse
import h5py
import numpy as np
if __name__ == "__main__":
    # Concatenate two HDF5 transition datasets into one, keeping only whole
    # trajectories from file1 and capping the result at --maxlen rows.
    parser = argparse.ArgumentParser()
    parser.add_argument('file1', type=str, default=None)
    parser.add_argument('file2', type=str, default=None)
    parser.add_argument('--output_file', type=str, default='output.hdf5')
    parser.add_argument('--maxlen', type=int, default=2000000)
    args = parser.parse_args()
    hfile1 = h5py.File(args.file1, 'r')
    hfile2 = h5py.File(args.file2, 'r')
    outf = h5py.File(args.output_file, 'w')
    keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals', 'timeouts', 'infos/action_log_probs', 'infos/qpos', 'infos/qvel']
    # be careful with trajectories not ending at the end of a file!
    # find end of last traj
    terms = hfile1['terminals'][:]
    tos = hfile1['timeouts'][:]
    last_term = 0
    # Scan backwards for the last step flagged terminal or timeout, so file1
    # is truncated at a trajectory boundary before concatenation.
    for i in range(terms.shape[0]-1, -1, -1):
        if terms[i] or tos[i]:
            last_term = i
            break
    N = last_term + 1
    for k in keys:
        d1 = hfile1[k][:N]
        d2 = hfile2[k][:]
        # NOTE(review): file2's tail is NOT truncated to a trajectory boundary
        # and the maxlen cut can also split a trajectory -- confirm downstream
        # consumers tolerate a trailing partial trajectory.
        combined = np.concatenate([d1,d2],axis=0)[:args.maxlen]
        print(k, combined.shape)
        outf.create_dataset(k, data=combined, compression='gzip')
    outf.close()
================================================
FILE: d4rl/scripts/generation/relabel_antmaze_rewards.py
================================================
import d4rl.locomotion
from d4rl.offline_env import get_keys
import os
import argparse
import numpy as np
import gym
import h5py
if __name__ == '__main__':
    # Relabel an antmaze dataset's rewards against the env's target goal and
    # write "<file>_<relabel_type>.hdf5". Rewards/terminals are computed from
    # the NEXT observation (all_obs[1:]), with a zero appended for the final
    # step, which has no successor in the file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='antmaze-umaze-v0', help='')
    parser.add_argument('--relabel_type', default='sparse', help='')
    parser.add_argument('--filename', type=str)
    args = parser.parse_args()
    env = gym.make(args.env_name)
    target_goal = env.target_goal
    # print ('Target Goal: ', target_goal)
    rdataset = h5py.File(args.filename, 'r')
    fpath, ext = os.path.splitext(args.filename)
    wdataset = h5py.File(fpath + '_' + args.relabel_type + ext, 'w')
    all_obs = rdataset['observations'][:]
    if args.relabel_type == 'dense':
        """reward at the next state = dist(s', g)"""
        _rew = np.exp(-np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1))
    elif args.relabel_type == 'sparse':
        # Sparse: 1 when the next state is within 0.5 of the goal, else 0.
        _rew = (np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32)
    else:
        _rew = rdataset['rewards'][:]
    # Also add terminals here
    # NOTE(review): terminals are written as float32, not bool -- downstream
    # loaders appear to tolerate this, but confirm before changing.
    _terminals = (np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32)
    _terminals = np.concatenate([_terminals, np.array([0])], 0)
    _rew = np.concatenate([_rew, np.array([0])], 0)
    print ('Sum of rewards: ', _rew.sum())
    for k in get_keys(rdataset):
        print(k)
        if k == 'rewards':
            wdataset.create_dataset(k, data=_rew, compression='gzip')
        elif k == 'terminals':
            wdataset.create_dataset(k, data=_terminals, compression='gzip')
        else:
            wdataset.create_dataset(k, data=rdataset[k], compression='gzip')
================================================
FILE: d4rl/scripts/generation/relabel_maze2d_rewards.py
================================================
from d4rl.pointmaze import MazeEnv, maze_model
from d4rl.offline_env import get_keys
import os
import argparse
import numpy as np
import h5py
if __name__ == "__main__":
    # Relabel a maze2d dataset's rewards (dense / sparse / original) against
    # the env's goal and write "<file>-<relabel_type>.hdf5".
    parser = argparse.ArgumentParser(description='SAC-BEAR')
    parser.add_argument('--maze', default='umaze', help='')
    parser.add_argument('--relabel_type', default='dense', help='')
    parser.add_argument('--filename', type=str)
    args = parser.parse_args()
    if args.maze == 'umaze':
        maze = maze_model.U_MAZE
    elif args.maze == 'open':
        maze = maze_model.OPEN
    elif args.maze == 'medium':
        maze = maze_model.MEDIUM_MAZE
    else:
        maze = maze_model.LARGE_MAZE
    env = MazeEnv(maze, reset_target=False, reward_type='sparse')
    target_goal = env._target
    rdataset = h5py.File(args.filename, 'r')
    fpath, ext = os.path.splitext(args.filename)
    wdataset = h5py.File(fpath+'-'+args.relabel_type+ext, 'w')
    all_obs = rdataset['observations']
    if args.relabel_type == 'dense':
        # Dense: shaped reward decaying with distance to the goal.
        _rew = np.exp(-np.linalg.norm(all_obs[:,:2] - target_goal, axis=1))
    elif args.relabel_type == 'sparse':
        # Sparse: 1 within a 0.5 radius of the goal, else 0.
        _rew = (np.linalg.norm(all_obs[:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32)
    else:
        # Keep original rewards.
        # BUG FIX: Dataset.value was removed in h5py 3.0; `[()]` is the
        # supported way to read the whole array.
        _rew = rdataset['rewards'][()]
    for k in get_keys(rdataset):
        print(k)
        if k == 'rewards':
            wdataset.create_dataset(k, data=_rew, compression='gzip')
        else:
            if k.startswith('metadata'):
                wdataset[k] = rdataset[k][()]
            else:
                wdataset.create_dataset(k, data=rdataset[k], compression='gzip')
================================================
FILE: d4rl/scripts/ope_rollout.py
================================================
"""
This script runs rollouts on the OPE policies
using the ONNX runtime and averages the returns.
"""
import d4rl
import gym
import sys
import onnx
import onnxruntime as ort
import numpy as np
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('policy', type=str, help='ONNX policy file. i.e. cheetah.sampler.onnx')
parser.add_argument('env_name', type=str, help='Env name')
parser.add_argument('--num_rollouts', type=int, default=10, help='Number of rollouts to run.')
args = parser.parse_args()

env = gym.make(args.env_name)
policy = ort.InferenceSession(args.policy)

all_returns = []
for _ in range(args.num_rollouts):
    s = env.reset()
    returns = 0
    for t in range(env._max_episode_steps):
        obs_input = np.expand_dims(s, axis=0).astype(np.float32)
        # The exported sampler takes an explicit noise vector for its
        # reparameterized action sample.
        noise_input = np.random.randn(1, env.action_space.shape[0]).astype(np.float32)
        action, _, _ = policy.run(None, {'observations': obs_input, 'noise': noise_input})
        s, r, d, _ = env.step(action)
        returns += r
        print(returns, end='\r')
    all_returns.append(returns)
# BUG FIX: previously printed np.mean(returns) -- the mean of the LAST
# episode's scalar return (a no-op), leaving all_returns unused. Report the
# average over all rollouts as intended.
print(args.env_name, ':', np.mean(all_returns))
================================================
FILE: d4rl/scripts/reference_scores/adroit_expert.py
================================================
"""
Instructions:
1) Download the expert policies from https://github.com/aravindr93/hand_dapg
2) Place the policies from dapg_policies in the current directory
3) Run this script passing in the appropriate env_name
"""
import d4rl
import argparse
import os
import gym
import numpy as np
import pickle
from mjrl.utils.gym_env import GymEnv
def main():
    """Average the packaged DAPG expert policy's return over N episodes.

    Expects the expert pickle at ./policies/<env_name>.pickle (see the module
    docstring for download instructions).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='', help='Environment Name')
    parser.add_argument('--num_episodes', type=int, default=100)
    args = parser.parse_args()
    policy = './policies/'+args.env_name+'.pickle'
    # BUG FIX: close the pickle file deterministically (the original leaked
    # the handle returned by open()).
    with open(policy, 'rb') as f:
        pi = pickle.load(f)
    e = gym.make(args.env_name)
    e.seed(0)
    e.reset()
    ravg = []
    for n in range(args.num_episodes):
        e.reset()
        returns = 0
        for t in range(e._max_episode_steps):
            obs = e.get_obs()
            # NOTE(review): get_action is called twice; the stepped action is a
            # fresh stochastic sample, not the one paired with `infos`.
            action, infos = pi.get_action(obs)
            action = pi.get_action(obs)[0] # eval
            _, rew, done, info = e.step(action)
            returns += rew
            if done:
                break
        # e.env.mj_render() # this is much faster
        # e.render()
        ravg.append(returns)
    print(args.env_name, 'returns', np.mean(ravg))

if __name__ == '__main__':
    main()
================================================
FILE: d4rl/scripts/reference_scores/carla_lane_controller.py
================================================
import d4rl
import gym
from d4rl.carla import data_collection_agent_lane
import numpy as np
import argparse
def main():
    """Score the scripted RoamingAgent lane-following controller on a CARLA
    env, averaging episode returns (used as a reference score)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='carla-lane-v0', help='Maze type. small or default')
    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')
    args = parser.parse_args()
    env = gym.make(args.env_name)
    env.seed(0)
    np.random.seed(0)
    ravg = []
    for i in range(args.num_episodes):
        s = env.reset()
        # A fresh controller per episode: it holds per-rollout agent state.
        controller = data_collection_agent_lane.RoamingAgent(env)
        returns = 0
        for t in range(env._max_episode_steps):
            act = controller.compute_action()
            s, rew, done, _ = env.step(act)
            returns += rew
            if done:
                break
        ravg.append(returns)
        print(i, returns, ' mean:', np.mean(ravg))
    print(args.env_name, 'returns', np.mean(ravg))

if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/reference_scores/generate_ref_min_score.py
================================================
"""
Generate "minimum" reference scores by averaging the score for a random
policy over 100 episodes.
"""
import d4rl
import argparse
import gym
import numpy as np
def main():
    """Estimate the "minimum" reference score for an env: the average return
    of a uniformly random policy over N episodes (used for normalization)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='', help='Environment Name')
    parser.add_argument('--num_episodes', type=int, default=100)
    args = parser.parse_args()
    env = gym.make(args.env_name)
    env.seed(0)
    try:
        env.action_space.seed(0)
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallows
        # KeyboardInterrupt. Older gym action spaces may not support seeding;
        # best-effort, so failures are still ignored.
        pass
    ravg = []
    for n in range(args.num_episodes):
        env.reset()
        returns = 0
        for t in range(env._max_episode_steps):
            action = env.action_space.sample()
            _, rew, done, info = env.step(action)
            returns += rew
            if done:
                break
        ravg.append(returns)
    print('%s Average returns (%d ep): %f' % (args.env_name, args.num_episodes, np.mean(ravg)))

if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/reference_scores/generate_ref_min_score.sh
================================================
# Run the random-policy reference-score generator for every env listed in
# envs.txt (one env name per line).
for e in $(cat scripts/reference_scores/envs.txt)
do
    python scripts/reference_scores/generate_ref_min_score.py --env_name=$e
done
================================================
FILE: d4rl/scripts/reference_scores/maze2d_bullet_controller.py
================================================
import d4rl
import gym
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model
import numpy as np
import argparse
import time
def main():
    """Score the PD waypoint controller on a (bullet) maze2d env, averaging
    returns over full-length episodes (reference score)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default')
    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')
    parser.add_argument('--render', action='store_true')
    args = parser.parse_args()
    env = gym.make(args.env_name)
    if args.render:
        env.render('human')
    env.seed(0)
    np.random.seed(0)
    d_gain = -2.0
    p_gain = 10.0
    # NOTE(review): this first controller is immediately shadowed by the
    # per-episode one below; it only serves to fail fast on a bad maze spec.
    controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain)
    print('max steps:', env._max_episode_steps)
    ravg = []
    for _ in range(args.num_episodes):
        # Rebuild the controller each episode to clear its waypoint state.
        controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain)
        s = env.reset()
        returns = 0
        for t in range(env._max_episode_steps):
            position = s[0:2]
            velocity = s[2:4]
            act, done = controller.get_action(position, velocity, np.array(env.env.get_target()))
            #print(position-1, controller.current_waypoint(), np.array(env.env.get_target()) - 1)
            #print('\t', act)
            s, rew, _, _ = env.step(act)
            if args.render:
                time.sleep(0.01)
                env.render('human')
            returns += rew
        print(returns)
        ravg.append(returns)
    print(args.env_name, 'returns', np.mean(ravg))

if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/reference_scores/maze2d_controller.py
================================================
import d4rl
import gym
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model
import numpy as np
import argparse
def main():
    """Evaluate the waypoint controller on a maze2d env and print the mean return."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default')
    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    env.seed(0)
    np.random.seed(0)
    # NOTE(review): a single controller instance is reused across episodes here,
    # unlike the bullet variant which rebuilds it on every reset — confirm intended.
    controller = waypoint_controller.WaypointController(env.str_maze_spec)

    episode_returns = []
    for _ in range(args.num_episodes):
        obs = env.reset()
        total = 0
        for _ in range(env._max_episode_steps):
            act, done = controller.get_action(obs[0:2], obs[2:4], env.get_target())
            obs, rew, _, _ = env.step(act)
            total += rew
        episode_returns.append(total)
    print(args.env_name, 'returns', np.mean(episode_returns))


if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/reference_scores/minigrid_controller.py
================================================
import logging
from offline_rl.gym_minigrid import fourroom_controller
from offline_rl.gym_minigrid.envs import fourrooms
import numpy as np
import pickle
import gzip
import h5py
import argparse
def main():
    """Run the scripted four-rooms controller and print the mean return over episodes."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_episodes', type=int, default=100, help='Num trajs to collect')
    args = parser.parse_args()

    np.random.seed(0)
    env = fourrooms.FourRoomsEnv()
    env.seed(0)
    controller = fourroom_controller.FourRoomController()
    controller.set_target(env.get_target())

    episode_returns = []
    for _ in range(args.num_episodes):
        env.reset()
        total = 0
        # Fixed 50-step horizon; the controller's done flag is ignored here.
        for _ in range(50):
            act, done = controller.get_action(env.agent_pos, env.agent_dir)
            _, rew, _, _ = env.step(act)
            total += rew
        episode_returns.append(total)
    print('returns', np.mean(episode_returns))


if __name__ == "__main__":
    main()
================================================
FILE: d4rl/scripts/visualize_dataset.py
================================================
import argparse
import d4rl
import gym
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    dataset = env.get_dataset()
    # Replaying a dataset visually requires the recorded simulator states.
    if 'infos/qpos' not in dataset:
        raise ValueError('Only MuJoCo-based environments can be visualized')
    qpos = dataset['infos/qpos']
    qvel = dataset['infos/qvel']
    rewards = dataset['rewards']
    actions = dataset['actions']

    env.reset()
    env.set_state(qpos[0], qvel[0])
    # Step through every recorded (qpos, qvel) pair and render each frame.
    for step in range(qpos.shape[0]):
        env.set_state(qpos[step], qvel[step])
        env.render()
================================================
FILE: d4rl/setup.py
================================================
from distutils.core import setup
from platform import platform
from setuptools import find_packages
# Package metadata for the D4RL fork bundled with this repository.
setup(
    name='d4rl',
    version='1.1',
    install_requires=['gym',
                      'numpy',
                      'mujoco_py',
                      'pybullet',
                      'h5py',
                      'termcolor',  # adept_envs dependency
                      'click',  # adept_envs dependency
                      # dm_control: plain PyPI wheel on macOS, git install elsewhere.
                      'dm_control' if 'macOS' in platform() else
                      'dm_control @ git+https://github.com/deepmind/dm_control@main#egg=dm_control',
                      'mjrl @ git+https://github.com/aravindr93/mjrl@master#egg=mjrl'],
    packages=find_packages(),
    # Ship non-Python assets (MuJoCo XML models, meshes, textures) with the package.
    package_data={'d4rl': ['locomotion/assets/*',
                           'hand_manipulation_suite/assets/*',
                           'hand_manipulation_suite/Adroit/*',
                           'hand_manipulation_suite/Adroit/gallery/*',
                           'hand_manipulation_suite/Adroit/resources/*',
                           'hand_manipulation_suite/Adroit/resources/meshes/*',
                           'hand_manipulation_suite/Adroit/resources/textures/*',
                           ]},
    include_package_data=True,
)
================================================
FILE: dataset_utils.py
================================================
import collections
from typing import Optional
import jax
import d4rl
import gym
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm, trange
# One sampled batch of transitions. `masks` is 1 - terminal flag (see the
# dataset constructors below, which build it as `1.0 - terminals`).
Batch = collections.namedtuple(
    'Batch',
    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])
def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Group flat transition arrays into per-trajectory lists of tuples.

    A new trajectory is opened after every index where dones_float == 1.0
    (except at the very last transition).
    """
    trajectories = [[]]
    n = len(observations)
    for idx in tqdm(range(n), desc="split"):
        transition = (observations[idx], actions[idx], rewards[idx],
                      masks[idx], dones_float[idx], next_observations[idx])
        trajectories[-1].append(transition)
        # Start a fresh trajectory unless this was the final transition.
        if dones_float[idx] == 1.0 and idx + 1 < n:
            trajectories.append([])
    return trajectories
def merge_trajectories(trajs):
    """Flatten a list of trajectories back into stacked transition arrays.

    Inverse of split_into_trajectories: returns a 6-tuple of stacked arrays
    (observations, actions, rewards, masks, dones_float, next_observations).
    """
    columns = ([], [], [], [], [], [])
    for traj in trajs:
        for transition in traj:
            for column, value in zip(columns, transition):
                column.append(value)
    return tuple(np.stack(column) for column in columns)
class Dataset(object):
    """In-memory transition store with uniform random batch sampling."""

    def __init__(self, observations: np.ndarray, actions: np.ndarray,
                 rewards: np.ndarray, masks: np.ndarray,
                 dones_float: np.ndarray, next_observations: np.ndarray,
                 size: int):
        # Arrays are stored as-is (no copies); `size` is the number of valid rows.
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.size = size

    def sample(self, batch_size: int) -> Batch:
        """Draw `batch_size` transitions uniformly at random (with replacement)."""
        idx = np.random.randint(self.size, size=batch_size)
        return Batch(observations=self.observations[idx],
                     actions=self.actions[idx],
                     rewards=self.rewards[idx],
                     masks=self.masks[idx],
                     next_observations=self.next_observations[idx])
class D4RLDataset(Dataset):
    """Dataset built from `d4rl.qlearning_dataset` for a given gym env."""

    def __init__(self,
                 env: gym.Env,
                 clip_to_eps: bool = True,
                 eps: float = 1e-5):
        data = d4rl.qlearning_dataset(env)
        if clip_to_eps:
            # Clip actions slightly inside [-1, 1] (presumably to keep
            # tanh-squashed policy log-probs finite — confirm with callers).
            bound = 1 - eps
            data['actions'] = np.clip(data['actions'], -bound, bound)

        # Recover episode boundaries: a transition ends an episode when the
        # next stored observation does not continue from it, or it is terminal.
        dones_float = np.zeros_like(data['rewards'])
        for idx in range(len(dones_float) - 1):
            obs_gap = np.linalg.norm(data['observations'][idx + 1] -
                                     data['next_observations'][idx])
            dones_float[idx] = float(obs_gap > 1e-5 or data['terminals'][idx] == 1.0)
        dones_float[-1] = 1

        super().__init__(data['observations'].astype(np.float32),
                         actions=data['actions'].astype(np.float32),
                         rewards=data['rewards'].astype(np.float32),
                         masks=1.0 - data['terminals'].astype(np.float32),
                         dones_float=dones_float.astype(np.float32),
                         next_observations=data['next_observations'].astype(np.float32),
                         size=len(data['observations']))
class RelabeledDataset(Dataset):
    """Dataset wrapping externally (reward-)relabeled transition arrays.

    Episode boundaries are re-derived from observation discontinuities and
    terminal flags, mirroring D4RLDataset (note the tighter 1e-6 threshold
    here versus 1e-5 there).
    """

    def __init__(self, observations, actions, rewards, terminals, next_observations, clip_to_eps: bool = True, eps: float = 1e-5):
        if clip_to_eps:
            # Keep actions strictly inside (-1, 1).
            bound = 1 - eps
            actions = np.clip(actions, -bound, bound)

        dones_float = np.zeros_like(rewards)
        for idx in range(len(dones_float) - 1):
            gap = np.linalg.norm(observations[idx + 1] - next_observations[idx])
            # Episode ends when the buffer "jumps" to a new start state or the
            # transition is terminal.
            dones_float[idx] = float(gap > 1e-6 or terminals[idx] == 1.0)
        dones_float[-1] = 1

        super().__init__(
            observations=observations,
            actions=actions,
            rewards=rewards,
            masks=1.0 - terminals,
            dones_float=dones_float.astype(np.float32),
            next_observations=next_observations,
            size=len(observations)
        )
class ReplayBuffer(Dataset):
    """Fixed-capacity circular transition buffer, optionally seeded from a dataset."""

    def __init__(self, observation_space: gym.spaces.Box, action_dim: int,
                 capacity: int):
        obs_shape = (capacity, *observation_space.shape)
        super().__init__(
            observations=np.empty(obs_shape, dtype=observation_space.dtype),
            actions=np.empty((capacity, action_dim), dtype=np.float32),
            rewards=np.empty((capacity, ), dtype=np.float32),
            masks=np.empty((capacity, ), dtype=np.float32),
            dones_float=np.empty((capacity, ), dtype=np.float32),
            next_observations=np.empty(obs_shape, dtype=observation_space.dtype),
            size=0)
        self.size = 0
        self.insert_index = 0  # next slot to overwrite
        self.capacity = capacity

    def initialize_with_dataset(self, dataset: Dataset,
                                num_samples: Optional[int]):
        """Copy up to `num_samples` transitions (all if None) from `dataset` into the buffer."""
        assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.'
        dataset_size = len(dataset.observations)
        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        if num_samples < dataset_size:
            # Subsample without replacement.
            indices = np.random.permutation(dataset_size)[:num_samples]
        else:
            indices = np.arange(num_samples)

        for field in ('observations', 'actions', 'rewards', 'masks',
                      'dones_float', 'next_observations'):
            getattr(self, field)[:num_samples] = getattr(dataset, field)[indices]
        self.insert_index = num_samples
        self.size = num_samples

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float,
               next_observation: np.ndarray):
        """Append one transition, overwriting the oldest slot once at capacity."""
        slot = self.insert_index
        self.observations[slot] = observation
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.masks[slot] = mask
        self.dones_float[slot] = done_float
        self.next_observations[slot] = next_observation
        self.insert_index = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
@jax.jit
def batch_to_jax(batch):
    # Move every array in the batch (a pytree, e.g. dict of np arrays) onto
    # the default JAX device.
    return jax.tree_util.tree_map(jax.device_put, batch)
def reward_from_preference(
    env_name: str,
    dataset: D4RLDataset,
    reward_model,
    batch_size: int = 256,
):
    """Overwrite `dataset.rewards` with predictions from a learned reward model.

    Runs the model over the dataset in mini-batches of (observations, actions,
    next_observations) and writes the predicted rewards back in place.

    Args:
        env_name: Environment name (unused; kept for interface parity with
            reward_from_preference_transformer).
        dataset: Dataset whose rewards are replaced.
        reward_model: Model exposing get_reward(batch) -> per-step rewards.
        batch_size: Number of transitions per forward pass.

    Returns:
        The same dataset object with relabeled rewards.
    """
    data_size = dataset.rewards.shape[0]
    # Ceil division. The previous `int(data_size / batch_size) + 1` scheduled
    # an extra, empty batch whenever data_size was an exact multiple of
    # batch_size, calling the model on zero-length slices.
    num_batches = -(-data_size // batch_size)
    new_rewards = np.zeros_like(dataset.rewards)
    for i in trange(num_batches):
        start = i * batch_size
        end = min((i + 1) * batch_size, data_size)
        # `batch` instead of the old `input`, which shadowed the builtin.
        batch = dict(
            observations=dataset.observations[start:end],
            actions=dataset.actions[start:end],
            next_observations=dataset.next_observations[start:end],
        )
        jax_batch = batch_to_jax(batch)
        new_rewards[start:end] = np.asarray(reward_model.get_reward(jax_batch))
    dataset.rewards = new_rewards
    return dataset
def reward_from_preference_transformer(
    env_name: str,
    dataset: D4RLDataset,
    reward_model,
    seq_len: int,
    batch_size : int = 256,
    use_diff: bool = False,
    label_mode: str = 'last',
    with_attn_weights: bool = False # Option for attention analysis.
):
    """Relabel `dataset.rewards` with a sequence-based (transformer) reward model.

    Every dataset index is mapped to the length-`seq_len` window of its own
    trajectory that ends at that index (zero-left-padded near trajectory
    starts). The model scores each window and a per-step label is extracted
    according to `label_mode`. The dataset is modified in place.

    Args:
        env_name: Environment name (unused in the body).
        dataset: Source dataset; its `rewards` array is overwritten.
        reward_model: Model exposing get_reward(batch) -> (rewards, attn_weights).
        seq_len: Window length fed to the transformer.
        batch_size: Number of windows per forward pass.
        use_diff: If True, label with the difference between the score of the
            full window and the score of the window minus its last step.
        label_mode: 'mean' or 'last' — how to reduce per-step scores to a label.
        with_attn_weights: If True, also collect attention weights per batch.

    Returns:
        The relabeled dataset, plus (attn_weights, pts) when with_attn_weights.
    """
    trajs = split_into_trajectories(
        dataset.observations,
        dataset.actions,
        dataset.rewards,
        dataset.masks,
        dataset.dones_float,
        dataset.next_observations
    )
    trajectories = []
    trj_mapper = []
    observation_dim = dataset.observations.shape[-1]
    action_dim = dataset.actions.shape[-1]
    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"):
        _obs, _act, _reward, _mask, _done, _next_obs = [], [], [], [], [], []
        for _o, _a, _r, _m, _d, _no in traj:
            _obs.append(_o)
            _act.append(_a)
            _reward.append(_r)
            _mask.append(_m)
            _done.append(_d)
            _next_obs.append(_no)
        traj_len = len(traj)
        _obs, _act = np.asarray(_obs), np.asarray(_act)
        trajectories.append((_obs, _act))
        # Flat dataset index -> (trajectory index, offset within trajectory).
        for seg_idx in range(traj_len):
            trj_mapper.append((trj_idx, seg_idx))
    data_size = dataset.rewards.shape[0]
    interval = int(data_size / batch_size) + 1
    new_r = np.zeros_like(dataset.rewards)
    pts = []           # flat indices covered by each window (for attention analysis)
    attn_weights = []
    for i in trange(interval, desc="relabel reward"):
        start_pt = i * batch_size
        end_pt = min((i + 1) * batch_size, data_size)
        _input_obs, _input_act, _input_timestep, _input_attn_mask, _input_pt = [], [], [], [], []
        for pt in range(start_pt, end_pt):
            _trj_idx, _seg_idx = trj_mapper[pt]
            if _seg_idx < seq_len - 1:
                # Window would run past the trajectory start: left-pad with
                # zeros and mark the padded steps with attn_mask = 0.
                __input_obs = np.concatenate([np.zeros((seq_len - 1 - _seg_idx, observation_dim)), trajectories[_trj_idx][0][:_seg_idx + 1, :]], axis=0)
                __input_act = np.concatenate([np.zeros((seq_len - 1 - _seg_idx, action_dim)), trajectories[_trj_idx][1][:_seg_idx + 1, :]], axis=0)
                __input_timestep = np.concatenate([np.zeros(seq_len - 1 - _seg_idx, dtype=np.int32), np.arange(1, _seg_idx + 2, dtype=np.int32)], axis=0)
                __input_attn_mask = np.concatenate([np.zeros(seq_len - 1 - _seg_idx, dtype=np.int32), np.ones(_seg_idx + 1, dtype=np.float32)], axis=0)
                __input_pt = np.concatenate([np.zeros(seq_len - 1 - _seg_idx), np.arange(pt - _seg_idx , pt + 1)], axis=0)
            else:
                # Full-length window ending at this step.
                __input_obs = trajectories[_trj_idx][0][_seg_idx - seq_len + 1:_seg_idx + 1, :]
                __input_act = trajectories[_trj_idx][1][_seg_idx - seq_len + 1:_seg_idx + 1, :]
                __input_timestep = np.arange(1, seq_len + 1, dtype=np.int32)
                __input_attn_mask = np.ones((seq_len), dtype=np.float32)
                __input_pt = np.arange(pt - seq_len + 1, pt + 1)
            _input_obs.append(__input_obs)
            _input_act.append(__input_act)
            _input_timestep.append(__input_timestep)
            _input_attn_mask.append(__input_attn_mask)
            _input_pt.append(__input_pt)
        _input_obs = np.asarray(_input_obs)
        _input_act = np.asarray(_input_act)
        _input_timestep = np.asarray(_input_timestep)
        _input_attn_mask = np.asarray(_input_attn_mask)
        _input_pt = np.asarray(_input_pt)
        input = dict(
            observations=_input_obs,
            actions=_input_act,
            timestep=_input_timestep,
            attn_mask=_input_attn_mask,
        )
        jax_input = batch_to_jax(input)
        if with_attn_weights:
            new_reward, attn_weight = reward_model.get_reward(jax_input)
            attn_weights.append(np.array(attn_weight))
            pts.append(_input_pt)
        else:
            new_reward, _ = reward_model.get_reward(jax_input)
        # Zero out scores at padded positions.
        new_reward = new_reward.reshape(end_pt - start_pt, seq_len) * _input_attn_mask
        if use_diff:
            # Score the same windows without their final step; the label becomes
            # the marginal contribution of the last transition.
            prev_input = dict(
                observations=_input_obs[:, :seq_len - 1, :],
                actions=_input_act[:, :seq_len - 1, :],
                timestep=_input_timestep[:, :seq_len - 1],
                attn_mask=_input_attn_mask[:, :seq_len - 1],
            )
            jax_prev_input = batch_to_jax(prev_input)
            prev_reward, _ = reward_model.get_reward(jax_prev_input)
            prev_reward = prev_reward.reshape(end_pt - start_pt, seq_len - 1) * prev_input["attn_mask"]
            if label_mode == "mean":
                # NOTE: in the diff path "mean" sums over steps (no division).
                new_reward = jnp.sum(new_reward, axis=1).reshape(-1, 1)
                prev_reward = jnp.sum(prev_reward, axis=1).reshape(-1, 1)
            elif label_mode == "last":
                new_reward = new_reward[:, -1].reshape(-1, 1)
                prev_reward = prev_reward[:, -1].reshape(-1, 1)
            new_reward -= prev_reward
        else:
            if label_mode == "mean":
                # Average over real (unpadded) steps only.
                new_reward = jnp.sum(new_reward, axis=1) / jnp.sum(_input_attn_mask, axis=1)
                new_reward = new_reward.reshape(-1, 1)
            elif label_mode == "last":
                new_reward = new_reward[:, -1].reshape(-1, 1)
        new_reward = np.asarray(list(new_reward))
        new_r[start_pt:end_pt, ...] = new_reward.squeeze(-1)
    dataset.rewards = new_r.copy()
    if with_attn_weights:
        return dataset, (attn_weights, pts)
    return dataset
================================================
FILE: evaluation.py
================================================
from typing import Dict
import flax.linen as nn
import gym
import numpy as np
from tqdm import trange
def evaluate(agent: nn.Module, env: gym.Env,
             num_episodes: int) -> Dict[str, float]:
    """Roll out the agent for `num_episodes` and return mean episode statistics.

    Relies on the env putting per-episode 'return', 'length', and 'success'
    under info['episode'] (presumably via a monitor wrapper — confirm).
    """
    stats = {'return': [], 'length': [], 'success': []}
    for _ in trange(num_episodes, desc='evaluation', leave=False):
        observation, done = env.reset(), False
        while not done:
            # temperature=0.0: presumably greedy action selection — confirm in agent.
            action = agent.sample_actions(observation, temperature=0.0)
            observation, _, done, info = env.step(action)
        for key in stats:
            stats[key].append(info['episode'][key])
    return {key: np.mean(values) for key, values in stats.items()}
================================================
FILE: flaxmodels/README.md
================================================
### About
The goal of this project is to make current deep learning models more easily available for the awesome Jax/Flax ecosystem.
### Models
* GPT2 [[model](flaxmodels/gpt2)]
* StyleGAN2 [[model](flaxmodels/stylegan2)] [[training](training/stylegan2)]
* ResNet{18, 34, 50, 101, 152} [[model](flaxmodels/resnet)] [[training](training/resnet)]
* VGG{16, 19} [[model](flaxmodels/vgg)] [[training](training/vgg)]
* FewShotGanAdaption [[model](flaxmodels/few_shot_gan_adaption)] [[training](training/few_shot_gan_adaption)]
### Installation
You will need Python 3.7 or later.
1. For GPU usage, follow the Jax installation with CUDA.
2. Then install:
```sh
> pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
```
For CPU-only you can skip step 1.
### Documentation
The documentation for the models can be found [here](docs/Documentation.md#models).
### Checkpoints
The checkpoints are taken from the repositories that are referenced on the model pages. The processing steps and the format of the checkpoints are documented [here](docs/Documentation.md#1-checkpoints).
### Testing
To run the tests, pytest needs to be installed.
```sh
> git clone https://github.com/matthias-wright/flaxmodels.git
> cd flaxmodels
> python -m pytest tests/
```
See [here](docs/Documentation.md#2-testing) for an explanation of the testing strategy.
### Acknowledgments
Thank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by Marta Matyszczyk.
### License
Each model has an individual license.
================================================
FILE: flaxmodels/flaxmodels/__init__.py
================================================
from . import gpt2, lstm
__version__ = '0.1.2'
================================================
FILE: flaxmodels/flaxmodels/gpt2/README.md
================================================
# Better Language Models and Their Implications (GPT2)
Paper: https://openai.com/blog/better-language-models/

Repository: https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2
##### Table of Contents
* [1. Models](#models)
* [2. Basic Usage](#usage)
* [3. Documentation](#documentation)
* [4. Acknowledgments](#ack)
* [5. License](#license)
## 1. Models
| Model | Parameters | Size | URL |
| ------------- | ------------- | ------------- | ------------- |
| gpt2 | ~ 120 Million | ~ 500 MB | https://huggingface.co/gpt2 |
| gpt2-medium | ~ 350 Million | ~ 1.5 GB | https://huggingface.co/gpt2-medium |
| gpt2-large | ~ 800 Million | ~ 3 GB | https://huggingface.co/gpt2-large |
| gpt2-xl | ~ 1.5 Billion | ~ 6 GB | https://huggingface.co/gpt2-xl |
## 2. Basic Usage
For more usage examples check out this [Colab](gpt2_demo.ipynb).
This is very simple greedy text generation. There are more sophisticated methods out there.
```python
import jax
import jax.numpy as jnp
import flaxmodels as fm
key = jax.random.PRNGKey(0)
# Initialize tokenizer
tokenizer = fm.gpt2.get_tokenizer()
# Encode start sequence
generated = tokenizer.encode('The Manhattan bridge')
context = jnp.array([generated])
past = None
# Initialize model
# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=context, past_key_values=past)
for i in range(20):
# Predict next token in sequence
output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
token = jnp.argmax(output['logits'][..., -1, :])
context = jnp.expand_dims(token, axis=0)
# Add token to sequence
generated += [token]
# Update past keys and values
past = output['past_key_values']
# Decode sequence of tokens
sequence = tokenizer.decode(generated)
print(sequence)
```
## 3. Documentation
The documentation can be found [here](../../docs/Documentation.md#gpt2).
## 4. Acknowledgments
The tokenizer is taken from Huggingface.
## 5. License
Apache-2.0 License
================================================
FILE: flaxmodels/flaxmodels/gpt2/__init__.py
================================================
from .gpt2 import GPT2Model
from .gpt2 import GPT2LMHeadModel
from .trajectory_gpt2 import GPT2Model as TrajectoryGPT2Model
from .trajectory_gpt2 import TransRewardModel
from .tokenizer import *
================================================
FILE: flaxmodels/flaxmodels/gpt2/gpt2.py
================================================
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any
import h5py
from .. import utils
from . import ops
# Dropbox mirrors of the pretrained GPT2 weights (HDF5 checkpoints).
URLS = {'gpt2': 'https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5?dl=1',
        'gpt2-medium': 'https://www.dropbox.com/s/nam11kbd83wsm7d/gpt2-medium.h5?dl=1',
        'gpt2-large': 'https://www.dropbox.com/s/oy8623qwkkjm8gt/gpt2-large.h5?dl=1',
        'gpt2-xl': 'https://www.dropbox.com/s/6c6qt0bzz4v2afx/gpt2-xl.h5?dl=1'}
# Matching model configuration files (JSON), keyed identically to URLS.
CONFIGS = {'gpt2': 'https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json?dl=1',
           'gpt2-medium': 'https://www.dropbox.com/s/7mwkijxoh1earm5/gpt2-medium.json?dl=1',
           'gpt2-large': 'https://www.dropbox.com/s/nhslkxwxtpn7auz/gpt2-large.json?dl=1',
           'gpt2-xl': 'https://www.dropbox.com/s/1iv0nq1xigsfdvb/gpt2-xl.json?dl=1'}
class GPT2SelfAttention(nn.Module):
    """
    GPT2 Self Attention.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict=None
    param_dict: dict=None

    def setup(self):
        # Cache the hyperparameters read in __call__.
        self.max_pos = self.config.n_positions
        self.embd_dim = self.config.n_embd
        self.num_heads = self.config.n_head
        self.head_dim = self.embd_dim // self.num_heads
        self.attn_dropout = self.config.attn_pdrop
        self.resid_dropout = self.config.resid_pdrop
        self.scale_attn_weights = self.config.scale_attn_weights

    @nn.compact
    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run attention.

        Args:
            x (tensor): Input tensor.
            layer_past (Tuple): Tuple of past keys and values.
            attn_mask (tensor): Mask to avoid performing attention on padding token indices.
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (tensor, Tuple): Output tensor, tuple of keys and values.
        """
        # Single projection to 3*embd_dim, then split into query/key/value.
        # NOTE(review): the qkv projection loads the 'c_proj' params and the
        # output projection loads 'out_proj' — presumably the checkpoint
        # conversion renamed HuggingFace's c_attn/c_proj; confirm against the
        # .h5 key layout.
        x = ops.linear(3 * self.embd_dim, ops.get(self.param_dict, 'c_proj'))(x)
        query, key, value = jnp.split(x, 3, axis=2)
        query = ops.split_heads(query, self.num_heads, self.head_dim)
        value = ops.split_heads(value, self.num_heads, self.head_dim)
        key = ops.split_heads(key, self.num_heads, self.head_dim)
        if layer_past is not None:
            # Prepend cached keys/values from previous decoding steps.
            past_key, past_value = layer_past
            key = jnp.concatenate((past_key, key), axis=-2)
            value = jnp.concatenate((past_value, value), axis=-2)
        present = (key, value) if use_cache else None
        query_len, key_len = query.shape[-2], key.shape[-2]
        # Lower-triangular causal mask cropped to the current query window.
        # ('casual' is a typo for causal, kept as-is.)
        casual_mask = jnp.tril(jnp.ones((1, 1, self.max_pos, self.max_pos)))[:, :, key_len - query_len :key_len, :key_len]
        casual_mask = casual_mask.astype(bool)
        attn_dropout = nn.Dropout(rate=self.attn_dropout)
        # -1e4 is presumably the additive fill value for masked positions —
        # confirm in ops.attention.
        out, _ = ops.attention(query, key, value, casual_mask, -1e4, attn_dropout, self.scale_attn_weights, training, attn_mask, head_mask)
        out = ops.merge_heads(out, self.num_heads, self.head_dim)
        out = ops.linear(self.embd_dim, ops.get(self.param_dict, 'out_proj'))(out)
        out = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)
        return out, present
class GPT2MLP(nn.Module):
    """
    GPT2 MLP.

    Attributes:
        intermediate_dim (int): Dimension of the intermediate layer.
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    intermediate_dim: int
    config: dict=None
    param_dict: dict=None

    def setup(self):
        self.embd_dim = self.config.n_embd
        self.resid_dropout = self.config.resid_pdrop
        self.activation = self.config.activation_function

    @nn.compact
    def __call__(self, x, training=False):
        """
        Run the MLP.

        Args:
            x (tensor): Input tensor.
            training (bool): Training mode.

        Returns:
            (tensor): Output tensor of shape [..., embd_dim].
        """
        # Expand to intermediate_dim, apply the configured nonlinearity,
        # project back to the embedding size, then dropout.
        x = ops.linear(self.intermediate_dim, ops.get(self.param_dict, 'c_fc'))(x)
        x = ops.apply_activation(x, activation=self.activation)
        x = ops.linear(self.embd_dim, ops.get(self.param_dict, 'c_proj'))(x)
        x = nn.Dropout(rate=self.resid_dropout)(x, deterministic=not training)
        return x
class GPT2Block(nn.Module):
    """
    GPT2 Block: pre-LayerNorm self-attention and MLP, each with a residual connection.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict=None
    param_dict: dict=None

    def setup(self):
        self.embd_dim = self.config.n_embd
        self.eps = self.config.layer_norm_epsilon
        # MLP hidden size defaults to 4x the embedding dim when not configured.
        self.inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * self.embd_dim

    @nn.compact
    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run the block.

        Args:
            x (tensor): Input tensor.
            layer_past (Tuple): Tuple of past keys and values.
            attn_mask (tensor): Mask to avoid performing attention on padding token indices.
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (tensor, Tuple): Output tensor, tuple of keys and values.
        """
        # Attention sub-layer: LN -> attention -> residual add.
        residual = x
        x = ops.layer_norm(ops.get(self.param_dict, 'ln_1'), eps=self.eps)(x)
        kwargs = {'layer_past': layer_past, 'attn_mask': attn_mask, 'head_mask': head_mask,
                  'use_cache': use_cache, 'training': training}
        x, present = GPT2SelfAttention(self.config, ops.get(self.param_dict, 'attn'))(x, **kwargs)
        x += residual
        # MLP sub-layer: LN -> MLP -> residual add.
        residual = x
        x = ops.layer_norm(ops.get(self.param_dict, 'ln_2'), eps=self.eps)(x)
        x = GPT2MLP(self.inner_dim, self.config, ops.get(self.param_dict, 'mlp'))(x, training)
        x += residual
        return x, present
class GPT2Model(nn.Module):
    """
    The GPT2 Model.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        pretrained (str): Which pretrained model to use, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict=None
    pretrained: str=None
    ckpt_dir: str=None
    param_dict: dict=None

    def setup(self):
        if self.pretrained is not None:
            # 'pretrained' overrides config/param_dict: download checkpoint + config.
            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict_ = h5py.File(ckpt_file, 'r')['transformer']
            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])
            self.config_ = ops.load_config(config_file)
        else:
            self.config_ = self.config
            self.param_dict_ = self.param_dict
        # Cache config values used in __call__.
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.embd_dropout = self.config_.embd_pdrop
        self.num_layers = self.config_.n_layer
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(self,
                 input_ids=None,
                 past_key_values=None,
                 input_embds=None,
                 position_ids=None,
                 attn_mask=None,
                 head_mask=None,
                 use_cache=False,
                 training=False):
        """
        Run the model.

        Args:
            input_ids (tensor): Input token ids, shape [B, seq_len].
            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.
                                     If past_key_values is used, only input_ids that do not have their
                                     past calculated should be passed as input_ids.
            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].
            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].
            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (dict): Dictionary containing 'last_hidden_state', 'past_key_values'.
        """
        # Exactly one of input_ids / input_embds must be provided.
        if input_ids is not None and input_embds is not None:
            raise ValueError('You cannot specify both input_ids and input_embd at the same time.')
        elif input_ids is not None:
            input_shape = input_ids.shape
            input_ids = jnp.reshape(input_ids, newshape=(-1, input_shape[-1]))
            batch_size = input_ids.shape[0]
        elif input_embds is not None:
            input_shape = input_embds.shape[:-1]
            batch_size = input_embds.shape[0]
        else:
            raise ValueError('You have to specify either input_ids or input_embd.')
        if position_ids is not None:
            position_ids = jnp.reshape(position_ids, newshape=(-1, input_shape[-1]))
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.num_layers)
        else:
            # Number of timesteps already cached; offsets default position ids.
            past_length = past_key_values[0][0].shape[-2]
        if position_ids is None:
            position_ids = jnp.arange(start=past_length, stop=input_shape[-1] + past_length)
            position_ids = jnp.reshape(jnp.expand_dims(position_ids, axis=0), newshape=(-1, input_shape[-1]))
        if input_embds is None:
            input_embds = ops.embedding(self.vocab_size, self.embd_dim, ops.get(self.param_dict_, 'token_embd'))(input_ids)
        if attn_mask is not None:
            # Presumably expands [B, seq_len] into the broadcastable form the
            # attention op expects — see ops.get_attention_mask.
            attn_mask = ops.get_attention_mask(attn_mask, batch_size)
        if head_mask is not None:
            head_mask = ops.get_head_mask(head_mask, self.num_layers)
        else:
            head_mask = [None] * self.num_layers
        position_embds = ops.embedding(self.max_pos, self.embd_dim, ops.get(self.param_dict_, 'pos_embd'))(position_ids)
        x = input_embds + position_embds
        x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)
        output_shape = input_shape + (x.shape[-1],)  # NOTE: computed but never used.
        presents = () if use_cache else None
        # Stack of transformer blocks; each may emit its (key, value) cache.
        for i in range(self.num_layers):
            kwargs = {'layer_past': past_key_values[i], 'attn_mask': attn_mask, 'head_mask': head_mask[i],
                      'use_cache': use_cache, 'training': training}
            x, present = GPT2Block(self.config_, ops.get(self.param_dict_, f'block{i}'))(x, **kwargs)
            if use_cache:
                presents = presents + (present,)
        x = ops.layer_norm(ops.get(self.param_dict_, 'ln_final'), eps=self.eps)(x)
        return {'last_hidden_state': x, 'past_key_values': presents}
class GPT2LMHeadModel(nn.Module):
    """
    The GPT2 Model transformer with a language model head on top.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        pretrained (str): Which pretrained model to use, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
    """
    config: Any=None
    pretrained: str=None
    ckpt_dir: str=None

    def setup(self):
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict = h5py.File(ckpt_file, 'r')
            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])
            self.config_ = ops.load_config(config_file)
        else:
            self.config_ = self.config
            # Fix: this branch previously left `self.param_dict` undefined, so
            # running with pretrained=None raised AttributeError in __call__.
            # GPT2Model itself defaults param_dict to None (random init), and
            # ops.get is presumably None-tolerant for that path — mirror it.
            self.param_dict = None
        # Cache config values used in __call__.
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.embd_dropout = self.config_.embd_pdrop
        self.num_layers = self.config_.n_layer
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(self,
                 input_ids=None,
                 past_key_values=None,
                 input_embds=None,
                 labels=None,
                 position_ids=None,
                 attn_mask=None,
                 head_mask=None,
                 use_cache=False,
                 training=False):
        """
        Run the model.

        Args:
            input_ids (tensor): Input token ids, shape [B, seq_len].
            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.
                                     If past_key_values is used, only input_ids that do not have their
                                     past calculated should be passed as input_ids.
            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].
            labels (tensor): Labels for language modeling, shape [B, seq_len]. Will be shifted inside the model. Ignore label = -100.
            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].
            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (dict): Dictionary containing 'last_hidden_state', 'past_key_values', 'loss', and 'logits'.
        """
        kwargs = {'input_ids': input_ids,
                  'past_key_values': past_key_values,
                  'input_embds': input_embds,
                  'position_ids': position_ids,
                  'attn_mask': attn_mask,
                  'head_mask': head_mask,
                  'use_cache': use_cache,
                  'training': training}
        # Backbone transformer, then an untied linear head over the vocabulary.
        output = GPT2Model(self.config_, param_dict=ops.get(self.param_dict, 'transformer'))(**kwargs)
        lm_logits = ops.linear(self.vocab_size, ops.get(self.param_dict, 'lm_head'), bias=False)(output['last_hidden_state'])
        loss = None
        if labels is not None:
            # Shift so that position t predicts token t+1.
            shift_logits = lm_logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            # flatten the tokens
            loss = ops.cross_entropy(jnp.reshape(shift_logits, (-1, shift_logits.shape[-1])), jnp.reshape(shift_labels, (-1)))
        output['loss'] = loss
        output['logits'] = lm_logits
        return output
================================================
FILE: flaxmodels/flaxmodels/gpt2/gpt2_demo.ipynb
================================================
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "gpt2_demo.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6_i3EQa2yOzA",
"outputId": "07cea0ca-55a5-4545-fd64-064d0652690f"
},
"source": [
"!pip install --upgrade pip\n",
"!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (21.2.4)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
"Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html\n",
"Requirement already satisfied: jax[cuda111] in /usr/local/lib/python3.7/dist-packages (0.2.19)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (3.3.0)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (0.12.0)\n",
"Requirement already satisfied: numpy>=1.18 in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (1.19.5)\n",
"Collecting jaxlib==0.1.70+cuda111\n",
" Downloading https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.70%2Bcuda111-cp37-none-manylinux2010_x86_64.whl (197.0 MB)\n",
"\u001b[K |████████████████████████████████| 197.0 MB 19 kB/s \n",
"\u001b[?25hRequirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.70+cuda111->jax[cuda111]) (1.12)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.70+cuda111->jax[cuda111]) (1.4.1)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax[cuda111]) (1.15.0)\n",
"Installing collected packages: jaxlib\n",
" Attempting uninstall: jaxlib\n",
" Found existing installation: jaxlib 0.1.66+cuda111\n",
" Uninstalling jaxlib-0.1.66+cuda111:\n",
" Successfully uninstalled jaxlib-0.1.66+cuda111\n",
"Successfully installed jaxlib-0.1.70+cuda111\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
"Collecting git+https://github.com/matthias-wright/flaxmodels.git\n",
" Cloning https://github.com/matthias-wright/flaxmodels.git to /tmp/pip-req-build-cg84k2dn\n",
" Running command git clone -q https://github.com/matthias-wright/flaxmodels.git /tmp/pip-req-build-cg84k2dn\n",
" Resolved https://github.com/matthias-wright/flaxmodels.git to commit 242ced2a4a12ace8adc32a705b08064ffeeb31ac\n",
"Requirement already satisfied: h5py==2.10.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2.10.0)\n",
"Requirement already satisfied: numpy==1.19.5 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (1.19.5)\n",
"Requirement already satisfied: requests==2.23.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2.23.0)\n",
"Requirement already satisfied: packaging==20.9 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (20.9)\n",
"Requirement already satisfied: dataclasses==0.6 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.6)\n",
"Requirement already satisfied: filelock==3.0.12 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (3.0.12)\n",
"Requirement already satisfied: jax in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.2.19)\n",
"Requirement already satisfied: jaxlib in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.1.70+cuda111)\n",
"Requirement already satisfied: flax in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.3.4)\n",
"Requirement already satisfied: Pillow==7.1.2 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (7.1.2)\n",
"Requirement already satisfied: regex==2021.4.4 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2021.4.4)\n",
"Requirement already satisfied: tqdm==4.60.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (4.60.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from h5py==2.10.0->flaxmodels==0.1.0) (1.15.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging==20.9->flaxmodels==0.1.0) (2.4.7)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (2021.5.30)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (1.24.3)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (3.2.2)\n",
"Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (1.0.2)\n",
"Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (0.0.9)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax->flaxmodels==0.1.0) (3.3.0)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax->flaxmodels==0.1.0) (0.12.0)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib->flaxmodels==0.1.0) (1.4.1)\n",
"Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib->flaxmodels==0.1.0) (1.12)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (2.8.2)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (1.3.1)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (0.10.0)\n",
"Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax->flaxmodels==0.1.0) (0.0.8)\n",
"Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax->flaxmodels==0.1.0) (0.11.1)\n",
"Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax->flaxmodels==0.1.0) (0.1.6)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qr2BfYc9YVHx"
},
"source": [
"# Generate text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RHa6ySp-ywef"
},
"source": [
"This is very simple greedy text generation. There are more sophisticated [methods](https://huggingface.co/blog/how-to-generate) out there."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Y-nDnbE-yvWY",
"outputId": "3a8d9c4a-6349-4967-aacc-be9b8335f3c0"
},
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flaxmodels as fm\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"# Initialize tokenizer\n",
"tokenizer = fm.gpt2.get_tokenizer()\n",
"\n",
"# Encode start sequence\n",
"generated = tokenizer.encode('The Manhattan bridge')\n",
"\n",
"context = jnp.array([generated])\n",
"past = None\n",
"\n",
"# Initialize model\n",
"# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']\n",
"model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n",
"params = model.init(key, input_ids=context, past_key_values=past)\n",
"\n",
"for i in range(20):\n",
" # Predict next token in sequence\n",
" output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)\n",
" token = jnp.argmax(output['logits'][..., -1, :])\n",
" #context = jnp.expand_dims(token, axis=(0, 1))\n",
" context = jnp.expand_dims(token, axis=0)\n",
" # Add token to sequence\n",
" generated += [token]\n",
" # Update past keys and values\n",
" past = output['past_key_values']\n",
"\n",
"# Decode sequence of tokens\n",
"sequence = tokenizer.decode(generated)\n",
"\n",
"print()\n",
"print(sequence)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading: \"https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt\" to /tmp/flaxmodels/merges.txt\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 456k/456k [00:00<00:00, 12.1MiB/s]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json\" to /tmp/flaxmodels/vocab.json\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 1.04M/1.04M [00:00<00:00, 23.1MiB/s]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5\" to /tmp/flaxmodels/gpt2.h5\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 703M/703M [00:14<00:00, 48.1MiB/s]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading: \"https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json\" to /tmp/flaxmodels/gpt2.json\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"100%|██████████| 715/715 [00:00<00:00, 159kiB/s]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kKnwDOU2YhSN"
},
"source": [
"# Get language model head output from text input"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zW-IBk_FYm9a"
},
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flaxmodels as fm\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"# Initialize tokenizer\n",
"tokenizer = fm.gpt2.get_tokenizer()\n",
"\n",
"# Encode start sequence\n",
"input_ids = tokenizer.encode('The Manhattan bridge')\n",
"input_ids = jnp.array([input_ids])\n",
"\n",
"# Initialize model\n",
"model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n",
"params = model.init(key, input_ids=input_ids)\n",
"\n",
"# Compute output\n",
"output = model.apply(params, input_ids=input_ids, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ui2DneCuYrOA"
},
"source": [
"# Get language model head output from embeddings\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "W8PrhOpZYuRZ"
},
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flaxmodels as fm\n",
" \n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"# Dummy input \n",
"input_embds = jax.random.normal(key, shape=(2, 10, 768))\n",
"\n",
"# Initialize model\n",
"model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n",
"params = model.init(key, input_embds=input_embds)\n",
"# Compute output\n",
"output = model.apply(params, input_embds=input_embds, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j0IUgj4yYwET"
},
"source": [
"# Get model output from text input"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jSuAZ1YjYxmo"
},
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flaxmodels as fm\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"# Initialize tokenizer\n",
"tokenizer = fm.gpt2.get_tokenizer()\n",
"\n",
"# Encode start sequence\n",
"input_ids = tokenizer.encode('The Manhattan bridge')\n",
"input_ids = jnp.array([input_ids])\n",
"\n",
"# Initialize model\n",
"model = fm.gpt2.GPT2Model(pretrained='gpt2')\n",
"params = model.init(key, input_ids=input_ids)\n",
"\n",
"# Compute output\n",
"output = model.apply(params, input_ids=input_ids, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ...}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "-jR2kX9GYzIn"
},
"source": [
"# Get model output from embeddings"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z1taV3BGY06n"
},
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flaxmodels as fm\n",
" \n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"# Dummy input\n",
"input_embds = jax.random.normal(key, shape=(2, 10, 768))\n",
" \n",
"# Initialize model\n",
"model = fm.gpt2.GPT2Model(pretrained='gpt2')\n",
"params = model.init(key, input_embds=input_embds)\n",
"\n",
"# Compute output\n",
"output = model.apply(params, input_embds=input_embds, use_cache=True)\n",
"# output: {'last_hidden_state': ..., 'past_key_values': ...}"
],
"execution_count": null,
"outputs": []
}
]
}
================================================
FILE: flaxmodels/flaxmodels/gpt2/ops.py
================================================
import jax
import jax.numpy as jnp
import flax.linen as nn
import math
import json
from types import SimpleNamespace
#----------------------------------------------------------
# Linear
#----------------------------------------------------------
def linear(features, param_dict, bias=True):
    """
    Build a Dense layer, optionally initialized from pretrained weights.

    Args:
        features (int): Output feature dimension.
        param_dict (dict): Mapping with 'weight' (and 'bias' if bias=True)
                           arrays; None for default random initialization.
        bias (bool): Whether a bias parameter is expected in param_dict.

    Returns:
        (nn.Dense): The configured Dense module.
    """
    if param_dict is None:
        return nn.Dense(features=features, use_bias=bias)
    assert 'weight' in param_dict
    dense_kwargs = {'features': features,
                    'kernel_init': lambda *_: jnp.array(param_dict['weight'])}
    if bias:
        assert 'bias' in param_dict
        dense_kwargs['bias_init'] = lambda *_: jnp.array(param_dict['bias'])
    # NOTE(review): when bias=False and param_dict is given, use_bias is not
    # passed (matching the original), so the layer still creates a
    # zero-initialized bias parameter — confirm this is intended.
    return nn.Dense(**dense_kwargs)
def embedding(num_embeddings, features, param_dict, dtype='float32'):
    """
    Build an Embed layer, optionally initialized from pretrained weights.

    Args:
        num_embeddings (int): Vocabulary size.
        features (int): Embedding dimension.
        param_dict (dict): Mapping with a 'weight' array; None for default init.
        dtype (str): Computation dtype.

    Returns:
        (nn.Embed): The configured embedding module.
    """
    if param_dict is None:
        return nn.Embed(num_embeddings=num_embeddings, features=features, dtype=dtype)
    assert 'weight' in param_dict
    return nn.Embed(num_embeddings=num_embeddings,
                    features=features,
                    embedding_init=lambda *_: jnp.array(param_dict['weight']),
                    dtype=dtype)
#----------------------------------------------------------
# Activation
#----------------------------------------------------------
def apply_activation(x, activation='linear'):
    """
    Apply the activation named by `activation` to `x`.

    Supported names: 'linear' (identity), 'gelu_new' and 'gelu_fast'
    (tanh-based GELU approximations), 'gelu', 'relu', 'leaky_relu',
    'sigmoid', 'tanh'.

    Raises:
        ValueError: If the activation name is not recognized.
    """
    if activation == 'linear':
        return x
    if activation == 'gelu_new':
        # Tanh approximation of GELU (the variant used by GPT-2).
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.power(x, 3.0))
        return 0.5 * x * (1.0 + nn.tanh(inner))
    if activation == 'gelu_fast':
        return 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
    if activation == 'gelu':
        return jax.nn.gelu(x)
    if activation == 'relu':
        return jax.nn.relu(x)
    if activation == 'leaky_relu':
        return jax.nn.leaky_relu(x)
    if activation == 'sigmoid':
        return jax.nn.sigmoid(x)
    if activation == 'tanh':
        return nn.tanh(x)
    raise ValueError(f'Unknown activation function: {activation}.')
#----------------------------------------------------------
# Normalization
#----------------------------------------------------------
def layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dtype='float32'):
    """
    Build a LayerNorm module, optionally initialized from pretrained parameters.

    Args:
        param_dict (dict): Mapping with 'bias'/'scale' arrays as required by
                           use_bias/use_scale; None for default initialization.
        use_bias (bool): Whether the layer has a learnable bias.
        use_scale (bool): Whether the layer has a learnable scale.
        eps (float): Numerical-stability epsilon.
        dtype (str): Computation dtype.

    Returns:
        (nn.LayerNorm): The configured LayerNorm module.
    """
    if param_dict is None:
        return nn.LayerNorm(use_bias=use_bias, use_scale=use_scale, epsilon=eps, dtype=dtype)
    ln_kwargs = {'use_bias': use_bias, 'use_scale': use_scale, 'epsilon': eps, 'dtype': dtype}
    if use_bias:
        assert 'bias' in param_dict, 'use_bias is set True but bias parameter does not exist in param_dict.'
        ln_kwargs['bias_init'] = lambda *_: jnp.array(param_dict['bias'])
    if use_scale:
        assert 'scale' in param_dict, 'use_scale is set True but scale parameter does not exist in param_dict.'
        ln_kwargs['scale_init'] = lambda *_: jnp.array(param_dict['scale'])
    return nn.LayerNorm(**ln_kwargs)
#----------------------------------------------------------
# Attention
#----------------------------------------------------------
def split_heads(x, num_heads, head_dim):
    """
    Split the last embedding dimension into per-head slices and move the
    head axis in front of the sequence axis.

    Args:
        x (tensor): Input, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].
        num_heads (int): Number of attention heads.
        head_dim (int): Per-head embedding dimension.

    Returns:
        (tensor): Output, shape [B, num_heads, seq_len, head_dim] or
                  [B, blocks, num_heads, block_len, head_dim].

    Raises:
        ValueError: If the reshaped tensor is not rank 4 or 5.
    """
    x = x.reshape(x.shape[:-1] + (num_heads, head_dim))
    if x.ndim == 5:
        # [batch, blocks, head, block_len, head_dim]
        return x.transpose(0, 1, 3, 2, 4)
    if x.ndim == 4:
        # [batch, head, seq_len, head_dim]
        return x.transpose(0, 2, 1, 3)
    raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')
def merge_heads(x, num_heads, head_dim):
    """
    Inverse of split_heads: move the head axis back behind the sequence axis
    and fold (num_heads, head_dim) into a single embedding dimension.

    Args:
        x (tensor): Input, shape [B, num_heads, seq_len, head_dim] or
                    [B, blocks, num_heads, block_len, head_dim].
        num_heads (int): Number of attention heads.
        head_dim (int): Per-head embedding dimension.

    Returns:
        (tensor): Output, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].

    Raises:
        ValueError: If the input is not rank 4 or 5.
    """
    if x.ndim == 5:
        x = x.transpose(0, 1, 3, 2, 4)
    elif x.ndim == 4:
        x = x.transpose(0, 2, 1, 3)
    else:
        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')
    return x.reshape(x.shape[:-2] + (num_heads * head_dim,))
def attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, feedback=None):
    """
    Computes scaled dot-product attention for the given query, key and value.

    Args:
        query (tensor): Query, shape [B, num_heads, seq_len, embd_dim].
        key (tensor): Key, shape [B, num_heads, seq_len, embd_dim].
        value (tensor): Value, shape [B, num_heads, seq_len, embd_dim].
        casual_mask (tensor): Boolean causal mask ("casual" is a historical typo,
                              kept for API compatibility); positions where it is
                              False receive masked_bias before the softmax,
                              shape [1, 1, key_len - query_len :key_len, :key_len].
        masked_bias (float): Value to insert for masked part of the sequence.
        dropout (nn.Dropout): Dropout module applied to the attention weights
                              (not to the output).
        scale_attn_weights (bool): If True, divide the raw scores by
                                   sqrt(head_dim) (i.e. sqrt of value's last dim).
        training (bool): Training mode; when False, dropout is deterministic.
        attn_mask (tensor): Additive mask to avoid attending to padded tokens,
                            shape [B, seq_len] (broadcast onto the scores).
        head_mask (tensor): Multiplicative mask to nullify selected heads,
                            shape [num_heads,] or [num_layers, num_heads].
        feedback (tensor): Unused in this implementation; accepted only for
                           signature compatibility with callers that pass it.

    Returns:
        (tensor): Attention output, shape [B, num_heads, seq_len, embd_dim].
        (tensor): Pre-dropout attention weights in float32,
                  shape [B, num_heads, seq_len, seq_len].
    """
    # Compute raw scores in float32 for numerical stability of the softmax.
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
    attn_weights = jnp.matmul(query, jnp.swapaxes(key, -1, -2))
    if scale_attn_weights:
        attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)
    # Causal masking: disallowed (future) positions get masked_bias.
    attn_weights = jnp.where(casual_mask, attn_weights, masked_bias)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    _attn_weights = nn.softmax(attn_weights, axis=-1)
    attn_weights = _attn_weights.astype(value.dtype)
    attn_weights = dropout(attn_weights, deterministic=not training)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask
    out = jnp.matmul(attn_weights, value)
    # Note: the returned weights are the pre-dropout float32 softmax output.
    return out, _attn_weights
#----------------------------------------------------------
# Losses
#----------------------------------------------------------
def cross_entropy(logits, labels, ignore_index=-100):
    """
    Compute the mean cross-entropy loss over non-ignored rows (on logits).

    Args:
        logits (tensor): Logits, shape [B, num_classes].
        labels (tensor): Integer class labels, shape [B,].
        ignore_index (int): Label value whose rows are excluded from the loss.

    Returns:
        (tensor): Scalar cross-entropy loss averaged over non-ignored rows.
    """
    n_rows, n_classes = logits.shape
    log_probs = jax.nn.log_softmax(logits)
    # Rows whose label equals ignore_index contribute nothing to the loss.
    ignored_rows = jnp.nonzero(labels == ignore_index)[0]
    picked = jax.nn.one_hot(labels, num_classes=n_classes) * log_probs
    picked = picked.at[ignored_rows].set(jnp.zeros((ignored_rows.shape[0], n_classes)))
    # Average only over the rows that were not ignored.
    return -jnp.sum(picked) / (n_rows - ignored_rows.shape[0])
def kld_loss(p, q):
    """KL divergence sum(p * log(p / q)); terms where p == 0 contribute zero."""
    pointwise = p * (jnp.log(p) - jnp.log(q))
    return jnp.sum(jnp.where(p != 0, pointwise, 0))
#----------------------------------------------------------
# Misc
#----------------------------------------------------------
def get(dictionary, key):
    """Safe lookup: return dictionary[key], or None when dictionary is None or key is absent."""
    if dictionary is not None and key in dictionary:
        return dictionary[key]
    return None
def get_attention_mask(attn_mask, batch_size):
    """
    Convert a [B, seq_len] 0/1 padding mask into an additive attention bias
    of shape [B, 1, 1, seq_len]: positions with mask 1 become 0.0 and
    positions with mask 0 become -10000.0.

    Args:
        attn_mask (tensor): Padding mask (1 = attend, 0 = ignore).
        batch_size (int): Batch size, must be positive.

    Returns:
        (tensor): Additive mask, shape [B, 1, 1, seq_len].
    """
    assert batch_size > 0, 'batch_size should be > 0.'
    # Pass the shape positionally: the `newshape` keyword of jnp.reshape is
    # deprecated and removed in recent JAX releases.
    attn_mask = jnp.reshape(attn_mask, (batch_size, -1))
    attn_mask = jnp.expand_dims(attn_mask, axis=(1, 2))
    attn_mask = (1.0 - attn_mask) * -10000.0
    return attn_mask
def get_head_mask(head_mask, num_layers):
    """
    Broadcast a head mask to shape [num_layers, 1, num_heads, 1, 1].

    Args:
        head_mask (tensor): Rank-1 [num_heads] mask (replicated across layers)
                            or rank-2 [num_layers, num_heads] mask.
        num_layers (int): Number of transformer layers.

    Returns:
        (tensor): Mask of shape [num_layers, 1, num_heads, 1, 1].

    Raises:
        ValueError: If head_mask is not rank 1 or 2.
    """
    if head_mask.ndim == 1:
        # Bug fix: jnp.expand_dims takes `axis`, not `newshape` — the original
        # raised TypeError on every call.
        head_mask = jnp.expand_dims(head_mask, axis=(0, 1, -2, -1))
        head_mask = jnp.repeat(head_mask, repeats=num_layers, axis=0)
    elif head_mask.ndim == 2:
        head_mask = jnp.expand_dims(head_mask, axis=(1, -2, -1))
    else:
        # Fixed message: the *input* must be rank 1 or 2 (output is rank 5).
        raise ValueError(f'head_mask must have rank 1 or 2, but has rank {head_mask.ndim}.')
    return head_mask
def load_config(path):
    """
    Load a JSON config file and expose it as nested SimpleNamespace objects
    (attribute access instead of dict lookups).

    Args:
        path (str): Path to the JSON config file.

    Returns:
        (SimpleNamespace): Parsed configuration.
    """
    # Use a context manager so the file handle is closed (the original
    # open(...).read() leaked the handle).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f, object_hook=lambda d: SimpleNamespace(**d))
def custom_softmax(array, axis=-1, temperature=1.0):
    """Temperature-scaled softmax of `array` along `axis`."""
    return jax.nn.softmax(array / temperature, axis=axis)
def mse_loss(val, target):
    """Mean squared error between `val` and `target`."""
    diff = val - target
    return jnp.mean(diff * diff)
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/__init__.py
================================================
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py
================================================
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/configuration_gpt2.py
================================================
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
import json
import os
from functools import lru_cache
from typing import TYPE_CHECKING, List, Optional, Tuple
import regex as re
from .utils.tokenization_utils import AddedToken, PreTrainedTokenizer
from .utils import logging
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = logging.get_logger(__name__)
# On-disk filenames of the two artifacts a byte-level BPE tokenizer needs.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}

# Download locations of the vocab/merges files for each pretrained checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "gpt2": "https://huggingface.co/gpt2/resolve/main/vocab.json",
        "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
        "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/vocab.json",
        "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
        "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/vocab.json",
    },
    "merges_file": {
        "gpt2": "https://huggingface.co/gpt2/resolve/main/merges.txt",
        "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
        "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/merges.txt",
        "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
        "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/merges.txt",
    },
}

# Maximum input length (number of positional embeddings) per pretrained model.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "gpt2": 1024,
    "gpt2-medium": 1024,
    "gpt2-large": 1024,
    "gpt2-xl": 1024,
    "distilgpt2": 1024,
}
@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by byte-level BPE.

    Every one of the 256 byte values maps to a distinct printable unicode
    character: printable Latin-1 bytes map to themselves, and the remaining
    bytes (whitespace/control characters, which the BPE code cannot handle)
    are shifted into the range starting at U+0100. The mapping is reversible,
    so any UTF-8 byte sequence can round-trip through BPE without UNKs.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = list(printable)
    char_codes = list(printable)
    shift = 0
    for byte in range(256):
        if byte not in byte_values:
            # Non-printable byte: assign the next code point above 255.
            byte_values.append(byte)
            char_codes.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, char_codes)}
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings).

    Robustness fix: the original indexed word[0] unconditionally and raised
    IndexError on empty input; zip handles empty and single-symbol words by
    returning an empty set, and is otherwise equivalent.
    """
    return set(zip(word, word[1:]))
class GPT2Tokenizer(PreTrainedTokenizer):
"""
Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
::
>>> from transformers import GPT2Tokenizer
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> tokenizer("Hello world")['input_ids']
[15496, 995]
>>> tokenizer(" Hello world")['input_ids']
[18435, 995]
You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer or when you
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
.. note::
When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first
one).
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
merges_file (:obj:`str`):
Path to the merges file.
errors (:obj:`str`, `optional`, defaults to :obj:`"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode
`__ for more information.
unk_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
The beginning of sequence token.
eos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
The end of sequence token.
add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word. (GPT2 tokenizer detect beginning of words by the preceding space).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
    self,
    vocab_file,
    merges_file,
    errors="replace",
    unk_token="<|endoftext|>",
    bos_token="<|endoftext|>",
    eos_token="<|endoftext|>",
    add_prefix_space=False,
    **kwargs
):
    """Load the vocab and merges files and build the byte-level BPE tables.

    Args:
        vocab_file: Path to the JSON token -> id vocabulary.
        merges_file: Path to the BPE merges file (one merge rule per line;
            the first line is a header and the last is dropped).
        errors: How UTF-8 decode errors are handled when converting ids back
            to text (passed to bytes.decode).
        unk_token / bos_token / eos_token: Special tokens; GPT-2 uses the
            same string for all three.
        add_prefix_space: If True, an initial space is added so the leading
            word is tokenized like any other word.
        **kwargs: Forwarded to the PreTrainedTokenizer base class.
    """
    # Wrap plain strings as AddedToken so the special tokens keep their exact
    # surface form (no stripping on either side).
    bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
    eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
    unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
    super().__init__(
        errors=errors,
        unk_token=unk_token,
        bos_token=bos_token,
        eos_token=eos_token,
        add_prefix_space=add_prefix_space,
        **kwargs,
    )
    with open(vocab_file, encoding="utf-8") as vocab_handle:
        self.encoder = json.load(vocab_handle)
    # Reverse mapping id -> token for decoding.
    self.decoder = {v: k for k, v in self.encoder.items()}
    self.errors = errors  # how to handle errors in decoding
    self.byte_encoder = bytes_to_unicode()
    self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
    with open(merges_file, encoding="utf-8") as merges_handle:
        bpe_merges = merges_handle.read().split("\n")[1:-1]
    bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
    # Earlier merges get a lower rank, i.e. higher priority in bpe().
    self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
    # Per-token memoization cache for bpe().
    self.cache = {}
    self.add_prefix_space = add_prefix_space
    # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions.
    self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
@property
def vocab_size(self):
    """Size of the base vocabulary (excludes tokens added after loading)."""
    return len(self.encoder)
def get_vocab(self):
    """Return the full token -> id mapping (base vocab merged with added tokens)."""
    return dict(self.encoder, **self.added_tokens_encoder)
def bpe(self, token):
    """Apply byte-pair-encoding merges to one byte-encoded token.

    Repeatedly fuses the adjacent symbol pair with the lowest rank in
    ``self.bpe_ranks`` until no rankable pair remains, then returns the
    symbols joined by single spaces. Results are memoized in ``self.cache``.
    """
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)
    # A single-character token has no pairs; return it unchanged (not cached).
    if not pairs:
        return token
    while True:
        # Best merge candidate = pair with the lowest rank; unknown pairs rank +inf.
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        # Rebuild the word, fusing every occurrence of (first, second).
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # No further occurrence of `first`: keep the remaining tail and stop.
                new_word.extend(word[i:])
                break
            else:
                # Copy everything before the occurrence untouched.
                new_word.extend(word[i:j])
                i = j
            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    # Spaces cannot occur inside a symbol (bytes are remapped to printable
    # unicode elsewhere), so a space is a safe separator here.
    word = " ".join(word)
    self.cache[token] = word
    return word
def _tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
    """Join BPE tokens and undo the byte-level mapping back to a plain string."""
    joined = "".join(tokens)
    raw = bytearray(self.byte_decoder[ch] for ch in joined)
    return raw.decode("utf-8", errors=self.errors)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Write the vocab (JSON) and merges (text) files into ``save_directory``.

    Returns:
        The ``(vocab_file, merge_file)`` paths, or ``None`` when
        ``save_directory`` is not an existing directory (an error is logged
        rather than raised).
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return
    vocab_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
    )
    merge_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
    )
    # Vocabulary is dumped as-is, keeping non-ASCII token strings readable.
    with open(vocab_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(self.encoder, ensure_ascii=False))
    index = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        writer.write("#version: 0.2\n")
        # Merges are written in rank order; `index` tracks the expected rank so
        # gaps (a corrupted tokenizer) are detected and warned about.
        for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
            if index != token_index:
                logger.warning(
                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!"
                )
                index = token_index
            writer.write(" ".join(bpe_tokens) + "\n")
            index += 1
    return vocab_file, merge_file
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
    """Optionally prepend a space so the first word is treated like any other word."""
    prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
    if is_split_into_words or prefix_space:
        text = " " + text
    return (text, kwargs)
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
return input_ids
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py
================================================
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/file_utils.py
================================================
# Copyright 2020 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at
https://github.com/allenai/allennlp.
"""
import copy
import fnmatch
import importlib.util
import io
import json
import os
import re
import shutil
import sys
import tarfile
import tempfile
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import fields
from enum import Enum
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4
from zipfile import ZipFile, is_zipfile
import numpy as np
from packaging import version
from tqdm.auto import tqdm
import requests
from filelock import FileLock
from .versions import importlib_metadata
#from . import __version__
from .hf_api import HfFolder
from . import logging
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Env-var values treated as "enabled"; AUTO additionally means
# "enable when the package turns out to be importable".
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# NOTE(review): reads the USE_FLAX variable but is named USE_JAX — the env var
# users must set is USE_FLAX.
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

# PyTorch counts as available only when requested (or AUTO), TF was not
# explicitly forced on, the package is importable AND its version metadata resolves.
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version("torch")
            logger.info(f"PyTorch version {_torch_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _torch_available = False
else:
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False
# TensorFlow counts as available only when requested (or AUTO) and torch was not
# explicitly forced on; version metadata may live under any of several dist names.
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
        )
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
                break
            except importlib_metadata.PackageNotFoundError:
                pass
        _tf_available = _tf_version is not None
    if _tf_available:
        # Only TF >= 2 is supported; TF1 installs are treated as unavailable.
        if version.parse(_tf_version) < version.parse("2"):
            logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
            _tf_available = False
        else:
            logger.info(f"TensorFlow version {_tf_version} available.")
else:
    logger.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False
# Flax counts as available only when both jax and flax are importable and both
# have resolvable version metadata.
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
else:
    _flax_available = False
_datasets_available = importlib.util.find_spec("datasets") is not None
try:
    # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
    # AND checking it has an author field in the metadata that is HuggingFace.
    _ = importlib_metadata.version("datasets")
    _datasets_metadata = importlib_metadata.metadata("datasets")
    if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
        _datasets_available = False
except importlib_metadata.PackageNotFoundError:
    _datasets_available = False
# faiss may be distributed as "faiss" or "faiss-cpu"; try both for metadata.
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
    _faiss_version = importlib_metadata.version("faiss")
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
    try:
        _faiss_version = importlib_metadata.version("faiss-cpu")
        logger.debug(f"Successfully imported faiss version {_faiss_version}")
    except importlib_metadata.PackageNotFoundError:
        _faiss_available = False

# ONNX support requires both the exporter (keras2onnx) and the runtime, plus
# resolvable "onnx" metadata.
_onnx_available = (
    importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
)
try:
    # NOTE(review): "_onxx_version" is a (harmless) variable-name typo kept as-is.
    _onxx_version = importlib_metadata.version("onnx")
    logger.debug(f"Successfully imported onnx version {_onxx_version}")
except importlib_metadata.PackageNotFoundError:
    _onnx_available = False

_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
    _scatter_version = importlib_metadata.version("torch_scatter")
    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError:
    _scatter_available = False

_soundfile_available = importlib.util.find_spec("soundfile") is not None
try:
    _soundfile_version = importlib_metadata.version("soundfile")
    logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
except importlib_metadata.PackageNotFoundError:
    _soundfile_available = False

_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
try:
    _torchaudio_version = importlib_metadata.version("torchaudio")
    logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
except importlib_metadata.PackageNotFoundError:
    _torchaudio_available = False
# Legacy (pre-v4) cache location under torch's cache directory.
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "transformers")

# Onetime move from the old location to the new one if no ENV variable has been set.
if (
    os.path.isdir(old_default_cache_path)
    and not os.path.isdir(default_cache_path)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
    logger.warning(
        "In Transformers v4.0.0, the default path to cache downloaded models changed from "
        "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden "
        "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to "
        "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should "
        "only see this message once."
    )
    shutil.move(old_default_cache_path, default_cache_path)

# Cache-dir env vars chain to one another for backward compatibility:
# TRANSFORMERS_CACHE -> PYTORCH_TRANSFORMERS_CACHE -> PYTORCH_PRETRAINED_BERT_CACHE.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# Random per-process id, plus the telemetry opt-out flag.
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
# Canonical weight/config file names for the supported frameworks.
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"

SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility

MULTIPLE_CHOICE_DUMMY_INPUTS = [
    [[0, 1, 0, 1], [1, 0, 0, 1]]
] * 2  # Needs to have 0s and 1s only since XLM uses it for langs too.
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
# Fixed: the last placeholder had been corrupted into a literal "(unknown)";
# it must be "{filename}" so callers can .format(model_id=..., revision=...,
# filename=...) to build a hub "resolve" URL.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# Offline mode is resolved once, at import time, from TRANSFORMERS_OFFLINE.
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False

def is_offline_mode():
    """Whether TRANSFORMERS_OFFLINE was truthy at import time."""
    return _is_offline_mode
def is_torch_available():
    """Whether PyTorch was detected at import time."""
    return _torch_available

def is_torch_cuda_available():
    """Whether PyTorch is available AND reports a usable CUDA device."""
    if is_torch_available():
        import torch

        return torch.cuda.is_available()
    else:
        return False

def is_tf_available():
    """Whether TensorFlow (>= 2) was detected at import time."""
    return _tf_available

def is_onnx_available():
    """Whether keras2onnx + onnxruntime (and onnx metadata) were detected at import time."""
    return _onnx_available

def is_flax_available():
    """Whether jax + flax were detected at import time."""
    return _flax_available

def is_torch_tpu_available():
    """Whether torch_xla (PyTorch/XLA TPU support) is importable."""
    if not _torch_available:
        return False
    # This test is probably enough, but just in case, we unpack a bit.
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if importlib.util.find_spec("torch_xla.core") is None:
        return False
    return importlib.util.find_spec("torch_xla.core.xla_model") is not None

def is_datasets_available():
    """Whether the HuggingFace `datasets` library was detected at import time."""
    return _datasets_available

def is_psutil_available():
    """Whether psutil is importable (checked on every call, not cached)."""
    return importlib.util.find_spec("psutil") is not None

def is_py3nvml_available():
    """Whether py3nvml is importable (checked on every call, not cached)."""
    return importlib.util.find_spec("py3nvml") is not None

def is_apex_available():
    """Whether NVIDIA apex is importable (checked on every call, not cached)."""
    return importlib.util.find_spec("apex") is not None

def is_faiss_available():
    """Whether faiss / faiss-cpu was detected at import time."""
    return _faiss_available
def is_sklearn_available():
    """Whether scikit-learn (with the scipy pieces used here) can be imported.

    Returns:
        bool: True when ``sklearn``, ``scipy``, ``sklearn.metrics`` and
        ``scipy.stats`` are all importable.

    The original implementation returned the raw ``find_spec`` result (a
    ``ModuleSpec`` or ``None``) on the success path; this version returns a
    real ``bool`` with identical truthiness, so all existing callers behave
    the same.
    """
    # Guard the submodule probes: find_spec("sklearn.metrics") would raise if
    # the parent package were missing entirely.
    if importlib.util.find_spec("sklearn") is None:
        return False
    if importlib.util.find_spec("scipy") is None:
        return False
    return (
        importlib.util.find_spec("sklearn.metrics") is not None
        and importlib.util.find_spec("scipy.stats") is not None
    )
def is_sentencepiece_available():
    """Whether sentencepiece is importable (checked on every call, not cached)."""
    return importlib.util.find_spec("sentencepiece") is not None

def is_protobuf_available():
    """Whether google.protobuf is importable; guards against a missing `google` namespace package."""
    if importlib.util.find_spec("google") is None:
        return False
    return importlib.util.find_spec("google.protobuf") is not None

def is_tokenizers_available():
    """Whether the fast `tokenizers` library is importable."""
    return importlib.util.find_spec("tokenizers") is not None

def is_vision_available():
    """Whether PIL (pillow) is importable."""
    return importlib.util.find_spec("PIL") is not None

def is_in_notebook():
    """Best-effort check for running inside a Jupyter notebook kernel
    (excludes plain IPython consoles and VS Code's kernel)."""
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:
            raise ImportError("vscode")
        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        # KeyError: IPython not in sys.modules; AttributeError: unexpected IPython API.
        return False

def is_scatter_available():
    """Whether torch_scatter was detected at import time."""
    return _scatter_available

def is_pandas_available():
    """Whether pandas is importable (checked on every call, not cached)."""
    return importlib.util.find_spec("pandas") is not None
def is_sagemaker_dp_enabled():
    """True when SageMaker distributed data-parallel is configured AND smdistributed is importable."""
    # The platform exposes its framework config as JSON in SM_FRAMEWORK_PARAMS.
    raw_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        framework_params = json.loads(raw_params)
    except json.JSONDecodeError:
        return False
    # Data-parallelism must be explicitly switched on in that config.
    if not framework_params.get("sagemaker_distributed_dataparallel_enabled", False):
        return False
    # Lastly, check if the `smdistributed` module is present.
    return importlib.util.find_spec("smdistributed") is not None
def is_sagemaker_mp_enabled():
    """True when SageMaker model parallelism is fully configured AND smdistributed is importable."""
    # Model-parallel config must parse and contain a "partitions" entry.
    try:
        smp_config = json.loads(os.getenv("SM_HP_MP_PARAMETERS", "{}"))
    except json.JSONDecodeError:
        return False
    if "partitions" not in smp_config:
        return False
    # MPI must be explicitly enabled in the framework parameters.
    try:
        framework_config = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    except json.JSONDecodeError:
        return False
    if not framework_config.get("sagemaker_mpi_enabled", False):
        return False
    # Lastly, check if the `smdistributed` module is present.
    return importlib.util.find_spec("smdistributed") is not None
def is_training_run_on_sagemaker():
    """True when running inside a SageMaker training job (the platform sets SAGEMAKER_JOB_NAME)."""
    return os.environ.get("SAGEMAKER_JOB_NAME") is not None
def is_soundfile_availble():
    """Whether soundfile was detected at import time.

    NOTE(review): the name misspells "available" but is public API, so it is kept.
    """
    return _soundfile_available

def is_torchaudio_available():
    """Whether torchaudio was detected at import time."""
    return _torchaudio_available

def is_speech_available():
    """Whether speech-processing support is available."""
    # For now this depends on torchaudio but the exact dependency might evolve in the future.
    return _torchaudio_available
def torch_only_method(fn):
    """Decorator: calling the wrapped function raises a helpful ImportError when PyTorch is unavailable.

    Fixed: the wrapper now carries ``fn``'s metadata via ``functools.wraps``
    (already imported at the top of this file); the original dropped the
    wrapped function's name and docstring, breaking introspection and docs.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        else:
            return fn(*args, **kwargs)
    return wrapper
# Human-readable install-instruction templates; `{0}` is replaced with the name
# of the class/function that needs the backend, and the formatted message is
# raised by `requires_backends` when the backend is missing.

# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
```
pip install datasets
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install datasets
```
then restarting your kernel.
Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case.
"""

# docstyle-ignore
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
```
pip install tokenizers
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install tokenizers
```
"""

# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment.
"""

# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment.
"""

# docstyle-ignore
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment.
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
"""

# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""

# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

# docstyle-ignore
SCATTER_IMPORT_ERROR = """
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/rusty1s/pytorch_scatter.
"""

# docstyle-ignore
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
"""

# docstyle-ignore
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`
"""

# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`
"""
# Maps a backend name to a (availability_check, error_message_template) pair,
# consumed by `requires_backends` below.
BACKENDS_MAPPING = OrderedDict(
    [
        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
        ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
        ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        # Fixed key: it was misspelled "tokenziers", which made
        # requires_backends(obj, "tokenizers") raise KeyError.
        ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
    ]
)
def requires_backends(obj, backends):
    """Raise an ImportError (with install instructions for each listed backend) unless all are available."""
    if not isinstance(backends, (list, tuple)):
        backends = [backends]
    # Use the object's own name for functions/classes, its type's name for instances.
    if hasattr(obj, "__name__"):
        name = obj.__name__
    else:
        name = obj.__class__.__name__
    checks = [BACKENDS_MAPPING[backend] for backend in backends]
    if not all(available() for available, _ in checks):
        # Note: the message concatenates the templates for ALL requested
        # backends, matching the original behavior.
        raise ImportError("".join(template.format(name) for _, template in checks))
def add_start_docstrings(*docstr):
    """Decorator factory that prepends the given doc fragments to the decorated function's docstring."""

    def docstring_decorator(fn):
        existing = fn.__doc__ if fn.__doc__ is not None else ""
        fn.__doc__ = "".join(docstr) + existing
        return fn

    return docstring_decorator
def add_start_docstrings_to_model_forward(*docstr):
    """Decorator factory for model ``forward`` methods: injects an intro line and a
    call-convention note ahead of *docstr* and the existing docstring."""

    def docstring_decorator(fn):
        # Derive the owning class name from the method's qualified name.
        class_name = f":class:`~transformers.{fn.__qualname__.split('.')[0]}`"
        intro = f" The {class_name} forward method, overrides the :func:`__call__` special method."
        note = r"""
.. note::
Although the recipe for forward pass needs to be defined within this function, one should call the
:class:`Module` instance afterwards instead of this since the former takes care of running the pre and post
processing steps while the latter silently ignores them.
"""
        existing = fn.__doc__ if fn.__doc__ is not None else ""
        fn.__doc__ = intro + note + "".join(docstr) + existing
        return fn

    return docstring_decorator
def add_end_docstrings(*docstr):
    """Decorator factory that appends the given doc fragments to the decorated function's docstring.

    Fixed: a function without a docstring no longer crashes with
    ``TypeError`` (``None + str``); a missing docstring is treated as empty,
    consistent with ``add_start_docstrings`` above.
    """

    def docstring_decorator(fn):
        fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
        return fn

    return docstring_decorator
# "Returns:" intro templates injected ahead of an output class's argument list
# by `_prepare_output_docstrings`; {full_output_type} and {config_class} are
# filled in via str.format.
PT_RETURN_INTRODUCTION = r"""
Returns:
:class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: A :class:`~{full_output_type}` (if
``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`torch.FloatTensor`
comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.
"""

TF_RETURN_INTRODUCTION = r"""
Returns:
:class:`~{full_output_type}` or :obj:`tuple(tf.Tensor)`: A :class:`~{full_output_type}` (if
``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`tf.Tensor` comprising
various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.
"""
def _get_indent(t):
"""Returns the indentation in the first line of t"""
search = re.search(r"^(\s*)\S", t)
return "" if search is None else search.groups()[0]
def _convert_output_args_doc(output_args_doc):
    """Convert output_args_doc to display properly."""
    # Split output_arg_doc in blocks argument/description
    indent = _get_indent(output_args_doc)
    blocks = []
    current_block = ""
    for line in output_args_doc.split("\n"):
        # If the indent is the same as the beginning, the line is the name of new arg.
        if _get_indent(line) == indent:
            if len(current_block) > 0:
                # Drop the trailing newline accumulated on the finished block.
                blocks.append(current_block[:-1])
            current_block = f"{line}\n"
        else:
            # Otherwise it's part of the description of the current arg.
            # We need to remove 2 spaces to the indentation.
            current_block += f"{line[2:]}\n"
    # Flush the final block (also without its trailing newline).
    blocks.append(current_block[:-1])
    # Format each block for proper rendering
    for i in range(len(blocks)):
        # Turn the leading argument name into a "- **name**" bullet...
        blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
        # ...and join "name:\ndescription" onto one line with " -- ".
        blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
    return "\n".join(blocks)
def _prepare_output_docstrings(output_type, config_class):
    """
    Prepares the return part of the docstring using `output_type`.
    """
    docstrings = output_type.__doc__
    # Remove the head of the docstring to keep the list of args only
    lines = docstrings.split("\n")
    i = 0
    while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
        i += 1
    if i < len(lines):
        # Keep only what follows the "Args:"/"Parameters:" header and reformat it.
        docstrings = "\n".join(lines[(i + 1) :])
        docstrings = _convert_output_args_doc(docstrings)
    # Add the return introduction
    full_output_type = f"{output_type.__module__}.{output_type.__name__}"
    # TF output classes are prefixed with "TF"; pick the matching template.
    intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
    intro = intro.format(full_output_type=full_output_type, config_class=config_class)
    return intro + docstrings
# Ready-made usage examples per model-head/framework combination; picked and
# `.format()`-ed into model docstrings by `add_code_sample_docstrings` below.
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

PT_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='pt')
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])
>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
"""

PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

PT_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

PT_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""

PT_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
>>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

PT_CAUSAL_LM_SAMPLE = r"""
Example::
>>> import torch
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> input_ids = inputs["input_ids"]
>>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
>>> outputs = model(inputs)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

TF_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> input_dict = tokenizer(question, text, return_tensors='tf')
>>> outputs = model(input_dict)
>>> start_logits = outputs.start_logits
>>> end_logits = outputs.end_logits
>>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
>>> answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0]+1])
"""

TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
>>> outputs = model(inputs)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

TF_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
>>> inputs["labels"] = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
>>> outputs = model(inputs)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""

TF_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""

TF_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
>>> outputs = model(inputs) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> logits = outputs.logits
"""

TF_CAUSAL_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> logits = outputs.logits
"""
def add_code_sample_docstrings(
    *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None
):
    """
    Decorator factory that appends a ready-to-run usage example to ``fn.__doc__``.

    The example template is picked from the PT_*/TF_* sample constants according
    to the head type encoded in the decorated class's name (the part of
    ``__qualname__`` before the first dot), and filled in with ``tokenizer_class``
    and ``checkpoint``.  ``output_type``/``config_class`` optionally add a
    "Returns" section; ``mask`` overrides the mask token in MLM examples.
    """

    def docstring_decorator(fn):
        model_class = fn.__qualname__.split(".")[0]
        is_tf_class = model_class.startswith("TF")
        doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)

        # Order matters: the specific head names must be matched before the
        # generic "Model"/"Encoder" fallback at the bottom.
        if "SequenceClassification" in model_class:
            pt_sample, tf_sample = PT_SEQUENCE_CLASSIFICATION_SAMPLE, TF_SEQUENCE_CLASSIFICATION_SAMPLE
        elif "QuestionAnswering" in model_class:
            pt_sample, tf_sample = PT_QUESTION_ANSWERING_SAMPLE, TF_QUESTION_ANSWERING_SAMPLE
        elif "TokenClassification" in model_class:
            pt_sample, tf_sample = PT_TOKEN_CLASSIFICATION_SAMPLE, TF_TOKEN_CLASSIFICATION_SAMPLE
        elif "MultipleChoice" in model_class:
            pt_sample, tf_sample = PT_MULTIPLE_CHOICE_SAMPLE, TF_MULTIPLE_CHOICE_SAMPLE
        elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
            # MLM examples contain a mask token placeholder.
            doc_kwargs["mask"] = "[MASK]" if mask is None else mask
            pt_sample, tf_sample = PT_MASKED_LM_SAMPLE, TF_MASKED_LM_SAMPLE
        elif "LMHead" in model_class or "CausalLM" in model_class:
            pt_sample, tf_sample = PT_CAUSAL_LM_SAMPLE, TF_CAUSAL_LM_SAMPLE
        elif "Model" in model_class or "Encoder" in model_class:
            pt_sample, tf_sample = PT_BASE_MODEL_SAMPLE, TF_BASE_MODEL_SAMPLE
        else:
            raise ValueError(f"Docstring can't be built for model {model_class}")

        code_sample = tf_sample if is_tf_class else pt_sample
        output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + code_sample.format(**doc_kwargs)
        return fn

    return docstring_decorator
def replace_return_docstrings(output_type=None, config_class=None):
    """
    Decorator factory that fills the empty ``Returns:`` placeholder line of the
    decorated function's docstring with a generated description of
    ``output_type`` / ``config_class``.

    Raises:
        ValueError: if the docstring has no bare ``Return:``/``Returns:`` line.
    """

    def docstring_decorator(fn):
        doc = fn.__doc__
        doc_lines = doc.split("\n")
        placeholder = re.compile(r"^\s*Returns?:\s*$")
        # Replace only the first matching placeholder line.
        for idx, line in enumerate(doc_lines):
            if placeholder.search(line) is not None:
                doc_lines[idx] = _prepare_output_docstrings(output_type, config_class)
                fn.__doc__ = "\n".join(doc_lines)
                return fn
        raise ValueError(
            f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{doc}"
        )

    return docstring_decorator
def is_remote_url(url_or_filename):
    """Return ``True`` when the argument parses as an http(s) URL rather than a local path."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """
    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
    to Cloudfront (a Content Delivery Network, or CDN) for large files.

    Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
    bandwidth costs).

    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
    can't ever be stale.

    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
    """
    # BUG FIX: the three f-strings below had lost their "{filename}" placeholder
    # (they contained the literal text "(unknown)"), so every URL built by this
    # function pointed at a non-existent "(unknown)" path instead of the
    # requested file.
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if mirror:
        # Resolve a mirror alias to its endpoint; an unknown alias is treated
        # as a raw endpoint URL.
        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
        legacy_format = "/" not in model_id
        if legacy_format:
            return f"{endpoint}/{model_id}-{filename}"
        else:
            return f"{endpoint}/{model_id}/{filename}"

    if revision is None:
        revision = "main"
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
    delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
    identify it as a HDF5 file (see
    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    # Build "<sha256(url)>" or "<sha256(url)>.<sha256(etag)>".
    parts = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode("utf-8")).hexdigest())
    filename = ".".join(parts)
    # Preserve the .h5 extension so Keras can recognize HDF5 weight files.
    if url.endswith(".h5"):
        filename += ".h5"
    return filename
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`. Raise ``EnvironmentError`` if `filename` or
    its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    cache_dir = str(cache_dir) if isinstance(cache_dir, Path) else cache_dir

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError(f"file {cache_path} not found")

    # Metadata lives in a sidecar "<filename>.json" file next to the blob.
    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError(f"file {meta_path} not found")

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata["url"], metadata["etag"]
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
    :obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are use to get the metadata for each model, only
    urls ending with `.bin` are added.

    Args:
        cache_dir (:obj:`Union[str, Path]`, `optional`):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cached_models = []
    for file in os.listdir(cache_dir):
        if not file.endswith(".json"):
            continue
        meta_path = os.path.join(cache_dir, file)
        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
        url = metadata["url"]
        etag = metadata["etag"]
        if url.endswith(".bin"):
            # BUG FIX: this used ``meta_path.strip(".json")``, but ``str.strip``
            # removes any of the characters '.', 'j', 's', 'o', 'n' from BOTH
            # ends — e.g. a blob named "bson" became "b" and getsize() raised.
            # Slice off the literal ".json" suffix instead.
            model_path = meta_path[: -len(".json")]
            size_MB = os.path.getsize(model_path) / 1e6
            cached_models.append((url, etag, size_MB))
    return cached_models
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
    and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
    then return the path

    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
            will get token from ~/.huggingface.
        extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
            file in a folder along the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    # Normalize pathlib.Path inputs to plain strings.
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True
    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError(f"file {url_or_filename} not found")
    else:
        # Something unknown
        raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")
    if extract_compressed_file:
        # Non-archives are returned unchanged.
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path
        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
        # Reuse a previous, non-empty extraction unless force_extract is set.
        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted
        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            # Start from a clean directory so a partial earlier extraction
            # cannot leave stale files behind.
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
                    zip_file.close()
            elif tarfile.is_tarfile(output_path):
                tar_file = tarfile.open(output_path)
                tar_file.extractall(output_path_extracted)
                tar_file.close()
            else:
                raise EnvironmentError(f"Archive format of {output_path} could not be identified")
        return output_path_extracted
    return output_path
def define_sagemaker_information():
    """
    Collect best-effort metadata about the surrounding SageMaker training
    environment (container image, region, device counts, account id) from
    environment variables and the ECS metadata endpoint.
    """
    # The metadata endpoint is only reachable inside an ECS/SageMaker
    # container; any failure simply leaves the container fields unset.
    try:
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
        dlc_container_used = instance_data["Image"]
        dlc_tag = instance_data["Image"].split(":")[1]
    except Exception:
        dlc_container_used = None
        dlc_tag = None

    sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params

    # The account id is the 5th ':'-separated component of the training job ARN.
    arn = os.getenv("TRAINING_JOB_ARN")
    account_id = arn.split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None

    return {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.

    Args:
        user_agent: optional extra ``key/value`` pairs (dict) or a raw string
            appended to the generated user-agent.

    Returns:
        The assembled user-agent string.
    """
    import sys  # local import: this vendored copy may not import ``sys`` at module level

    # BUG FIX: the line initializing ``ua`` was commented out (it referenced
    # ``__version__``/``SESSION_ID``, which this vendored copy no longer
    # defines), leaving ``ua`` undefined so every call raised NameError.
    # Build a minimal base string instead.
    ua = f"transformers/unknown; python/{sys.version.split()[0]}"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        ua += f"; tensorflow/{_tf_version}"
    if DISABLE_TELEMETRY:
        return ua + "; telemetry/off"
    if is_training_run_on_sagemaker():
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
    # CI will set this value to True
    if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """
    Download remote file. Do not gobble up errors.

    Streams ``url`` into ``temp_file`` in 1 KiB chunks, showing a tqdm progress
    bar. ``resume_size`` > 0 requests a byte-range continuation of a previous
    partial download.
    """
    # Copy so the caller's dict is never mutated; also handle headers=None,
    # which previously raised TypeError when resume_size > 0 tried to set
    # the Range key on None.
    headers = copy.deepcopy(headers) if headers is not None else {}
    if resume_size > 0:
        headers["Range"] = f"bytes={resume_size}-"
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    # Content-Length covers only the remaining bytes when resuming.
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logging.get_verbosity() == logging.NOTSET),
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    The cache key is a hash of the URL plus the remote ETag, so a changed
    remote file gets a fresh cache entry (see ``url_to_filename``).

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    headers = {"user-agent": http_user_agent(user_agent)}
    # Attach a Bearer token: either the explicit string, or the stored CLI
    # token when use_auth_token is truthy (e.g. True).
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"
    url_to_download = url
    etag = None
    if not local_files_only:
        # Probe the remote file with a HEAD request to learn its ETag (and a
        # possible redirect target) without downloading anything yet.
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Otherwise, our Internet connection is down.
            # etag is None
            pass
    filename = url_to_filename(url, etag)
    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)
    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            # Without an ETag we cannot compute the exact cache filename, so
            # fall back to any cached variant of this URL (same URL-hash
            # prefix), excluding metadata and lock files.
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise FileNotFoundError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )
    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path
        if resume_download:
            # Append to a persistent ".incomplete" file so an interrupted
            # download can be continued from its current size.
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            # delete=False: the file must survive the context exit so it can
            # be atomically renamed into place below.
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
        logger.info(f"storing {url} in cache at {cache_path}")
        # os.replace is atomic on the same filesystem.
        os.replace(temp_file.name, cache_path)
        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)
    return cache_path
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            # Class-level access returns the descriptor itself.
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        cache_attr = "__cached_" + self.fget.__name__
        value = getattr(obj, cache_attr, None)
        if value is None:
            # NOTE: a getter that legitimately returns None is recomputed on
            # every access, since None doubles as the "not cached" sentinel.
            value = self.fget(obj)
            setattr(obj, cache_attr, value)
        return value
def torch_required(func):
    """Decorator that raises ImportError at call time when PyTorch is unavailable."""
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not is_torch_available():
            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
        return func(*args, **kwargs)

    return wrapper
def tf_required(func):
    """Decorator that raises ImportError at call time when TensorFlow is unavailable."""
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not is_tf_available():
            raise ImportError(f"Method `{func.__name__}` requires TF.")
        return func(*args, **kwargs)

    return wrapper
def is_tensor(x):
    """ Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`. """
    # torch/tensorflow are imported lazily, and only probed when installed, so
    # this helper works in environments missing either framework.
    if is_torch_available():
        import torch

        if isinstance(x, torch.Tensor):
            return True
    if is_tf_available():
        import tensorflow as tf

        if isinstance(x, tf.Tensor):
            return True
    # numpy is an unconditional dependency, so it is checked last and directly.
    return isinstance(x, np.ndarray)
def _is_numpy(x):
return isinstance(x, np.ndarray)
def _is_torch(x):
import torch
return isinstance(x, torch.Tensor)
def _is_torch_device(x):
import torch
return isinstance(x, torch.device)
def _is_tensorflow(x):
import tensorflow as tf
return isinstance(x, tf.Tensor)
def _is_jax(x):
import jax.numpy as jnp # noqa: F811
return isinstance(x, jnp.ndarray)
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
    """
    # Containers are converted recursively; framework tensors are only probed
    # when the corresponding framework is installed.
    if isinstance(obj, (dict, UserDict)):
        return {key: to_py_obj(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_py_obj(item) for item in obj]
    if is_tf_available() and _is_tensorflow(obj):
        return obj.numpy().tolist()
    if is_torch_available() and _is_torch(obj):
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else (ints, floats, strings, None, ...) passes through unchanged.
    return obj
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
    a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular
    python dictionary.

    .. warning::
        You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
        method to convert it to a tuple before.
    """

    def __post_init__(self):
        # Runs after the @dataclass-generated __init__ of subclasses: mirrors
        # every non-None dataclass field into the underlying OrderedDict so
        # dict-style access works.
        class_fields = fields(self)
        # Safety and consistency checks
        assert len(class_fields), f"{self.__class__.__name__} has no fields."
        assert all(
            field.default is None for field in class_fields[1:]
        ), f"{self.__class__.__name__} should not have more than one required field."
        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
        if other_fields_are_none and not is_tensor(first_field):
            # Special case: a single non-tensor first field may itself be a
            # (key, value) iterator whose pairs should populate this output.
            try:
                iterator = iter(first_field)
                first_field_iterator = True
            except TypeError:
                first_field_iterator = False
            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for element in iterator:
                    # Stop as soon as an element does not look like a
                    # (str key, value) pair.
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            # Common case: copy every non-None field into the dict, preserving
            # field declaration order.
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    # The four methods below block every dict-mutation entry point so the
    # output stays consistent with its dataclass attributes.
    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        # String keys behave like dict access; ints/slices index the tuple of
        # non-None values.
        if isinstance(k, str):
            inner_dict = {k: v for (k, v) in self.items()}
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep attribute and dict views in sync for existing keys.
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not ``None``.
        """
        return tuple(self[k] for k in self.keys())
class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        # Called by Enum machinery when lookup by value fails; list the
        # accepted values instead of the terse default message.
        valid_values = list(cls._value2member_map_.keys())
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    # The member values double as the literal strings accepted by ``padding=``.
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    # The member values double as the literal strings accepted by
    # ``return_tensors=`` (one per supported framework).
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
class _BaseLazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, import_structure):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
# Needed for autocompletion in an IDE
def __dir__(self):
return super().__dir__() + self.__all__
def __getattr__(self, name: str) -> Any:
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str) -> ModuleType:
raise NotImplementedError
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/hf_api.py
================================================
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from os.path import expanduser
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import requests
# Default base URL of the Hugging Face Hub API, used by HfApi when no custom
# endpoint is supplied.
ENDPOINT = "https://huggingface.co"
class RepoObj:
    """
    HuggingFace git-based system, data structure that represents a file belonging to the current user.
    """

    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
        # Extra keys returned by the API are accepted but deliberately ignored.
        self.filename = filename
        self.lastModified = lastModified
        self.commit = commit
        self.size = size
class ModelSibling:
    """
    Data structure that represents a public file inside a model, accessible from huggingface.co
    """

    def __init__(self, rfilename: str, **kwargs):
        self.rfilename = rfilename  # filename relative to the model root
        # Unlike RepoObj, any extra API fields become attributes verbatim.
        for key, value in kwargs.items():
            setattr(self, key, value)
class ModelInfo:
    """
    Info about a public model accessible from huggingface.co
    """

    def __init__(
        self,
        modelId: Optional[str] = None,  # id of model
        tags: Optional[List[str]] = None,
        pipeline_tag: Optional[str] = None,
        siblings: Optional[List[Dict]] = None,  # list of files that constitute the model
        **kwargs
    ):
        self.modelId = modelId
        # BUG FIX: ``tags`` previously defaulted to a shared mutable list
        # (``tags: List[str] = []``), so every instance created without an
        # explicit ``tags`` aliased the same object and mutations leaked
        # between instances. Default to None and allocate a fresh list here.
        self.tags = tags if tags is not None else []
        self.pipeline_tag = pipeline_tag
        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
        # Any extra API fields become attributes verbatim.
        for k, v in kwargs.items():
            setattr(self, k, v)
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(self, username: str, password: str) -> str:
"""
Call HF API to sign in a user and get a token if credentials are valid.
Outputs: token if credentials are valid
Throws: requests.exceptions.HTTPError if credentials are invalid
"""
path = f"{self.endpoint}/api/login"
r = requests.post(path, json={"username": username, "password": password})
r.raise_for_status()
d = r.json()
return d["token"]
def whoami(self, token: str) -> Tuple[str, List[str]]:
"""
Call HF API to know "whoami"
"""
path = f"{self.endpoint}/api/whoami"
r = requests.get(path, headers={"authorization": f"Bearer {token}"})
r.raise_for_status()
d = r.json()
return d["user"], d["orgs"]
def logout(self, token: str) -> None:
"""
Call HF API to log out.
"""
path = f"{self.endpoint}/api/logout"
r = requests.post(path, headers={"authorization": f"Bearer {token}"})
r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface.co
"""
path = f"{self.endpoint}/api/models"
r = requests.get(path)
r.raise_for_status()
d = r.json()
return [ModelInfo(**x) for x in d]
def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
"""
HuggingFace git-based system, used for models.
Call HF API to list all stored files for user (or one of their organizations).
"""
path = f"{self.endpoint}/api/repos/ls"
params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": f"Bearer {token}"})
r.raise_for_status()
d = r.json()
return [RepoObj(**x) for x in d]
def create_repo(
self,
token: str,
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
"""
HuggingFace git-based system, used for models.
Call HF API to create a whole repo.
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = f"{self.endpoint}/api/repos/create"
json = {"name": name, "organization": organization, "private": private}
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
path,
headers={"authorization": f"Bearer {token}"},
json=json,
)
if exist_ok and r.status_code == 409:
return ""
r.raise_for_status()
d = r.json()
return d["url"]
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
"""
HuggingFace git-based system, used for models.
Call HF API to delete a whole repo.
CAUTION(this is irreversible).
"""
path = f"{self.endpoint}/api/repos/delete"
r = requests.delete(
path,
headers={"authorization": f"Bearer {token}"},
json={"name": name, "organization": organization},
)
r.raise_for_status()
class TqdmProgressFileReader:
    """
    Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
    tqdm progress bar.

    see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
    """

    def __init__(self, f: io.BufferedReader):
        self.f = f
        # Total file size (bytes) used as the progress bar's upper bound.
        self.total_size = os.fstat(f.fileno()).st_size
        self.pbar = tqdm(total=self.total_size, leave=False)
        # Keep a reference to the original read, then monkey-patch the file
        # object so any consumer of ``f`` transparently advances the bar.
        self.read = f.read
        f.read = self._read

    def _read(self, n=-1):
        # Advance the bar by the *requested* size, then delegate to the
        # original read.
        self.pbar.update(n)
        return self.read(n)

    def close(self):
        # Close only the progress bar; the caller still owns the file object.
        self.pbar.close()
class HfFolder:
    """Persist, read and delete the huggingface.co CLI token on disk."""

    # Location where the CLI token is stored.
    path_token = expanduser("~/.huggingface/token")

    @classmethod
    def save_token(cls, token):
        """
        Save token, creating folder as needed.
        """
        os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
        with open(cls.path_token, "w+") as token_file:
            token_file.write(token)

    @classmethod
    def get_token(cls):
        """
        Get token or None if not existent.
        """
        try:
            with open(cls.path_token, "r") as token_file:
                return token_file.read()
        except FileNotFoundError:
            return None

    @classmethod
    def delete_token(cls):
        """
        Delete token. Do not fail if token does not exist.
        """
        try:
            os.remove(cls.path_token)
        except FileNotFoundError:
            pass
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/logging.py
================================================
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities. """
import logging
import os
import sys
import threading
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
from logging import ERROR # NOQA
from logging import FATAL # NOQA
from logging import INFO # NOQA
from logging import NOTSET # NOQA
from logging import WARN # NOQA
from logging import WARNING # NOQA
from typing import Optional
# Guards one-time lazy configuration of the library root logger.
_lock = threading.Lock()
# Handler installed by _configure_library_root_logger(); None until then.
_default_handler: Optional[logging.Handler] = None

# Level names accepted by the TRANSFORMERS_VERBOSITY environment variable.
log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

# Fallback verbosity when TRANSFORMERS_VERBOSITY is unset or invalid.
_default_log_level = logging.WARNING
def _get_default_logging_level():
    """
    Resolve the library's default verbosity: honor a valid TRANSFORMERS_VERBOSITY
    environment variable if present, otherwise fall back to ``_default_log_level``.
    """
    env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
    if env_level_str:
        try:
            return log_levels[env_level_str]
        except KeyError:
            # Unknown value: warn on the root logger and fall through to default.
            logging.getLogger().warning(
                f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
                f"has to be one of: { ', '.join(log_levels.keys()) }"
            )
    return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
    """Return the logger at the root of this library's logger hierarchy."""
    library_name = _get_library_name()
    return logging.getLogger(library_name)
def _configure_library_root_logger() -> None:
    """
    Idempotently install the library's default stderr handler and default
    verbosity on the library root logger. Safe to call from multiple threads.
    """
    global _default_handler

    with _lock:
        if _default_handler:
            # Already configured by an earlier call; nothing to do.
            return
        # Route library logs to sys.stderr.
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        root = _get_library_root_logger()
        root.addHandler(_default_handler)
        root.setLevel(_get_default_logging_level())
        root.propagate = False
def _reset_library_root_logger() -> None:
    """Undo ``_configure_library_root_logger``: detach the default handler."""
    global _default_handler

    with _lock:
        if not _default_handler:
            # Never configured (or already reset); nothing to undo.
            return
        root = _get_library_root_logger()
        root.removeHandler(_default_handler)
        root.setLevel(logging.NOTSET)
        _default_handler = None
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return a logger with the specified name (the library root logger when
    *name* is None). Not meant to be accessed directly unless you are writing
    a custom transformers module.
    """
    _configure_library_root_logger()
    resolved_name = name if name is not None else _get_library_name()
    return logging.getLogger(resolved_name)
def get_verbosity() -> int:
    """
    Return the current level for the 🤗 Transformers's root logger as an int.

    Returns:
        :obj:`int`: The logging level, one of 50 (``CRITICAL``/``FATAL``),
        40 (``ERROR``), 30 (``WARNING``/``WARN``), 20 (``INFO``) or
        10 (``DEBUG``).
    """
    _configure_library_root_logger()
    root_logger = _get_library_root_logger()
    return root_logger.getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
    """
    Set the verbosity level for the 🤗 Transformers's root logger.

    Args:
        verbosity (:obj:`int`):
            Logging level, e.g. one of ``transformers.logging.CRITICAL``/``FATAL``,
            ``ERROR``, ``WARNING``/``WARN``, ``INFO`` or ``DEBUG``.
    """
    _configure_library_root_logger()
    root_logger = _get_library_root_logger()
    root_logger.setLevel(verbosity)
def set_verbosity_info():
    """Shortcut for ``set_verbosity(INFO)``."""
    level = INFO
    return set_verbosity(level)
def set_verbosity_warning():
    """Shortcut for ``set_verbosity(WARNING)``."""
    level = WARNING
    return set_verbosity(level)
def set_verbosity_debug():
    """Shortcut for ``set_verbosity(DEBUG)``."""
    level = DEBUG
    return set_verbosity(level)
def set_verbosity_error():
    """Shortcut for ``set_verbosity(ERROR)``."""
    level = ERROR
    return set_verbosity(level)
def disable_default_handler() -> None:
    """Disable the default handler of the HuggingFace Transformers's root logger."""
    _configure_library_root_logger()

    assert _default_handler is not None
    root = _get_library_root_logger()
    root.removeHandler(_default_handler)
def enable_default_handler() -> None:
    """Enable the default handler of the HuggingFace Transformers's root logger."""
    _configure_library_root_logger()

    assert _default_handler is not None
    root = _get_library_root_logger()
    root.addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
    """Adds a handler to the HuggingFace Transformers's root logger."""
    _configure_library_root_logger()

    assert handler is not None
    root = _get_library_root_logger()
    root.addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
    """Removes given handler from the HuggingFace Transformers's root logger."""
    _configure_library_root_logger()

    # Bug fix: the original asserted ``handler not in ...handlers``, i.e. it
    # required the handler to be ABSENT before removing it — an inverted
    # precondition that fires exactly when the call is legitimate. The handler
    # must currently be attached for removal to make sense (this matches the
    # fix shipped in later upstream transformers releases).
    assert handler is not None and handler in _get_library_root_logger().handlers
    _get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None:
    """
    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
    """
    _configure_library_root_logger()
    root = _get_library_root_logger()
    root.propagate = False
def enable_propagation() -> None:
    """
    Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
    prevent double logging if the root logger has been configured.
    """
    _configure_library_root_logger()
    root = _get_library_root_logger()
    root.propagate = True
def enable_explicit_format() -> None:
    """
    Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
    ::
        [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
    All handlers currently bound to the root logger are affected by this method.
    """
    # The formatter is identical for every handler: build it once instead of
    # allocating a fresh logging.Formatter on every loop iteration.
    formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
    for handler in _get_library_root_logger().handlers:
        handler.setFormatter(formatter)
def reset_format() -> None:
    """
    Resets the formatting for HuggingFace Transformers's loggers.
    All handlers currently bound to the root logger are affected by this method.
    """
    # Setting the formatter back to None restores logging's default output.
    for handler in _get_library_root_logger().handlers:
        handler.setFormatter(None)
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils.py
================================================
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
tokenization_utils_fast.py
"""
import bisect
import itertools
import re
import unicodedata
from typing import Any, Dict, List, Optional, Tuple, Union, overload
from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
from .tokenization_utils_base import (
ENCODE_KWARGS_DOCSTRING,
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
INIT_TOKENIZER_DOCSTRING,
AddedToken,
BatchEncoding,
EncodedInput,
EncodedInputPair,
PreTokenizedInput,
PreTokenizedInputPair,
PreTrainedTokenizerBase,
TextInput,
TextInputPair,
TruncationStrategy,
)
from . import logging
# Module-level logger for this file.
logger = logging.get_logger(__name__)

# Slow tokenizers are saved in a vocabulary plus three separated files
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
def _is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `char` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `char` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def _is_end_of_word(text):
    """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
    last_char = text[-1]
    return bool(_is_control(last_char) or _is_punctuation(last_char) or _is_whitespace(last_char))
def _is_start_of_word(text):
    """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
    first_char = text[0]
    return bool(_is_control(first_char) or _is_punctuation(first_char) or _is_whitespace(first_char))
def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str):
"""
Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
"""
insertion_idx = bisect.bisect_left(token_list, new_token)
# Checks if new_token is already in the ordered token_list
if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
# new_token is in token_list, don't add
return
else:
token_list.insert(insertion_idx, new_token)
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizer(PreTrainedTokenizerBase):
    """
    Base class for all slow tokenizers.
    Inherits from :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase`.
    Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading
    pretrained tokenizers as well as adding tokens to the vocabulary.
    This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the
    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Added tokens - We store this for both slow and fast tokenizers
        # until the serialization of Fast tokenizers is updated
        self.added_tokens_encoder: Dict[str, int] = {}
        self.added_tokens_decoder: Dict[int, str] = {}
        self.unique_no_split_tokens: List[str] = []

        self._decode_use_source_tokenizer = False

    @property
    def is_fast(self) -> bool:
        # Pure-Python ("slow") tokenizer, so always False.
        return False

    @property
    def vocab_size(self) -> int:
        """
        :obj:`int`: Size of the base vocabulary (without the added tokens).
        """
        raise NotImplementedError

    def get_added_vocab(self) -> Dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index.
        Returns:
            :obj:`Dict[str, int]`: The added tokens.
        """
        return self.added_tokens_encoder

    def __len__(self):
        """
        Size of the full vocabulary with the added tokens.
        """
        return self.vocab_size + len(self.added_tokens_encoder)

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
        it with indices starting from length of the current vocabulary.

        Args:
            new_tokens (:obj:`List[str]`or :obj:`List[tokenizers.AddedToken]`):
                Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
                checking if the tokenizer assign the index of the ``unk_token`` to them).
            special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the tokens should be added as special tokens.

        Returns:
            :obj:`int`: The number of tokens actually added to the vocabulary.

        Examples::
            # Let's see how to increase the vocabulary of Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')
            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
            model.resize_token_embeddings(len(tokenizer))
        """
        new_tokens = [str(tok) for tok in new_tokens]

        tokens_to_add = []
        for token in new_tokens:
            assert isinstance(token, str)
            if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
                token = token.lower()
            # A token is "new" when it maps to the unknown-token id (and isn't
            # the unk token itself, nor a duplicate within this batch).
            if (
                token != self.unk_token
                and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
                and token not in tokens_to_add
            ):
                tokens_to_add.append(token)
                if self.verbose:
                    logger.info(f"Adding {token} to the vocabulary")

        # New ids continue after the current full vocabulary (base + added).
        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)
        if special_tokens:
            if len(new_tokens) == 1:
                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, new_tokens[0])
            else:
                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
        else:
            # Or on the newly added tokens
            if len(tokens_to_add) == 1:
                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
            else:
                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))

        return len(tokens_to_add)

    def num_special_tokens_to_add(self, pair: bool = False) -> int:
        """
        Returns the number of added tokens when encoding a sequence with special tokens.

        .. note::
            This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not
            put this inside your training loop.

        Args:
            pair (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the number of added tokens should be computed in the case of a sequence pair or a single
                sequence.

        Returns:
            :obj:`int`: Number of special tokens added to sequences.
        """
        token_ids_0 = []
        token_ids_1 = []
        return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))

    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
        """
        Converts a string in a sequence of tokens, using the tokenizer.
        Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
        (BPE/SentencePieces/WordPieces). Takes care of added tokens.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.
            **kwargs (additional keyword arguments):
                Passed along to the model-specific ``prepare_for_tokenization`` preprocessing method.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """
        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
        all_special_tokens_extended = dict(
            (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
        )

        text, kwargs = self.prepare_for_tokenization(text, **kwargs)

        if kwargs:
            logger.warning(f"Keyword arguments {kwargs} not recognized.")

        # TODO: should this be in the base class?
        if hasattr(self, "do_lower_case") and self.do_lower_case:
            # convert non-special tokens to lowercase
            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
            text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

        def split_on_token(tok, text):
            # Split `text` on the special token `tok`, re-inserting `tok` as its
            # own element and applying the AddedToken stripping options.
            result = []
            tok_extended = all_special_tokens_extended.get(tok, None)
            split_text = text.split(tok)
            full_word = ""
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior depending on the special token
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                if isinstance(tok_extended, AddedToken):
                    if tok_extended.single_word:
                        # Try to avoid splitting on token
                        if (
                            i < len(split_text) - 1
                            and not _is_end_of_word(sub_text)
                            and not _is_start_of_word(split_text[i + 1])
                        ):
                            # Don't extract the special token
                            full_word += sub_text + tok
                        elif full_word:
                            full_word += sub_text
                            result.append(full_word)
                            full_word = ""
                            continue
                    # Strip white spaces on the right
                    if tok_extended.rstrip and i > 0:
                        # A bit counter-intuitive but we strip the left of the string
                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
                        sub_text = sub_text.lstrip()
                    # Strip white spaces on the left
                    if tok_extended.lstrip and i < len(split_text) - 1:
                        sub_text = sub_text.rstrip()  # Opposite here
                else:
                    # We strip left and right by default
                    if i < len(split_text) - 1:
                        sub_text = sub_text.rstrip()
                    if i > 0:
                        sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            # Repeatedly split `text` on each no-split token, then run the
            # model-specific `_tokenize` on the remaining plain-text pieces.
            if not text.strip():
                return []
            if not tok_list:
                return self._tokenize(text)

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.unique_no_split_tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.unique_no_split_tokens
        tokenized_text = split_on_tokens(no_split_token, text)
        return tokenized_text

    def _tokenize(self, text, **kwargs):
        """
        Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
        Do NOT take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
        vocabulary.

        Args:
            tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).

        Returns:
            :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        # Added-token vocabulary takes precedence over the base vocabulary.
        if token is None:
            return None

        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        # Base-vocabulary lookup; implemented by concrete tokenizers.
        raise NotImplementedError

    def _encode_plus(
        self,
        text: Union[TextInput, PreTokenizedInput, EncodedInput],
        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        # Tokenize (or pass through) a single text / text pair, then delegate
        # padding, truncation and tensor conversion to prepare_for_model.
        def get_input_ids(text):
            # Accepts a raw string, a list of tokens/words, or a list of ids.
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                if is_split_into_words:
                    raise ValueError(
                        f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`."
                    )
                else:
                    raise ValueError(
                        f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
                    )

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers."
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast."
                "More information on available tokenizers at "
                "https://github.com/huggingface/transformers/pull/2674"
            )

        first_ids = get_input_ids(text)
        second_ids = get_input_ids(text_pair) if text_pair is not None else None

        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput],
            List[TextInputPair],
            List[PreTokenizedInput],
            List[PreTokenizedInputPair],
            List[EncodedInput],
            List[EncodedInputPair],
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        # Batched counterpart of _encode_plus: convert every entry to ids,
        # then pad/truncate the whole batch in _batch_prepare_for_model.
        def get_input_ids(text):
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                raise ValueError(
                    "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
                )

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers."
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast."
            )

        input_ids = []
        for ids_or_pair_ids in batch_text_or_text_pairs:
            # Each entry is either a single input or an (input, pair) tuple.
            if not isinstance(ids_or_pair_ids, (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
                # Pretokenized single input: a flat list of words, not a pair.
                ids, pair_ids = ids_or_pair_ids, None
            else:
                ids, pair_ids = ids_or_pair_ids

            first_ids = get_input_ids(ids)
            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
            input_ids.append((first_ids, second_ids))

        batch_outputs = self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            return_tensors=return_tensors,
            verbose=verbose,
        )

        return BatchEncoding(batch_outputs)

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def _batch_prepare_for_model(
        self,
        batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
        adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
        manages a moving window (with user defined stride) for overflowing tokens

        Args:
            batch_ids_pairs: list of tokenized input ids or input ids pairs
        """

        batch_outputs = {}
        for first_ids, second_ids in batch_ids_pairs:
            outputs = self.prepare_for_model(
                first_ids,
                second_ids,
                add_special_tokens=add_special_tokens,
                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward
                truncation=truncation_strategy.value,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=None,  # we pad in batch afterward
                return_attention_mask=False,  # we pad in batch afterward
                return_token_type_ids=return_token_type_ids,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_length=return_length,
                return_tensors=None,  # We convert the whole batch to tensors at the end
                prepend_batch_axis=False,
                verbose=verbose,
            )

            # Accumulate per-example outputs into per-key lists.
            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        batch_outputs = self.pad(
            batch_outputs,
            padding=padding_strategy.value,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)

        return batch_outputs

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.
        This method should pop the arguments from kwargs and return the remaining :obj:`kwargs` as well. We test the
        :obj:`kwargs` at the end of the encoding process to be sure all the arguments have been used.

        Args:
            text (:obj:`str`):
                The text to prepare.
            is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the text has been pretokenized.
            kwargs:
                Keyword arguments to use for the tokenization.

        Returns:
            :obj:`Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
        """
        return (text, kwargs)

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`):
                List of ids of the second sequence.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )

            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        # Without special tokens, nothing in either sequence is special.
        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))

    @overload
    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str:
        ...

    @overload
    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]:
        ...

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
        added tokens.

        Args:
            ids (:obj:`int` or :obj:`List[int]`):
                The token id (or token ids) to convert to tokens.
            skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to remove special tokens in the decoding.

        Returns:
            :obj:`str` or :obj:`List[str]`: The decoded token(s).
        """
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            else:
                return self._convert_id_to_token(ids)
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.all_special_ids:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index: int) -> str:
        # Base-vocabulary reverse lookup; implemented by concrete tokenizers.
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Default (space-joined) detokenization; subclasses override for BPE etc.
        return " ".join(tokens)

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = True,
        spaces_between_special_tokens: bool = True,
        **kwargs
    ) -> str:
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)

        # To avoid mixing byte-level and unicode for byte-level BPT
        # we need to build string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            # NOTE(review): `token` is a string here while `all_special_ids`
            # holds ints, so this membership test looks like it can never be
            # True (special ids are already skipped in convert_ids_to_tokens).
            # Kept as-is; verify against upstream before changing.
            if skip_special_tokens and token in self.all_special_ids:
                continue
            if token in self.added_tokens_encoder:
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))

        if spaces_between_special_tokens:
            text = " ".join(sub_texts)
        else:
            text = "".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils_base.py
================================================
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (host all the user
fronting encoding methods) Special token mixing (host the special tokens logic) and BatchEncoding (wrap the dictionary
of output with special method for the Fast tokenizers)
"""
import copy
import json
import os
import warnings
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
import requests
from .file_utils import (
ExplicitEnum,
PaddingStrategy,
TensorType,
_is_jax,
_is_numpy,
_is_tensorflow,
_is_torch,
_is_torch_device,
add_end_docstrings,
cached_path,
hf_bucket_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
to_py_obj,
torch_required,
)
from . import logging
# Type-checking-only imports: these framework modules are needed solely for the
# annotations below (e.g. the "torch.device" string annotation on BatchEncoding.to)
# and must not be imported at runtime when the frameworks are absent.
if TYPE_CHECKING:
    if is_torch_available():
        import torch
    if is_tf_available():
        import tensorflow as tf
    if is_flax_available():
        import jax.numpy as jnp  # noqa: F401
if is_tokenizers_available():
    from tokenizers import AddedToken
    from tokenizers import Encoding as EncodingFast
else:
    # Pure-Python fallbacks used when the Rust `tokenizers` package is not
    # installed, so the rest of this module can reference both names uniformly.
    @dataclass(frozen=True, eq=True)
    class AddedToken:
        """
        AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the
        way it should behave.
        """

        # Mirrors the field set of tokenizers.AddedToken; frozen so instances
        # are hashable and safe to use as dict keys like plain strings.
        content: str = field(default_factory=str)
        single_word: bool = False
        lstrip: bool = False
        rstrip: bool = False
        normalized: bool = True

        def __getstate__(self):
            # Plain dict state keeps pickling symmetric with the Rust class.
            return self.__dict__

    @dataclass
    class EncodingFast:
        """ This is dummy class because without the `tokenizers` library we don't have these objects anyway """

        pass
# Module-level logger, following the project-wide logging convention.
logger = logging.get_logger(__name__)
VERY_LARGE_INTEGER = int(1e30)  # This is used to set the max input length for a model with infinite size input
LARGE_INTEGER = int(1e20)  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
# Define type aliases and NamedTuples for the accepted tokenizer input shapes:
# raw text, pre-tokenized (list of words), or already-encoded ids — and pairs thereof.
TextInput = str
PreTokenizedInput = List[str]
EncodedInput = List[int]
TextInputPair = Tuple[str, str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]
# Slow tokenizers used to be saved in three separated files
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    ONLY_FIRST = "only_first"
    ONLY_SECOND = "only_second"
    LONGEST_FIRST = "longest_first"
    DO_NOT_TRUNCATE = "do_not_truncate"
class CharSpan(NamedTuple):
    """
    Character span in the original string. Half-open interval: ``end`` points one
    past the last character, so ``original[start:end]`` yields the spanned text.
    Args:
        start (:obj:`int`): Index of the first character in the original string.
        end (:obj:`int`): Index of the character following the last character in the original string.
    """

    start: int
    end: int
class TokenSpan(NamedTuple):
    """
    Token span in an encoded string (list of tokens). Half-open interval: ``end``
    points one past the last token, so ``tokens[start:end]`` yields the span.
    Args:
        start (:obj:`int`): Index of the first token in the span.
        end (:obj:`int`): Index of the token following the last token in the span.
    """

    start: int
    end: int
class BatchEncoding(UserDict):
    """
    Holds the output of the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.encode_plus` and
    :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.batch_encode` methods (tokens,
    attention_masks, etc).
    This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes
    utility methods to map from word/character space to token space.
    Args:
        data (:obj:`dict`):
            Dictionary of lists/arrays/tensors returned by the encode/batch_encode methods ('input_ids',
            'attention_mask', etc.).
        encoding (:obj:`tokenizers.Encoding` or :obj:`Sequence[tokenizers.Encoding]`, `optional`):
            If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character
            space to token space the :obj:`tokenizers.Encoding` instance or list of instance (for batches) hold this
            information.
        tensor_type (:obj:`Union[None, str, TensorType]`, `optional`):
            You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
            initialization.
        prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to add a batch axis when converting to tensors (see :obj:`tensor_type` above).
        n_sequences (:obj:`Optional[int]`, `optional`):
            Number of sequences encoded in each sample (:obj:`1` for single sentences, :obj:`2` for pairs); inferred
            from ``encoding`` when not given.
    """

    def __init__(
        self,
        data: Optional[Dict[str, Any]] = None,
        encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,
        tensor_type: Union[None, str, TensorType] = None,
        prepend_batch_axis: bool = False,
        n_sequences: Optional[int] = None,
    ):
        super().__init__(data)
        # Normalize a single Encoding to a one-element list so `_encodings`
        # is always either None (slow tokenizer) or a list (fast tokenizer).
        if isinstance(encoding, EncodingFast):
            encoding = [encoding]
        self._encodings = encoding
        # Infer n_sequences from the first encoding when not given explicitly.
        if n_sequences is None and encoding is not None and len(encoding):
            n_sequences = encoding[0].n_sequences
        self._n_sequences = n_sequences
        # Eagerly convert stored lists to framework tensors when requested.
        self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)

    @property
    def n_sequences(self) -> Optional[int]:
        """
        :obj:`Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this
        :class:`~transformers.BatchEncoding`. Currently can be one of :obj:`None` (unknown), :obj:`1` (a single
        sentence) or :obj:`2` (a pair of sentences)
        """
        return self._n_sequences

    @property
    def is_fast(self) -> bool:
        """
        :obj:`bool`: Indicate whether this :class:`~transformers.BatchEncoding` was generated from the result of a
        :class:`~transformers.PreTrainedTokenizerFast` or not.
        """
        return self._encodings is not None

    def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:
        """
        If the key is a string, returns the value of the dict associated to :obj:`key` ('input_ids', 'attention_mask',
        etc.).
        If the key is an integer, get the :obj:`tokenizers.Encoding` for batch item with index :obj:`key`.
        """
        if isinstance(item, str):
            return self.data[item]
        elif self._encodings is not None:
            return self._encodings[item]
        else:
            raise KeyError(
                "Indexing with integers (to access backend Encoding for a given batch index) "
                "is not available when using Python based tokenizers"
            )

    def __getattr__(self, item: str):
        # Attribute access falls back to the dict keys, so `enc.input_ids`
        # works like `enc["input_ids"]`. Missing keys must surface as
        # AttributeError (not KeyError) to keep hasattr()/pickle protocols sane.
        try:
            return self.data[item]
        except KeyError:
            raise AttributeError

    def __getstate__(self):
        # Only data and encodings are pickled; n_sequences is intentionally
        # not part of the state here — NOTE(review): it is silently lost on
        # a pickle round-trip; confirm this matches upstream behavior.
        return {"data": self.data, "encodings": self._encodings}

    def __setstate__(self, state):
        if "data" in state:
            self.data = state["data"]
        if "encodings" in state:
            self._encodings = state["encodings"]

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    # After this point:
    # Extended properties and methods only available for fast (Rust-based) tokenizers
    # provided by HuggingFace tokenizers library.

    @property
    def encodings(self) -> Optional[List[EncodingFast]]:
        """
        :obj:`Optional[List[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns
        :obj:`None` if the input was tokenized through Python (i.e., not a fast) tokenizer.
        """
        return self._encodings

    def tokens(self, batch_index: int = 0) -> List[str]:
        """
        Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to
        integer indices) at a given batch index (only works for the output of a fast tokenizer).
        Args:
            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.
        Returns:
            :obj:`List[str]`: The list of tokens at that index.
        """
        if not self._encodings:
            raise ValueError("tokens() is not available when using Python-based tokenizers")
        return self._encodings[batch_index].tokens

    def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:
        """
        Return a list mapping the tokens to the id of their original sentences:
            - :obj:`None` for special tokens added around or between sequences,
            - :obj:`0` for tokens corresponding to words in the first sequence,
            - :obj:`1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly
              encoded.
        Args:
            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.
        Returns:
            :obj:`List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens
            added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their
            corresponding sequence.
        """
        if not self._encodings:
            raise ValueError("sequence_ids() is not available when using Python-based tokenizers")
        return self._encodings[batch_index].sequence_ids

    def words(self, batch_index: int = 0) -> List[Optional[int]]:
        """
        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
        Args:
            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.
        Returns:
            :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by
            the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding
            word (several tokens will be mapped to the same word index if they are parts of that word).
        """
        if not self._encodings:
            raise ValueError("words() is not available when using Python-based tokenizers")
        # Deprecated alias kept for backward compatibility; delegates to word_ids().
        warnings.warn(
            "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
            "but more self-explanatory `BatchEncoding.word_ids()` property.",
            FutureWarning,
        )
        return self.word_ids(batch_index)

    def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:
        """
        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
        Args:
            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.
        Returns:
            :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by
            the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding
            word (several tokens will be mapped to the same word index if they are parts of that word).
        """
        if not self._encodings:
            raise ValueError("word_ids() is not available when using Python-based tokenizers")
        return self._encodings[batch_index].word_ids

    def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
        """
        Get the index of the sequence represented by the given token. In the general use case, this method returns
        :obj:`0` for a single sequence or the first sequence of a pair, and :obj:`1` for the second sequence of a pair
        Can be called as:
        - ``self.token_to_sequence(token_index)`` if batch size is 1
        - ``self.token_to_sequence(batch_index, token_index)`` if batch size is greater than 1
        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
        words are defined by the user). In this case it allows to easily associate encoded tokens with provided
        tokenized words.
        Args:
            batch_or_token_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the token in the sequence.
            token_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the
                sequence.
        Returns:
            :obj:`int`: Index of the word in the input sequence.
        """
        if not self._encodings:
            raise ValueError("token_to_sequence() is not available when using Python based tokenizers")
        # Single-argument form: the sole index is the token index, batch 0.
        if token_index is not None:
            batch_index = batch_or_token_index
        else:
            batch_index = 0
            token_index = batch_or_token_index
        # NOTE(review): `self._batch_size` and `self._seq_len` are never defined
        # on this class, so negative indices raise AttributeError here — confirm
        # against upstream; later transformers versions removed these branches.
        if batch_index < 0:
            batch_index = self._batch_size + batch_index
        if token_index < 0:
            token_index = self._seq_len + token_index
        return self._encodings[batch_index].token_to_sequence(token_index)

    def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
        """
        Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch.
        Can be called as:
        - ``self.token_to_word(token_index)`` if batch size is 1
        - ``self.token_to_word(batch_index, token_index)`` if batch size is greater than 1
        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
        words are defined by the user). In this case it allows to easily associate encoded tokens with provided
        tokenized words.
        Args:
            batch_or_token_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
                the token in the sequence.
            token_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the
                sequence.
        Returns:
            :obj:`int`: Index of the word in the input sequence.
        """
        if not self._encodings:
            raise ValueError("token_to_word() is not available when using Python based tokenizers")
        if token_index is not None:
            batch_index = batch_or_token_index
        else:
            batch_index = 0
            token_index = batch_or_token_index
        # NOTE(review): same undefined `_batch_size`/`_seq_len` issue as in
        # token_to_sequence() — negative indices will raise AttributeError.
        if batch_index < 0:
            batch_index = self._batch_size + batch_index
        if token_index < 0:
            token_index = self._seq_len + token_index
        return self._encodings[batch_index].token_to_word(token_index)

    def word_to_tokens(
        self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
    ) -> Optional[TokenSpan]:
        """
        Get the encoded token span corresponding to a word in a sequence of the batch.
        Token spans are returned as a :class:`~transformers.tokenization_utils_base.TokenSpan` with:
        - **start** -- Index of the first token.
        - **end** -- Index of the token following the last token.
        Can be called as:
        - ``self.word_to_tokens(word_index, sequence_index: int = 0)`` if batch size is 1
        - ``self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)`` if batch size is greater or equal
          to 1
        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
        words.
        Args:
            batch_or_word_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the word in the sequence.
            word_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the
                sequence.
            sequence_index (:obj:`int`, `optional`, defaults to 0):
                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
                or 1) the provided word index belongs to.
        Returns:
            Optional :class:`~transformers.tokenization_utils_base.TokenSpan` Span of tokens in the encoded sequence.
            Returns :obj:`None` if no tokens correspond to the word.
        """
        if not self._encodings:
            raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
        if word_index is not None:
            batch_index = batch_or_word_index
        else:
            batch_index = 0
            word_index = batch_or_word_index
        # NOTE(review): same undefined `_batch_size`/`_seq_len` issue as above.
        if batch_index < 0:
            batch_index = self._batch_size + batch_index
        if word_index < 0:
            word_index = self._seq_len + word_index
        span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)
        return TokenSpan(*span) if span is not None else None

    def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
        """
        Get the character span corresponding to an encoded token in a sequence of the batch.
        Character spans are returned as a :class:`~transformers.tokenization_utils_base.CharSpan` with:
        - **start** -- Index of the first character in the original string associated to the token.
        - **end** -- Index of the character following the last character in the original string associated to the
          token.
        Can be called as:
        - ``self.token_to_chars(token_index)`` if batch size is 1
        - ``self.token_to_chars(batch_index, token_index)`` if batch size is greater or equal to 1
        Args:
            batch_or_token_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
                the token in the sequence.
            token_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the token or tokens in
                the sequence.
        Returns:
            :class:`~transformers.tokenization_utils_base.CharSpan`: Span of characters in the original string.
        """
        if not self._encodings:
            raise ValueError("token_to_chars() is not available when using Python based tokenizers")
        if token_index is not None:
            batch_index = batch_or_token_index
        else:
            batch_index = 0
            token_index = batch_or_token_index
        return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index)))

    def char_to_token(
        self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0
    ) -> int:
        """
        Get the index of the token in the encoded output comprising a character in the original string for a sequence
        of the batch.
        Can be called as:
        - ``self.char_to_token(char_index)`` if batch size is 1
        - ``self.char_to_token(batch_index, char_index)`` if batch size is greater or equal to 1
        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
        words.
        Args:
            batch_or_char_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
                the word in the sequence
            char_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the
                sequence.
            sequence_index (:obj:`int`, `optional`, defaults to 0):
                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
                or 1) the provided character index belongs to.
        Returns:
            :obj:`int`: Index of the token.
        """
        if not self._encodings:
            raise ValueError("char_to_token() is not available when using Python based tokenizers")
        if char_index is not None:
            batch_index = batch_or_char_index
        else:
            batch_index = 0
            char_index = batch_or_char_index
        return self._encodings[batch_index].char_to_token(char_index, sequence_index)

    def word_to_chars(
        self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
    ) -> CharSpan:
        """
        Get the character span in the original string corresponding to given word in a sequence of the batch.
        Character spans are returned as a CharSpan NamedTuple with:
        - start: index of the first character in the original string
        - end: index of the character following the last character in the original string
        Can be called as:
        - ``self.word_to_chars(word_index)`` if batch size is 1
        - ``self.word_to_chars(batch_index, word_index)`` if batch size is greater or equal to 1
        Args:
            batch_or_word_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
                the word in the sequence
            word_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the
                sequence.
            sequence_index (:obj:`int`, `optional`, defaults to 0):
                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
                or 1) the provided word index belongs to.
        Returns:
            :obj:`CharSpan` or :obj:`List[CharSpan]`: Span(s) of the associated character or characters in the string.
            CharSpan are NamedTuple with:
                - start: index of the first character associated to the token in the original string
                - end: index of the character following the last character associated to the token in the original
                  string
        """
        if not self._encodings:
            raise ValueError("word_to_chars() is not available when using Python based tokenizers")
        if word_index is not None:
            batch_index = batch_or_word_index
        else:
            batch_index = 0
            word_index = batch_or_word_index
        return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))

    def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:
        """
        Get the word in the original string corresponding to a character in the original string of a sequence of the
        batch.
        Can be called as:
        - ``self.char_to_word(char_index)`` if batch size is 1
        - ``self.char_to_word(batch_index, char_index)`` if batch size is greater than 1
        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
        words.
        Args:
            batch_or_char_index (:obj:`int`):
                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
                the character in the original string.
            char_index (:obj:`int`, `optional`):
                If a batch index is provided in `batch_or_token_index`, this can be the index of the character in the
                original string.
            sequence_index (:obj:`int`, `optional`, defaults to 0):
                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
                or 1) the provided character index belongs to.
        Returns:
            :obj:`int` or :obj:`List[int]`: Index or indices of the associated encoded token(s).
        """
        if not self._encodings:
            raise ValueError("char_to_word() is not available when using Python based tokenizers")
        if char_index is not None:
            batch_index = batch_or_char_index
        else:
            batch_index = 0
            char_index = batch_or_char_index
        return self._encodings[batch_index].char_to_word(char_index, sequence_index)

    def convert_to_tensors(
        self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
    ):
        """
        Convert the inner content to tensors.
        Args:
            tensor_type (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
                The type of tensors to use. If :obj:`str`, should be one of the values of the enum
                :class:`~transformers.file_utils.TensorType`. If :obj:`None`, no modification is done.
            prepend_batch_axis (:obj:`int`, `optional`, defaults to :obj:`False`):
                Whether or not to add the batch dimension during the conversion.
        """
        if tensor_type is None:
            return self
        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
            tensor_type = TensorType(tensor_type)
        # Get a function reference for the correct framework; each branch binds
        # `as_tensor` (constructor) and `is_tensor` (already-converted check).
        if tensor_type == TensorType.TENSORFLOW:
            if not is_tf_available():
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            import tensorflow as tf

            as_tensor = tf.constant
            is_tensor = tf.is_tensor
        elif tensor_type == TensorType.PYTORCH:
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch

            as_tensor = torch.tensor
            is_tensor = torch.is_tensor
        elif tensor_type == TensorType.JAX:
            if not is_flax_available():
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            import jax.numpy as jnp  # noqa: F811

            as_tensor = jnp.array
            is_tensor = _is_jax
        else:
            as_tensor = np.asarray
            is_tensor = _is_numpy
        # (mfuntowicz: This code is unreachable)
        # else:
        #     raise ImportError(
        #         f"Unable to convert output to tensors format {tensor_type}"
        #     )
        # Do the tensor conversion in batch
        for key, value in self.items():
            try:
                if prepend_batch_axis:
                    value = [value]
                if not is_tensor(value):
                    tensor = as_tensor(value)
                    # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
                    # # at-least2d
                    # if tensor.ndim > 2:
                    #     tensor = tensor.squeeze(0)
                    # elif tensor.ndim < 2:
                    #     tensor = tensor[None, :]
                    self[key] = tensor
            # NOTE(review): bare except deliberately converts any framework
            # conversion failure (typically ragged lengths) into a ValueError
            # with padding/truncation advice; it also swallows KeyboardInterrupt.
            except:  # noqa E722
                if key == "overflowing_tokens":
                    raise ValueError(
                        "Unable to create tensor returning overflowing tokens of different lengths. "
                        "Please see if a fast version of this tokenizer is available to have this feature available."
                    )
                raise ValueError(
                    "Unable to create tensor, you should probably activate truncation and/or padding "
                    "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
                )
        return self

    @torch_required
    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
        """
        Send all values to device by calling :obj:`v.to(device)` (PyTorch only).
        Args:
            device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
        Returns:
            :class:`~transformers.BatchEncoding`: The same instance after modification.
        """
        # This check catches things like APEX blindly calling "to" on all inputs to a module
        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
        # into a HalfTensor
        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
            self.data = {k: v.to(device=device) for k, v in self.data.items()}
        else:
            logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
        return self
class SpecialTokensMixin:
"""
A mixin derived by :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast` to
handle specific behaviors related to special tokens. In particular, this class hold the attributes which can be
used to directly access these special tokens in a model-independent manner and allow to set and update the special
tokens.
Args:
bos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing the beginning of a sentence.
eos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing the end of a sentence.
unk_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing an out-of-vocabulary token.
sep_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token separating two different sentences in the same input (used by BERT for instance).
pad_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation.
cls_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing the class of the input (used by BERT for instance).
mask_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing a masked token (used by masked-language modeling pretraining objectives, like
BERT).
additional_special_tokens (tuple or list of :obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A tuple or a list of additional special tokens.
"""
SPECIAL_TOKENS_ATTRIBUTES = [
"bos_token",
"eos_token",
"unk_token",
"sep_token",
"pad_token",
"cls_token",
"mask_token",
"additional_special_tokens",
]
def __init__(self, verbose=True, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._pad_token_type_id = 0
self._additional_special_tokens = []
self.verbose = verbose
# We directly set the hidden value to allow initialization with special tokens
# which are not yet in the vocabulary. Necessary for serialization/de-serialization
# TODO clean this up at some point (probably by switching to fast tokenizers)
for key, value in kwargs.items():
if value is None:
continue
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
assert all(isinstance(t, str) for t in value), "One of the tokens is not a string"
setattr(self, key, value)
elif isinstance(value, (str, AddedToken)):
setattr(self, key, value)
else:
raise TypeError(f"special token {key} has to be either str or AddedToken but got: {type(value)}")
    def sanitize_special_tokens(self) -> int:
        """
        Make sure that all the special tokens attributes of the tokenizer (:obj:`tokenizer.mask_token`,
        :obj:`tokenizer.cls_token`, etc.) are in the vocabulary.
        Add the missing ones to the vocabulary if needed.
        Return:
            :obj:`int`: The number of tokens added in the vocabulary during the operation.
        """
        # Delegates to add_tokens(), which only adds tokens not already present.
        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
"""
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
current vocabulary).
.. Note::
When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of
the model so that its embedding matrix matches the tokenizer.
In order to do that, please use the :meth:`~transformers.PreTrainedModel.resize_token_embeddings` method.
Using :obj:`add_special_tokens` will ensure your special tokens can be used in several ways:
- Special tokens are carefully handled by the tokenizer (they are never split).
- You can easily refer to special tokens using tokenizer class attributes like :obj:`tokenizer.cls_token`. This
makes it easy to develop model-agnostic training and fine-tuning scripts.
When possible, special tokens are already registered for provided pretrained models (for instance
:class:`~transformers.BertTokenizer` :obj:`cls_token` is already registered to be :obj`'[CLS]'` and XLM's one
is also registered to be :obj:`''`).
Args:
special_tokens_dict (dictionary `str` to `str` or :obj:`tokenizers.AddedToken`):
Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``,
``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
``additional_special_tokens``].
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
assign the index of the ``unk_token`` to them).
Returns:
:obj:`int`: Number of tokens added to the vocabulary.
Examples::
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
special_tokens_dict = {'cls_token': ''}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens')
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
assert tokenizer.cls_token == ''
"""
if not special_tokens_dict:
return 0
added_tokens = 0
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
if self.verbose:
logger.info(f"Assigning {value} to the {key} key of the tokenizer")
setattr(self, key, value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(
isinstance(t, (str, AddedToken)) for t in value
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
added_tokens += self.add_tokens(value, special_tokens=True)
else:
assert isinstance(
value, (str, AddedToken)
), f"Token {value} for key {key} should be a str or an AddedToken instance"
added_tokens += self.add_tokens([value], special_tokens=True)
return added_tokens
def add_tokens(
    self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False
) -> int:
    """
    Add new tokens to the tokenizer's vocabulary.

    Tokens that are already part of the vocabulary are skipped; newly added tokens are assigned
    indices starting from the current vocabulary size.

    .. Note::
        After adding tokens you should resize the model's token embedding matrix with
        :meth:`~transformers.PreTrainedModel.resize_token_embeddings` so that it matches the tokenizer.

    Args:
        new_tokens (:obj:`str`, :obj:`tokenizers.AddedToken` or a list of `str` or :obj:`tokenizers.AddedToken`):
            Token(s) to add. A :obj:`tokenizers.AddedToken` wraps a string token and lets you
            customize its matching behavior (single-word matching, whitespace stripping on either
            side, etc.).
        special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the token(s) should be treated as special tokens. This mostly changes
            normalization behavior (special tokens like CLS or [MASK] are usually not lower-cased).
            See details for :obj:`tokenizers.AddedToken` in the HuggingFace tokenizers library.

    Returns:
        :obj:`int`: The number of tokens actually added to the vocabulary.
    """
    # Nothing to do for None / empty string / empty list.
    if not new_tokens:
        return 0
    # Normalize the input to a list before delegating to the subclass hook.
    token_list = new_tokens if isinstance(new_tokens, (list, tuple)) else [new_tokens]
    return self._add_tokens(token_list, special_tokens=special_tokens)
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
    """
    Abstract hook: concrete subclasses (slow/fast tokenizers) perform the actual vocabulary
    update here and return the number of tokens added.
    """
    raise NotImplementedError
@property
def bos_token(self) -> str:
    """
    :obj:`str`: Beginning of sentence token. Returns :obj:`None` (logging an error when ``verbose``)
    if it has not been set.
    """
    if self._bos_token is None:
        # BUG FIX: previously `if self._bos_token is None and self.verbose` fell through to
        # `str(None)` == "None" when verbose was off; always return a real None when unset.
        if self.verbose:
            logger.error("Using bos_token, but it is not set yet.")
        return None
    return str(self._bos_token)

@property
def eos_token(self) -> str:
    """
    :obj:`str`: End of sentence token. Returns :obj:`None` (logging an error when ``verbose``)
    if it has not been set.
    """
    if self._eos_token is None:
        if self.verbose:
            logger.error("Using eos_token, but it is not set yet.")
        return None
    return str(self._eos_token)

@property
def unk_token(self) -> str:
    """
    :obj:`str`: Unknown token. Returns :obj:`None` (logging an error when ``verbose``)
    if it has not been set.
    """
    if self._unk_token is None:
        if self.verbose:
            logger.error("Using unk_token, but it is not set yet.")
        return None
    return str(self._unk_token)

@property
def sep_token(self) -> str:
    """
    :obj:`str`: Separation token, to separate context and query in an input sequence. Returns
    :obj:`None` (logging an error when ``verbose``) if it has not been set.
    """
    if self._sep_token is None:
        if self.verbose:
            logger.error("Using sep_token, but it is not set yet.")
        return None
    return str(self._sep_token)

@property
def pad_token(self) -> str:
    """
    :obj:`str`: Padding token. Returns :obj:`None` (logging an error when ``verbose``)
    if it has not been set.
    """
    if self._pad_token is None:
        if self.verbose:
            logger.error("Using pad_token, but it is not set yet.")
        return None
    return str(self._pad_token)

@property
def cls_token(self) -> str:
    """
    :obj:`str`: Classification token, to extract a summary of an input sequence leveraging
    self-attention along the full depth of the model. Returns :obj:`None` (logging an error when
    ``verbose``) if it has not been set.
    """
    if self._cls_token is None:
        if self.verbose:
            logger.error("Using cls_token, but it is not set yet.")
        return None
    return str(self._cls_token)

@property
def mask_token(self) -> str:
    """
    :obj:`str`: Mask token, to use when training a model with masked-language modeling. Returns
    :obj:`None` (logging an error when ``verbose``) if it has not been set.
    """
    if self._mask_token is None:
        if self.verbose:
            logger.error("Using mask_token, but it is not set yet.")
        return None
    return str(self._mask_token)

@property
def additional_special_tokens(self) -> List[str]:
    """
    :obj:`List[str]`: All the additional special tokens you may want to use. Returns :obj:`None`
    (logging an error when ``verbose``) if they have not been set.
    """
    if self._additional_special_tokens is None:
        # BUG FIX: previously reached the list comprehension over None (TypeError) when
        # verbose was off; always return None when unset.
        if self.verbose:
            logger.error("Using additional_special_tokens, but it is not set yet.")
        return None
    return [str(tok) for tok in self._additional_special_tokens]
@bos_token.setter
def bos_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._bos_token = value

@eos_token.setter
def eos_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._eos_token = value

@unk_token.setter
def unk_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._unk_token = value

@sep_token.setter
def sep_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._sep_token = value

@pad_token.setter
def pad_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._pad_token = value

@cls_token.setter
def cls_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._cls_token = value

@mask_token.setter
def mask_token(self, value):
    # Accepts a str or tokenizers.AddedToken (or None to clear); stored raw, stringified by the getter.
    self._mask_token = value

@additional_special_tokens.setter
def additional_special_tokens(self, value):
    # Expects a list/tuple of str or tokenizers.AddedToken (or None to clear).
    self._additional_special_tokens = value
@property
def bos_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the beginning of sentence token in the vocabulary, or :obj:`None`
    if the token has not been set.
    """
    return None if self._bos_token is None else self.convert_tokens_to_ids(self.bos_token)

@property
def eos_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the end of sentence token in the vocabulary, or :obj:`None` if the
    token has not been set.
    """
    return None if self._eos_token is None else self.convert_tokens_to_ids(self.eos_token)

@property
def unk_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the unknown token in the vocabulary, or :obj:`None` if the token
    has not been set.
    """
    return None if self._unk_token is None else self.convert_tokens_to_ids(self.unk_token)

@property
def sep_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the separation token in the vocabulary (separates context and query
    in an input sequence), or :obj:`None` if the token has not been set.
    """
    return None if self._sep_token is None else self.convert_tokens_to_ids(self.sep_token)

@property
def pad_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the padding token in the vocabulary, or :obj:`None` if the token
    has not been set.
    """
    return None if self._pad_token is None else self.convert_tokens_to_ids(self.pad_token)

@property
def pad_token_type_id(self) -> int:
    """
    :obj:`int`: Id of the padding token type in the vocabulary.
    """
    return self._pad_token_type_id

@property
def cls_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the classification token in the vocabulary (used to extract a
    summary of an input sequence leveraging self-attention along the full depth of the model), or
    :obj:`None` if the token has not been set.
    """
    return None if self._cls_token is None else self.convert_tokens_to_ids(self.cls_token)

@property
def mask_token_id(self) -> Optional[int]:
    """
    :obj:`Optional[int]`: Id of the mask token in the vocabulary (used when training with
    masked-language modeling), or :obj:`None` if the token has not been set.
    """
    return None if self._mask_token is None else self.convert_tokens_to_ids(self.mask_token)

@property
def additional_special_tokens_ids(self) -> List[int]:
    """
    :obj:`List[int]`: Ids of all the additional special tokens in the vocabulary. Log an error if
    used while not having been set.
    """
    return self.convert_tokens_to_ids(self.additional_special_tokens)
@bos_token_id.setter
def bos_token_id(self, value):
    # BUG FIX: `value` is a vocabulary *id*, so it must be converted back to its token string
    # before being stored (the private slots hold tokens, not ids). The old code called
    # convert_tokens_to_ids on an id, corrupting the stored token. None clears the token.
    self._bos_token = self.convert_ids_to_tokens(value) if value is not None else None

@eos_token_id.setter
def eos_token_id(self, value):
    # id -> token string (see bos_token_id setter); None clears the token.
    self._eos_token = self.convert_ids_to_tokens(value) if value is not None else None

@unk_token_id.setter
def unk_token_id(self, value):
    # id -> token string (see bos_token_id setter); None clears the token.
    self._unk_token = self.convert_ids_to_tokens(value) if value is not None else None

@sep_token_id.setter
def sep_token_id(self, value):
    # id -> token string (see bos_token_id setter); None clears the token.
    self._sep_token = self.convert_ids_to_tokens(value) if value is not None else None

@pad_token_id.setter
def pad_token_id(self, value):
    # id -> token string (see bos_token_id setter); None clears the token.
    self._pad_token = self.convert_ids_to_tokens(value) if value is not None else None

@cls_token_id.setter
def cls_token_id(self, value):
    # id -> token string (see bos_token_id setter); None clears the token.
    self._cls_token = self.convert_ids_to_tokens(value) if value is not None else None

@mask_token_id.setter
def mask_token_id(self, value):
    # id -> token string (see bos_token_id setter); None clears the token.
    self._mask_token = self.convert_ids_to_tokens(value) if value is not None else None

@additional_special_tokens_ids.setter
def additional_special_tokens_ids(self, values):
    # ids -> token strings, element-wise (see bos_token_id setter).
    self._additional_special_tokens = [self.convert_ids_to_tokens(value) for value in values]
@property
def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
    """
    :obj:`Dict[str, Union[str, List[str]]]`: Maps each *set* special-token class attribute
    (:obj:`cls_token`, :obj:`unk_token`, etc.) to its value, converting any
    :obj:`tokenizers.AddedToken` to a plain string. Unset (falsy) attributes are omitted.
    """
    pairs = ((name, getattr(self, "_" + name)) for name in self.SPECIAL_TOKENS_ATTRIBUTES)
    return {name: str(value) for name, value in pairs if value}
@property
def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]:
    """
    :obj:`Dict[str, Union[str, tokenizers.AddedToken, List[Union[str, tokenizers.AddedToken]]]]`:
    Like :obj:`special_tokens_map`, but values keep their original type:
    :obj:`tokenizers.AddedToken` entries are *not* converted to strings, so they can still control
    fine-grained tokenization behavior. Unset (falsy) attributes are omitted.
    """
    pairs = ((name, getattr(self, "_" + name)) for name in self.SPECIAL_TOKENS_ATTRIBUTES)
    return {name: value for name, value in pairs if value}
@property
def all_special_tokens(self) -> List[str]:
    """
    :obj:`List[str]`: All the special tokens mapped to class attributes, each converted to a plain
    string (any :obj:`tokenizers.AddedToken` is stringified).
    """
    return [str(token) for token in self.all_special_tokens_extended]
@property
def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
    """
    :obj:`List[Union[str, tokenizers.AddedToken]]`: All the special tokens mapped to class
    attributes, without converting :obj:`tokenizers.AddedToken` entries to strings. The result is
    de-duplicated while preserving first-seen order.
    """
    tokens = []
    for value in self.special_tokens_map_extended.values():
        if isinstance(value, (list, tuple)):
            tokens.extend(value)
        else:
            tokens.append(value)
    # OrderedDict.fromkeys de-duplicates while keeping insertion order.
    return list(OrderedDict.fromkeys(tokens))
@property
def all_special_ids(self) -> List[int]:
    """
    :obj:`List[int]`: Ids of all the special tokens mapped to class attributes.
    """
    return self.convert_tokens_to_ids(self.all_special_tokens)
ENCODE_KWARGS_DOCSTRING = r"""
add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to encode the sequences with the special tokens relative to their model.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
max_length (:obj:`int`, `optional`):
Controls the maximum length to use by one of the truncation/padding parameters.
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
stride (:obj:`int`, `optional`, defaults to 0):
If set to a number along with :obj:`max_length`, the overflowing tokens returned when
:obj:`return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
returned to provide some overlap between truncated and overflowing sequences. The value of this
argument defines the number of overlapping tokens.
is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the input is already pre-tokenized (e.g., split into words), in which case the tokenizer
will skip the pre-tokenization step. This is useful for NER or token classification.
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
"""
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
return_token_type_ids (:obj:`bool`, `optional`):
Whether to return token type IDs. If left to the default, will return the token type IDs according to
the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
`What are token type IDs? <../glossary.html#token-type-ids>`__
return_attention_mask (:obj:`bool`, `optional`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
`What are attention masks? <../glossary.html#attention-mask>`__
return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return overflowing token sequences.
return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return special tokens mask information.
return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return :obj:`(char_start, char_end)` for each token.
This is only available on fast tokenizers inheriting from
:class:`~transformers.PreTrainedTokenizerFast`, if using Python's tokenizer, this method will raise
:obj:`NotImplementedError`.
return_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return the lengths of the encoded inputs.
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to print more information and warnings.
**kwargs: passed to the :obj:`self.tokenize()` method
Return:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
`What are input IDs? <../glossary.html#input-ids>`__
- **token_type_ids** -- List of token type ids to be fed to a model (when :obj:`return_token_type_ids=True`
or if `"token_type_ids"` is in :obj:`self.model_input_names`).
`What are token type IDs? <../glossary.html#token-type-ids>`__
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
:obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`).
`What are attention masks? <../glossary.html#attention-mask>`__
- **overflowing_tokens** -- List of overflowing tokens sequences (when a :obj:`max_length` is specified and
:obj:`return_overflowing_tokens=True`).
- **num_truncated_tokens** -- Number of tokens truncated (when a :obj:`max_length` is specified and
:obj:`return_overflowing_tokens=True`).
- **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
regular sequence tokens (when :obj:`add_special_tokens=True` and :obj:`return_special_tokens_mask=True`).
- **length** -- The length of the inputs (when :obj:`return_length=True`)
"""
INIT_TOKENIZER_DOCSTRING = r"""
Class attributes (overridden by derived classes)
- **vocab_files_names** (:obj:`Dict[str, str]`) -- A dictionary with, as keys, the ``__init__`` keyword name of
each vocabulary file required by the model, and as associated values, the filename for saving the associated
file (string).
- **pretrained_vocab_files_map** (:obj:`Dict[str, Dict[str, str]]`) -- A dictionary of dictionaries, with the
high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the
low-level being the :obj:`short-cut-names` of the pretrained models with, as associated values, the
:obj:`url` to the associated pretrained vocabulary file.
- **max_model_input_sizes** (:obj:`Dict[str, Optional[int]]`) -- A dictionary with, as keys, the
:obj:`short-cut-names` of the pretrained models, and as associated values, the maximum length of the sequence
inputs of this model, or :obj:`None` if the model has no maximum input size.
- **pretrained_init_configuration** (:obj:`Dict[str, Dict[str, Any]]`) -- A dictionary with, as keys, the
:obj:`short-cut-names` of the pretrained models, and as associated values, a dictionary of specific arguments
to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the
tokenizer with the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`
method.
- **model_input_names** (:obj:`List[str]`) -- A list of inputs expected in the forward pass of the model.
- **padding_side** (:obj:`str`) -- The default value for the side on which the model should have padding
applied. Should be :obj:`'right'` or :obj:`'left'`.
Args:
model_max_length (:obj:`int`, `optional`):
The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is
loaded with :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`, this
will be set to the value stored for the associated model in ``max_model_input_sizes`` (see above). If no
value is provided, will default to VERY_LARGE_INTEGER (:obj:`int(1e30)`).
padding_side: (:obj:`str`, `optional`):
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
model_input_names (:obj:`List[string]`, `optional`):
The list of inputs accepted by the forward pass of the model (like :obj:`"token_type_ids"` or
:obj:`"attention_mask"`). Default value is picked from the class attribute of the same name.
bos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing the beginning of a sentence. Will be associated to ``self.bos_token`` and
``self.bos_token_id``.
eos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing the end of a sentence. Will be associated to ``self.eos_token`` and
``self.eos_token_id``.
unk_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing an out-of-vocabulary token. Will be associated to ``self.unk_token`` and
``self.unk_token_id``.
sep_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token separating two different sentences in the same input (used by BERT for instance). Will be
associated to ``self.sep_token`` and ``self.sep_token_id``.
pad_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation. Will be associated to ``self.pad_token`` and
``self.pad_token_id``.
cls_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing the class of the input (used by BERT for instance). Will be associated to
``self.cls_token`` and ``self.cls_token_id``.
mask_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A special token representing a masked token (used by masked-language modeling pretraining objectives, like
BERT). Will be associated to ``self.mask_token`` and ``self.mask_token_id``.
additional_special_tokens (tuple or list of :obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):
A tuple or a list of additional special tokens. Add them here to ensure they won't be split by the
tokenization process. Will be associated to ``self.additional_special_tokens`` and
``self.additional_special_tokens_ids``.
"""
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizerBase(SpecialTokensMixin):
    """
    Base class for :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`.
    Handles shared (mostly boiler plate) methods for those two classes.
    """
    # ``__init__`` keyword name of each required vocabulary file -> filename used when saving.
    vocab_files_names: Dict[str, str] = {}
    # vocab-file keyword -> {pretrained shortcut name -> url of the pretrained vocabulary file}.
    pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
    # pretrained shortcut name -> extra kwargs passed to ``__init__`` when loading that model.
    pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
    # pretrained shortcut name -> maximum input sequence length (None if the model is unbounded).
    max_model_input_sizes: Dict[str, Optional[int]] = {}
    # first name has to correspond to main model input name
    # to make sure `tokenizer.pad(...)` works correctly
    model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
    # Default side on which padding is applied ('right' or 'left'); may be overridden per instance.
    padding_side: str = "right"
    # Fast tokenizer subclasses point this at their slow (pure-Python) counterpart class.
    slow_tokenizer_class = None
def __init__(self, **kwargs):
    """
    Store initialization inputs/kwargs for later serialization and configure length/padding
    defaults; remaining kwargs are passed up the MRO (presumably handled by
    :class:`SpecialTokensMixin` — special tokens, ``verbose``, etc.).
    """
    # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
    self.init_inputs = ()
    # Deep copy so later mutation of the caller's kwargs cannot corrupt the saved configuration.
    self.init_kwargs = copy.deepcopy(kwargs)
    self.name_or_path = kwargs.pop("name_or_path", "")
    # For backward compatibility we fallback to set model_max_length from max_len if provided
    model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
    self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
    # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
    self.padding_side = kwargs.pop("padding_side", self.padding_side)
    assert self.padding_side in [
        "right",
        "left",
    ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
    self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
    self.deprecation_warnings = (
        {}
    )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
    super().__init__(**kwargs)
@property
def max_len_single_sentence(self) -> int:
    """
    :obj:`int`: The maximum length of a sentence that can be fed to the model.
    """
    # Budget left once the single-sequence special tokens are accounted for.
    special_budget = self.num_special_tokens_to_add(pair=False)
    return self.model_max_length - special_budget
@property
def max_len_sentences_pair(self) -> int:
    """
    :obj:`int`: The maximum combined length of a pair of sentences that can be fed to the model.
    """
    # Budget left once the pair-of-sequences special tokens are accounted for.
    special_budget = self.num_special_tokens_to_add(pair=True)
    return self.model_max_length - special_budget
@max_len_single_sentence.setter
def max_len_single_sentence(self, value) -> int:
    # For backward compatibility, allow to try to setup 'max_len_single_sentence'.
    # Only the already-correct derived value is accepted (with a one-time deprecation warning);
    # any other value raises. NOTE(review): when ``self.verbose`` is False, even the correct
    # value falls through to the ValueError branch — confirm this is intended.
    if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
        if not self.deprecation_warnings.get("max_len_single_sentence", False):
            logger.warning(
                "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
            )
        self.deprecation_warnings["max_len_single_sentence"] = True
    else:
        raise ValueError(
            "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
        )
@max_len_sentences_pair.setter
def max_len_sentences_pair(self, value) -> int:
    # For backward compatibility, allow to try to setup 'max_len_sentences_pair'.
    # Only the already-correct derived value is accepted (with a one-time deprecation warning);
    # any other value raises. NOTE(review): when ``self.verbose`` is False, even the correct
    # value falls through to the ValueError branch — confirm this is intended.
    if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
        if not self.deprecation_warnings.get("max_len_sentences_pair", False):
            logger.warning(
                "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
            )
        self.deprecation_warnings["max_len_sentences_pair"] = True
    else:
        raise ValueError(
            "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
        )
def __repr__(self) -> str:
    """Summarize the tokenizer: class kind, source, vocab size, limits and special tokens."""
    kind = "PreTrainedTokenizerFast" if self.is_fast else "PreTrainedTokenizer"
    return (
        f"{kind}(name_or_path='{self.name_or_path}', "
        f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, "
        f"padding_side='{self.padding_side}', special_tokens={self.special_tokens_map_extended})"
    )
def get_vocab(self) -> Dict[str, int]:
    """
    Returns the vocabulary as a dictionary of token to index.
    :obj:`tokenizer.get_vocab()[token]` is equivalent to :obj:`tokenizer.convert_tokens_to_ids(token)` when
    :obj:`token` is in the vocab.
    Returns:
        :obj:`Dict[str, int]`: The vocabulary.
    """
    # Abstract hook: the concrete slow/fast tokenizer subclasses supply the mapping.
    raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
    r"""
    Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from
    a predefined tokenizer.
    Args:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            Can be either:
            - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
              Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a
              user or organization name, like ``dbmdz/bert-base-german-cased``.
            - A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
              using the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`
              method, e.g., ``./my_model_directory/``.
            - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
              file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
              ``./my_model_directory/vocab.txt``.
        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
            Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
            standard cache should not be used.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download the vocabulary files and override the cached versions if they
            exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Attempt to resume the download if such a file
            exists.
        proxies (:obj:`Dict[str, str], `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        subfolder (:obj:`str`, `optional`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
            facebook/rag-token-base), specify it here.
        inputs (additional positional arguments, `optional`):
            Will be passed along to the Tokenizer ``__init__`` method.
        kwargs (additional keyword arguments, `optional`):
            Will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like
            ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
            ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details.
    .. note::
        Passing :obj:`use_auth_token=True` is required when you want to use a private model.
    Examples::
        # We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer
        # Download vocabulary from huggingface.co and cache.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Download vocabulary from huggingface.co (user-uploaded) and cache.
        tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
        # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
        tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
        # If the tokenizer uses a single vocabulary file, you can point directly to this file
        tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
        # You can link tokens to special vocabulary when instantiating
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='')
        # You should be sure '' is in the vocabulary when doing that.
        # Otherwise use tokenizer.add_special_tokens({'unk_token': ''}) instead)
        assert tokenizer.unk_token == ''
    """
    # Pull the download/resolution options out of kwargs; whatever remains is forwarded to
    # the tokenizer ``__init__`` via ``_from_pretrained``.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    # Telemetry payload describing how the tokenizer is being loaded.
    user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline
    # Offline mode (env-driven) forces cache-only resolution.
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    vocab_files = {}
    init_configuration = {}
    if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
        # Deprecated single-file path/url: only valid when the tokenizer needs exactly one vocab file.
        if len(cls.vocab_files_names) > 1:
            raise ValueError(
                f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
                "supported for this tokenizer. Use a model identifier or the path to a directory instead."
            )
        warnings.warn(
            f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and "
            "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.",
            FutureWarning,
        )
        file_id = list(cls.vocab_files_names.keys())[0]
        vocab_files[file_id] = pretrained_model_name_or_path
    else:
        # At this point pretrained_model_name_or_path is either a directory or a model identifier name
        additional_files_names = {
            "added_tokens_file": ADDED_TOKENS_FILE,
            "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
            "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
            "tokenizer_file": FULL_TOKENIZER_FILE,
        }
        # Look for the tokenizer files
        for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
            if os.path.isdir(pretrained_model_name_or_path):
                # Local directory: missing optional files are recorded as None, not errors.
                if subfolder is not None:
                    full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
                else:
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                if not os.path.exists(full_file_name):
                    logger.info(f"Didn't find file {full_file_name}. We won't load it.")
                    full_file_name = None
            else:
                # Model identifier: build the huggingface.co url for this file.
                full_file_name = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=file_name,
                    subfolder=subfolder,
                    revision=revision,
                    mirror=None,
                )
            vocab_files[file_id] = full_file_name
    # Get files from url, cache, or disk depending on the case
    resolved_vocab_files = {}
    unresolved_files = []
    for file_id, file_path in vocab_files.items():
        if file_path is None:
            resolved_vocab_files[file_id] = None
        else:
            try:
                resolved_vocab_files[file_id] = cached_path(
                    file_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except FileNotFoundError as error:
                # Cache-only mode: a missing file may simply be optional; remember it and move on.
                if local_files_only:
                    unresolved_files.append(file_id)
                else:
                    raise error
            except requests.exceptions.HTTPError as err:
                # A 404 just means this (optional) file does not exist for that model;
                # any other HTTP error is fatal.
                if "404 Client Error" in str(err):
                    logger.debug(err)
                    resolved_vocab_files[file_id] = None
                else:
                    raise err
    if len(unresolved_files) > 0:
        logger.info(
            f"Can't load following files from cache: {unresolved_files} and cannot check if these "
            "files are necessary for the tokenizer to operate."
        )
    # If nothing at all could be resolved, the identifier/path is wrong.
    if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
        msg = (
            f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
            f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
            f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
        )
        raise EnvironmentError(msg)
    # Log where each file was actually loaded from (original path vs. cache).
    for file_id, file_path in vocab_files.items():
        if file_id not in resolved_vocab_files:
            continue
        if file_path == resolved_vocab_files[file_id]:
            logger.info(f"loading file {file_path}")
        else:
            logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
    return cls._from_pretrained(
        resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
    )
@classmethod
def _from_pretrained(
    cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
):
    """
    Instantiate a tokenizer from already-resolved local file paths.

    ``resolved_vocab_files`` maps file ids (e.g. ``tokenizer_config_file``, ``added_tokens_file``,
    ``special_tokens_map_file``, ``tokenizer_file``) to local paths (or ``None``).  The method merges the
    saved tokenizer config with the caller-provided ``kwargs``, instantiates the tokenizer class, then
    re-applies saved special tokens and added tokens on top of the fresh instance.
    """
    # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
    # file or if `from_slow` is set to True.
    from_slow = kwargs.get("from_slow", False)
    has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
    if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None:
        # Deep-copy the inputs: the slow tokenizer's _from_pretrained pops entries from these
        # containers, and we still need them intact below.
        slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
            copy.deepcopy(resolved_vocab_files),
            pretrained_model_name_or_path,
            copy.deepcopy(init_configuration),
            *init_inputs,
            **(copy.deepcopy(kwargs)),
        )
    else:
        slow_tokenizer = None

    # Prepare tokenizer initialization kwargs
    # Did we saved some inputs and kwargs to reload ?
    tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
    if tokenizer_config_file is not None:
        with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
            init_kwargs = json.load(tokenizer_config_handle)
        # Saved positional init args only apply when the caller didn't pass their own.
        saved_init_inputs = init_kwargs.pop("init_inputs", ())
        if not init_inputs:
            init_inputs = saved_init_inputs
    else:
        init_kwargs = init_configuration

    # Update with newly provided kwargs (caller-provided values win over saved config).
    init_kwargs.update(kwargs)

    # Convert AddedTokens serialized as dict to class instances
    def convert_added_tokens(obj: Union[AddedToken, Any]):
        # Dicts tagged with "__type": "AddedToken" were serialized AddedToken instances.
        if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken":
            obj.pop("__type")
            return AddedToken(**obj)
        elif isinstance(obj, (list, tuple)):
            return list(convert_added_tokens(o) for o in obj)
        elif isinstance(obj, dict):
            return {k: convert_added_tokens(v) for k, v in obj.items()}
        return obj

    init_kwargs = convert_added_tokens(init_kwargs)

    # Set max length if needed
    if pretrained_model_name_or_path in cls.max_model_input_sizes:
        # if we're using a pretrained model, ensure the tokenizer
        # wont index sequences longer than the number of positional embeddings
        model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
        if model_max_length is not None and isinstance(model_max_length, (int, float)):
            init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)

    # Merge resolved_vocab_files arguments in init_kwargs.
    added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
    for args_name, file_path in resolved_vocab_files.items():
        if args_name not in init_kwargs:
            init_kwargs[args_name] = file_path

    if slow_tokenizer is not None:
        init_kwargs["__slow_tokenizer"] = slow_tokenizer

    init_kwargs["name_or_path"] = pretrained_model_name_or_path

    # Instantiate tokenizer.
    try:
        tokenizer = cls(*init_inputs, **init_kwargs)
    except OSError:
        raise OSError(
            "Unable to load vocabulary from file. "
            "Please check that the provided vocabulary is accessible and not corrupted."
        )

    # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
    # Removed: Now done at the base class level
    # tokenizer.init_inputs = init_inputs
    # tokenizer.init_kwargs = init_kwargs

    # If there is a complementary special token map, load it
    special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
    if special_tokens_map_file is not None:
        with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
            special_tokens_map = json.load(special_tokens_map_handle)
        for key, value in special_tokens_map.items():
            # Values may be serialized AddedToken dicts (or lists of them); rehydrate before setattr.
            if isinstance(value, dict):
                value = AddedToken(**value)
            elif isinstance(value, list):
                value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
            setattr(tokenizer, key, value)

    # Add supplementary tokens.
    special_tokens = tokenizer.all_special_tokens
    if added_tokens_file is not None:
        with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
            added_tok_encoder = json.load(added_tokens_handle)

        # Sort added tokens by index
        added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
        for token, index in added_tok_encoder_sorted:
            # Added-token ids must extend the vocabulary contiguously; a gap means the saved
            # files are inconsistent with this tokenizer's base vocabulary.
            assert index == len(tokenizer), (
                f"Non-consecutive added token '{token}' found. "
                f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
            )
            tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))

    # Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
    added_tokens = tokenizer.sanitize_special_tokens()
    if added_tokens:
        logger.warning(
            "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
        )

    return tokenizer
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    legacy_format: bool = True,
    filename_prefix: Optional[str] = None,
) -> Tuple[str]:
    """
    Save the full tokenizer state.

    This method makes sure the full tokenizer can then be re-loaded using the
    :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained` class method.

    .. Note::
        A "fast" tokenizer (instance of :class:`transformers.PreTrainedTokenizerFast`) saved with this method will
        not be possible to load back in a "slow" tokenizer, i.e. in a :class:`transformers.PreTrainedTokenizer`
        instance. It can only be loaded in a "fast" tokenizer, i.e. in a
        :class:`transformers.PreTrainedTokenizerFast` instance.

    .. Warning::
        This won't save modifications you may have applied to the tokenizer after the instantiation (for instance,
        modifying :obj:`tokenizer.do_lower_case` after creation).

    Args:
        save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved.
        legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a
            separate added_tokens files or in the unified JSON file format for the `tokenizers` library. It's only
            possible to save a Fast tokenizer in the unified JSON format and this format is incompatible with
            "slow" tokenizers (not powered by the `tokenizers` library).
        filename_prefix: (:obj:`str`, `optional`):
            A prefix to add to the names of the files saved by the tokenizer.

    Returns:
        A tuple of :obj:`str`: The files saved. Returns ``None`` (after logging an error) if
        ``save_directory`` points to an existing file instead of a directory.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return
    os.makedirs(save_directory, exist_ok=True)

    # Target paths for the two JSON files written by this method.
    special_tokens_map_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
    )
    tokenizer_config_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
    )

    # Deep-copy so that popping keys below does not mutate the live tokenizer's init_kwargs.
    tokenizer_config = copy.deepcopy(self.init_kwargs)
    if len(self.init_inputs) > 0:
        tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
    # Vocabulary file paths are machine-specific; drop them from the portable config.
    for file_id in self.vocab_files_names.keys():
        tokenizer_config.pop(file_id, None)

    # Sanitize AddedTokens
    def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
        # Serialize AddedToken instances to plain dicts (recursing into containers).
        if isinstance(obj, AddedToken):
            out = obj.__getstate__()
            if add_type_field:
                # The "__type" tag lets `_from_pretrained` tell these dicts apart from ordinary kwargs.
                out["__type"] = "AddedToken"
            return out
        elif isinstance(obj, (list, tuple)):
            return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)
        elif isinstance(obj, dict):
            return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}
        return obj

    # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
    tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
    with open(tokenizer_config_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(tokenizer_config, ensure_ascii=False))
    logger.info(f"tokenizer config file saved in {tokenizer_config_file}")

    # Sanitize AddedTokens in special_tokens_map
    write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
    with open(special_tokens_map_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(write_dict, ensure_ascii=False))
    logger.info(f"Special tokens file saved in {special_tokens_map_file}")

    file_names = (tokenizer_config_file, special_tokens_map_file)

    # Vocabulary and added-token files are written by `_save_pretrained` (overridden by fast tokenizers).
    return self._save_pretrained(
        save_directory=save_directory,
        file_names=file_names,
        legacy_format=legacy_format,
        filename_prefix=filename_prefix,
    )
def _save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    file_names: Tuple[str],
    legacy_format: bool = True,
    filename_prefix: Optional[str] = None,
) -> Tuple[str]:
    """
    Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens.

    Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the
    specific :meth:`~transformers.tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`

    Args:
        save_directory (:obj:`str` or :obj:`os.PathLike`): Directory the files are written into.
        file_names (:obj:`Tuple[str]`): Paths already written by :meth:`save_pretrained`, prepended to the result.
        legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`): Must be :obj:`True` for slow tokenizers.
        filename_prefix (:obj:`str`, `optional`): Optional prefix for the saved file names.

    Returns:
        A tuple of :obj:`str`: All files actually saved (``file_names`` + vocabulary files + the
        added-tokens file when one was written).

    Raises:
        ValueError: If ``legacy_format`` is :obj:`False` (only fast tokenizers support that format).
    """
    if not legacy_format:
        raise ValueError(
            "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format."
        )

    save_directory = str(save_directory)

    added_tokens_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
    )
    added_vocab = self.get_added_vocab()
    if added_vocab:
        with open(added_tokens_file, "w", encoding="utf-8") as f:
            out_str = json.dumps(added_vocab, ensure_ascii=False)
            f.write(out_str)
        logger.info(f"added tokens file saved in {added_tokens_file}")

    vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

    # Bug fix: only report the added-tokens file when it was actually written. Previously the
    # path was returned unconditionally, handing callers a path to a file that may not exist.
    added_tokens_files = (added_tokens_file,) if added_vocab else ()
    return file_names + vocab_files + added_tokens_files
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write out only the tokenizer's vocabulary (base vocabulary plus added tokens).

    The tokenizer configuration and special-token mappings are *not* written by this method; use
    :meth:`~transformers.PreTrainedTokenizerFast._save_pretrained` to persist the complete tokenizer state.
    Subclasses must override this method -- the base class defines no vocabulary file format of its own.

    Args:
        save_directory (:obj:`str`):
            The directory in which to save the vocabulary.
        filename_prefix (:obj:`str`, `optional`):
            An optional prefix to add to the names of the saved files.

    Returns:
        :obj:`Tuple(str)`: Paths to the files saved.
    """
    raise NotImplementedError
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
    """
    Split a string into a sequence of tokens, mapping out-of-vocabulary pieces to the :obj:`unk_token`.

    Subclasses provide the concrete implementation; the base class only defines the interface.

    Args:
        text (:obj:`str`):
            The sequence to be encoded.
        pair (:obj:`str`, `optional`):
            A second sequence to be encoded with the first.
        add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to add the special tokens associated with the corresponding model.
        kwargs (additional keyword arguments, `optional`):
            Will be passed to the underlying model specific encode method. See details in
            :meth:`~transformers.PreTrainedTokenizerBase.__call__`

    Returns:
        :obj:`List[str]`: The list of tokens.
    """
    raise NotImplementedError
@add_end_docstrings(
    ENCODE_KWARGS_DOCSTRING,
    """
        **kwargs: Passed along to the `.tokenize()` method.
    """,
    """
    Returns:
        :obj:`List[int]`, :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`: The tokenized ids of the
        text.
    """,
)
def encode(
    self,
    text: Union[TextInput, PreTokenizedInput, EncodedInput],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs
) -> List[int]:
    """
    Convert a string into a sequence of ids (integers) using the tokenizer and its vocabulary.

    Equivalent to running ``self.convert_tokens_to_ids(self.tokenize(text))``.

    Args:
        text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`):
            The first sequence to encode: a raw string, a list of string tokens (the output of the
            ``tokenize`` method), or a list of integer token ids (the output of the
            ``convert_tokens_to_ids`` method).
        text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`):
            An optional second sequence to encode, accepted in any of the same three forms as ``text``.
    """
    # All the heavy lifting happens in `encode_plus`; this wrapper just extracts the ids.
    return self.encode_plus(
        text,
        text_pair=text_pair,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        stride=stride,
        return_tensors=return_tensors,
        **kwargs,
    )["input_ids"]
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    """
    Return how many special tokens the model adds around a single sequence (``pair=False``)
    or around a pair of sequences (``pair=True``). Implemented by subclasses.
    """
    raise NotImplementedError
def _get_padding_truncation_strategies(
    self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
):
    """
    Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
    and pad_to_max_length) and behaviors.

    Returns a 4-tuple ``(padding_strategy, truncation_strategy, max_length, kwargs)`` where the legacy
    ``truncation_strategy`` / ``pad_to_max_length`` entries have been removed from ``kwargs``.
    """
    # Pop the legacy kwargs first so they are never forwarded to downstream methods.
    old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
    old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

    # Backward compatibility for previous behavior, maybe we should deprecate it:
    # If you only set max_length, it activates truncation for max_length
    if max_length is not None and padding is False and truncation is False:
        if verbose:
            # Each deprecation warning is emitted at most once per tokenizer instance.
            if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
                logger.warning(
                    "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                    "please use `truncation=True` to explicitly truncate examples to max length. "
                    "Defaulting to 'longest_first' truncation strategy. "
                    "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                    "more precisely by providing a specific strategy to `truncation`."
                )
            self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
        truncation = "longest_first"

    # Get padding strategy
    if padding is False and old_pad_to_max_length:
        # Legacy `pad_to_max_length=True` path.
        if verbose:
            warnings.warn(
                "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                "maximal input size of the model (e.g. 512 for Bert).",
                FutureWarning,
            )
        if max_length is None:
            padding_strategy = PaddingStrategy.LONGEST
        else:
            padding_strategy = PaddingStrategy.MAX_LENGTH
    elif padding is not False:
        # `padding` may be True, a strategy string, or a PaddingStrategy enum member.
        if padding is True:
            padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
        elif not isinstance(padding, PaddingStrategy):
            padding_strategy = PaddingStrategy(padding)
        elif isinstance(padding, PaddingStrategy):
            padding_strategy = padding
    else:
        padding_strategy = PaddingStrategy.DO_NOT_PAD

    # Get truncation strategy
    if truncation is False and old_truncation_strategy != "do_not_truncate":
        # Legacy `truncation_strategy=...` path.
        if verbose:
            warnings.warn(
                "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                "use `truncation=True` to truncate examples to a max length. You can give a specific "
                "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                "maximal input size of the model (e.g. 512 for Bert). "
                " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                FutureWarning,
            )
        truncation_strategy = TruncationStrategy(old_truncation_strategy)
    elif truncation is not False:
        # `truncation` may be True, a strategy string, or a TruncationStrategy enum member.
        if truncation is True:
            truncation_strategy = (
                TruncationStrategy.LONGEST_FIRST
            )  # Default to truncate the longest sequences in pairs of inputs
        elif not isinstance(truncation, TruncationStrategy):
            truncation_strategy = TruncationStrategy(truncation)
        elif isinstance(truncation, TruncationStrategy):
            truncation_strategy = truncation
    else:
        truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

    # Set max length if needed
    if max_length is None:
        # With no explicit max_length, fall back to the model's own limit -- unless the model
        # declares no usable limit (model_max_length > LARGE_INTEGER), in which case the
        # strategy is downgraded with a one-time warning.
        if padding_strategy == PaddingStrategy.MAX_LENGTH:
            if self.model_max_length > LARGE_INTEGER:
                if verbose:
                    if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
                        logger.warning(
                            "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                            "Default to no padding."
                        )
                    self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                padding_strategy = PaddingStrategy.DO_NOT_PAD
            else:
                max_length = self.model_max_length

        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
            if self.model_max_length > LARGE_INTEGER:
                if verbose:
                    if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
                        logger.warning(
                            "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                            "Default to no truncation."
                        )
                    self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
                truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
            else:
                max_length = self.model_max_length

    # Test if we have a padding token
    if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
        raise ValueError(
            "Asking to pad but the tokenizer does not have a padding token. "
            "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
            "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
        )

    # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
    if (
        truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
        and padding_strategy != PaddingStrategy.DO_NOT_PAD
        and pad_to_multiple_of is not None
        and max_length is not None
        and (max_length % pad_to_multiple_of != 0)
    ):
        raise ValueError(
            f"Truncation and padding are both activated but "
            f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
        )

    return padding_strategy, truncation_strategy, max_length, kwargs
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def __call__(
    self,
    text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """
    Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
    sequences.

    Args:
        text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
            The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
            (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
            :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
        text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
            The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
            (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
            :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
    """
    # Input type checking for clearer error
    # Accepted shapes for `text`: str | list[str] | list[list[str]] (empty containers allowed).
    assert isinstance(text, str) or (
        isinstance(text, (list, tuple))
        and (
            len(text) == 0
            or (
                isinstance(text[0], str)
                or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str)))
            )
        )
    ), (
        "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
        "or `List[List[str]]` (batch of pretokenized examples)."
    )

    # Same shape check for the optional `text_pair`, which may also be None.
    assert (
        text_pair is None
        or isinstance(text_pair, str)
        or (
            isinstance(text_pair, (list, tuple))
            and (
                len(text_pair) == 0
                or (
                    isinstance(text_pair[0], str)
                    or (
                        isinstance(text_pair[0], (list, tuple))
                        and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str))
                    )
                )
            )
        )
    ), (
        "text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
        "or `List[List[str]]` (batch of pretokenized examples)."
    )

    # A list input is a batch -- except when `is_split_into_words` is set, in which case a flat
    # list of strings is ONE pretokenized example and only a list of lists counts as a batch.
    is_batched = bool(
        (not is_split_into_words and isinstance(text, (list, tuple)))
        or (
            is_split_into_words and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
        )
    )

    if is_batched:
        # Pair up texts with their partners (if any) and dispatch to the batch API.
        batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
        return self.batch_encode_plus(
            batch_text_or_text_pairs=batch_text_or_text_pairs,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
    else:
        # Single example (optionally with a pair): dispatch to the single-sequence API.
        return self.encode_plus(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def encode_plus(
    self,
    text: Union[TextInput, PreTokenizedInput, EncodedInput],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """
    Tokenize and prepare a single sequence (or a pair of sequences) for the model.

    .. warning::
        Deprecated entry point -- call the tokenizer directly (``__call__``) instead.

    Args:
        text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for not-fast tokenizers)):
            The first sequence to encode: a raw string, a list of string tokens (output of the
            ``tokenize`` method), or a list of integer ids (output of ``convert_tokens_to_ids``).
        text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`):
            An optional second sequence, accepted in any of the same forms as ``text``.
    """
    # Resolve the legacy 'truncation_strategy' / 'pad_to_max_length' kwargs into explicit
    # strategies before handing everything to the backend implementation.
    pad_strategy, trunc_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )

    return self._encode_plus(
        text=text,
        text_pair=text_pair,
        add_special_tokens=add_special_tokens,
        padding_strategy=pad_strategy,
        truncation_strategy=trunc_strategy,
        max_length=max_length,
        stride=stride,
        is_split_into_words=is_split_into_words,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )
def _encode_plus(
    self,
    text: Union[TextInput, PreTokenizedInput, EncodedInput],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """
    Backend for :meth:`encode_plus`: receives already-resolved padding/truncation strategies
    (instead of the user-facing ``padding``/``truncation`` arguments) and performs the actual
    encoding. Implemented by the slow and fast tokenizer subclasses.
    """
    raise NotImplementedError
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def batch_encode_plus(
    self,
    batch_text_or_text_pairs: Union[
        List[TextInput],
        List[TextInputPair],
        List[PreTokenizedInput],
        List[PreTokenizedInputPair],
        List[EncodedInput],
        List[EncodedInputPair],
    ],
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """
    Tokenize and prepare a batch of sequences (or a batch of sequence pairs) for the model.

    .. warning::
        Deprecated entry point -- call the tokenizer directly (``__call__``) instead.

    Args:
        batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`, :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`):
            Batch of sequences or pair of sequences to be encoded. This can be a list of
            string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see
            details in ``encode_plus``).
    """
    # Resolve the legacy 'truncation_strategy' / 'pad_to_max_length' kwargs into explicit
    # strategies before handing the batch to the backend implementation.
    pad_strategy, trunc_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )

    return self._batch_encode_plus(
        batch_text_or_text_pairs=batch_text_or_text_pairs,
        add_special_tokens=add_special_tokens,
        padding_strategy=pad_strategy,
        truncation_strategy=trunc_strategy,
        max_length=max_length,
        stride=stride,
        is_split_into_words=is_split_into_words,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )
def _batch_encode_plus(
    self,
    batch_text_or_text_pairs: Union[
        List[TextInput],
        List[TextInputPair],
        List[PreTokenizedInput],
        List[PreTokenizedInputPair],
        List[EncodedInput],
        List[EncodedInputPair],
    ],
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """
    Backend for :meth:`batch_encode_plus`: receives already-resolved padding/truncation strategies
    (instead of the user-facing ``padding``/``truncation`` arguments) and performs the actual
    batch encoding. Implemented by the slow and fast tokenizer subclasses.
    """
    raise NotImplementedError
def pad(
    self,
    encoded_inputs: Union[
        BatchEncoding,
        List[BatchEncoding],
        Dict[str, EncodedInput],
        Dict[str, List[EncodedInput]],
        List[Dict[str, EncodedInput]],
    ],
    padding: Union[bool, str, PaddingStrategy] = True,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    verbose: bool = True,
) -> BatchEncoding:
    """
    Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
    in the batch.
    Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
    ``self.pad_token_id`` and ``self.pad_token_type_id``)
    .. note::
        If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
        result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
        case of PyTorch tensors, you will lose the specific device of your tensors however.
    Args:
        encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
            Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
            List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
            List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
            well as in a PyTorch Dataloader collate function.
            Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
            see the note above for the return type.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding
            index) among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
            >= 7.5 (Volta).
        return_attention_mask (:obj:`bool`, `optional`):
            Whether to return the attention mask. If left to the default, will return the attention mask according
            to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
            `What are attention masks? <../glossary.html#attention-mask>`__
        return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
            If set, will return tensors instead of list of python integers. Acceptable values are:
            * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
            * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
            * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
        verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to print more information and warnings.
    """
    # If we have a list of dicts, let's convert it in a dict of lists
    # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
    if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
        encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
    # The model's main input name, usually `input_ids`, has to be passed for padding
    if self.model_input_names[0] not in encoded_inputs:
        # FIX: added the missing space between the two message fragments
        # (previously rendered as "...this methodthat includes...").
        raise ValueError(
            "You should supply an encoding or a list of encodings to this method "
            f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
        )
    required_input = encoded_inputs[self.model_input_names[0]]
    # Empty batch: nothing to pad; honor return_attention_mask with an empty mask.
    if not required_input:
        if return_attention_mask:
            encoded_inputs["attention_mask"] = []
        return encoded_inputs
    # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
    # and rebuild them afterwards if no return_tensors is specified
    # Note that we lose the specific device the tensor may be on for PyTorch
    first_element = required_input[0]
    if isinstance(first_element, (list, tuple)):
        # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
        # FIX: bound the scan so a batch made up entirely of empty lists/tuples no
        # longer raises IndexError (the original `while` could walk past the end).
        index = 0
        while index < len(required_input) and len(required_input[index]) == 0:
            index += 1
        if index < len(required_input):
            first_element = required_input[index][0]
    # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
    if not isinstance(first_element, (int, list, tuple)):
        # Infer the tensor framework from the first element and remember it so the
        # output can be rebuilt with the same type (unless return_tensors overrides).
        if is_tf_available() and _is_tensorflow(first_element):
            return_tensors = "tf" if return_tensors is None else return_tensors
        elif is_torch_available() and _is_torch(first_element):
            return_tensors = "pt" if return_tensors is None else return_tensors
        elif isinstance(first_element, np.ndarray):
            return_tensors = "np" if return_tensors is None else return_tensors
        else:
            raise ValueError(
                f"type of {first_element} unknown: {type(first_element)}. "
                f"Should be one of a python, numpy, pytorch or tensorflow object."
            )
        for key, value in encoded_inputs.items():
            encoded_inputs[key] = to_py_obj(value)
    # Convert padding_strategy in PaddingStrategy
    padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
        padding=padding, max_length=max_length, verbose=verbose
    )
    required_input = encoded_inputs[self.model_input_names[0]]
    # Single (non-batched) input: pad it directly.
    if required_input and not isinstance(required_input[0], (list, tuple)):
        encoded_inputs = self._pad(
            encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
    batch_size = len(required_input)
    assert all(
        len(v) == batch_size for v in encoded_inputs.values()
    ), "Some items in the output dictionary have a different batch size than others."
    # LONGEST is resolved to MAX_LENGTH using the longest sequence in the batch,
    # so _pad below only ever sees a concrete target length.
    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = max(len(inputs) for inputs in required_input)
        padding_strategy = PaddingStrategy.MAX_LENGTH
    batch_outputs = {}
    for i in range(batch_size):
        inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
        outputs = self._pad(
            inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        for key, value in outputs.items():
            if key not in batch_outputs:
                batch_outputs[key] = []
            batch_outputs[key].append(value)
    return BatchEncoding(batch_outputs, tensor_type=return_tensors)
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create the token type IDs corresponding to the sequences passed. `What are token type IDs?
    <../glossary.html#token-type-ids>`__
    Should be overridden in a subclass if the model has a special way of building those.

    Args:
        token_ids_0 (:obj:`List[int]`): The first tokenized sequence.
        token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.

    Returns:
        :obj:`List[int]`: The token type ids — 0 for every token of the first
        sequence, 1 for every token of the second (when given).
    """
    segment_ids = [0] * len(token_ids_0)
    if token_ids_1 is not None:
        segment_ids.extend([1] * len(token_ids_1))
    return segment_ids
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
    adding special tokens.
    This implementation does not add special tokens and this method should be overridden in a subclass.

    Args:
        token_ids_0 (:obj:`List[int]`): The first tokenized sequence.
        token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.

    Returns:
        :obj:`List[int]`: The model input with special tokens.
    """
    # Base class adds nothing: a pair is simply concatenated, a single
    # sequence is returned as-is.
    if token_ids_1 is not None:
        return token_ids_0 + token_ids_1
    return token_ids_0
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def prepare_for_model(
    self,
    ids: List[int],
    pair_ids: Optional[List[int]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    prepend_batch_axis: bool = False,
    **kwargs
) -> BatchEncoding:
    """
    Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
    adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
    manages a moving window (with user defined stride) for overflowing tokens

    Pipeline: resolve padding/truncation strategies -> truncate -> add special
    tokens and token type ids -> pad -> wrap in a :class:`~transformers.BatchEncoding`.

    Args:
        ids (:obj:`List[int]`):
            Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize``
            and ``convert_tokens_to_ids`` methods.
        pair_ids (:obj:`List[int]`, `optional`):
            Tokenized input ids of the second sequence. Can be obtained from a string by chaining the ``tokenize``
            and ``convert_tokens_to_ids`` methods.
    """
    # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
    padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    pair = bool(pair_ids is not None)
    len_ids = len(ids)
    len_pair_ids = len(pair_ids) if pair else 0
    # token_type_ids are only meaningful relative to special-token boundaries,
    # so the combination below is rejected outright.
    if return_token_type_ids and not add_special_tokens:
        raise ValueError(
            "Asking to return token_type_ids while setting add_special_tokens to False "
            "results in an undefined behavior. Please set add_special_tokens to True or "
            "set return_token_type_ids to None."
        )
    # Load from model defaults
    if return_token_type_ids is None:
        return_token_type_ids = "token_type_ids" in self.model_input_names
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names
    encoded_inputs = {}
    # Compute the total size of the returned encodings
    total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
    # Truncation: Handle max sequence length
    overflowing_tokens = []
    if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
        # Room for special tokens is reserved via total_len, so the surviving
        # ids fit within max_length once they are added below.
        ids, pair_ids, overflowing_tokens = self.truncate_sequences(
            ids,
            pair_ids=pair_ids,
            num_tokens_to_remove=total_len - max_length,
            truncation_strategy=truncation_strategy,
            stride=stride,
        )
    if return_overflowing_tokens:
        encoded_inputs["overflowing_tokens"] = overflowing_tokens
        encoded_inputs["num_truncated_tokens"] = total_len - max_length
    # Add special tokens
    if add_special_tokens:
        sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
        token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
    else:
        sequence = ids + pair_ids if pair else ids
        token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
    # Build output dictionary
    encoded_inputs["input_ids"] = sequence
    if return_token_type_ids:
        encoded_inputs["token_type_ids"] = token_type_ids
    if return_special_tokens_mask:
        if add_special_tokens:
            encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
        else:
            encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
    # Check lengths
    self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
    # Padding
    # pad() is also responsible for producing the attention mask, hence the
    # call even when padding itself is disabled.
    if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
        encoded_inputs = self.pad(
            encoded_inputs,
            max_length=max_length,
            padding=padding_strategy.value,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
    if return_length:
        encoded_inputs["length"] = len(encoded_inputs["input_ids"])
    batch_outputs = BatchEncoding(
        encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
    )
    return batch_outputs
def truncate_sequences(
    self,
    ids: List[int],
    pair_ids: Optional[List[int]] = None,
    num_tokens_to_remove: int = 0,
    truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
    stride: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Truncates a sequence pair in-place following the strategy.

    Args:
        ids (:obj:`List[int]`):
            Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize``
            and ``convert_tokens_to_ids`` methods.
        pair_ids (:obj:`List[int]`, `optional`):
            Tokenized input ids of the second sequence. Can be obtained from a string by chaining the ``tokenize``
            and ``convert_tokens_to_ids`` methods.
        num_tokens_to_remove (:obj:`int`, `optional`, defaults to 0):
            Number of tokens to remove using the truncation strategy.
        truncation_strategy (:obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`False`):
            The strategy to follow for truncation. Can be:
            * :obj:`'longest_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
              to the maximum acceptable input length for the model if that argument is not provided. This will
              truncate token by token, removing a token from the longest sequence in the pair if a pair of
              sequences (or a batch of pairs) is provided.
            * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
              the maximum acceptable input length for the model if that argument is not provided. This will only
              truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
              to the maximum acceptable input length for the model if that argument is not provided. This will only
              truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
              greater than the model maximum admissible input size).
        stride (:obj:`int`, `optional`, defaults to 0):
            If set to a positive number, the overflowing tokens returned will contain some tokens from the main
            sequence returned. The value of this argument defines the number of additional tokens.

    Returns:
        :obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
        list of overflowing tokens.
    """
    # Nothing to remove: hand back the inputs untouched, with no overflow.
    if num_tokens_to_remove <= 0:
        return ids, pair_ids, []
    if not isinstance(truncation_strategy, TruncationStrategy):
        truncation_strategy = TruncationStrategy(truncation_strategy)
    overflowing_tokens = []
    if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
        # Remove one token per iteration from whichever sequence is currently longer.
        # NOTE(review): on the first removal a stride-sized window is copied into
        # ``overflowing_tokens`` while only one token is dropped, so the overflow
        # list can contain tokens that remain in the output — long-standing
        # upstream behavior, preserved here.
        for _ in range(num_tokens_to_remove):
            if pair_ids is None or len(ids) > len(pair_ids):
                if not overflowing_tokens:
                    window_len = min(len(ids), stride + 1)
                else:
                    window_len = 1
                overflowing_tokens.extend(ids[-window_len:])
                ids = ids[:-1]
            else:
                if not overflowing_tokens:
                    window_len = min(len(pair_ids), stride + 1)
                else:
                    window_len = 1
                overflowing_tokens.extend(pair_ids[-window_len:])
                pair_ids = pair_ids[:-1]
    elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
        if len(ids) > num_tokens_to_remove:
            # The overflow window keeps ``stride`` extra tokens of context.
            window_len = min(len(ids), stride + num_tokens_to_remove)
            overflowing_tokens = ids[-window_len:]
            ids = ids[:-num_tokens_to_remove]
        else:
            # FIX: added the missing space after "input" (message fragments
            # previously concatenated as "...the inputbut...").
            logger.error(
                f"We need to remove {num_tokens_to_remove} to truncate the input "
                f"but the first sequence has a length {len(ids)}. "
                f"Please select another truncation strategy than {truncation_strategy}, "
                f"for instance 'longest_first' or 'only_second'."
            )
    elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
        if len(pair_ids) > num_tokens_to_remove:
            window_len = min(len(pair_ids), stride + num_tokens_to_remove)
            overflowing_tokens = pair_ids[-window_len:]
            pair_ids = pair_ids[:-num_tokens_to_remove]
        else:
            # FIX: same missing-space bug as the ONLY_FIRST branch.
            logger.error(
                f"We need to remove {num_tokens_to_remove} to truncate the input "
                f"but the second sequence has a length {len(pair_ids)}. "
                f"Please select another truncation strategy than {truncation_strategy}, "
                f"for instance 'longest_first' or 'only_first'."
            )
    return (ids, pair_ids, overflowing_tokens)
def _pad(
    self,
    encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
    max_length: Optional[int] = None,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """
    Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

    Args:
        encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
        max_length: maximum length of the returned list and optionally padding length (see below).
            Will truncate by taking into account the special tokens.
        padding_strategy: PaddingStrategy to use for padding.
            - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
            - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
            - PaddingStrategy.DO_NOT_PAD: Do not pad
            The tokenizer padding sides are defined in self.padding_side:
            - 'left': pads on the left of the sequences
            - 'right': pads on the right of the sequences
        pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
            >= 7.5 (Volta).
        return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)

    Returns:
        The same mapping with its sequences padded in place (and an
        ``attention_mask`` added when requested).
    """
    # Load from model defaults
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names
    required_input = encoded_inputs[self.model_input_names[0]]
    # For a single sequence, LONGEST degenerates to "its own length" (no padding).
    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = len(required_input)
    # Round the target length up to the requested multiple (e.g. for Tensor Cores).
    if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
    if needs_to_be_padded:
        difference = max_length - len(required_input)
        if self.padding_side == "right":
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
            if "token_type_ids" in encoded_inputs:
                encoded_inputs["token_type_ids"] = (
                    encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
                )
            if "special_tokens_mask" in encoded_inputs:
                encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
            encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
        elif self.padding_side == "left":
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
            if "token_type_ids" in encoded_inputs:
                encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                    "token_type_ids"
                ]
            if "special_tokens_mask" in encoded_inputs:
                encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
        else:
            # FIX: this branch validates ``self.padding_side``, not the padding
            # strategy — the old message ("Invalid padding strategy:") named the
            # wrong setting and lacked a separating space.
            raise ValueError("Invalid padding side: " + str(self.padding_side))
    elif return_attention_mask and "attention_mask" not in encoded_inputs:
        # No padding needed: the mask is all ones over the existing tokens.
        encoded_inputs["attention_mask"] = [1] * len(required_input)
    return encoded_inputs
def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """
    Converts a sequence of tokens in a single string. The most simple way to do it is ``" ".join(tokens)`` but we
    often want to remove sub-word tokenization artifacts at the same time.

    Must be overridden by concrete tokenizer subclasses; the base class has no
    knowledge of the sub-word scheme in use.

    Args:
        tokens (:obj:`List[str]`): The token to join in a string.

    Returns:
        :obj:`str`: The joined tokens.
    """
    raise NotImplementedError
def batch_decode(
    self,
    sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    **kwargs
) -> List[str]:
    """
    Convert a list of lists of token ids into a list of strings by calling decode.

    Args:
        sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the ``__call__`` method.
        skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to clean up the tokenization spaces.
        kwargs (additional keyword arguments, `optional`):
            Will be passed to the underlying model specific decode method.

    Returns:
        :obj:`List[str]`: The list of decoded sentences.
    """
    # Decode each sequence independently; all options are forwarded unchanged.
    decoded_texts = []
    for token_seq in sequences:
        decoded_texts.append(
            self.decode(
                token_seq,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                **kwargs,
            )
        )
    return decoded_texts
def decode(
    self,
    token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    **kwargs
) -> str:
    """
    Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
    tokens and clean up tokenization spaces.
    Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.

    Args:
        token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the ``__call__`` method.
        skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to clean up the tokenization spaces.
        kwargs (additional keyword arguments, `optional`):
            Will be passed to the underlying model specific decode method.

    Returns:
        :obj:`str`: The decoded sentence.
    """
    # Framework tensors / numpy arrays are normalized to plain python lists
    # before handing off to the subclass-specific backend.
    plain_ids = to_py_obj(token_ids)
    return self._decode(
        token_ids=plain_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        **kwargs,
    )
def _decode(
    self,
    token_ids: Union[int, List[int]],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    **kwargs
) -> str:
    """
    Backend hook for :meth:`decode`: concrete tokenizer subclasses override this
    to map ids back to text. ``decode`` normalizes tensor inputs to plain python
    objects before calling it.
    """
    raise NotImplementedError
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
    special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of ids of the first sequence.
        token_ids_1 (:obj:`List[int]`, `optional`):
            List of ids of the second sequence.
        already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the token list is already formatted with special tokens for the model.

    Returns:
        A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

    Raises:
        ValueError: if ``already_has_special_tokens`` is False or a second
            sequence is supplied (unsupported by this base implementation).
    """
    # FIX: validate with an explicit raise instead of ``assert`` — asserts are
    # stripped when python runs with -O, which silently disabled this check.
    if not (already_has_special_tokens and token_ids_1 is None):
        raise ValueError(
            "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
            "Please use a slow (full python) tokenizer to activate this argument."
            "Or set `return_special_tokens_mask=True` when calling the encoding method "
            "to get the special tokens mask in any tokenizer. "
        )
    # Cache the property and use a set for O(1) membership per token.
    all_special_ids = set(self.all_special_ids)
    return [1 if token in all_special_ids else 0 for token in token_ids_0]
@staticmethod
def clean_up_tokenization(out_string: str) -> str:
"""
Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.
Args:
out_string (:obj:`str`): The text to clean up.
Returns:
:obj:`str`: The cleaned-up string.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
"""
Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
corresponding model
Args:
ids (:obj:`List[str]`): The ids produced by the tokenization
max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set)
verbose (:obj:`bool`): Whether or not to print more information and warnings.
"""
if max_length is None and len(ids) > self.model_max_length and verbose:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model "
"will result in indexing errors"
)
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
@contextmanager
def as_target_tokenizer(self):
    """
    Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
    sequence-to-sequence models that need a slightly different processing for the labels.

    The base implementation is a no-op context manager: subclasses override it
    to switch vocabularies/settings around the ``yield``.
    """
    yield
def prepare_seq2seq_batch(
    self,
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    max_target_length: Optional[int] = None,
    padding: str = "longest",
    return_tensors: str = None,
    truncation: bool = True,
    **kwargs,
) -> BatchEncoding:
    """
    Prepare model inputs for translation. For best performance, translate one sentence at a time.

    Arguments:
        src_texts (:obj:`List[str]`):
            List of documents to summarize or source language texts.
        tgt_texts (:obj:`list`, `optional`):
            List of summaries or target language texts.
        max_length (:obj:`int`, `optional`):
            Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
            left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
            is required by one of the truncation/padding parameters. If the model has no specific maximum input
            length (like XLNet) truncation/padding to a maximum length will be deactivated.
        max_target_length (:obj:`int`, `optional`):
            Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
            to :obj:`None`, this will use the max_length value.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`):
            Activates and controls padding. Accepts the following values:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
            If set, will return tensors instead of list of python integers. Acceptable values are:
            * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
            * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
            * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
        truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
            Activates and controls truncation. Accepts the following values:
            * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
              :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
              provided. This will truncate token by token, removing a token from the longest sequence in the pair
              if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
              the maximum acceptable input length for the model if that argument is not provided. This will only
              truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
              to the maximum acceptable input length for the model if that argument is not provided. This will only
              truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
              sequence lengths greater than the model maximum admissible input size).
        **kwargs:
            Additional keyword arguments passed along to :obj:`self.__call__`.

    Return:
        :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
        - **input_ids** -- List of token ids to be fed to the encoder.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
        - **labels** -- List of token ids for tgt_texts.
        The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
        Otherwise, input_ids, attention_mask will be the only keys.
    """
    # FIX: the deprecation message pointed users at a non-existent
    # `with_target_tokenizer` context manager; the actual method defined on
    # this class is `as_target_tokenizer`.
    warnings.warn(
        "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
        "regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` "
        "context manager to prepare your targets. See the documentation of your specific tokenizer for more "
        "details",
        FutureWarning,
    )
    # mBART-specific kwargs that should be ignored by other models.
    kwargs.pop("src_lang", None)
    kwargs.pop("tgt_lang", None)
    if max_length is None:
        max_length = self.model_max_length
    # Encode the source side with the tokenizer in its normal configuration.
    model_inputs = self(
        src_texts,
        add_special_tokens=True,
        return_tensors=return_tensors,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        **kwargs,
    )
    if tgt_texts is None:
        return model_inputs
    # Process tgt_texts
    if max_target_length is None:
        max_target_length = max_length
    # Targets are encoded under the target-tokenizer context so seq2seq
    # tokenizers can switch vocabularies/settings for the labels.
    with self.as_target_tokenizer():
        labels = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/versions.py
================================================
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for working with package versions
"""
import operator
import re
import sys
from typing import Optional
from packaging import version
# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
# Mapping from pip-style comparison operator strings to the corresponding
# functions in the stdlib `operator` module; consumed by `_compare_versions`.
ops = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
    """
    Check that `got_ver` satisfies `want_ver` under the comparison operator `op`.

    Args:
        op (str): One of the keys of the module-level `ops` dict ("<", "<=", ...).
        got_ver (str): Installed version string.
        want_ver (str): Required version string.
        requirement (str): Full pip-style requirement, used in the error message.
        pkg (str): Package name, used in the error message.
        hint (str): Extra text appended to the error message.

    Raises:
        ValueError: If either version string is None.
        ImportError: If the installed version does not satisfy the requirement.
    """
    for label, ver in (("got_ver", got_ver), ("want_ver", want_ver)):
        if ver is None:
            raise ValueError(f"{label} is None")
    satisfied = ops[op](version.parse(got_ver), version.parse(want_ver))
    if not satisfied:
        raise ImportError(
            f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
        )
def require_version(requirement: str, hint: Optional[str] = None) -> None:
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.

    The installed module version comes from the `site-packages` dir via `importlib_metadata`.

    Args:
        requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
        hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met

    Raises:
        ValueError: if the requirement string is malformed or uses an unknown operator.
        ImportError: if an installed version does not satisfy the requirement.
        importlib_metadata.PackageNotFoundError: if the package is not installed at all.

    Example::

        require_version("pandas>1.1.2")
        require_version("numpy>1.18.5", "this is important to have for whatever reason")
    """
    hint = f"\n{hint}" if hint is not None else ""

    # Operator -> wanted version; stays empty for a non-versioned requirement.
    # BUG FIX: the original only created `wanted` inside the versioned branch,
    # so a bare `require_version("python")` crashed with NameError below.
    wanted = {}

    # non-versioned check
    if re.match(r"^[\w_\-\d]+$", requirement):
        pkg, op, want_ver = requirement, None, None
    else:
        match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
        if not match:
            raise ValueError(
                f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
            )
        pkg, want_full = match[0]
        want_range = want_full.split(",")  # there could be multiple requirements
        for w in want_range:
            match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
            if not match:
                raise ValueError(
                    f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
                )
            op, want_ver = match[0]
            wanted[op] = want_ver
            if op not in ops:
                raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")

    # special case: compare against the running interpreter's version
    if pkg == "python":
        got_ver = ".".join([str(x) for x in sys.version_info[:3]])
        for op, want_ver in wanted.items():
            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
        return

    # check if any version is installed
    try:
        got_ver = importlib_metadata.version(pkg)
    except importlib_metadata.PackageNotFoundError:
        raise importlib_metadata.PackageNotFoundError(
            f"The '{requirement}' distribution was not found and is required by this application. {hint}"
        )

    # check that the right version is installed if version number or a range was provided
    if want_ver is not None:
        for op, want_ver in wanted.items():
            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
def require_version_core(requirement):
    """require_version wrapper which emits a core-specific hint on failure."""
    return require_version(
        requirement,
        "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master",
    )
def require_version_examples(requirement):
    """require_version wrapper which emits examples-specific hint on failure."""
    return require_version(requirement, "Try: pip install -r examples/requirements.txt")
================================================
FILE: flaxmodels/flaxmodels/gpt2/tokenizer.py
================================================
from .third_party.huggingface_transformers.configuration_gpt2 import GPT2Tokenizer
from .. import utils
def get_tokenizer(errors='replace',
                  unk_token='<|endoftext|>',
                  bos_token='<|endoftext|>',
                  eos_token='<|endoftext|>',
                  add_prefix_space=False,
                  ckpt_dir=None):
    """
    Returns the GPT2Tokenizer from Huggingface with loaded merges and vocab files.
    See: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizer

    Args:
        errors (str): Paradigm to follow when decoding bytes to UTF-8.
        unk_token (str): The unknown token. A token that is not in the
                         vocabulary cannot be converted to an ID and is set to be this token instead.
        bos_token (str): The beginning of sequence token.
        eos_token (str): The end of sequence token.
        add_prefix_space (bool): Whether or not to add an initial space to the input.
                                 This allows to treat the leading word just as any other word.
        ckpt_dir (str): Path to directory, where merges and vocab files are downloaded to.
                        If None, the files will be downloaded to a temp directory.

    Returns:
        (GPT2Tokenizer): GPT2 Tokenizer.
    """
    # Fetch the BPE merges and vocabulary files (cached in ckpt_dir when given).
    files = {
        'merges_file': utils.download(ckpt_dir, 'https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt?dl=1'),
        'vocab_file': utils.download(ckpt_dir, 'https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json?dl=1'),
    }
    return GPT2Tokenizer(errors=errors,
                         unk_token=unk_token,
                         bos_token=bos_token,
                         eos_token=eos_token,
                         add_prefix_space=add_prefix_space,
                         **files)
================================================
FILE: flaxmodels/flaxmodels/gpt2/trajectory_gpt2.py
================================================
import jax.numpy as jnp
import flax.linen as nn
from typing import Any
import h5py
from .. import utils
from . import ops
# Dropbox mirrors of the pretrained GPT2 checkpoints (h5 weight files).
URLS = {'gpt2': 'https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5?dl=1',
        'gpt2-medium': 'https://www.dropbox.com/s/nam11kbd83wsm7d/gpt2-medium.h5?dl=1',
        'gpt2-large': 'https://www.dropbox.com/s/oy8623qwkkjm8gt/gpt2-large.h5?dl=1',
        'gpt2-xl': 'https://www.dropbox.com/s/6c6qt0bzz4v2afx/gpt2-xl.h5?dl=1'}

# Matching JSON configuration files for each checkpoint above.
CONFIGS = {'gpt2': 'https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json?dl=1',
           'gpt2-medium': 'https://www.dropbox.com/s/7mwkijxoh1earm5/gpt2-medium.json?dl=1',
           'gpt2-large': 'https://www.dropbox.com/s/nhslkxwxtpn7auz/gpt2-large.json?dl=1',
           'gpt2-xl': 'https://www.dropbox.com/s/1iv0nq1xigsfdvb/gpt2-xl.json?dl=1'}
class GPT2SelfAttention(nn.Module):
    """
    GPT2 Self Attention.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict = None

    def setup(self):
        # Hyper-parameters pulled from the GPT2-style config namespace.
        self.max_pos = self.config.n_positions        # longest sequence the causal mask supports
        self.embd_dim = self.config.n_embd
        self.num_heads = self.config.n_head
        self.head_dim = self.embd_dim // self.num_heads
        self.attn_dropout = self.config.attn_pdrop    # dropout rate on the attention weights
        self.resid_dropout = self.config.resid_pdrop  # dropout rate on the output projection
        self.scale_attn_weights = True                # scale logits by 1/sqrt(head_dim) in ops.attention

    @nn.compact
    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run attention.

        Args:
            x (tensor): Input tensor, shape [B, seq_len, embd_dim].
            layer_past (Tuple): Tuple of past keys and values.
            attn_mask (tensor): Mask to avoid performing attention on padding token indices.
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (tensor, Tuple, tensor): Output tensor, the (key, value) cache tuple
            (None unless use_cache), and the attention weights.
        """
        # Joint projection to q/k/v, then split along the feature axis.
        x = nn.Dense(features=3*self.embd_dim)(x)
        query, key, value = jnp.split(x, 3, axis=2)
        # Reshape each to [B, num_heads, seq_len, head_dim].
        query = ops.split_heads(query, self.num_heads, self.head_dim)
        value = ops.split_heads(value, self.num_heads, self.head_dim)
        key = ops.split_heads(key, self.num_heads, self.head_dim)
        if layer_past is not None:
            # Incremental decoding: prepend the cached keys/values.
            past_key, past_value = layer_past
            key = jnp.concatenate((past_key, key), axis=-2)
            value = jnp.concatenate((past_value, value), axis=-2)
        present = (key, value) if use_cache else None
        query_len, key_len = query.shape[-2], key.shape[-2]
        # Lower-triangular causal mask, cropped to the current query/key window.
        # ("casual" is a typo for "causal", kept to match ops.attention's parameter name.)
        casual_mask = jnp.tril(jnp.ones((1, 1, self.max_pos, self.max_pos)))[:, :, key_len - query_len :key_len, :key_len]
        # casual_mask = jnp.ones((1, 1, self.max_pos, self.max_pos))[:, :, key_len - query_len :key_len, :key_len]
        casual_mask = casual_mask.astype(bool)
        attn_dropout = nn.Dropout(rate=self.attn_dropout)
        # Masked positions receive a -1e4 bias before the softmax.
        out, _attn_weights = ops.attention(query, key, value, casual_mask, -1e4, attn_dropout, self.scale_attn_weights, training, attn_mask, head_mask)
        # Merge heads back to [B, seq_len, embd_dim] and project the output.
        out = ops.merge_heads(out, self.num_heads, self.head_dim)
        out = nn.Dense(features=self.embd_dim)(out)
        out = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)
        return out, present, _attn_weights
class GPT2MLP(nn.Module):
    """
    GPT2 feed-forward block: Dense expansion -> activation -> Dense projection -> Dropout.

    Attributes:
        intermediate_dim (int): Width of the hidden (expansion) layer.
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    intermediate_dim: int
    config: dict = None

    def setup(self):
        self.embd_dim = self.config.n_embd
        self.resid_dropout = self.config.resid_pdrop
        self.activation = self.config.activation_function

    @nn.compact
    def __call__(self, x, training=False):
        """
        Apply the MLP.

        Args:
            x (tensor): Input tensor.
            training (bool): Training mode (enables dropout).

        Returns:
            (tensor): Output tensor with the same embedding dimension as the input.
        """
        hidden = nn.Dense(features=self.intermediate_dim)(x)
        hidden = ops.apply_activation(hidden, activation=self.activation)
        projected = nn.Dense(features=self.embd_dim)(hidden)
        return nn.Dropout(rate=self.resid_dropout)(projected, deterministic=not training)
class GPT2Block(nn.Module):
    """
    A single pre-LayerNorm GPT2 transformer block: self-attention and MLP
    sub-layers, each wrapped in a residual connection.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict = None

    def setup(self):
        self.embd_dim = self.config.n_embd
        self.eps = self.config.layer_norm_epsilon
        self.inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * self.embd_dim

    @nn.compact
    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run the block.

        Args:
            x (tensor): Input tensor.
            layer_past (Tuple): Tuple of past keys and values.
            attn_mask (tensor): Mask to avoid performing attention on padding token indices.
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (tensor, Tuple, tensor): Output tensor, key/value cache, attention weights.
        """
        # Attention sub-layer (pre-norm + residual).
        shortcut = x
        normed = nn.LayerNorm(epsilon=self.eps)(x)
        attn_out, present, _attn_weights = GPT2SelfAttention(config=self.config)(
            normed,
            layer_past=layer_past,
            attn_mask=attn_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            training=training,
        )
        x = attn_out + shortcut
        # MLP sub-layer (pre-norm + residual).
        shortcut = x
        normed = nn.LayerNorm(epsilon=self.eps)(x)
        mlp_out = GPT2MLP(intermediate_dim=self.inner_dim, config=self.config)(normed, training)
        x = mlp_out + shortcut
        return x, present, _attn_weights
class GPT2Model(nn.Module):
    """
    The GPT2 Model.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        pretrained (str): Which pretrained model to use, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict = None
    pretrained: str = None
    ckpt_dir: str = None

    def setup(self):
        # NOTE: this assert makes the pretrained-loading branch below
        # unreachable; in this repository the model is always built from
        # `config` and trained from scratch.
        assert self.pretrained is None, "pretrain must be None for training."
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict_ = h5py.File(ckpt_file, 'r')['transformer']
            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])
            self.config_ = ops.load_config(config_file)
        else:
            self.config_ = self.config
        # Cache frequently used config values.
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.embd_dropout = self.config_.embd_pdrop
        self.num_layers = self.config_.n_layer
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(self,
                 input_ids=None,
                 past_key_values=None,
                 input_embds=None,
                 position_ids=None,
                 attn_mask=None,
                 head_mask=None,
                 use_cache=False,
                 training=False
                 ):
        """
        Run the model.

        Args:
            input_ids (tensor): Input token ids, shape [B, seq_len].
            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.
                                     If past_key_values is used, only input_ids that do not have their
                                     past calculated should be passed as input_ids.
            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].
            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].
            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (dict): Dictionary containing 'last_hidden_state', 'past_key_values'
            and 'attn_weights_list' (one entry per layer).
        """
        # Exactly one of input_ids / input_embds must be provided.
        if input_ids is not None and input_embds is not None:
            raise ValueError('You cannot specify both input_ids and input_embd at the same time.')
        elif input_ids is not None:
            input_shape = input_ids.shape
            input_ids = jnp.reshape(input_ids, newshape=(-1, input_shape[-1]))
            batch_size = input_ids.shape[0]
        elif input_embds is not None:
            input_shape = input_embds.shape[:-1]
            batch_size = input_embds.shape[0]
        else:
            raise ValueError('You have to specify either input_ids or input_embd.')
        if position_ids is not None:
            position_ids = jnp.reshape(position_ids, newshape=(-1, input_shape[-1]))
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.num_layers)
        else:
            # Length of the cached prefix; offsets the generated position ids.
            past_length = past_key_values[0][0].shape[-2]
        if position_ids is None:
            position_ids = jnp.arange(start=past_length, stop=input_shape[-1] + past_length)
            position_ids = jnp.reshape(jnp.expand_dims(position_ids, axis=0), newshape=(-1, input_shape[-1]))
        if input_embds is None:
            input_embds = nn.Embed(num_embeddings=self.vocab_size, features=self.embd_dim)(input_ids)
        if attn_mask is not None:
            # Convert the [B, seq_len] padding mask into an additive bias.
            attn_mask = ops.get_attention_mask(attn_mask, batch_size)
        if head_mask is not None:
            head_mask = ops.get_head_mask(head_mask, self.num_layers)
        else:
            head_mask = [None] * self.num_layers
        # Position embeddings are disabled (commented out); TransRewardModel
        # adds its own timestep embeddings to input_embds instead.
        # position_embds = nn.Embed(num_embeddings=self.max_pos, features=self.embd_dim)(position_ids)
        # x = input_embds + position_embds
        x = input_embds
        x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)
        output_shape = input_shape + (x.shape[-1],)  # (unused)
        presents = () if use_cache else None
        attn_weights_list = []
        for i in range(self.num_layers):
            kwargs = {'layer_past': past_key_values[i], 'attn_mask': attn_mask, 'head_mask': head_mask[i],
                      'use_cache': use_cache, 'training': training}
            x, present, attn_weights = GPT2Block(config=self.config_)(x, **kwargs)
            if use_cache:
                presents = presents + (present,)
            attn_weights_list.append(attn_weights)
        # Final layer norm over the last hidden states.
        x = nn.LayerNorm(epsilon=self.eps)(x)
        return {'last_hidden_state': x, 'past_key_values': presents, 'attn_weights_list': attn_weights_list}
class TransRewardModel(nn.Module):
    """
    Preference Transformer reward model: embeds (state, action, timestep)
    triples, interleaves them into a single sequence, runs a GPT2 backbone,
    and predicts a per-step reward — either through an extra preference
    attention layer (weighted sum) or a plain MLP head.
    """
    config: Any = None
    pretrained: str = None
    ckpt_dir: str = None
    observation_dim: int = 29
    action_dim: int = 8
    activation: str = None
    activation_final: str = None
    max_episode_steps: int = 1000

    def setup(self):
        self.config_ = self.config
        # Propagate the activation choices into the config read by submodules.
        self.config_.activation_function = self.activation
        self.config_.activation_final = self.activation_final
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.pref_attn_embd_dim = self.config_.pref_attn_embd_dim
        self.embd_dropout = self.config_.embd_pdrop
        self.attn_dropout = self.config_.attn_pdrop
        self.resid_dropout = self.config_.resid_pdrop
        self.num_layers = self.config_.n_layer
        self.inner_dim = self.config_.n_embd // 2
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(
        self,
        states,
        actions,
        timesteps,
        attn_mask=None,
        training=False,
        reverse=False,
        target_idx=1,
    ):
        """
        Predict per-step rewards for a trajectory segment.

        Args:
            states (tensor): States, shape [B, seq_len, observation_dim].
            actions (tensor): Actions, shape [B, seq_len, action_dim].
            timesteps (tensor): Integer timesteps, shape [B, seq_len].
            attn_mask (tensor): [B, seq_len] padding mask; all-ones if None.
            training (bool): Training mode (enables dropout).
            reverse (bool): If True interleave tokens as (state, action),
                otherwise as (action, state).
            target_idx (int): Which of the two interleaved streams to read
                the reward from after the backbone.

        Returns:
            (dict, list): Reward predictions ({'weighted_sum', 'value'} or
            {'value'}) and the per-layer attention weights.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]
        if attn_mask is None:
            attn_mask = jnp.ones((batch_size, seq_length), dtype=jnp.float32)
        # Linear token embeddings plus a learned timestep embedding.
        embd_state = nn.Dense(features=self.embd_dim)(states)
        embd_action = nn.Dense(features=self.embd_dim)(actions)
        embd_timestep = nn.Embed(num_embeddings=self.max_episode_steps + 1, features=self.embd_dim)(timesteps)
        embd_state = embd_state + embd_timestep
        embd_action = embd_action + embd_timestep
        # Interleave the two streams into one sequence of length 2*seq_len.
        if reverse:
            stacked_inputs = jnp.stack(
                [embd_state, embd_action],
                axis=1
            ).transpose(0, 2, 1, 3).reshape(batch_size, 2 * seq_length, self.embd_dim)
        else:
            stacked_inputs = jnp.stack(
                [embd_action, embd_state],
                axis=1
            ).transpose(0, 2, 1, 3).reshape(batch_size, 2 * seq_length, self.embd_dim)
        stacked_inputs = nn.LayerNorm(epsilon=self.eps)(stacked_inputs)
        # Duplicate the padding mask to match the interleaved sequence.
        stacked_attn_mask = jnp.stack(
            [attn_mask, attn_mask],
            axis=1
        ).transpose(0, 2, 1).reshape(batch_size, 2 * seq_length)
        transformer_outputs = GPT2Model(
            config=self.config
        )(
            input_embds=stacked_inputs,
            attn_mask=stacked_attn_mask,
            training=training,
        )
        x = transformer_outputs["last_hidden_state"]
        attn_weights_list = transformer_outputs["attn_weights_list"]
        # De-interleave back to [B, 2, seq_len, embd_dim] and select the target stream.
        x = x.reshape(batch_size, seq_length, 2, self.embd_dim).transpose(0, 2, 1, 3)
        hidden_output = x[:, target_idx]
        if self.config_.use_weighted_sum:
            '''
            add additional Attention Layer for Weighted Sum.
            x (= output, tensor): Predicted Reward, shape [B, seq_len, embd_dim]
            '''
            # Single projection producing query, key, and a 1-dim value per step.
            x = nn.Dense(features=2 * self.pref_attn_embd_dim + 1)(hidden_output)
            # only one head, because value has 1 dim for predicting rewards directly.
            num_heads = 1
            # query: [B, seq_len, embd_dim]
            # key: [B, seq_len, embd_dim]
            # value: [B, seq_len, 1]
            query, key, value = jnp.split(x, [self.pref_attn_embd_dim, self.pref_attn_embd_dim * 2], axis=2)
            query = ops.split_heads(query, num_heads, self.pref_attn_embd_dim)
            key = ops.split_heads(key, num_heads, self.pref_attn_embd_dim)
            value = ops.split_heads(value, num_heads, 1)
            # query: [B, 1, seq_len, embd_dim]
            # key: [B, 1, seq_len, embd_dim]
            # value: [B, 1, seq_len, 1]
            query_len, key_len = query.shape[-2], key.shape[-2]
            # All-ones (bidirectional) mask: unlike the causal backbone, the
            # weighted sum may attend over the whole sequence.
            # casual_mask = jnp.tril(jnp.ones((1, 1, self.config_.n_positions, self.config_.n_positions)))[:, :, key_len - query_len :key_len, :key_len]
            # casual_mask = casual_mask.astype(bool)
            casual_mask = jnp.ones((1, 1, seq_length, seq_length))[:, :, key_len - query_len :key_len, :key_len]
            casual_mask = casual_mask.astype(bool)
            # attn_dropout = nn.Dropout(rate=self.attn_dropout) # split dropout rate
            attn_dropout = nn.Dropout(rate=0.0) # boilerplate code.
            new_attn_mask = ops.get_attention_mask(attn_mask, batch_size)
            # NOTE(review): masked_bias is -1e-4 here (vs -1e4 in
            # GPT2SelfAttention) — looks like a typo, but with the all-ones
            # casual_mask above the bias is never applied; confirm before changing.
            out, last_attn_weights = ops.attention(query, key, value, casual_mask, -1e-4, attn_dropout, scale_attn_weights=True, training=training, attn_mask=new_attn_mask, head_mask=None)
            attn_weights_list.append(last_attn_weights)
            # out: [B, 1, seq_len, 1]
            output = ops.merge_heads(out, num_heads, 1)
            # output: [B, seq_len, 1]
            # output = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)
            return {"weighted_sum": output, "value": value}, attn_weights_list
        else:
            # Plain MLP head on the selected token stream.
            x = nn.Dense(features=self.inner_dim)(hidden_output)
            x = ops.apply_activation(x, activation=self.activation)
            output = nn.Dense(features=1)(x)
            if self.activation_final != 'none':
                output = ops.apply_activation(output, activation=self.activation_final)
            return {"value": output}, attn_weights_list
================================================
FILE: flaxmodels/flaxmodels/lstm/lstm.py
================================================
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any
import h5py
from .. import utils
from . import ops
class SimpleLSTM(nn.Module):
    """A simple unidirectional LSTM, scanned over the time axis (axis 1)."""

    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1, out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, carry, x):
        # One LSTM step; nn.transforms.scan applies it across the sequence
        # while sharing ('broadcasting') the parameters over time steps.
        return nn.OptimizedLSTMCell()(carry, x)

    @staticmethod
    def initialize_carry(batch_dims, hidden_size):
        # Use fixed random key since default state init fn is just zeros.
        return nn.OptimizedLSTMCell.initialize_carry(
            jax.random.PRNGKey(0), batch_dims, hidden_size)
class LSTMRewardModel(nn.Module):
    """
    LSTM reward model baseline: a pointwise MLP encoder over concatenated
    (state, action) features, a unidirectional LSTM over time, and an MLP
    head predicting a scalar reward per step.
    """
    config: Any=None
    pretrained: str=None
    ckpt_dir: str=None
    observation_dim: int=29
    action_dim: int=8
    activation: str=None
    activation_final: str=None
    max_episode_steps: int=1000

    def setup(self):
        self.config_ = self.config
        # Propagate the activation choices into the config.
        self.config_.activation_function = self.activation
        self.config_.activation_final = self.activation_final
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.embd_dropout = self.config_.embd_pdrop
        self.num_layers = self.config_.n_layer
        self.inner_dim = self.config_.n_inner
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(
        self,
        states,
        actions,
        timesteps,
        attn_mask=None,
        training=False,
        reverse=False,
        target_idx=1
    ):
        """
        Predict per-step rewards.

        Args:
            states (tensor): States, shape [B, seq_len, observation_dim].
            actions (tensor): Actions, shape [B, seq_len, action_dim].
            timesteps: Unused; kept for interface parity with TransRewardModel.
            attn_mask / reverse / target_idx: Unused; interface parity only.
            training (bool): Training mode (enables dropout).

        Returns:
            (tensor, tensor): Per-step reward predictions and the raw LSTM outputs.
        """
        batch_size = states.shape[0]
        x = jnp.concatenate([states, actions], axis=-1)
        # Pointwise MLP encoder.
        for hd in [self.embd_dim, self.embd_dim // 2, self.embd_dim // 2]:
            x = nn.Dense(features=hd)(x)
            x = ops.apply_activation(x, activation=self.activation)
            x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)
        lstm = SimpleLSTM()
        initial_state = lstm.initialize_carry((batch_size, ), self.embd_dim // 2)
        _, lstm_outputs = lstm(initial_state, x)
        # Skip connection: concatenate encoder features with the LSTM outputs.
        x = jnp.concatenate([x, lstm_outputs], axis=-1)
        # MLP head down to a scalar reward per step.
        for hd in [self.embd_dim // 2, self.embd_dim // 4, self.embd_dim // 4]:
            x = nn.Dense(features=hd)(x)
            x = ops.apply_activation(x, activation=self.activation)
            x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)
        output = nn.Dense(features=1)(x)
        return output, lstm_outputs
================================================
FILE: flaxmodels/flaxmodels/lstm/ops.py
================================================
import jax
import jax.numpy as jnp
import flax.linen as nn
import math
import json
from types import SimpleNamespace
#----------------------------------------------------------
# Linear
#----------------------------------------------------------
def linear(features, param_dict, bias=True):
    """
    Build an nn.Dense layer, optionally initialized from a pretrained
    parameter dict (keys 'weight' and, when bias is used, 'bias').

    Args:
        features (int): Output feature dimension.
        param_dict (dict): Pretrained parameters, or None for default init.
        bias (bool): Whether the layer has a bias term.

    Returns:
        (nn.Dense): The configured layer.
    """
    if param_dict is None:
        return nn.Dense(features=features, use_bias=bias)
    # Pretrained path: the init functions simply return the stored arrays.
    if bias:
        assert 'bias' in param_dict
        assert 'weight' in param_dict
        return nn.Dense(features=features,
                        kernel_init=lambda *_: jnp.array(param_dict['weight']),
                        bias_init=lambda *_: jnp.array(param_dict['bias']))
    assert 'weight' in param_dict
    return nn.Dense(features=features,
                    kernel_init=lambda *_: jnp.array(param_dict['weight']))
def embedding(num_embeddings, features, param_dict, dtype='float32'):
    """
    Build an nn.Embed layer, optionally initialized from a pretrained
    parameter dict (key 'weight').

    Args:
        num_embeddings (int): Vocabulary size.
        features (int): Embedding dimension.
        param_dict (dict): Pretrained parameters, or None for default init.
        dtype (str): Parameter dtype.

    Returns:
        (nn.Embed): The configured embedding layer.
    """
    if param_dict is None:
        return nn.Embed(num_embeddings=num_embeddings, features=features, dtype=dtype)
    assert 'weight' in param_dict
    return nn.Embed(num_embeddings=num_embeddings,
                    features=features,
                    embedding_init=lambda *_: jnp.array(param_dict['weight']),
                    dtype=dtype)
#----------------------------------------------------------
# Activation
#----------------------------------------------------------
def apply_activation(x, activation='linear'):
    """
    Apply the activation function named by *activation* to *x*.

    Args:
        x (tensor): Input tensor.
        activation (str): One of 'linear', 'gelu_new', 'gelu_fast', 'gelu',
            'relu', 'leaky_relu', 'sigmoid', 'tanh'.

    Returns:
        (tensor): Activated tensor ('linear' returns the input unchanged).

    Raises:
        ValueError: For an unrecognized activation name.
    """
    # Dispatch table; each entry is only evaluated for the selected name.
    table = {
        'linear': lambda v: v,
        'gelu_new': lambda v: 0.5 * v * (1.0 + nn.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * jnp.power(v, 3.0)))),
        'gelu_fast': lambda v: 0.5 * v * (1.0 + nn.tanh(v * 0.7978845608 * (1.0 + 0.044715 * v * v))),
        'gelu': lambda v: jax.nn.gelu(v),
        'relu': lambda v: jax.nn.relu(v),
        'leaky_relu': lambda v: jax.nn.leaky_relu(v),
        'sigmoid': lambda v: jax.nn.sigmoid(v),
        'tanh': lambda v: nn.tanh(v),
    }
    fn = table.get(activation)
    if fn is None:
        raise ValueError(f'Unknown activation function: {activation}.')
    return fn(x)
#----------------------------------------------------------
# Normalization
#----------------------------------------------------------
def layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dtype='float32'):
    """
    Build an nn.LayerNorm, optionally initialized from a pretrained parameter
    dict (keys 'bias' and/or 'scale', matching use_bias/use_scale).

    Args:
        param_dict (dict): Pretrained parameters, or None for default init.
        use_bias (bool): Whether the layer has a learned bias.
        use_scale (bool): Whether the layer has a learned scale.
        eps (float): Numerical-stability epsilon.
        dtype (str): Parameter dtype.

    Returns:
        (nn.LayerNorm): The configured layer.
    """
    if param_dict is None:
        return nn.LayerNorm(use_bias=use_bias, use_scale=use_scale, epsilon=eps, dtype=dtype)
    kwargs = {'use_bias': use_bias, 'use_scale': use_scale, 'epsilon': eps, 'dtype': dtype}
    if use_bias:
        assert 'bias' in param_dict, 'use_bias is set True but bias parameter does not exist in param_dict.'
        kwargs['bias_init'] = lambda *_: jnp.array(param_dict['bias'])
    if use_scale:
        assert 'scale' in param_dict, 'use_scale is set True but scale parameter does not exist in param_dict.'
        kwargs['scale_init'] = lambda *_: jnp.array(param_dict['scale'])
    return nn.LayerNorm(**kwargs)
#----------------------------------------------------------
# Attention
#----------------------------------------------------------
def split_heads(x, num_heads, head_dim):
    """
    Splits embeddings for different heads.

    Args:
        x (tensor): Input tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].
        num_heads (int): Number of heads.
        head_dim (int): Dimension of embedding for each head.

    Returns:
        (tensor): Output tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].
    """
    # Break the last (embedding) axis into (num_heads, head_dim)...
    x = jnp.reshape(x, x.shape[:-1] + (num_heads, head_dim))
    # ...then move the head axis in front of the sequence axis.
    if x.ndim == 4:
        # [batch, head, seq_len, head_dim]
        return jnp.transpose(x, axes=(0, 2, 1, 3))
    if x.ndim == 5:
        # [batch, blocks, head, block_len, head_dim]
        return jnp.transpose(x, axes=(0, 1, 3, 2, 4))
    raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')
def merge_heads(x, num_heads, head_dim):
    """
    Merge embeddings for different heads (inverse of split_heads).

    Args:
        x (tensor): Input tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].
        num_heads (int): Number of heads.
        head_dim (int): Dimension of embedding for each head.

    Returns:
        (tensor): Output tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].
    """
    # Move the head axis back next to head_dim...
    if x.ndim == 4:
        x = jnp.transpose(x, axes=(0, 2, 1, 3))
    elif x.ndim == 5:
        x = jnp.transpose(x, axes=(0, 1, 3, 2, 4))
    else:
        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')
    # ...and fuse (num_heads, head_dim) into a single embedding axis.
    return jnp.reshape(x, x.shape[:-2] + (num_heads * head_dim,))
def attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, explicit_sparse=False, k=5):
    """
    Computes Dot-Product Attention for the given query, key and value.

    Args:
        query (tensor): Query, shape [B, num_heads, seq_len, embd_dim].
        key (tensor): Key, shape [B, num_heads, seq_len, embd_dim].
        value (tensor): Value, shape [B, num_heads, seq_len, embd_dim].
        casual_mask (tensor): Mask to ensure that attention is only applied to the left of the input sequence,
                              shape [1, 1, key_len - query_len :key_len, :key_len].
        masked_bias (float): Value to insert for masked part of the sequence.
        dropout (nn.Dropout): Dropout module that is applied to the attention output.
        scale_attn_weights (bool): If True, scale the attention weights.
        training (bool): Training mode.
        attn_mask (tensor): Mask to avoid performing attention on padded tokens indices, shape [B, seq_len].
        head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads,] or [num_layers, num_heads].
        explicit_sparse (bool): If True, keep only the k largest logits per query row
                                (explicit sparse attention) before the softmax.
        k (int): Number of entries kept per query row when explicit_sparse is True.

    Returns:
        (tensor): Attention output, shape [B, num_heads, seq_len, embd_dim].
        (tensor): Attention weights, shape [B, num_heads, seq_len, seq_len].
    """
    # Compute logits in float32 for numerical stability.
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
    attn_weights = jnp.matmul(query, jnp.swapaxes(key, -1, -2))
    if scale_attn_weights:
        attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)
    # Positions outside the causal window receive masked_bias instead of their logit.
    attn_weights = jnp.where(casual_mask, attn_weights, masked_bias)
    if attn_mask is not None:
        # Additive padding bias (0 for valid tokens, large negative for padding).
        attn_weights = attn_weights + attn_mask
    if explicit_sparse:
        # Explicit sparse attention: per query row, keep the k largest logits
        # and push everything below the k-th largest to -1e18 so it vanishes
        # in the softmax.
        top_vals, _ = jax.lax.top_k(attn_weights, k=k)
        threshold = jnp.expand_dims(top_vals[..., -1], axis=-1)
        threshold = jnp.tile(threshold, [1, 1, 1, attn_weights.shape[-1]])
        below_threshold = jnp.less(attn_weights, threshold)
        # BUG FIX: the original swapped the jnp.where branches, which masked
        # the top-k entries and kept everything else.
        attn_weights = jnp.where(below_threshold, -1e18, attn_weights)
    attn_weights = nn.softmax(attn_weights, axis=-1)
    attn_weights = attn_weights.astype(value.dtype)
    attn_weights = dropout(attn_weights, deterministic=not training)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask
    out = jnp.matmul(attn_weights, value)
    return out, attn_weights
#----------------------------------------------------------
# Losses
#----------------------------------------------------------
def cross_entropy(logits, labels, ignore_index=-100):
    """
    Computes the cross entropy loss (on logits).

    Args:
        logits (tensor): Logits, shape [B, num_classes].
        labels (tensor): Labels, shape [B,].
        ignore_index (int): Value of label to ignore for loss computation.

    Returns:
        (tensor): Cross entropy loss, averaged over the non-ignored entries.
    """
    batch_size, num_classes = logits.shape
    logits = nn.log_softmax(logits)
    # Get indices where label is equal to ignore_index
    idx = jnp.nonzero(labels == ignore_index)[0]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    mult = one_hot_labels * logits
    # Insert zeros, where the labels are equal to ignore_index
    mult = mult.at[idx].set(jnp.zeros((idx.shape[0], num_classes)))
    # NOTE(review): if every label equals ignore_index the denominator is
    # zero — confirm callers never pass an all-ignored batch.
    return -jnp.sum(jnp.sum(mult, axis=-1)) / (batch_size - idx.shape[0])
#----------------------------------------------------------
# Misc
#----------------------------------------------------------
def get(dictionary, key):
    """Return dictionary[key], or None when the dict is None or lacks the key."""
    if dictionary is None:
        return None
    return dictionary.get(key)
def get_attention_mask(attn_mask, batch_size):
    """
    Convert a [B, seq_len] padding mask (1 = attend, 0 = pad) into an additive
    attention bias of shape [B, 1, 1, seq_len]: 0.0 for valid tokens and
    -10000.0 for padded ones.
    """
    assert batch_size > 0, 'batch_size should be > 0.'
    mask = jnp.reshape(attn_mask, (batch_size, -1))
    # Insert singleton head and query axes for broadcasting over the logits.
    mask = mask[:, jnp.newaxis, jnp.newaxis, :]
    return (1.0 - mask) * -10000.0
def get_head_mask(head_mask, num_layers):
    """
    Broadcast a head mask to shape [num_layers, 1, num_heads, 1, 1].

    Args:
        head_mask (tensor): Either [num_heads] (applied to every layer) or
            [num_layers, num_heads] (per-layer masks).
        num_layers (int): Number of transformer layers.

    Returns:
        (tensor): Mask of shape [num_layers, 1, num_heads, 1, 1].

    Raises:
        ValueError: If head_mask is not rank 1 or 2.
    """
    # BUG FIX: the original called jnp.expand_dims(..., newshape=...), which is
    # not a valid keyword (the API takes `axis`) and always raised TypeError.
    if head_mask.ndim == 1:
        # [num_heads] -> [1, 1, num_heads, 1, 1], then repeat per layer.
        head_mask = jnp.expand_dims(head_mask, axis=(0, 1, -2, -1))
        head_mask = jnp.repeat(head_mask, repeats=num_layers, axis=0)
    elif head_mask.ndim == 2:
        # [num_layers, num_heads] -> [num_layers, 1, num_heads, 1, 1]
        head_mask = jnp.expand_dims(head_mask, axis=(1, -2, -1))
    else:
        raise ValueError(f'head_mask must have rank 1 or 2, but has rank {head_mask.ndim}.')
    return head_mask
def load_config(path):
    """
    Load a JSON config file as nested SimpleNamespace objects, so entries can
    be read with attribute access (cfg.n_embd rather than cfg['n_embd']).

    Args:
        path (str): Path to the JSON file.

    Returns:
        (SimpleNamespace): Parsed config; nested JSON objects become nested namespaces.
    """
    # Use a context manager so the file handle is closed (the original leaked it).
    with open(path, 'r', encoding='utf-8') as f:
        return json.loads(f.read(), object_hook=lambda d: SimpleNamespace(**d))
================================================
FILE: flaxmodels/flaxmodels/utils.py
================================================
from tqdm import tqdm
import requests
import os
import tempfile
def download(ckpt_dir, url):
    """
    Download *url* into a 'flaxmodels' subfolder of *ckpt_dir* and return the
    local file path; an already-present file is reused without re-downloading.

    Args:
        ckpt_dir (str): Target directory; if None, the system temp dir is used.
        url (str): Download URL. The local file name is taken from the segment
            between the last '/' and the last '?' (the URL must contain '?').

    Returns:
        (str): Path to the local file. NOTE(review): on a size-mismatch failure
        the temp file is removed but this path is still returned even though no
        file exists there — confirm callers can handle a missing file.
    """
    name = url[url.rfind('/') + 1 : url.rfind('?')]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        # Stream the response so large checkpoints are not held in memory.
        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

        # first create temp file, in case the download fails
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        # A size mismatch (when the server reported a content-length) means the
        # download was truncated; discard the partial file.
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            print('An error occured while downloading, please try again.')
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
        else:
            # if download was successful, rename the temp file
            os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file
================================================
FILE: flaxmodels/setup.py
================================================
from setuptools import setup, find_packages
import os

# Read the long description from the package README so PyPI renders it.
directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

# Distribution metadata for the `flaxmodels` package.
setup(name='flaxmodels',
      version='0.1.2',
      url='https://github.com/matthias-wright/flaxmodels',
      author='Matthias Wright',
      packages=find_packages(),
      install_requires=['h5py>=2.10.0',
                        'numpy>=1.19.5',
                        'requests>=2.23.0',
                        'packaging>=20.9',
                        'dataclasses>=0.6',
                        'filelock>=3.0.12',
                        'jax>=0.3',
                        'jaxlib',
                        'flax>=0.4.0',
                        'Pillow>=7.1.2',
                        'regex>=2021.4.4',
                        'tqdm>=4.60.0'],
      extras_require={
          'testing': ['pytest'],
      },
      python_requires='>=3.6',
      license='Each model has an individual license.',
      description='A collection of pretrained models in Flax.',
      long_description=long_description,
      long_description_content_type='text/markdown')
================================================
FILE: human_label/README.md
================================================
# Generating your own human preferences
Based on the query indices collected in this folder, you can also generate your own real human preference labels.
## Generating Videos
First, generate videos for the queries by running the commands below.
```python
python -m JaxPref.human_label_preprocess_antmaze --env_name {AntMaze env name} --query_path ./human_label --save_dir {video folder to save} --num_query {number of query} --query_len {query length}
python -m JaxPref.human_label_preprocess_mujoco --env_name {Mujoco env name} --query_path ./human_label --save_dir {video folder to save} --num_query {number of query} --query_len {query length}
python -m JaxPref.human_label_preprocess_robosuite --dataset /mnt/changyeon/ICLR2023_rebuttal/robosuite --dataset_type ph --env {Lift/Can/Square} --use-obs --video_path {video folder to save} --render_image_names agentview_image --indices_path ./human_label/ --query_len {query length} --num_query {number of query}
```
## Labeling Human Preferences
After generating the videos, you can use `label_program.ipynb` to collect human preferences.
================================================
FILE: human_label/label_program.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import os\n",
"import numpy as np\n",
"\n",
"from IPython.display import Video\n",
"\n",
"\n",
"def get_label(ans):\n",
" try:\n",
" ans = int(ans)\n",
" except:\n",
" print(\"Wrong Input\")\n",
" return False\n",
" if ans not in [1,2,3]:\n",
" print(\"Invalid option.\")\n",
" return False\n",
" if ans == 1:\n",
" return [1, 0]\n",
" elif ans == 2:\n",
" return [0, 1]\n",
" else:\n",
" return [0.5, 0.5]\n",
"\n",
"\n",
"def create_human_label(save_dir, env_name, num_query=1000, start_idx=None, width=1000, height=500):\n",
" video_path = os.path.join(save_dir, env_name)\n",
" os.makedirs(os.path.join(video_path, \"label\"), exist_ok=True)\n",
" print(\"START!\")\n",
" if start_idx:\n",
" assert start_idx > 0, \"you must input with video number (1, 2, 3, ...)\"\n",
" interval = range(start_idx - 1, num_query)\n",
" else:\n",
" interval = range(num_query)\n",
" \n",
" for i in interval:\n",
" label = False\n",
" while not label:\n",
" print(f\"\\nVideo {i + 1}\")\n",
" video_file = os.path.join(video_path, f\"idx{i}.mp4\")\n",
" display(Video(video_file, width=width, height=height, html_attributes=\"loop autoplay\"))\n",
" reward = input(f\"[{i + 1}/{num_query}] Put Preference (1 (left), 2 (right), 3 (equal)): \").strip()\n",
" label = get_label(reward)\n",
" if label:\n",
" with open(os.path.join(video_path, \"label\", f\"label_{i}.txt\"), \"w\") as f:\n",
" f.write(reward)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"### create human label in save_dir, you could fix the start point.\n",
"create_human_label(save_dir=\"../video\", env_name=\"antmaze-large-diverse-v2\", start_idx=956, num_query=1000)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"import pickle\n",
"import numpy as np\n",
"from tqdm import trange\n",
"\n",
"# make final pickle file from separated label files.\n",
"def merge_labels(save_dir, env_name=\"antmaze-medium-play-v2\", num_query=1000, query_len=100, seed=3407):\n",
" label_dir = os.path.join(save_dir, env_name, \"label\")\n",
" # label_files = sorted(glob.glob(os.path.join(label_dir, \"*.txt\")), key=lambda x: int(x.split(\".\")[0].split(\"_\")[-1]))\n",
" labels = []\n",
" for idx in trange(num_query):\n",
" assert os.path.exists(os.path.join(label_dir, f\"label_{idx}.txt\")), f\"labeling is not finished. {idx + 1} / {num_query}\"\n",
" with open(os.path.join(label_dir, f\"label_{idx}.txt\")) as f:\n",
" choice = int(f.read().strip())\n",
" if choice == 1:\n",
" _label = 0\n",
" elif choice == 2:\n",
" _label = 1\n",
" elif choice == 3:\n",
" _label = -1\n",
" labels.append(_label)\n",
" \n",
" # labels = np.array(labels)\n",
" \n",
" with open(os.path.join(save_dir, env_name, f\"human_labels_numq{num_query}_len{query_len}_s{seed}.pkl\"), \"wb\") as f:\n",
" pickle.dump(labels, f)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 22433.63it/s]\n"
]
}
],
"source": [
"merge_labels(save_dir=\"../video\", env_name=\"antmaze-medium-play-v2\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 19003.26it/s]\n"
]
}
],
"source": [
"merge_labels(save_dir=\"../video\", env_name=\"antmaze-medium-diverse-v2\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 20405.77it/s]\n"
]
}
],
"source": [
"merge_labels(save_dir=\"../video\", env_name=\"antmaze-large-diverse-v2\")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"with open(\"../video/antmaze-medium-play-v2/human_labels_numq1000_len100_s3407.pkl\", \"rb\") as f:\n",
" labels = pickle.load(f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6.8 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: learner.py
================================================
"""Implementations of algorithms for continuous control."""
from typing import Optional, Sequence, Tuple
import jax
import jax.numpy as jnp
import numpy as np
import optax
import policy
import value_net
from actor import update as awr_update_actor
from common import Batch, InfoDict, Model, PRNGKey
from critic import update_q, update_v
def target_update(critic: Model, target_critic: Model, tau: float) -> Model:
    """Return the target critic with parameters Polyak-averaged toward the critic."""
    def _soft(p, tp):
        # Blend a fraction `tau` of the online parameter into the target.
        return p * tau + tp * (1 - tau)

    mixed_params = jax.tree_util.tree_map(_soft, critic.params,
                                          target_critic.params)
    return target_critic.replace(params=mixed_params)
@jax.jit
def _update_jit(
    rng: PRNGKey, actor: Model, critic: Model, value: Model,
    target_critic: Model, batch: Batch, discount: float, tau: float,
    expectile: float, temperature: float
) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]:
    # One jitted IQL-style update: fit the value net by expectile regression
    # against the target critic, update the actor by advantage weighting,
    # update the critic by TD, then Polyak-update the target critic.
    # (Return annotation fixed: the function returns 6 items, not 7.)
    new_value, value_info = update_v(target_critic, value, batch, expectile)
    key, rng = jax.random.split(rng)
    new_actor, actor_info = awr_update_actor(key, actor, target_critic,
                                             new_value, batch, temperature)
    new_critic, critic_info = update_q(critic, new_value, batch, discount)
    new_target_critic = target_update(new_critic, target_critic, tau)
    # Merged info dict: later entries override on key collisions.
    return rng, new_actor, new_critic, new_value, new_target_critic, {
        **critic_info,
        **value_info,
        **actor_info
    }
class Learner(object):
    # Offline RL agent: holds the actor, critic, value, and target-critic
    # models and performs jitted updates via `_update_jit`.
    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 3e-4,
                 value_lr: float = 3e-4,
                 critic_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 discount: float = 0.99,
                 tau: float = 0.005,
                 expectile: float = 0.8,
                 temperature: float = 0.1,
                 dropout_rate: Optional[float] = None,
                 max_steps: Optional[int] = None,
                 opt_decay_schedule: str = "cosine"):
        """
        An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1801.01290

        NOTE(review): the update rule used here (expectile value regression +
        advantage-weighted actor update, see `_update_jit`) matches Implicit
        Q-Learning (https://arxiv.org/abs/2110.06169) rather than SAC; the
        citation above appears inherited from another codebase -- confirm.

        `observations`/`actions` are sample arrays used only to initialize
        model parameter shapes.
        """
        self.expectile = expectile
        self.tau = tau
        self.discount = discount
        self.temperature = temperature
        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)
        action_dim = actions.shape[-1]
        # State-independent std and tanh-bounded mean (no squashed
        # distribution), per the flags passed below.
        actor_def = policy.NormalTanhPolicy(hidden_dims,
                                            action_dim,
                                            log_std_scale=1e-3,
                                            log_std_min=-5.0,
                                            dropout_rate=dropout_rate,
                                            state_dependent_std=False,
                                            tanh_squash_distribution=False)
        if opt_decay_schedule == "cosine":
            # Negative rate + scale_by_adam/scale_by_schedule implements
            # gradient descent under optax's additive-update convention.
            schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
            optimiser = optax.chain(optax.scale_by_adam(),
                                    optax.scale_by_schedule(schedule_fn))
        else:
            optimiser = optax.adam(learning_rate=actor_lr)
        actor = Model.create(actor_def,
                             inputs=[actor_key, observations],
                             tx=optimiser)
        critic_def = value_net.DoubleCritic(hidden_dims)
        critic = Model.create(critic_def,
                              inputs=[critic_key, observations, actions],
                              tx=optax.adam(learning_rate=critic_lr))
        value_def = value_net.ValueCritic(hidden_dims)
        value = Model.create(value_def,
                             inputs=[value_key, observations],
                             tx=optax.adam(learning_rate=value_lr))
        # Target critic reuses critic_key so it starts with identical
        # parameters; it carries no optimizer (updated via Polyak averaging).
        target_critic = Model.create(
            critic_def, inputs=[critic_key, observations, actions])
        self.actor = actor
        self.critic = critic
        self.value = value
        self.target_critic = target_critic
        self.rng = rng

    def sample_actions(self,
                       observations: np.ndarray,
                       temperature: float = 1.0) -> jnp.ndarray:
        """Sample actions for `observations`, clipped to [-1, 1]; advances self.rng."""
        rng, actions = policy.sample_actions(self.rng, self.actor.apply_fn,
                                             self.actor.params, observations,
                                             temperature)
        self.rng = rng
        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def update(self, batch: Batch) -> InfoDict:
        """Run one jitted training step and swap in the new model states."""
        new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_jit(
            self.rng, self.actor, self.critic, self.value, self.target_critic,
            batch, self.discount, self.tau, self.expectile, self.temperature)
        self.rng = new_rng
        self.actor = new_actor
        self.critic = new_critic
        self.value = new_value
        self.target_critic = new_target_critic
        return info
================================================
FILE: policy.py
================================================
import functools
from typing import Optional, Sequence, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
from common import MLP, Params, PRNGKey, default_init
# Default clipping bounds for the policy's log standard deviation, used when
# NormalTanhPolicy is constructed without explicit bounds.
LOG_STD_MIN = -10.0
LOG_STD_MAX = 2.0
class NormalTanhPolicy(nn.Module):
    """Gaussian policy head over an MLP trunk.

    Produces either a tanh-squashed `TransformedDistribution` (when
    `tanh_squash_distribution` is True) or a plain diagonal Gaussian whose
    mean has been tanh-bounded. The log-std may be state-dependent (an extra
    Dense head) or a free parameter vector.

    Attributes:
        hidden_dims: Sizes of the MLP trunk layers.
        action_dim: Dimensionality of the action space.
        state_dependent_std: If True, predict log-std from the observation.
        dropout_rate: Optional dropout rate for the trunk.
        log_std_scale: Kernel-init scale for the state-dependent log-std head.
        log_std_min / log_std_max: Clipping bounds for log-std; `None` falls
            back to the module-level LOG_STD_MIN / LOG_STD_MAX defaults.
        tanh_squash_distribution: Squash the distribution itself vs. only
            bounding the mean.
    """
    hidden_dims: Sequence[int]
    action_dim: int
    state_dependent_std: bool = True
    dropout_rate: Optional[float] = None
    log_std_scale: float = 1.0
    log_std_min: Optional[float] = None
    log_std_max: Optional[float] = None
    tanh_squash_distribution: bool = True

    @nn.compact
    def __call__(self,
                 observations: jnp.ndarray,
                 temperature: float = 1.0,
                 training: bool = False) -> tfd.Distribution:
        """Return the action distribution for `observations`.

        `temperature` scales the Gaussian's std (0 gives deterministic-mean
        sampling); `training` enables dropout in the trunk.
        """
        outputs = MLP(self.hidden_dims,
                      activate_final=True,
                      dropout_rate=self.dropout_rate)(observations,
                                                      training=training)
        means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
        if self.state_dependent_std:
            log_stds = nn.Dense(self.action_dim,
                                kernel_init=default_init(
                                    self.log_std_scale))(outputs)
        else:
            log_stds = self.param('log_stds', nn.initializers.zeros,
                                  (self.action_dim, ))
        # Fix: `self.log_std_min or LOG_STD_MIN` treated an explicit bound of
        # 0.0 as "unset" (0.0 is falsy); test against None explicitly.
        log_std_min = LOG_STD_MIN if self.log_std_min is None else self.log_std_min
        log_std_max = LOG_STD_MAX if self.log_std_max is None else self.log_std_max
        log_stds = jnp.clip(log_stds, log_std_min, log_std_max)
        if not self.tanh_squash_distribution:
            # Bound only the mean; the distribution itself stays Gaussian.
            means = nn.tanh(means)
        base_dist = tfd.MultivariateNormalDiag(loc=means,
                                               scale_diag=jnp.exp(log_stds) *
                                               temperature)
        if self.tanh_squash_distribution:
            return tfd.TransformedDistribution(distribution=base_dist,
                                               bijector=tfb.Tanh())
        else:
            return base_dist
# Jitted sampling kernel; `actor_def` must be static so each module
# definition compiles once. NOTE(review): 'distribution' is not a parameter
# of this function -- it looks like a leftover static-arg name; confirm it
# is harmless on the jax version in use.
@functools.partial(jax.jit, static_argnames=('actor_def', 'distribution'))
def _sample_actions(rng: PRNGKey,
                    actor_def: nn.Module,
                    actor_params: Params,
                    observations: np.ndarray,
                    temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]:
    # Build the action distribution, then sample with a fresh subkey,
    # returning the advanced rng so the caller can thread it forward.
    dist = actor_def.apply({'params': actor_params}, observations, temperature)
    rng, key = jax.random.split(rng)
    return rng, dist.sample(seed=key)
def sample_actions(rng: PRNGKey,
                   actor_def: nn.Module,
                   actor_params: Params,
                   observations: np.ndarray,
                   temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]:
    """Sample actions from the policy (thin wrapper over the jitted kernel)."""
    new_rng, sampled = _sample_actions(rng, actor_def, actor_params,
                                       observations, temperature)
    return new_rng, sampled
================================================
FILE: requirements.txt
================================================
numpy >= 1.20.2
scipy >= 1.6.0
absl-py >= 0.12.0
gym[mujoco] >= 0.18.0
gdown >= 3.12.2
tqdm >= 4.60.0
flax >= 0.3.5
jax >= 0.2.27
ml_collections >= 0.1.0
optax >= 0.0.6
tensorboardX == 2.1
tensorflow-probability >= 0.14.1
imageio >= 2.9.0
imageio-ffmpeg >= 0.4.3
pandas
git+https://github.com/ARISE-Initiative/robosuite.git@v1.3
git+https://github.com/ARISE-Initiative/robomimic.git
================================================
FILE: robosuite_train_offline.py
================================================
import datetime
import os
import pickle
from typing import Tuple
import gym
import numpy as np
from tqdm import tqdm
from absl import app, flags
from flax.training import checkpoints
from ml_collections import config_flags
from tensorboardX import SummaryWriter
import robosuite as suite
from robosuite.wrappers import GymWrapper
import robomimic.utils.env_utils as EnvUtils
import wrappers
from JaxPref.reward_transform import qlearning_robosuite_dataset
from dataset_utils import D4RLDataset, RelabeledDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories
from evaluation import evaluate
from learner import Learner
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40'
FLAGS = flags.FLAGS

# --- Environment / logging / training schedule ---
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
# --- Preference reward-model options ---
flags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.')
flags.DEFINE_string('model_type', 'MLP', 'type of reward model.')
flags.DEFINE_string('ckpt_dir',
                    './logs/pref_reward',
                    'ckpt path for reward model.')
flags.DEFINE_string('comment',
                    'base',
                    'comment for distinguishing experiments.')
flags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.')
flags.DEFINE_bool('use_diff', False, 'boolean whether use difference in sequence for reward relabeling.')
flags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.')
flags.DEFINE_string('pref_attn_type', 'max', 'mode for preference attention with tranformer.')
# --- Robosuite dataset options ---
flags.DEFINE_integer('max_episode_steps', 500, 'max_episode_steps for rollout.')
flags.DEFINE_string('robosuite_dataset_path', './data', 'hdf5 dataset path for demonstrations')
flags.DEFINE_string('robosuite_dataset_type', 'ph', 'dataset type for robosuite')
# flags.DEFINE_list(
#     'obs_keys',
#     ["robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "robot0_gripper_qvel", "object"],
#     'obs keys for using in making observations.'
# )
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)
def normalize(dataset, env_name, max_episode_steps=1000):
    """Rescale `dataset.rewards` in place using the return range of the dataset.

    For antmaze tasks each reward is first shifted by
    (minimum return / its own trajectory length); every reward is then
    divided by (max return - min return) and scaled by `max_episode_steps`.
    """
    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)

    # Per-transition lookup table: (trajectory index, trajectory length).
    trj_mapper = []
    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"):
        trj_mapper.extend([(trj_idx, len(traj))] * len(traj))

    def compute_returns(traj):
        # Reward sits at position 2 of each transition tuple.
        return sum(step[2] for step in traj)

    returns = [compute_returns(traj) for traj in trajs]
    min_return, max_return = min(returns), max(returns)

    normalized_rewards = []
    for i in range(dataset.size):
        _reward = dataset.rewards[i]
        if 'antmaze' in env_name:
            _, len_trj = trj_mapper[i]
            _reward = _reward - min_return / len_trj
        _reward = _reward / (max_return - min_return)
        _reward = _reward * max_episode_steps
        normalized_rewards.append(_reward)
    dataset.rewards = np.array(normalized_rewards)
def make_env_and_dataset(env_name: str,
                         seed: int,
                         dataset_path: str,
                         max_episode_steps: int = 500) -> Tuple[gym.Env, D4RLDataset]:
    """Build a seeded robosuite gym env plus its (optionally relabeled) dataset.

    Loads the hdf5 demonstration dataset, optionally relabels rewards with a
    pretrained preference model (controlled by FLAGS), applies per-task
    reward shaping/normalization, and for pen/hammer keeps only the first
    half of the trajectories.

    NOTE(review): the reward-shaping branches read FLAGS.env_name while this
    function also takes `env_name` as a parameter -- callers pass the same
    value, but confirm before reusing with a different argument.
    """
    ds = qlearning_robosuite_dataset(dataset_path)
    dataset = RelabeledDataset(ds['observations'], ds['actions'], ds['rewards'], ds['terminals'], ds['next_observations'])
    # Override the episode horizon stored in the dataset's env metadata.
    ds['env_meta']['env_kwargs']['horizon'] = max_episode_steps
    env = EnvUtils.create_env_from_metadata(
        env_meta=ds['env_meta'],
        render=False,  # no on-screen rendering
        render_offscreen=False,  # off-screen rendering to support rendering video frames
    ).env
    env.ignore_done = False
    env._max_episode_steps = env.horizon
    env = GymWrapper(env)
    env = wrappers.RobosuiteWrapper(env)
    env = wrappers.EpisodeMonitor(env)
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    if FLAGS.use_reward_model:
        # Relabel dataset rewards with the learned preference reward model.
        reward_model = initialize_model()
        if FLAGS.model_type == "MR":
            dataset = reward_from_preference(FLAGS.env_name, dataset, reward_model, batch_size=FLAGS.batch_size)
        else:
            dataset = reward_from_preference_transformer(
                FLAGS.env_name,
                dataset,
                reward_model,
                batch_size=FLAGS.batch_size,
                seq_len=FLAGS.seq_len,
                use_diff=FLAGS.use_diff,
                label_mode=FLAGS.label_mode
            )
        # Only needed for relabeling; release before training.
        del reward_model
    if FLAGS.use_reward_model:
        normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps)
        # if 'antmaze' in FLAGS.env_name:
        #     dataset.rewards -= 1.0
        if ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name):
            dataset.rewards += 0.5
    else:
        if 'antmaze' in FLAGS.env_name:
            dataset.rewards -= 1.0
            # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
            # but I found no difference between (x - 0.5) * 4 and x - 1.0
        elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name):
            normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps)
    if 'pen' in FLAGS.env_name or 'hammer' in FLAGS.env_name:
        # Keep only the first half of the trajectories, splitting at a
        # trajectory boundary so no episode is cut in half.
        trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                        dataset.rewards, dataset.masks,
                                        dataset.dones_float,
                                        dataset.next_observations)
        trj_cumsum = np.cumsum([len(traj) for traj in trajs])
        split_point = trj_cumsum[int(len(trajs) // 2)]
        dataset.observations = dataset.observations[:split_point]
        dataset.actions = dataset.actions[:split_point]
        dataset.rewards = dataset.rewards[:split_point]
        dataset.masks = dataset.masks[:split_point]
        dataset.dones_float = dataset.dones_float[:split_point]
        dataset.next_observations = dataset.next_observations[:split_point]
        dataset.size = len(dataset.observations)
    return env, dataset
def initialize_model():
    """Load the pretrained preference reward model from FLAGS.ckpt_dir.

    Prefers `best_model.pkl` when present, otherwise falls back to
    `model.pkl`. For PrefTransformer models, the attention type is
    overridden from FLAGS.pref_attn_type.
    """
    if os.path.exists(os.path.join(FLAGS.ckpt_dir, "best_model.pkl")):
        model_path = os.path.join(FLAGS.ckpt_dir, "best_model.pkl")
    else:
        model_path = os.path.join(FLAGS.ckpt_dir, "model.pkl")
    # NOTE: pickle.load executes arbitrary code on load; only point ckpt_dir
    # at trusted checkpoints.
    with open(model_path, "rb") as f:
        ckpt = pickle.load(f)
    reward_model = ckpt['reward_model']
    if FLAGS.model_type == "PrefTransformer":
        reward_model.trans.config.pref_attn_type = FLAGS.pref_attn_type
    return reward_model
def main(_):
    """Train IQL offline on a robosuite dataset, evaluating and logging periodically.

    Builds the env/dataset from FLAGS, runs `max_steps` gradient updates,
    writes TensorBoard logs and a `progress.txt` of eval returns, and saves
    all four model checkpoints at the end.
    """
    # Log directory encodes env, reward-model usage, comment, seed, and time.
    save_dir = os.path.join(FLAGS.save_dir, 'tb',
                            FLAGS.env_name,
                            f"reward_{FLAGS.use_reward_model}_{FLAGS.model_type}" if FLAGS.use_reward_model else "original",
                            f"{FLAGS.comment}",
                            str(FLAGS.seed),
                            f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    summary_writer = SummaryWriter(save_dir,
                                   write_to_disk=True)
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    dataset_path = os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env_name.lower(), FLAGS.robosuite_dataset_type, "low_dim.hdf5")
    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed, dataset_path, max_episode_steps=FLAGS.max_episode_steps)
    kwargs = dict(FLAGS.config)
    agent = Learner(FLAGS.seed,
                    env.observation_space.sample()[np.newaxis],
                    env.action_space.sample()[np.newaxis],
                    max_steps=FLAGS.max_steps,
                    **kwargs)
    eval_returns = []
    for i in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm):
        batch = dataset.sample(FLAGS.batch_size)
        update_info = agent.update(batch)
        if i % FLAGS.log_interval == 0:
            for k, v in update_info.items():
                # Scalars and per-parameter arrays go to different TB panels.
                if v.ndim == 0:
                    summary_writer.add_scalar(f'training/{k}', v, i)
                else:
                    summary_writer.add_histogram(f'training/{k}', v, i)
            summary_writer.flush()
        if i % FLAGS.eval_interval == 0:
            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)
            for k, v in eval_stats.items():
                summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)
            summary_writer.flush()
            eval_returns.append((i, eval_stats['return']))
            np.savetxt(os.path.join(save_dir, 'progress.txt'),
                       eval_returns,
                       fmt=['%d', '%.1f'])
    # save IQL agent for last timestep.
    checkpoints.save_checkpoint(os.path.join(save_dir, "actor"), target=agent.actor, step=FLAGS.max_steps)
    checkpoints.save_checkpoint(os.path.join(save_dir, "critic"), target=agent.critic, step=FLAGS.max_steps)
    checkpoints.save_checkpoint(os.path.join(save_dir, "value"), target=agent.value, step=FLAGS.max_steps)
    # Fix: previously saved agent.actor under the "target_critic" name.
    checkpoints.save_checkpoint(os.path.join(save_dir, "target_critic"), target=agent.target_critic, step=FLAGS.max_steps)


if __name__ == '__main__':
    # Avoid preallocating the whole GPU so other processes can share it.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    app.run(main)
================================================
FILE: train_finetune.py
================================================
import os
from typing import Tuple
import gym
import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter
import wrappers
from dataset_utils import (Batch, D4RLDataset, ReplayBuffer,
split_into_trajectories)
from evaluation import evaluate
from learner import Learner
FLAGS = flags.FLAGS

# --- Environment / logging / training schedule ---
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 100,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 100000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
# --- Offline-pretraining / replay-buffer options ---
flags.DEFINE_integer('num_pretraining_steps', int(1e6),
                     'Number of pretraining steps.')
flags.DEFINE_integer('replay_buffer_size', 2000000,
                     'Replay buffer size (=max_steps if unspecified).')
flags.DEFINE_integer('init_dataset_size', None,
                     'Offline data size (uses all data if unspecified).')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
config_flags.DEFINE_config_file(
    'config',
    'configs/antmaze_finetune_config.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)
def normalize(dataset):
    """Scale `dataset.rewards` in place so episode returns span roughly 1000."""
    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)

    def episode_return(traj):
        # Reward is the third field of each transition tuple.
        return sum(transition[2] for transition in traj)

    trajs.sort(key=episode_return)
    # Divide by the best-minus-worst return, then scale up by 1000.
    return_span = episode_return(trajs[-1]) - episode_return(trajs[0])
    dataset.rewards /= return_span
    dataset.rewards *= 1000.0
def make_env_and_dataset(env_name: str,
                         seed: int) -> Tuple[gym.Env, D4RLDataset]:
    """Create a seeded D4RL gym environment and its normalized dataset.

    Fix: the reward-shaping branches previously consulted the global
    FLAGS.env_name while ignoring the `env_name` parameter; they now use the
    parameter (callers pass FLAGS.env_name, so behavior is unchanged).
    """
    env = gym.make(env_name)
    env = wrappers.EpisodeMonitor(env)
    env = wrappers.SinglePrecision(env)
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    dataset = D4RLDataset(env)
    if 'antmaze' in env_name:
        # dataset.rewards -= 1.0
        pass  # normalized in the batch instead
        # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
        # but I found no difference between (x - 0.5) * 4 and x - 1.0
    elif ('halfcheetah' in env_name or 'walker2d' in env_name
          or 'hopper' in env_name):
        normalize(dataset)
    return env, dataset
def main(_):
    """Offline-pretrain then online-finetune an agent, logging to TensorBoard.

    Steps with negative/zero index `i` are pure offline updates from the
    replay buffer (pre-seeded with the D4RL dataset); steps with `i >= 1`
    additionally collect online experience from the environment.
    """
    summary_writer = SummaryWriter(os.path.join(FLAGS.save_dir, 'tb',
                                                str(FLAGS.seed)),
                                   write_to_disk=True)
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)
    action_dim = env.action_space.shape[0]
    replay_buffer = ReplayBuffer(env.observation_space, action_dim,
                                 FLAGS.replay_buffer_size or FLAGS.max_steps)
    replay_buffer.initialize_with_dataset(dataset, FLAGS.init_dataset_size)
    kwargs = dict(FLAGS.config)
    agent = Learner(FLAGS.seed,
                    env.observation_space.sample()[np.newaxis],
                    env.action_space.sample()[np.newaxis], **kwargs)
    eval_returns = []
    observation, done = env.reset(), False
    # Use negative indices for pretraining steps.
    for i in tqdm.tqdm(range(1 - FLAGS.num_pretraining_steps,
                             FLAGS.max_steps + 1),
                       smoothing=0.1,
                       disable=not FLAGS.tqdm):
        if i >= 1:
            # Online phase: act, step the env, and store the transition.
            action = agent.sample_actions(observation, )
            action = np.clip(action, -1, 1)
            next_observation, reward, done, info = env.step(action)
            # mask=0 only on true terminations, not time-limit truncations.
            if not done or 'TimeLimit.truncated' in info:
                mask = 1.0
            else:
                mask = 0.0
            replay_buffer.insert(observation, action, reward, mask,
                                 float(done), next_observation)
            observation = next_observation
            if done:
                observation, done = env.reset(), False
                for k, v in info['episode'].items():
                    summary_writer.add_scalar(f'training/{k}', v,
                                              info['total']['timesteps'])
        else:
            # Offline pretraining: no env interaction; stub the info dict.
            info = {}
            info['total'] = {'timesteps': i}
        batch = replay_buffer.sample(FLAGS.batch_size)
        if 'antmaze' in FLAGS.env_name:
            # Antmaze reward shift (x - 1.0) is applied per-batch here
            # instead of once on the dataset.
            batch = Batch(observations=batch.observations,
                          actions=batch.actions,
                          rewards=batch.rewards - 1,
                          masks=batch.masks,
                          next_observations=batch.next_observations)
        update_info = agent.update(batch)
        if i % FLAGS.log_interval == 0:
            for k, v in update_info.items():
                if v.ndim == 0:
                    summary_writer.add_scalar(f'training/{k}', v, i)
                else:
                    summary_writer.add_histogram(f'training/{k}', v, i)
            summary_writer.flush()
        if i % FLAGS.eval_interval == 0:
            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)
            for k, v in eval_stats.items():
                summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)
            summary_writer.flush()
            eval_returns.append((i, eval_stats['return']))
            np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),
                       eval_returns,
                       fmt=['%d', '%.1f'])


if __name__ == '__main__':
    app.run(main)
================================================
FILE: train_offline.py
================================================
import datetime
import os
import pickle
from typing import Tuple
import gym
import numpy as np
from tqdm import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter
import wrappers
from dataset_utils import D4RLDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories
from evaluation import evaluate
from learner import Learner
# Cap XLA's GPU memory pool so other jobs can share the device.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40'

FLAGS = flags.FLAGS

# --- Environment / logging / training schedule ---
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
# --- Preference reward-model options ---
flags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.')
flags.DEFINE_string('model_type', 'MLP', 'type of reward model.')
flags.DEFINE_string('ckpt_dir',
                    './logs/pref_reward',
                    'ckpt path for reward model.')
flags.DEFINE_string('comment',
                    'base',
                    'comment for distinguishing experiments.')
flags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.')
flags.DEFINE_bool('use_diff', False, 'boolean whether use difference in sequence for reward relabeling.')
flags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.')
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)
def normalize(dataset, env_name, max_episode_steps=1000):
    """Normalize `dataset.rewards` in place by the dataset's return range.

    Antmaze rewards are additionally shifted by
    (minimum return / own trajectory length) before scaling.
    """
    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)

    def episode_return(traj):
        total = 0
        for _, _, rew, _, _, _ in traj:
            total += rew
        return total

    # Record each transition's trajectory length for the antmaze shift below.
    transition_traj_len = []
    for _trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"):
        transition_traj_len += [len(traj)] * len(traj)

    all_returns = [episode_return(traj) for traj in trajs]
    min_return = min(all_returns)
    max_return = max(all_returns)
    return_range = max_return - min_return

    new_rewards = []
    for i in range(dataset.size):
        r = dataset.rewards[i]
        if 'antmaze' in env_name:
            r = r - min_return / transition_traj_len[i]
        new_rewards.append(r / return_range * max_episode_steps)
    dataset.rewards = np.array(new_rewards)
def make_env_and_dataset(env_name: str,
                         seed: int) -> Tuple[gym.Env, D4RLDataset]:
    """Create a seeded D4RL gym environment and its reward-prepared dataset.

    When FLAGS.use_reward_model is set, rewards are relabeled with a
    pretrained preference reward model (MR or Transformer-based) before the
    per-task reward shaping is applied.

    Fix: the reward-shaping branches previously consulted FLAGS.env_name
    while ignoring the `env_name` parameter; they now use the parameter
    (callers pass FLAGS.env_name, so behavior is unchanged). The two
    consecutive `if FLAGS.use_reward_model:` blocks were also merged.
    """
    env = gym.make(env_name)
    env = wrappers.EpisodeMonitor(env)
    env = wrappers.SinglePrecision(env)
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    dataset = D4RLDataset(env)
    if FLAGS.use_reward_model:
        reward_model = initialize_model()
        if FLAGS.model_type == "MR":
            dataset = reward_from_preference(env_name, dataset, reward_model, batch_size=FLAGS.batch_size)
        else:
            dataset = reward_from_preference_transformer(
                env_name,
                dataset,
                reward_model,
                batch_size=FLAGS.batch_size,
                seq_len=FLAGS.seq_len,
                use_diff=FLAGS.use_diff,
                label_mode=FLAGS.label_mode
            )
        # Only needed for relabeling; release before training.
        del reward_model
        normalize(dataset, env_name, max_episode_steps=env.env.env._max_episode_steps)
        if 'antmaze' in env_name:
            dataset.rewards -= 1.0
        if ('halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name):
            dataset.rewards += 0.5
    else:
        if 'antmaze' in env_name:
            dataset.rewards -= 1.0
            # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
            # but I found no difference between (x - 0.5) * 4 and x - 1.0
        elif ('halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name):
            normalize(dataset, env_name, max_episode_steps=env.env.env._max_episode_steps)
    return env, dataset
def initialize_model():
    """Load the pickled reward model from FLAGS.ckpt_dir.

    Prefers `best_model.pkl` (best validation checkpoint) and falls back to
    `model.pkl`; returns the `reward_model` entry of the checkpoint dict.
    """
    best_path = os.path.join(FLAGS.ckpt_dir, "best_model.pkl")
    if os.path.exists(best_path):
        model_path = best_path
    else:
        model_path = os.path.join(FLAGS.ckpt_dir, "model.pkl")
    with open(model_path, "rb") as f:
        return pickle.load(f)['reward_model']
def main(_):
    """Entry point: train an offline agent, logging metrics and eval returns."""
    reward_tag = (f"reward_{FLAGS.use_reward_model}_{FLAGS.model_type}"
                  if FLAGS.use_reward_model else "original")
    save_dir = os.path.join(
        FLAGS.save_dir, 'tb',
        FLAGS.env_name,
        reward_tag,
        f"{FLAGS.comment}",
        str(FLAGS.seed),
        f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    summary_writer = SummaryWriter(save_dir, write_to_disk=True)
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)

    agent = Learner(FLAGS.seed,
                    env.observation_space.sample()[np.newaxis],
                    env.action_space.sample()[np.newaxis],
                    max_steps=FLAGS.max_steps,
                    **dict(FLAGS.config))

    eval_returns = []
    for step in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1,
                     disable=not FLAGS.tqdm):
        update_info = agent.update(dataset.sample(FLAGS.batch_size))

        if step % FLAGS.log_interval == 0:
            # Scalars go to add_scalar; anything with dimensions becomes a
            # histogram.
            for key, val in update_info.items():
                if val.ndim == 0:
                    summary_writer.add_scalar(f'training/{key}', val, step)
                else:
                    summary_writer.add_histogram(f'training/{key}', val, step)
            summary_writer.flush()

        if step % FLAGS.eval_interval == 0:
            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)
            for key, val in eval_stats.items():
                summary_writer.add_scalar(f'evaluation/average_{key}s', val,
                                          step)
            summary_writer.flush()
            eval_returns.append((step, eval_stats['return']))
            np.savetxt(os.path.join(save_dir, 'progress.txt'),
                       eval_returns,
                       fmt=['%d', '%.1f'])
if __name__ == '__main__':
    # Disable XLA's default GPU memory preallocation so JAX allocates on
    # demand (lets other processes share the device).
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    app.run(main)
================================================
FILE: value_net.py
================================================
from typing import Callable, Sequence, Tuple
import jax.numpy as jnp
from flax import linen as nn
from common import MLP
class ValueCritic(nn.Module):
    """State-value network V(s): an MLP with a scalar output head."""
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        value = MLP((*self.hidden_dims, 1))(observations)
        # Drop the trailing singleton output dimension.
        return jnp.squeeze(value, axis=-1)
class Critic(nn.Module):
    """Action-value network Q(s, a): an MLP over the concatenated (s, a)."""
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray,
                 actions: jnp.ndarray) -> jnp.ndarray:
        joint = jnp.concatenate([observations, actions], axis=-1)
        q_value = MLP((*self.hidden_dims, 1),
                      activations=self.activations)(joint)
        return jnp.squeeze(q_value, axis=-1)
class DoubleCritic(nn.Module):
    """Two independently-parameterized Critic heads (twin Q-networks)."""
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray,
                 actions: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        q1 = Critic(self.hidden_dims,
                    activations=self.activations)(observations, actions)
        q2 = Critic(self.hidden_dims,
                    activations=self.activations)(observations, actions)
        return q1, q2
================================================
FILE: viskit/__init__.py
================================================
__author__ = 'dementrock'
================================================
FILE: viskit/core.py
================================================
import csv
import math
import os
import numpy as np
import json
import itertools
class AttrDict(dict):
    """Dict whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the mapping itself so that
        # `d.key` and `d['key']` always agree.
        self.__dict__ = self
def unique(l):
    """Return the distinct elements of *l* (order is arbitrary, set-based)."""
    distinct = set(l)
    return list(distinct)
def flatten(l):
    """Concatenate one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return list(itertools.chain.from_iterable(l))
def load_progress(progress_csv_path):
    """Read a progress log into a dict of {column_name: np.ndarray}.

    Files ending in `.csv` are parsed comma-separated; anything else (e.g. a
    `log.txt` fallback) is parsed tab-separated. Cells that cannot be parsed
    as floats are recorded as 0.0 so a single malformed value does not drop
    the whole column.
    """
    print("Reading %s" % progress_csv_path)
    entries = dict()
    if progress_csv_path.split('.')[-1] == "csv":
        delimiter = ','
    else:
        delimiter = '\t'
    with open(progress_csv_path, 'r') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=delimiter)
        for row in reader:
            for k, v in row.items():
                if k not in entries:
                    entries[k] = []
                try:
                    entries[k].append(float(v))
                except (TypeError, ValueError):
                    # Was a bare `except:` — keep the best-effort 0.0
                    # coercion but stop swallowing KeyboardInterrupt /
                    # SystemExit. TypeError covers missing cells (None).
                    entries[k].append(0.)
    entries = dict([(k, np.array(v)) for k, v in entries.items()])
    return entries
def to_json(stub_object):
    """Recursively convert rllab Stub* wrappers into JSON-friendly dicts.

    Non-stub values pass through unchanged. Imports rllab lazily at call
    time, so this module stays importable without rllab installed.
    """
    from rllab.misc.instrument import StubObject
    from rllab.misc.instrument import StubAttr

    if isinstance(stub_object, StubObject):
        # Positional stub args are not supported by this serializer.
        assert len(stub_object.args) == 0
        data = {k: to_json(v) for k, v in stub_object.kwargs.items()}
        data["_name"] = (stub_object.proxy_class.__module__
                         + "." + stub_object.proxy_class.__name__)
        return data
    if isinstance(stub_object, StubAttr):
        return dict(
            obj=to_json(stub_object.obj),
            attr=to_json(stub_object.attr_name)
        )
    return stub_object
def flatten_dict(d):
    """Flatten nested dicts into dotted keys: {'a': {'b': 1}} -> {'a.b': 1}.

    Non-dict values are kept as-is under their (possibly dotted) key.
    """
    flat_params = dict()
    for k, v in d.items():
        if isinstance(v, dict):
            # The original flattened `v` twice (`flatten_dict(flatten_dict(v))`).
            # Flattening is idempotent — a flat dict maps to itself — so the
            # second pass was pure wasted recursion; one pass preserves
            # behavior exactly.
            for subk, subv in flatten_dict(v).items():
                flat_params[k + "." + subk] = subv
        else:
            flat_params[k] = v
    return flat_params
def load_params(params_json_path):
    """Load an experiment's params JSON.

    Drops the bulky `args_data` blob and, when `exp_name` is absent, derives
    it from the containing directory name.
    """
    with open(params_json_path, 'r') as f:
        data = json.loads(f.read())
    data.pop("args_data", None)
    if "exp_name" not in data:
        # Second-to-last path component is the experiment directory.
        data["exp_name"] = params_json_path.split("/")[-2]
    return data
def lookup(d, keys):
    """Walk nested mappings following *keys*; return None on any miss.

    `keys` may be a list of keys or a dotted string like "a.b.c".
    """
    key_list = keys if isinstance(keys, list) else keys.split(".")
    node = d
    for key in key_list:
        # Bail out if the current node is not indexable or lacks the key.
        if not hasattr(node, "__getitem__") or key not in node:
            return None
        node = node[key]
    return node
def load_exps_data(
        exp_folder_paths,
        data_filename='progress.csv',
        params_filename='params.json',
        disable_variant=False,
):
    """Recursively collect experiment runs found under the given folders.

    Every directory (at any depth) that contains a readable progress file and
    params file contributes one AttrDict with `progress` (column arrays from
    load_progress), `params` (variant.json when present and allowed, else
    params.json) and `flat_params` (dotted-key flattening of params).
    Directories missing any required file raise IOError, which is printed and
    the directory skipped.
    """
    exps = []
    for exp_folder_path in exp_folder_paths:
        # os.walk yields (dirpath, dirnames, filenames); keep dirpath.
        exps += [x[0] for x in os.walk(exp_folder_path)]
    exps_data = []
    for exp in exps:
        try:
            exp_path = exp
            params_json_path = os.path.join(exp_path, params_filename)
            variant_json_path = os.path.join(exp_path, "variant.json")
            progress_csv_path = os.path.join(exp_path, data_filename)
            # An empty progress file falls back to a tab-separated log.txt.
            if os.stat(progress_csv_path).st_size == 0:
                progress_csv_path = os.path.join(exp_path, "log.txt")
            progress = load_progress(progress_csv_path)
            if disable_variant:
                params = load_params(params_json_path)
            else:
                # Prefer variant.json; fall back to params.json when absent.
                try:
                    params = load_params(variant_json_path)
                except IOError:
                    params = load_params(params_json_path)
            exps_data.append(AttrDict(
                progress=progress,
                params=params,
                flat_params=flatten_dict(params)))
        except IOError as e:
            # Best-effort: report and continue with the other directories.
            print(e)
    return exps_data
def smart_repr(x):
    """repr() variant whose output smart_eval can round-trip.

    Tuples/lists are rendered recursively, callables become a pydoc lookup
    expression, NaN becomes `float("nan")`, and everything else uses repr().
    """
    if isinstance(x, tuple):
        if not x:
            return "tuple()"
        if len(x) == 1:
            return "(%s,)" % smart_repr(x[0])
        return "(" + ",".join(map(smart_repr, x)) + ")"
    if isinstance(x, list):
        if not x:
            return "[]"
        if len(x) == 1:
            return "[%s,]" % smart_repr(x[0])
        return "[" + ",".join(map(smart_repr, x)) + "]"
    if hasattr(x, "__call__"):
        # Callables are encoded by qualified name for later pydoc resolution.
        return "__import__('pydoc').locate('%s')" % (x.__module__ + "." + x.__name__)
    if isinstance(x, float) and math.isnan(x):
        return 'float("nan")'
    return repr(x)
def smart_eval(string):
    """Inverse of smart_repr: evaluate a repr string back to a Python value.

    SECURITY NOTE(review): this calls eval() on its input. It is only
    acceptable because inputs come from smart_repr over local experiment
    params — never feed it untrusted strings.
    """
    # A bare `inf` inside a tuple would raise NameError under eval; quote it.
    string = string.replace(',inf)', ',"inf")')
    return eval(string)
def extract_distinct_params(exps_data, excluded_params=('seed', 'log_dir'), l=1):
    """Find hyperparameters that vary across experiments.

    Returns [(param_name, [distinct values...])] for every flat param with
    more than `l` distinct values, excluding any whose name starts with one
    of `excluded_params`. 'version' is always kept. Values are round-tripped
    through smart_repr/smart_eval so unhashable values can be deduplicated.
    """
    try:
        params_as_evalable_strings = [
            list(
                map(
                    smart_repr,
                    list(d.flat_params.items())
                )
            )
            for d in exps_data
        ]
        unique_params = unique(
            flatten(
                params_as_evalable_strings
            )
        )
        # Sort (key, value) pairs by a stable string key so groupby below
        # sees all values of a given param contiguously.
        stringified_pairs = sorted(
            map(
                smart_eval,
                unique_params
            ),
            key=lambda x: (
                tuple(smart_repr(i) for i in x)
            )
        )
    except Exception as e:
        # BUG FIX: the original dropped into `import ipdb; ipdb.set_trace()`
        # here — a debugger breakpoint left in production code — and then
        # fell through to a guaranteed NameError on stringified_pairs.
        # Report and propagate instead.
        print(e)
        raise
    proposals = [(k, [x[1] for x in v])
                 for k, v in itertools.groupby(stringified_pairs, lambda x: x[0])]
    filtered = [
        (k, v) for (k, v) in proposals
        if k == 'version' or (
            len(v) > l and all(
                k.find(excluded_param) != 0
                for excluded_param in excluded_params
            )
        )
    ]
    return filtered
def exp_has_key_value(exp, k, v):
    """True when exp's flat param `k` stringifies to str(v).

    Experiments that lack `k` entirely also match (missing params are treated
    as wildcards so older runs are not filtered out).
    """
    if k not in exp.flat_params:
        return True
    return str(exp.flat_params.get(k, None)) == str(v)
class Selector(object):
    """Immutable query builder over a list of experiments.

    `where`/`where_not`/`custom_filter` each return a NEW Selector with one
    more constraint; `extract`/`iextract` materialize the matches.
    """

    def __init__(self, exps_data, filters=None, custom_filters=None):
        self._exps_data = exps_data
        self._filters = tuple() if filters is None else tuple(filters)
        self._custom_filters = [] if custom_filters is None else custom_filters

    def where(self, k, v):
        """Require flat param `k` to (string-)equal `v`."""
        return Selector(
            self._exps_data,
            self._filters + ((k, v),),
            self._custom_filters,
        )

    def where_not(self, k, v):
        """Exclude experiments whose flat param `k` (string-)equals `v`."""
        exclusion = lambda exp: not exp_has_key_value(exp, k, v)
        return Selector(
            self._exps_data,
            self._filters,
            self._custom_filters + [exclusion],
        )

    def custom_filter(self, filter):
        """Add an arbitrary predicate over experiments."""
        return Selector(self._exps_data, self._filters,
                        self._custom_filters + [filter])

    def _check_exp(self, exp):
        # An experiment passes only if every key/value filter and every
        # custom predicate accepts it.
        passes_filters = all(
            exp_has_key_value(exp, k, v) for k, v in self._filters
        )
        return passes_filters and all(
            predicate(exp) for predicate in self._custom_filters
        )

    def extract(self):
        """Return the matching experiments as a list."""
        return [exp for exp in self._exps_data if self._check_exp(exp)]

    def iextract(self):
        """Lazily iterate over the matching experiments."""
        return filter(self._check_exp, self._exps_data)
# Taken from plot.ly
color_defaults = [
'#1f77b4', # muted blue
'#ff7f0e', # safety orange
'#2ca02c', # cooked asparagus green
'#d62728', # brick red
'#9467bd', # muted purple
'#8c564b', # chestnut brown
'#e377c2', # raspberry yogurt pink
'#7f7f7f', # middle gray
'#bcbd22', # curry yellow-green
'#17becf' # blue-teal
]
def hex_to_rgb(hex, opacity=1.0):
    """Convert '#rrggbb' (or bare 'rrggbb') to a plotly 'rgba(r,g,b,a)' string."""
    digits = hex[1:] if hex[0] == '#' else hex
    assert (len(digits) == 6)
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return "rgba({0},{1},{2},{3})".format(r, g, b, opacity)
================================================
FILE: viskit/frontend.py
================================================
import sys
from viskit.core import AttrDict
sys.path.append('.')
import matplotlib
import os
matplotlib.use('Agg')
import flask # import Flask, render_template, send_from_directory
from viskit import core
import sys
import argparse
import json
import numpy as np
from plotly import tools
import plotly.offline as po
import plotly.graph_objs as go
named_colors = [
'dodgerblue',
'darkorange',
'green',
'cyan',
'magenta',
'orange',
'yellow',
'black',
'blue',
'brown',
'lime',
'pink',
'purple',
]
def flatten(xs):
    """Concatenate one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    out = []
    for group in xs:
        out.extend(group)
    return out
def sliding_mean(data_array, window=5):
    """Smooth a 1-D series with an asymmetric moving average.

    Point i averages indices [max(i - window + 1, 0), min(i + window + 1, n)),
    i.e. up to window-1 samples back and up to window samples forward.
    """
    data_array = np.array(data_array)
    n = len(data_array)
    smoothed = [
        np.mean(data_array[max(i - window + 1, 0):min(i + window + 1, n)])
        for i in range(n)
    ]
    return np.array(smoothed)
import itertools
app = flask.Flask(__name__, static_url_path='/static')
exps_data = None
plottable_keys = None
distinct_params = None
@app.route('/js/<path:path>')
def send_js(path):
    """Serve a static JavaScript asset from the local `js/` directory.

    FIX: the route rule must declare the `<path:path>` converter — without it
    Flask has no value to bind to this view's `path` parameter and every
    request to /js/... fails.
    """
    return flask.send_from_directory('js', path)
@app.route('/css/<path:path>')
def send_css(path):
    """Serve a static CSS asset from the local `css/` directory.

    FIX: the route rule must declare the `<path:path>` converter — without it
    Flask has no value to bind to this view's `path` parameter and every
    request to /css/... fails.
    """
    return flask.send_from_directory('css', path)
def create_bar_chart(
        plot_lists,
        use_median=False,
        plot_width=None,
        plot_height=None,
        title=None,
        value_i=-1,
):
    """
    plot_lists is a list of lists.
    Each outer list represents different y-axis attributes.
    Each inner list represents different experiments to run, within that y-axis
    attribute.
    Each plot is an AttrDict which should have the elements used below.

    value_i selects which timestep of each series is summarized in the bar
    (default -1, the final value). Returns an HTML div string.
    """
    # Series flagged with x_key define the shared x axis; the rest are bars.
    x_axis = [(subplot['plot_key'], subplot['means']) for plot_list in plot_lists for subplot in plot_list if subplot['x_key']]
    plot_lists = [[subplot for subplot in plot_list] for plot_list in plot_lists if not plot_list[0]['x_key']]
    xlabel = x_axis[0][0] if len(x_axis) else 'iteration'
    p25, p50, p75 = [], [], []  # NOTE(review): unused in this function
    num_y_axes = len(plot_lists)
    # One stacked subplot per y-axis attribute, sharing the x axis.
    fig = tools.make_subplots(
        rows=num_y_axes,
        cols=1,
        print_grid=False,
        shared_xaxes=True,
    )
    fig.layout.update(
        width=plot_width,
        height=plot_height,
        title=title,
        barmode='group',
    )
    all_plot_keys = []  # NOTE(review): collected but never used
    for plot_list in plot_lists:
        all_plot_keys.append(plot_list[0].plot_key)
    traces = []
    num_exps = len(plot_lists[0])  # NOTE(review): unused
    for y_idx, plot_list in enumerate(plot_lists):
        traces = []
        y_idx_plotly = y_idx + 1  # plotly subplots are 1-indexed
        for plt_idx, plt in enumerate(plot_list):
            # Bar height: median at value_i (with IQR as error bars) or the
            # mean of the whole mean-curve with std at value_i.
            if use_median:
                value = plt.percentile50[value_i]
                error = plt.percentile75[value_i] - value
                error_minus = value - plt.percentile25[value_i]
            else:
                value = np.mean(plt.means)
                error = plt.stds[value_i]
                error_minus = plt.stds[value_i]
            # convert numpy scalar to number
            # value = value.item()
            # error = error.item()
            # error_minus = error_minus.item()
            trace = go.Bar(
                x=[plt.legend],
                y=[value],
                # TODO: implement this correctly. I should give the option of
                # choosing another field as the error bar for this field.
                # Currently, this uses the own field to compute std. This might
                # be correct, but often will be misleading (e.g. "std of mean"
                # vs "mean of std" if each trial measures its own mean/std).
                # error_y=dict(
                #     type='data',
                #     symmetric=False,
                #     array=[error],
                #     arrayminus=[error_minus],
                #     visible=True,
                # ),
                name=plt.legend,
                # Only the first subplot contributes legend entries.
                showlegend=y_idx==0,
                legendgroup=plt.legend,
                marker=dict(
                    color=named_colors[plt_idx % len(named_colors)],
                ),
            )
            fig.append_trace(trace, y_idx_plotly, 1)
        fig['layout']['yaxis{}'.format(y_idx_plotly)].update(
            title=plt.plot_key,
        )
    fig_div = po.plot(
        fig,
        output_type='div',
        include_plotlyjs=False,
    )
    if "footnote" in plot_list[0]:
        footnote = " ".join([
            r"%s: %s" % (
                plt.legend, plt.footnote)
            for plt in plot_list
        ])
        # NOTE(review): markup surrounding this return value was lost in
        # extraction; reconstructed as fig_div followed by the footnote on
        # its own line — confirm against upstream viskit.
        return r"""%s
%s
""" % (fig_div, footnote)
    else:
        return fig_div
def make_plot(
        plot_lists,
        use_median=False,
        plot_width=None,
        plot_height=None,
        title=None,
):
    """
    plot_lists is a list of lists.
    Each outer list represents different y-axis attributes.
    Each inner list represents different experiments to run, within that y-axis
    attribute.
    Each plot is an AttrDict which should have the elements used below.

    Returns an HTML div string with one line-plot subplot per y-axis
    attribute, each series drawn with a shaded error band.
    """
    # Series flagged with x_key define the shared x axis; the rest are curves.
    x_axis = [(subplot['plot_key'], subplot['means']) for plot_list in plot_lists for subplot in plot_list if subplot['x_key']]
    plot_lists = [[subplot for subplot in plot_list] for plot_list in plot_lists if not plot_list[0]['x_key']]
    xlabel = x_axis[0][0] if len(x_axis) else 'iteration'
    p25, p50, p75 = [], [], []
    num_y_axes = len(plot_lists)
    fig = tools.make_subplots(rows=num_y_axes, cols=1, print_grid=False)
    fig['layout'].update(
        width=plot_width,
        height=plot_height,
        title=title,
    )
    for y_idx, plot_list in enumerate(plot_lists):
        for idx, plt in enumerate(plot_list):
            color = core.color_defaults[idx % len(core.color_defaults)]
            if use_median:
                p25.append(np.mean(plt.percentile25))
                p50.append(np.mean(plt.percentile50))
                p75.append(np.mean(plt.percentile75))
                if x_axis:
                    x = list(x_axis[idx][1])
                else:
                    x = list(range(len(plt.percentile50)))
                y = list(plt.percentile50)
                y_upper = list(plt.percentile75)
                y_lower = list(plt.percentile25)
            else:
                if x_axis:
                    x = list(x_axis[idx][1])
                else:
                    x = list(range(len(plt.means)))
                y = list(plt.means)
                y_upper = list(plt.means + plt.stds)
                y_lower = list(plt.means - plt.stds)
            # The error band is a closed polygon: upper curve forward, lower
            # curve backward.
            errors = go.Scatter(
                x=x + x[::-1],
                y=y_upper + y_lower[::-1],
                fill='tozerox',
                fillcolor=core.hex_to_rgb(color, 0.2),
                line=go.scatter.Line(color=core.hex_to_rgb(color, 0)),
                showlegend=False,
                legendgroup=plt.legend,
                hoverinfo='none',
            )
            values = go.Scatter(
                x=x,
                y=y,
                name=plt.legend,
                legendgroup=plt.legend,
                line=dict(color=core.hex_to_rgb(color)),
                hoverlabel=dict(namelength=-1),
                hoverinfo='all',
            )
            # plotly is 1-indexed like matplotlib for subplots
            y_idx_plotly = y_idx + 1
            fig.append_trace(values, y_idx_plotly, 1)
            fig.append_trace(errors, y_idx_plotly, 1)
            title = plt.plot_key
            # Long dotted keys are wrapped after the final '/' segment.
            if len(title) > 30:
                title_parts = title.split('/')
                title = " /".join(
                    title_parts[:-1]
                    + [r"{}".format(t) for t in title_parts[-1:]]
                )
            fig['layout']['yaxis{}'.format(y_idx_plotly)].update(
                title=title,
            )
            fig['layout']['xaxis{}'.format(y_idx_plotly)].update(
                title=xlabel,
            )
    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    if "footnote" in plot_list[0]:
        footnote = " ".join([
            r"%s: %s" % (
                plt.legend, plt.footnote)
            for plt in plot_list
        ])
        # NOTE(review): markup surrounding this return value was lost in
        # extraction; reconstructed as fig_div followed by the footnote on
        # its own line — confirm against upstream viskit.
        return r"""%s
%s
""" % (fig_div, footnote)
    else:
        return fig_div
def make_plot_eps(plot_list, use_median=False, counter=0):
    """Render plot_list with matplotlib and save it as tmp<counter>.pdf.

    The hard-coded legend renames and per-counter axis limits indicate this
    was written to reproduce specific VIME paper figures; treat those
    branches as figure-specific tweaks, not general behavior.
    """
    import matplotlib.pyplot as _plt
    f, ax = _plt.subplots(figsize=(8, 5))
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        if use_median:
            x = list(range(len(plt.percentile50)))
            y = list(plt.percentile50)
            y_upper = list(plt.percentile75)
            y_lower = list(plt.percentile25)
        else:
            x = list(range(len(plt.means)))
            y = list(plt.means)
            y_upper = list(plt.means + plt.stds)
            y_lower = list(plt.means - plt.stds)
        # Shorten rllab class paths to readable algorithm names.
        plt.legend = plt.legend.replace('rllab.algos.trpo.TRPO', 'TRPO')
        plt.legend = plt.legend.replace('rllab.algos.vpg.VPG', 'REINFORCE')
        plt.legend = plt.legend.replace('rllab.algos.erwr.ERWR', 'ERWR')
        plt.legend = plt.legend.replace('sandbox.rein.algos.trpo_vime.TRPO',
                                        'TRPO+VIME')
        plt.legend = plt.legend.replace('sandbox.rein.algos.vpg_vime.VPG',
                                        'REINFORCE+VIME')
        plt.legend = plt.legend.replace('sandbox.rein.algos.erwr_vime.ERWR',
                                        'ERWR+VIME')
        plt.legend = plt.legend.replace('0.0001', '1e-4')
        # plt.legend = plt.legend.replace('0.001', 'TRPO+VIME')
        # plt.legend = plt.legend.replace('0', 'TRPO')
        # plt.legend = plt.legend.replace('0.005', 'TRPO+L2')
        # Hard-coded series labels by position (figure-specific).
        if idx == 0:
            plt.legend = 'TRPO (0.0)'
        if idx == 1:
            plt.legend = 'TRPO+VIME (103.7)'
        if idx == 2:
            plt.legend = 'TRPO+L2 (0.0)'
        ax.fill_between(
            x, y_lower, y_upper, interpolate=True, facecolor=color,
            linewidth=0.0, alpha=0.3)
        if idx == 2:
            # Third series drawn dashed to distinguish overlapping curves.
            ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0,
                    linestyle="--")
        else:
            ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0)
        ax.grid(True)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # Per-figure axis limits / legend placement, keyed on counter.
        # NOTE(review): `loc` is unbound when counter is not in 1..4.
        if counter == 1:
            # ax.set_xlim([0, 120])
            ax.set_ylim([-3, 60])
            # ax.set_xlim([0, 80])
            loc = 'upper left'
        elif counter == 2:
            ax.set_ylim([-0.04, 0.4])
            # ax.set_ylim([-0.1, 0.4])
            ax.set_xlim([0, 2000])
            loc = 'upper left'
        elif counter == 3:
            # ax.set_xlim([0, 1000])
            loc = 'lower right'
        elif counter == 4:
            # ax.set_xlim([0, 800])
            # ax.set_ylim([0, 2])
            loc = 'lower right'
        leg = ax.legend(loc=loc, prop={'size': 12}, ncol=1)
        for legobj in leg.legendHandles:
            legobj.set_linewidth(5.0)

        def y_fmt(x, y):
            # Tick formatter: thousands as "NK".
            return str(int(np.round(x / 1000.0))) + 'K'

        import matplotlib.ticker as tick
        # ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
        # NOTE(review): indentation reconstructed — upstream saves (and
        # re-saves) inside the series loop; confirm against viskit.
        _plt.savefig('tmp' + str(counter) + '.pdf', bbox_inches='tight')
def summary_name(exp, selector=None):
    """Legend label for an experiment: currently just its exp_name.

    An older variant built the label from the non-filtered distinct params
    (using `selector`); that path is intentionally disabled, so `selector`
    is accepted but ignored.
    """
    return exp.params["exp_name"]
def check_nan(exp):
    """True when no progress column of `exp` contains a NaN value."""
    return not any(
        np.any(np.isnan(vals)) for vals in exp.progress.values()
    )
def get_plot_instruction(
        plot_keys,
        x_keys=None,
        split_keys=None,
        group_keys=None,
        best_filter_key=None,
        filters=None,
        exclusions=None,
        use_median=False,
        only_show_best=False,
        best_based_on_final=False,
        gen_eps=False,
        only_show_best_sofar=False,
        best_is_lowest=False,
        clip_plot_value=None,
        plot_width=None,
        plot_height=None,
        filter_nan=False,
        smooth_curve=False,
        custom_filter=None,
        legend_post_processor=None,
        normalize_error=False,
        make_bar_chart=False,
        value_i=-1,  # TODO: add option to set value_i
        custom_series_splitter=None,
):
    """Build the HTML for every requested figure.

    Pipeline: filter experiments (filters / exclusions / custom_filter,
    optionally dropping NaN-containing runs), split them into separate
    figures by `split_keys`, group them into series by `group_keys` (or a
    `custom_series_splitter` callable), optionally reduce each group to its
    best hyperparameter setting, then render with make_plot /
    create_bar_chart / make_plot_eps. Reads the module globals `exps_data`
    and `distinct_params`. Returns the concatenated figure divs.
    """
    if x_keys is None:
        x_keys = []
    if x_keys:
        assert len(x_keys) == 1  # only a single shared x axis is supported
        if x_keys[0] is None:
            x_keys = []
    # The x-axis series is fetched like an ordinary plot key, placed first.
    plot_keys = x_keys + plot_keys
    """
    A custom filter might look like
    "lambda exp: exp.flat_params['algo_params_base_kwargs.batch_size'] == 64"
    """
    if filter_nan:
        nonnan_exps_data = list(filter(check_nan, exps_data))
        selector = core.Selector(nonnan_exps_data)
    else:
        selector = core.Selector(exps_data)
    if legend_post_processor is None:
        legend_post_processor = lambda x: x
    if filters is None:
        filters = dict()
    if exclusions is None:
        exclusions = []
    if split_keys is None:
        split_keys = []
    if group_keys is None:
        group_keys = []
    if plot_height is None:
        plot_height = 300 * len(plot_keys)
    for k, v in filters.items():
        selector = selector.where(k, str(v))
    for k, v in exclusions:
        selector = selector.where_not(k, str(v))
    if custom_filter is not None:
        selector = selector.custom_filter(custom_filter)
    # One figure per combination of split-key values.
    if len(split_keys) > 0:
        split_selectors, split_titles = split_by_keys(
            selector, split_keys, distinct_params
        )
    else:
        split_selectors = [selector]
        split_titles = ["Plot"]
    plots = []
    counter = 1
    print("Plot_keys:", plot_keys)
    print("X keys:", x_keys)
    print("split_keys:", split_keys)
    print("group_keys:", group_keys)
    print("filters:", filters)
    print("exclusions:", exclusions)
    for split_selector, split_title in zip(split_selectors, split_titles):
        if custom_series_splitter is not None:
            # A user callable maps each experiment to its series label.
            exps = split_selector.extract()
            splitted_dict = dict()
            for exp in exps:
                key = custom_series_splitter(exp)
                if key not in splitted_dict:
                    splitted_dict[key] = list()
                splitted_dict[key].append(exp)
            splitted = list(splitted_dict.items())
            group_selectors = [core.Selector(list(x[1])) for x in splitted]
            group_legends = [x[0] for x in splitted]
        else:
            if len(group_keys) > 0:
                group_selectors, group_legends = split_by_keys(
                    split_selector, group_keys, distinct_params
                )
            else:
                group_selectors = [split_selector]
                group_legends = [split_title]
        list_of_list_of_plot_dicts = []
        for plot_ind, plot_key in enumerate(plot_keys):
            to_plot = []
            for group_selector, group_legend in zip(group_selectors, group_legends):
                filtered_data = group_selector.extract()
                if len(filtered_data) == 0:
                    continue
                # Optionally keep only the best value of best_filter_key
                # (skipped when that key already splits or groups the data).
                if (best_filter_key
                        and best_filter_key not in group_keys
                        and best_filter_key not in split_keys):
                    selectors = split_by_key(
                        group_selector, best_filter_key, distinct_params
                    )
                    scores = [
                        get_selector_score(plot_key, selector, use_median, best_based_on_final)
                        for selector in selectors
                    ]
                    if np.isfinite(scores).any():
                        if best_is_lowest:
                            best_idx = np.nanargmin(scores)
                        else:
                            best_idx = np.nanargmax(scores)
                        best_selector = selectors[best_idx]
                        filtered_data = best_selector.extract()
                        print("For split '{0}', group '{1}':".format(
                            split_title,
                            group_legend,
                        ))
                        print("  best '{0}': {1}".format(
                            best_filter_key,
                            dict(best_selector._filters)[best_filter_key]
                        ))
                if only_show_best or only_show_best_sofar:
                    # Group by seed and sort.
                    # -----------------------
                    filtered_params = core.extract_distinct_params(
                        filtered_data, l=0)
                    filtered_params2 = [p[1] for p in filtered_params]
                    filtered_params_k = [p[0] for p in filtered_params]
                    product_space = list(itertools.product(
                        *filtered_params2
                    ))
                    data_best_regret = None
                    best_regret = np.inf if best_is_lowest else -np.inf
                    kv_string_best_regret = None
                    # Score every hyperparameter combination; remember the
                    # best ("regret" here is just the aggregate score).
                    for idx, params in enumerate(product_space):
                        selector = core.Selector(exps_data)
                        for k, v in zip(filtered_params_k, params):
                            selector = selector.where(k, str(v))
                        data = selector.extract()
                        if len(data) > 0:
                            progresses = [
                                exp.progress.get(plot_key, np.array([np.nan]))
                                for exp in data
                            ]
                            sizes = list(map(len, progresses))
                            max_size = max(sizes)
                            # NaN-pad unequal-length runs to a rectangle.
                            progresses = [
                                np.concatenate(
                                    [ps, np.ones(max_size - len(ps)) * np.nan])
                                for ps in progresses]
                            if best_based_on_final:
                                progresses = np.asarray(progresses)[:, -1]
                            if only_show_best_sofar:
                                # Best value reached at any point in the run.
                                if best_is_lowest:
                                    progresses = np.min(np.asarray(progresses),
                                                        axis=1)
                                else:
                                    progresses = np.max(np.asarray(progresses),
                                                        axis=1)
                            if use_median:
                                medians = np.nanmedian(progresses, axis=0)
                                regret = np.mean(medians)
                            else:
                                means = np.nanmean(progresses, axis=0)
                                regret = np.mean(means)
                            distinct_params_k = [p[0] for p in distinct_params]
                            distinct_params_v = [
                                v for k, v in zip(filtered_params_k, params) if
                                k in distinct_params_k]
                            distinct_params_kv = [
                                (k, v) for k, v in
                                zip(distinct_params_k, distinct_params_v)]
                            distinct_params_kv_string = str(
                                distinct_params_kv).replace('), ', ')\t')
                            print(
                                '{}\t{}\t{}'.format(regret, len(progresses),
                                                    distinct_params_kv_string))
                            if best_is_lowest:
                                change_regret = regret < best_regret
                            else:
                                change_regret = regret > best_regret
                            if change_regret:
                                best_regret = regret
                                best_progress = progresses
                                data_best_regret = data
                                kv_string_best_regret = distinct_params_kv_string
                    print(group_selector._filters)
                    print('best regret: {}'.format(best_regret))
                    # -----------------------
                    if np.isfinite(best_regret):
                        progresses = [
                            exp.progress.get(plot_key, np.array([np.nan])) for
                            exp in data_best_regret]
                        # progresses = [progress[:500] for progress in progresses ]
                        sizes = list(map(len, progresses))
                        # more intelligent:
                        max_size = max(sizes)
                        progresses = [
                            np.concatenate(
                                [ps, np.ones(max_size - len(ps)) * np.nan]) for
                            ps in progresses]
                        legend = '{} (mu: {:.3f}, std: {:.5f})'.format(
                            group_legend, best_regret, np.std(best_progress))
                        window_size = np.maximum(
                            int(np.round(max_size / float(1000))), 1)
                        statistics = get_statistics(
                            progresses, use_median, normalize_error,
                        )
                        statistics = process_statistics(
                            statistics,
                            smooth_curve,
                            clip_plot_value,
                            window_size,
                        )
                        to_plot.append(
                            AttrDict(
                                legend=legend_post_processor(legend),
                                plot_key=plot_key,
                                **statistics
                            )
                        )
                        # NOTE(review): `data` below is whatever the last
                        # product-space iteration left behind, not necessarily
                        # the best configuration — looks fragile; confirm.
                        if len(to_plot) > 0 and len(data) > 0:
                            to_plot[-1]["footnote"] = "%s; e.g. %s" % (
                                kv_string_best_regret,
                                data[0].params.get("exp_name", "NA"))
                        else:
                            to_plot[-1]["footnote"] = ""
                else:
                    progresses = [
                        exp.progress.get(plot_key, np.array([np.nan])) for exp
                        in filtered_data
                    ]
                    sizes = list(map(len, progresses))
                    # more intelligent:
                    max_size = max(sizes)
                    # NaN-pad unequal-length runs to a rectangle.
                    progresses = [
                        np.concatenate(
                            [ps, np.ones(max_size - len(ps)) * np.nan]) for ps
                        in progresses]
                    window_size = np.maximum(
                        int(np.round(max_size / float(100))),
                        1,
                    )
                    statistics = get_statistics(
                        progresses, use_median, normalize_error,
                    )
                    statistics = process_statistics(
                        statistics,
                        smooth_curve,
                        clip_plot_value,
                        window_size,
                    )
                    to_plot.append(
                        AttrDict(
                            legend=legend_post_processor(group_legend),
                            plot_key=plot_key,
                            # Only the first plot key may act as the x axis.
                            x_key=plot_key in x_keys and plot_ind == 0,
                            **statistics
                        )
                    )
            if len(to_plot) > 0:
                list_of_list_of_plot_dicts.append(to_plot)
        if len(list_of_list_of_plot_dicts) > 0 and not gen_eps:
            fig_title = split_title
            if make_bar_chart:
                plots.append(create_bar_chart(
                    list_of_list_of_plot_dicts,
                    use_median=use_median, title=fig_title,
                    plot_width=plot_width, plot_height=plot_height,
                    value_i=value_i,
                ))
            else:
                plots.append(make_plot(
                    list_of_list_of_plot_dicts,
                    use_median=use_median, title=fig_title,
                    plot_width=plot_width, plot_height=plot_height
                ))
        if gen_eps:
            make_plot_eps(to_plot, use_median=use_median, counter=counter)
        counter += 1
    return "\n".join(plots)
def shorten_key(key):
    """
    Convert a dot-map string like "foo.bar.baz" into "f.b.baz"

    Every segment except the last is abbreviated to its first character;
    empty segments are dropped.
    """
    *prefix, last = key.split(".")
    abbreviated = [segment[0] for segment in prefix if segment]
    return ".".join(abbreviated + [last])
def get_selector_score(key, selector, use_median, best_based_on_final):
    """
    :param key: Thing to measure (e.g. Average Returns, Loss, etc.)
    :param selector: Selector instance
    :param use_median: Use the median? Else use the mean
    :param best_based_on_final: Only look at the final value? Else use all
    values.
    :return: A single number that gives the score of `key` inside `selector`
    (NaN when there are no values or any value is non-finite)
    """
    data = selector.extract()
    missing = np.array([np.nan])
    if best_based_on_final:
        values = [exp.progress.get(key, missing)[-1] for exp in data]
    else:
        series = [exp.progress.get(key, missing) for exp in data]
        values = np.concatenate(series or [[np.nan]])
    if len(values) == 0 or not np.isfinite(values).all():
        return np.nan
    if use_median:
        return np.nanpercentile(values, q=50, axis=0)
    return np.nanmean(values)
def get_statistics(progresses, use_median, normalize_errors):
    """
    Get some dictionary of statistics (e.g. the median, mean).
    :param progresses: 2-D array-like of runs x timesteps (NaN-padded)
    :param use_median: if True return quartile curves, else mean/std curves
    :param normalize_errors: divide std by sqrt(#non-NaN runs) per timestep
    :return: dict of named statistic curves
    """
    if use_median:
        quartiles = {
            'percentile25': np.nanpercentile(progresses, q=25, axis=0),
            'percentile50': np.nanpercentile(progresses, q=50, axis=0),
            'percentile75': np.nanpercentile(progresses, q=75, axis=0),
        }
        return quartiles
    stds = np.nanstd(progresses, axis=0)
    if normalize_errors:
        # Convert std to standard error using the per-timestep run count.
        stds /= np.sqrt(np.sum((1. - np.isnan(progresses)), axis=0))
    return dict(means=np.nanmean(progresses, axis=0), stds=stds)
def process_statistics(
        statistics,
        smooth_curve,
        clip_plot_value,
        window_size
):
    """
    Smoothen and clip time-series data.

    Each statistic curve is optionally smoothed with sliding_mean and then
    optionally clipped to [-clip_plot_value, clip_plot_value].
    """
    processed = {}
    for name, curve in statistics.items():
        if smooth_curve:
            curve = sliding_mean(curve, window=window_size)
        if clip_plot_value is not None:
            curve = np.clip(curve, -clip_plot_value, clip_plot_value)
        processed[name] = curve
    return processed
def get_possible_values(distinct_params, key):
    """Return the recorded value list for `key` in distinct_params.

    Raises IndexError when `key` is absent (matching the original behavior).
    """
    matches = [values for name, values in distinct_params if name == key]
    return matches[0]
def split_by_key(selector, key, distinct_params):
    """
    Return a list of selectors based on this selector.
    Each selector represents one distinct value of `key`.
    """
    return [
        selector.where(key, value)
        for value in get_possible_values(distinct_params, key)
    ]
def split_by_keys(base_selector, keys, distinct_params):
    """
    Return a list of selectors based on the base_selector.
    Each selector represents one distinct set of values for each key in `keys`.
    :param base_selector: Selector to refine
    :param keys: param names to split on
    :param distinct_params: [(name, [values...])] listing possible values
    :return: (selectors, descriptions) — one pair per value combination
    """
    per_key_pairs = [
        [(key, v) for v in get_possible_values(distinct_params, key)]
        for key in keys
    ]
    # per_key_pairs rows look like:
    #   [(color, red), (color, blue), ...]
    #   [(season, spring), (season, summer), ...]
    # The cartesian product enumerates every combination, e.g.
    #   [(color, red), (season, spring)], [(color, blue), (season, spring)], ...
    selectors = []
    descriptions = []
    for combo in itertools.product(*per_key_pairs):
        selector = None
        for key, value in combo:
            if selector is None:
                selector = base_selector.where(key, value)
            else:
                selector = selector.where(key, value)
        selectors.append(selector)
        descriptions.append(", ".join(
            "{0}={1}".format(shorten_key(key), value)
            for key, value in combo
        ))
    return selectors, descriptions
def parse_float_arg(args, key):
    """Parse request arg `key` as a float; None when absent or malformed."""
    raw = args.get(key, "")
    try:
        return float(raw)
    except Exception:
        return None
@app.route("/plot_div")
def plot_div():
    """AJAX endpoint: parse plotting options from the query string and
    return the rendered plot HTML div."""
    args = flask.request.args
    # List-valued options arrive as JSON-encoded strings.
    plot_keys_json = args.get("plot_keys")
    plot_keys = json.loads(plot_keys_json)
    x_keys_json = args.get("x_keys")
    x_keys = json.loads(x_keys_json)
    split_keys_json = args.get("split_keys", "[]")
    split_keys = json.loads(split_keys_json)
    group_keys_json = args.get("group_keys", "[]")
    group_keys = json.loads(group_keys_json)
    best_filter_key = args.get("best_filter_key", "")
    filters_json = args.get("filters", "{}")
    filters = json.loads(filters_json)
    exclusions_json = args.get("exclusions", "{}")
    exclusions = json.loads(exclusions_json)
    if len(best_filter_key) == 0:
        best_filter_key = None
    # Boolean flags arrive as the literal string 'True'.
    use_median = args.get("use_median", "") == 'True'
    gen_eps = args.get("eps", "") == 'True'
    only_show_best = args.get("only_show_best", "") == 'True'
    best_based_on_final = args.get("best_based_on_final", "") == 'True'
    only_show_best_sofar = args.get("only_show_best_sofar", "") == 'True'
    best_is_lowest = args.get("best_is_lowest", "") == 'True'
    normalize_error = args.get("normalize_error", "") == 'True'
    make_bar_chart = args.get("make_bar_chart", "") == 'True'
    filter_nan = args.get("filter_nan", "") == 'True'
    smooth_curve = args.get("smooth_curve", "") == 'True'
    clip_plot_value = parse_float_arg(args, "clip_plot_value")
    plot_width = parse_float_arg(args, "plot_width")
    plot_height = parse_float_arg(args, "plot_height")
    custom_filter = args.get("custom_filter", None)
    custom_series_splitter = args.get("custom_series_splitter", None)
    # SECURITY NOTE: the three options below are user-supplied code strings
    # evaluated via safer_eval — this endpoint must only be exposed to
    # trusted users on a trusted network.
    if custom_filter is not None and len(custom_filter.strip()) > 0:
        custom_filter = safer_eval(custom_filter)
    else:
        custom_filter = None
    legend_post_processor = args.get("legend_post_processor", None)
    if legend_post_processor is not None and len(
            legend_post_processor.strip()) > 0:
        legend_post_processor = safer_eval(legend_post_processor)
    else:
        legend_post_processor = None
    if custom_series_splitter is not None and len(
            custom_series_splitter.strip()) > 0:
        custom_series_splitter = safer_eval(custom_series_splitter)
    else:
        custom_series_splitter = None
    plot_div = get_plot_instruction(
        plot_keys=plot_keys,
        x_keys=x_keys,
        split_keys=split_keys,
        filter_nan=filter_nan,
        group_keys=group_keys,
        best_filter_key=best_filter_key,
        filters=filters,
        exclusions=exclusions,
        use_median=use_median,
        gen_eps=gen_eps,
        only_show_best=only_show_best,
        best_based_on_final=best_based_on_final,
        only_show_best_sofar=only_show_best_sofar,
        best_is_lowest=best_is_lowest,
        clip_plot_value=clip_plot_value,
        plot_width=plot_width,
        plot_height=plot_height,
        smooth_curve=smooth_curve,
        custom_filter=custom_filter,
        legend_post_processor=legend_post_processor,
        normalize_error=normalize_error,
        make_bar_chart=make_bar_chart,
        custom_series_splitter=custom_series_splitter,
    )
    return plot_div
def safer_eval(some_string):
    """
    Evaluate *some_string* with builtins stripped out.

    Not full-proof, but taking advice from:
    https://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html
    """
    # Reject the most common escape hatches before evaluating at all.
    for needle in ("__", "import"):
        if needle in some_string:
            raise Exception("string to eval looks suspicious")
    # Empty __builtins__ keeps open(), exec() etc. out of reach.
    return eval(some_string, {'__builtins__': {}})
@app.route("/")
def index():
    """Render the main page, preselecting a sensible default plot key."""
    # Prefer the well-known return metrics; otherwise fall back to the
    # first plottable key, or to no selection at all.
    plot_keys = None
    for candidate in ("AverageReturn", 'training/return-average'):
        if candidate in plottable_keys:
            plot_keys = [candidate]
            break
    if plot_keys is None and len(plottable_keys) > 0:
        plot_keys = plottable_keys[0:1]
    plot_div = get_plot_instruction(plot_keys=plot_keys)
    return flask.render_template(
        "main.html",
        plot_div=plot_div,
        plot_keys=plot_keys,
        group_keys=[],
        plottable_keys=plottable_keys,
        distinct_param_keys=[str(k) for k, v in distinct_params],
        distinct_params={str(k): [str(item) for item in v]
                         for k, v in distinct_params},
    )
@app.route("/reload-data", methods=['POST'])
def reload():
    # Re-scan the experiment directories on demand (POST triggered from the UI).
    reload_data()
    return 'Reloaded'
def reload_data():
    """Reload experiment data from disk and refresh the derived globals."""
    global exps_data
    global plottable_keys
    global distinct_params
    exps_data = core.load_exps_data(
        args.data_paths,
        args.data_filename,
        args.params_filename,
        args.disable_variant,
    )
    # Union of every progress key across all experiments, minus Nones,
    # in deterministic (sorted) order for the UI.
    key_union = set(flatten(list(exp.progress.keys()) for exp in exps_data))
    key_union.discard(None)
    plottable_keys = sorted(key_union)
    distinct_params = sorted(core.extract_distinct_params(exps_data))
def main():
    """Parse CLI arguments, load experiment data, and serve the viewer."""
    global args
    parser = argparse.ArgumentParser()
    parser.add_argument("data_paths", type=str, nargs='*')
    parser.add_argument("--prefix", type=str, nargs='?', default="???")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--port", type=int, default=5000)
    parser.add_argument("--disable-variant", default=False, action='store_true')
    parser.add_argument("--data-filename",
                        default='progress.csv',
                        help='name of data file.')
    parser.add_argument("--params-filename",
                        default='params.json',
                        help='name of params file.')
    args = parser.parse_args(sys.argv[1:])

    # load all folders following a prefix
    if args.prefix != "???":
        base_dir = os.path.dirname(args.prefix)
        name_prefix = os.path.basename(args.prefix)
        args.data_paths = [
            os.path.join(base_dir, entry)
            for entry in os.listdir(base_dir)
            if os.path.isdir(os.path.join(base_dir, entry))
            and (name_prefix in entry)
        ]

    print("Importing data from {path}...".format(path=args.data_paths))
    reload_data()
    port = args.port
    try:
        print("View http://localhost:%d in your browser" % port)
        app.run(host='0.0.0.0', port=port, debug=args.debug)
    except OSError as e:
        # Give an actionable hint for the single most common failure.
        if e.strerror == 'Address already in use':
            print("Port {} is busy. Try specifying a different port with ("
                  "e.g.) --port=5001".format(port))
================================================
FILE: viskit/logging.py
================================================
"""
File taken from RLKit (https://github.com/vitchyr/rlkit).
Based on rllab's logger.
https://github.com/rll/rllab
"""
from enum import Enum
from contextlib import contextmanager
import numpy as np
import os
import os.path as osp
import sys
import datetime
import dateutil.tz
import csv
import json
import pickle
import errno
import time
import tempfile
from viskit.tabulate import tabulate
class TerminalTablePrinter(object):
    """Redraws a table of tabular records in-place on the terminal.

    Keeps every row ever recorded and, on each refresh, clears the screen
    and reprints however many of the most recent rows fit vertically.
    """

    def __init__(self):
        self.headers = None   # column names, taken from the first record
        self.tabulars = []    # list of row-value lists, in arrival order

    def print_tabular(self, new_tabular):
        """Append one record (a list of (key, value) pairs) and redraw."""
        if self.headers is None:
            self.headers = [x[0] for x in new_tabular]
        else:
            # Later records must carry the same number of columns.
            assert len(self.headers) == len(new_tabular)
        self.tabulars.append([x[1] for x in new_tabular])
        self.refresh()

    def refresh(self):
        """Clear the screen and reprint the rows that fit on it."""
        import shutil
        # shutil.get_terminal_size works off-TTY and on Windows (falling
        # back to 80x24), unlike the previous `stty size` subprocess which
        # raised/garbled output when stdout was not a terminal.
        columns, rows = shutil.get_terminal_size()
        # Keep only the most recent rows, leaving space for the header.
        tabulars = self.tabulars[-(int(rows) - 3):]
        sys.stdout.write("\x1b[2J\x1b[H")
        sys.stdout.write(tabulate(tabulars, self.headers))
        sys.stdout.write("\n")
class MyEncoder(json.JSONEncoder):
    """JSON encoder that tolerates classes, enums, and callables.

    Each otherwise-unserializable object is replaced by a small tagged
    dict naming it, so variant dicts containing such objects still dump.
    """

    def default(self, o):
        if isinstance(o, type):
            return {'$class': o.__module__ + "." + o.__name__}
        if isinstance(o, Enum):
            tag = o.__module__ + "." + o.__class__.__name__ + '.' + o.name
            return {'$enum': tag}
        if callable(o):
            return {'$function': o.__module__ + "." + o.__name__}
        # Anything else: defer to the base class (which raises TypeError).
        return json.JSONEncoder.default(self, o)
def mkdir_p(path):
    """Create *path* and any missing parents, like ``mkdir -p``.

    Succeeds silently when the directory already exists; still raises if
    the path exists but is not a directory, or on any other OS error.
    """
    # exist_ok=True (Python >= 3.2) replaces the old errno.EEXIST dance;
    # it raises FileExistsError when the path exists as a non-directory,
    # matching the original re-raise behavior.
    os.makedirs(path, exist_ok=True)
class Logger(object):
    """Experiment logger modeled on rllab's.

    Supports free-form text lines (mirrored to stdout and registered text
    files) and tabular key/value records that accumulate until
    `dump_tabular`, which prints the table and appends one row to every
    registered CSV output.
    """

    def __init__(self):
        self._prefixes = []                    # stack of text-line prefixes
        self._prefix_str = ''                  # cached ''.join(_prefixes)
        self._tabular_prefixes = []            # stack of tabular key prefixes
        self._tabular_prefix_str = ''
        self._tabular = []                     # (key, value) pairs of current row
        self._text_outputs = []                # registered text file names
        self._tabular_outputs = []             # registered CSV file names
        self._text_fds = {}                    # file name -> open file object
        self._tabular_fds = {}
        self._tabular_header_written = set()   # fds whose CSV header was written
        self._snapshot_dir = None
        self._snapshot_mode = 'all'
        self._snapshot_gap = 1
        self._log_tabular_only = False
        self._header_printed = False
        self.table_printer = TerminalTablePrinter()

    def reset(self):
        # Re-run __init__ to drop all state. NOTE(review): previously open
        # file descriptors are not closed first.
        self.__init__()

    def _add_output(self, file_name, arr, fds, mode='a'):
        # Register file_name once: create its directory, record the name,
        # and keep the opened file object for later writes.
        if file_name not in arr:
            mkdir_p(os.path.dirname(file_name))
            arr.append(file_name)
            fds[file_name] = open(file_name, mode)

    def _remove_output(self, file_name, arr, fds):
        # Close and forget a previously registered output; no-op otherwise.
        if file_name in arr:
            fds[file_name].close()
            del fds[file_name]
            arr.remove(file_name)

    def push_prefix(self, prefix):
        # Prepend `prefix` to every subsequent text log line.
        self._prefixes.append(prefix)
        self._prefix_str = ''.join(self._prefixes)

    def add_text_output(self, file_name):
        # Text logs are opened in append mode so restarts keep old output.
        self._add_output(file_name, self._text_outputs, self._text_fds,
                         mode='a')

    def remove_text_output(self, file_name):
        self._remove_output(file_name, self._text_outputs, self._text_fds)

    def add_tabular_output(self, file_name, relative_to_snapshot_dir=False):
        if relative_to_snapshot_dir:
            file_name = osp.join(self._snapshot_dir, file_name)
        # CSV outputs are truncated ('w'): each run writes a fresh header.
        self._add_output(file_name, self._tabular_outputs, self._tabular_fds,
                         mode='w')

    def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False):
        if relative_to_snapshot_dir:
            file_name = osp.join(self._snapshot_dir, file_name)
        # Forget the header bookkeeping for this fd before closing it.
        if self._tabular_fds[file_name] in self._tabular_header_written:
            self._tabular_header_written.remove(self._tabular_fds[file_name])
        self._remove_output(file_name, self._tabular_outputs, self._tabular_fds)

    def set_snapshot_dir(self, dir_name):
        self._snapshot_dir = dir_name

    def get_snapshot_dir(self, ):
        return self._snapshot_dir

    def get_snapshot_mode(self, ):
        return self._snapshot_mode

    def set_snapshot_mode(self, mode):
        self._snapshot_mode = mode

    def get_snapshot_gap(self, ):
        return self._snapshot_gap

    def set_snapshot_gap(self, gap):
        self._snapshot_gap = gap

    def set_log_tabular_only(self, log_tabular_only):
        self._log_tabular_only = log_tabular_only

    def get_log_tabular_only(self, ):
        return self._log_tabular_only

    def log(self, s, with_prefix=True, with_timestamp=True):
        """Write one text line to stdout (unless tabular-only) and to all
        registered text files, optionally prefixed and timestamped."""
        out = s
        if with_prefix:
            out = self._prefix_str + out
        if with_timestamp:
            now = datetime.datetime.now(dateutil.tz.tzlocal())
            timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z')
            out = "%s | %s" % (timestamp, out)
        if not self._log_tabular_only:
            # Also log to stdout
            print(out)
        # File outputs always receive the line, even in tabular-only mode.
        for fd in list(self._text_fds.values()):
            fd.write(out + '\n')
            fd.flush()
        sys.stdout.flush()

    def record_tabular(self, key, val):
        # Key and value are stringified immediately; the current tabular
        # prefix is baked into the key at record time.
        self._tabular.append((self._tabular_prefix_str + str(key), str(val)))

    def record_dict(self, d, prefix=None):
        # Record every item of `d`, optionally under a temporary prefix.
        if prefix is not None:
            self.push_tabular_prefix(prefix)
        for k, v in d.items():
            self.record_tabular(k, v)
        if prefix is not None:
            self.pop_tabular_prefix()

    def push_tabular_prefix(self, key):
        self._tabular_prefixes.append(key)
        self._tabular_prefix_str = ''.join(self._tabular_prefixes)

    def pop_tabular_prefix(self, ):
        del self._tabular_prefixes[-1]
        self._tabular_prefix_str = ''.join(self._tabular_prefixes)

    def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'):
        """
        Data saved here will always override the last entry
        :param data: Something pickle'able.
        :param file_name: name relative to the snapshot directory.
        :param mode: 'joblib' (compressed) or 'pickle'.
        :return: the full path the data was written to.
        """
        file_name = osp.join(self._snapshot_dir, file_name)
        if mode == 'joblib':
            import joblib
            joblib.dump(data, file_name, compress=3)
        elif mode == 'pickle':
            pickle.dump(data, open(file_name, "wb"))
        else:
            raise ValueError("Invalid mode: {}".format(mode))
        return file_name

    def get_table_dict(self, ):
        # Snapshot of the current row as a dict (later keys win on dupes).
        return dict(self._tabular)

    def get_table_key_set(self, ):
        return set(key for key, value in self._tabular)

    @contextmanager
    def prefix(self, key):
        # Scoped text prefix; popped even if the body raises.
        self.push_prefix(key)
        try:
            yield
        finally:
            self.pop_prefix()

    @contextmanager
    def tabular_prefix(self, key):
        # NOTE(review): unlike `prefix`, no try/finally here -- an
        # exception in the body leaves the prefix pushed.
        self.push_tabular_prefix(key)
        yield
        self.pop_tabular_prefix()

    def log_variant(self, log_file, variant_data):
        # Dump the experiment variant as pretty, sorted JSON; MyEncoder
        # handles classes/enums/callables inside the variant.
        mkdir_p(os.path.dirname(log_file))
        with open(log_file, "w") as f:
            json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder)

    def record_tabular_misc_stat(self, key, values, placement='back'):
        # Record Average/Std/Median/Min/Max of `values`, with the stat
        # name placed before ('front') or after (default) `key`.
        if placement == 'front':
            prefix = ""
            suffix = key
        else:
            prefix = key
            suffix = ""
        if len(values) > 0:
            self.record_tabular(prefix + "Average" + suffix, np.average(values))
            self.record_tabular(prefix + "Std" + suffix, np.std(values))
            self.record_tabular(prefix + "Median" + suffix, np.median(values))
            self.record_tabular(prefix + "Min" + suffix, np.min(values))
            self.record_tabular(prefix + "Max" + suffix, np.max(values))
        else:
            # No data: emit NaNs so the CSV columns stay aligned.
            self.record_tabular(prefix + "Average" + suffix, np.nan)
            self.record_tabular(prefix + "Std" + suffix, np.nan)
            self.record_tabular(prefix + "Median" + suffix, np.nan)
            self.record_tabular(prefix + "Min" + suffix, np.nan)
            self.record_tabular(prefix + "Max" + suffix, np.nan)

    def dump_tabular(self, *args, **kwargs):
        """Print the accumulated row and append it to all CSV outputs,
        then clear the row. `write_header` kwarg forces/suppresses the
        CSV header; by default it is written once per fd."""
        wh = kwargs.pop("write_header", None)
        if len(self._tabular) > 0:
            if self._log_tabular_only:
                self.table_printer.print_tabular(self._tabular)
            else:
                for line in tabulate(self._tabular).split('\n'):
                    self.log(line, *args, **kwargs)
            tabular_dict = dict(self._tabular)
            # Also write to the csv files
            # This assumes that the keys in each iteration won't change!
            for tabular_fd in list(self._tabular_fds.values()):
                writer = csv.DictWriter(tabular_fd,
                                        fieldnames=list(tabular_dict.keys()))
                if wh or (
                        wh is None and tabular_fd not in self._tabular_header_written):
                    writer.writeheader()
                    self._tabular_header_written.add(tabular_fd)
                writer.writerow(tabular_dict)
                tabular_fd.flush()
            del self._tabular[:]

    def pop_prefix(self, ):
        del self._prefixes[-1]
        self._prefix_str = ''.join(self._prefixes)
def safe_json(data):
    """Return True if *data* is directly JSON-serializable as-is."""
    if data is None or isinstance(data, (bool, int, float)):
        return True
    if isinstance(data, (tuple, list)):
        return all(safe_json(item) for item in data)
    if isinstance(data, dict):
        # JSON objects require string keys and serializable values.
        return all(
            isinstance(key, str) and safe_json(value)
            for key, value in data.items()
        )
    return False
def dict_to_safe_json(d):
    """
    Convert each value in the dictionary into a JSON'able primitive.
    :param d:
    :return:
    """
    safe = {}
    for key, value in d.items():
        if safe_json(value):
            safe[key] = value
        elif isinstance(value, dict):
            # Recurse so nested dicts keep their structure.
            safe[key] = dict_to_safe_json(value)
        else:
            # Last resort: stringify whatever it is.
            safe[key] = str(value)
    return safe
def create_exp_name(exp_prefix, exp_id=0, seed=0):
    """
    Create a semi-unique experiment name that has a timestamp
    :param exp_prefix:
    :param exp_id:
    :param seed:
    :return: "<exp_prefix>_<local timestamp>-s-<seed>--<exp_id>"
    """
    # `astimezone()` attaches the local timezone using only the standard
    # library, replacing the third-party dateutil.tz.tzlocal() call with
    # identical formatted output.
    now = datetime.datetime.now().astimezone()
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    return "%s_%s-s-%d--%s" % (exp_prefix, timestamp, seed, str(exp_id))
def create_log_dir(
        exp_prefix,
        exp_id=0,
        seed=0,
        base_log_dir=None,
        include_exp_prefix_sub_dir=True,
):
    """
    Creates and returns a unique log directory.
    :param exp_prefix: All experiments with this prefix will have log
    directories be under this directory.
    :param exp_id: The number of the specific experiment run within this
    experiment.
    :param base_log_dir: The directory where all log should be saved.
    :return:
    """
    # NOTE(review): exp_name is currently unused because the per-experiment
    # subdirectory layout below is commented out; kept for easy re-enable.
    exp_name = create_exp_name(exp_prefix, exp_id=exp_id,
                               seed=seed)
    if base_log_dir is None:
        # NOTE(review): `conf` is not among this module's visible imports --
        # presumably viskit's config module; confirm before relying on the
        # default path (callers passing base_log_dir never reach this line).
        base_log_dir = conf.LOCAL_LOG_DIR
    # if include_exp_prefix_sub_dir:
    #     log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name)
    # else:
    #     log_dir = osp.join(base_log_dir, exp_name)
    # Everything goes straight into base_log_dir (subdir layout disabled).
    log_dir = base_log_dir
    if osp.exists(log_dir):
        print("WARNING: Log directory already exists {}".format(log_dir))
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
def setup_logger(
        exp_prefix="default",
        variant=None,
        text_log_file="debug.log",
        variant_log_file="variant.json",
        tabular_log_file="progress.csv",
        snapshot_mode="last",
        snapshot_gap=1,
        log_tabular_only=False,
        base_log_dir=None,
        **create_log_dir_kwargs
):
    """
    Set up logger to have some reasonable default settings.
    Will save log output to
    based_log_dir/exp_prefix/exp_name.
    exp_name will be auto-generated to be unique.
    If log_dir is specified, then that directory is used as the output dir.
    :param exp_prefix: The sub-directory for this specific experiment.
    :param variant:
    :param text_log_file:
    :param variant_log_file:
    :param tabular_log_file:
    :param snapshot_mode:
    :param log_tabular_only:
    :param snapshot_gap:
    :param log_dir:
    :return: the log directory that was set up.
    """
    log_dir = create_log_dir(exp_prefix, base_log_dir=base_log_dir,
                             **create_log_dir_kwargs)

    if variant is not None:
        # Record the experiment configuration both on screen and on disk.
        logger.log("Variant:")
        logger.log(json.dumps(dict_to_safe_json(variant), indent=2))
        logger.log_variant(osp.join(log_dir, variant_log_file), variant)

    logger.add_text_output(osp.join(log_dir, text_log_file))
    logger.add_tabular_output(osp.join(log_dir, tabular_log_file))
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)

    # Prefix every text line with the experiment (directory) name.
    exp_name = log_dir.split("/")[-1]
    logger.push_prefix("[%s] " % exp_name)
    return log_dir
# Module-level singleton logger shared by setup_logger and client code.
logger = Logger()
================================================
FILE: viskit/static/css/dropdowns-enhancement.css
================================================
/*
 * Companion stylesheet for dropdowns-enhancement.js: Bootstrap dropdown
 * menus with checkbox/radio items, submenus, and extra placement variants.
 * Vendor asset -- keep modifications minimal.
 */

/* Checkbox/radio menu items are rendered as <label>s styled like links. */
.dropdown-menu > li > label {
  display: block;
  padding: 3px 20px;
  clear: both;
  font-weight: normal;
  line-height: 1.42857143;
  color: #333333;
  white-space: nowrap;
}
.dropdown-menu > li > label:hover,
.dropdown-menu > li > label:focus {
  text-decoration: none;
  color: #262626;
  background-color: #f5f5f5;
}
/* Checked or active items get the Bootstrap primary highlight. */
.dropdown-menu > li > input:checked ~ label,
.dropdown-menu > li > input:checked ~ label:hover,
.dropdown-menu > li > input:checked ~ label:focus,
.dropdown-menu > .active > label,
.dropdown-menu > .active > label:hover,
.dropdown-menu > .active > label:focus {
  color: #ffffff;
  text-decoration: none;
  outline: 0;
  background-color: #428bca;
}
/* Disabled items are greyed out and ignore hover/focus. */
.dropdown-menu > li > input[disabled] ~ label,
.dropdown-menu > li > input[disabled] ~ label:hover,
.dropdown-menu > li > input[disabled] ~ label:focus,
.dropdown-menu > .disabled > label,
.dropdown-menu > .disabled > label:hover,
.dropdown-menu > .disabled > label:focus {
  color: #999999;
}
.dropdown-menu > li > input[disabled] ~ label:hover,
.dropdown-menu > li > input[disabled] ~ label:focus,
.dropdown-menu > .disabled > label:hover,
.dropdown-menu > .disabled > label:focus {
  text-decoration: none;
  background-color: transparent;
  background-image: none;
  filter: progid:DXImageTransform.Microsoft.gradient(enabled = false);
  cursor: not-allowed;
}
.dropdown-menu > li > label {
  margin-bottom: 0;
  cursor: pointer;
}
/* The real inputs are moved off-screen; the labels are the UI. */
.dropdown-menu > li > input[type="radio"],
.dropdown-menu > li > input[type="checkbox"] {
  display: none;
  position: absolute;
  top: -9999em;
  left: -9999em;
}
.dropdown-menu > li > label:focus,
.dropdown-menu > li > input:focus ~ label {
  outline: thin dotted;
  outline: 5px auto -webkit-focus-ring-color;
  outline-offset: -2px;
}
/* Placement variants: pull-right/top/center/middle and combinations. */
.dropdown-menu.pull-right {
  right: 0;
  left: auto;
}
.dropdown-menu.pull-top {
  bottom: 100%;
  top: auto;
  margin: 0 0 2px;
  -webkit-box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175);
  box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175);
}
.dropdown-menu.pull-center {
  right: 50%;
  left: auto;
}
.dropdown-menu.pull-middle {
  right: 100%;
  margin: 0 2px 0 0;
  box-shadow: -5px 0 10px rgba(0, 0, 0, 0.2);
  left: auto;
}
.dropdown-menu.pull-middle.pull-right {
  right: auto;
  left: 100%;
  margin: 0 0 0 2px;
  box-shadow: 5px 0 10px rgba(0, 0, 0, 0.2);
}
.dropdown-menu.pull-middle.pull-center {
  right: 50%;
  margin: 0;
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.2);
}
/* "bullet" variants draw a little CSS-triangle arrow toward the toggle. */
.dropdown-menu.bullet {
  margin-top: 8px;
}
.dropdown-menu.bullet:before {
  width: 0;
  height: 0;
  content: '';
  display: inline-block;
  position: absolute;
  border-color: transparent;
  border-style: solid;
  -webkit-transform: rotate(360deg);
  border-width: 0 7px 7px;
  border-bottom-color: #cccccc;
  border-bottom-color: rgba(0, 0, 0, 0.15);
  top: -7px;
  left: 9px;
}
.dropdown-menu.bullet:after {
  width: 0;
  height: 0;
  content: '';
  display: inline-block;
  position: absolute;
  border-color: transparent;
  border-style: solid;
  -webkit-transform: rotate(360deg);
  border-width: 0 6px 6px;
  border-bottom-color: #ffffff;
  top: -6px;
  left: 10px;
}
.dropdown-menu.bullet.pull-right:before {
  left: auto;
  right: 9px;
}
.dropdown-menu.bullet.pull-right:after {
  left: auto;
  right: 10px;
}
.dropdown-menu.bullet.pull-top {
  margin-top: 0;
  margin-bottom: 8px;
}
.dropdown-menu.bullet.pull-top:before {
  top: auto;
  bottom: -7px;
  border-bottom-width: 0;
  border-top-width: 7px;
  border-top-color: #cccccc;
  border-top-color: rgba(0, 0, 0, 0.15);
}
.dropdown-menu.bullet.pull-top:after {
  top: auto;
  bottom: -6px;
  border-bottom: none;
  border-top-width: 6px;
  border-top-color: #ffffff;
}
.dropdown-menu.bullet.pull-center:before {
  left: auto;
  right: 50%;
  margin-right: -7px;
}
.dropdown-menu.bullet.pull-center:after {
  left: auto;
  right: 50%;
  margin-right: -6px;
}
.dropdown-menu.bullet.pull-middle {
  margin-right: 8px;
}
.dropdown-menu.bullet.pull-middle:before {
  top: 50%;
  left: 100%;
  right: auto;
  margin-top: -7px;
  border-right-width: 0;
  border-bottom-color: transparent;
  border-top-width: 7px;
  border-left-color: #cccccc;
  border-left-color: rgba(0, 0, 0, 0.15);
}
.dropdown-menu.bullet.pull-middle:after {
  top: 50%;
  left: 100%;
  right: auto;
  margin-top: -6px;
  border-right-width: 0;
  border-bottom-color: transparent;
  border-top-width: 6px;
  border-left-color: #ffffff;
}
.dropdown-menu.bullet.pull-middle.pull-right {
  margin-right: 0;
  margin-left: 8px;
}
.dropdown-menu.bullet.pull-middle.pull-right:before {
  left: -7px;
  border-left-width: 0;
  border-right-width: 7px;
  border-right-color: #cccccc;
  border-right-color: rgba(0, 0, 0, 0.15);
}
.dropdown-menu.bullet.pull-middle.pull-right:after {
  left: -6px;
  border-left-width: 0;
  border-right-width: 6px;
  border-right-color: #ffffff;
}
.dropdown-menu.bullet.pull-middle.pull-center {
  margin-left: 0;
  margin-right: 0;
}
.dropdown-menu.bullet.pull-middle.pull-center:before {
  border: none;
  display: none;
}
.dropdown-menu.bullet.pull-middle.pull-center:after {
  border: none;
  display: none;
}
/* Nested submenus open to the side of their parent menu item. */
.dropdown-submenu {
  position: relative;
}
.dropdown-submenu > .dropdown-menu {
  top: 0;
  left: 100%;
  margin-top: -6px;
  margin-left: -1px;
  border-top-left-radius: 0;
}
.dropdown-submenu > a:before {
  display: block;
  float: right;
  width: 0;
  height: 0;
  content: "";
  margin-top: 6px;
  margin-right: -8px;
  border-width: 4px 0 4px 4px;
  border-style: solid;
  border-left-style: dashed;
  border-top-color: transparent;
  border-bottom-color: transparent;
}
/* Narrow screens: submenus stack vertically with increasing indents. */
@media (max-width: 767px) {
  .navbar-nav .dropdown-submenu > a:before {
    margin-top: 8px;
    border-color: inherit;
    border-style: solid;
    border-width: 4px 4px 0;
    border-left-color: transparent;
    border-right-color: transparent;
  }
  .navbar-nav .dropdown-submenu > a {
    padding-left: 40px;
  }
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > a,
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > label {
    padding-left: 35px;
  }
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > a,
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > label {
    padding-left: 45px;
  }
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a,
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label {
    padding-left: 55px;
  }
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a,
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label {
    padding-left: 65px;
  }
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a,
  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label {
    padding-left: 75px;
  }
}
/* Navbar color-scheme overrides for open submenu toggles. */
.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a,
.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover,
.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus {
  background-color: #e7e7e7;
  color: #555555;
}
@media (max-width: 767px) {
  .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before {
    border-top-color: #555555;
  }
}
.navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a,
.navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover,
.navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus {
  background-color: #080808;
  color: #ffffff;
}
@media (max-width: 767px) {
  .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before {
    border-top-color: #ffffff;
  }
}
================================================
FILE: viskit/static/js/dropdowns-enhancement.js
================================================
/* ========================================================================
* Bootstrap Dropdowns Enhancement: dropdowns-enhancement.js v3.1.1 (Beta 1)
* http://behigh.github.io/bootstrap_dropdowns_enhancement/
* ========================================================================
* Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE)
* ======================================================================== */
(function($) {
    "use strict";

    // Selector/class-name constants shared by all handlers below.
    var toggle = '[data-toggle="dropdown"]',
        disabled = '.disabled, :disabled',
        backdrop = '.dropdown-backdrop',
        menuClass = 'dropdown-menu',
        subMenuClass = 'dropdown-submenu',
        namespace = '.bs.dropdown.data-api',
        eventNamespace = '.bs.dropdown',
        openClass = 'open',
        touchSupport = 'ontouchstart' in document.documentElement,
        opened;   // jQuery element of the currently open root menu, if any

    function Dropdown(element) {
        $(element).on('click' + eventNamespace, this.toggle)
    }

    var proto = Dropdown.prototype;

    // Open/close the clicked menu, including its submenu ancestor chain.
    proto.toggle = function(event) {
        var $element = $(this);

        if ($element.is(disabled)) return;

        var $parent = getParent($element);
        var isActive = $parent.hasClass(openClass);
        var isSubMenu = $parent.hasClass(subMenuClass);
        var menuTree = isSubMenu ? getSubMenuParents($parent) : null;

        closeOpened(event, menuTree);

        if (!isActive) {
            if (!menuTree)
                menuTree = [$parent];

            if (touchSupport && !$parent.closest('.navbar-nav').length && !menuTree[0].find(backdrop).length) {
                // if mobile we use a backdrop because click events don't delegate
                // NOTE(review): the element literal below looks stripped by
                // extraction -- upstream appends a div.dropdown-backdrop here.
                $('').appendTo(menuTree[0]).on('click', closeOpened)
            }

            // Open every menu level in the chain and position its dropdown.
            for (var i = 0, s = menuTree.length; i < s; i++) {
                if (!menuTree[i].hasClass(openClass)) {
                    menuTree[i].addClass(openClass);
                    positioning(menuTree[i].children('.' + menuClass), menuTree[i]);
                }
            }
            opened = menuTree[0];
        }

        return false;
    };

    // Keyboard support: arrows move focus through items, Esc closes.
    proto.keydown = function (e) {
        if (!/(38|40|27)/.test(e.keyCode)) return;

        var $this = $(this);

        e.preventDefault();
        e.stopPropagation();

        if ($this.is('.disabled, :disabled')) return;

        var $parent = getParent($this);
        var isActive = $parent.hasClass('open');

        if (!isActive || (isActive && e.keyCode == 27)) {
            if (e.which == 27) $parent.find(toggle).trigger('focus');
            return $this.trigger('click')
        }

        // Focusable items: labels of enabled inputs plus plain menu links.
        var desc = ' li:not(.divider):visible a';
        var desc1 = 'li:not(.divider):visible > input:not(disabled) ~ label';
        var $items = $parent.find(desc1 + ', ' + '[role="menu"]' + desc + ', [role="listbox"]' + desc);

        if (!$items.length) return;

        var index = $items.index($items.filter(':focus'));

        if (e.keyCode == 38 && index > 0) index--; // up
        if (e.keyCode == 40 && index < $items.length - 1) index++; // down
        if (!~index) index = 0;

        $items.eq(index).trigger('focus')
    };

    // Reflect checked inputs in the toggle button's label text.
    proto.change = function (e) {
        var
            $parent,
            $menu,
            $toggle,
            selector,
            text = '',
            $items;

        $menu = $(this).closest('.' + menuClass);
        $toggle = $menu.parent().find('[data-label-placement]');
        if (!$toggle || !$toggle.length) {
            $toggle = $menu.parent().find(toggle);
        }
        if (!$toggle || !$toggle.length || $toggle.data('placeholder') === false)
            return; // do nothing, no control

        // First run: remember the toggle's original text as the placeholder.
        ($toggle.data('placeholder') == undefined && $toggle.data('placeholder', $.trim($toggle.text())));
        text = $.data($toggle[0], 'placeholder');

        $items = $menu.find('li > input:checked');

        if ($items.length) {
            text = [];
            $items.each(function () {
                var str = $(this).parent().find('label').eq(0),
                    label = str.find('.data-label');

                if (label.length) {
                    // NOTE(review): container literal below looks stripped
                    // by extraction; upstream wraps the clone in an element.
                    var p = $('');
                    p.append(label.clone());
                    str = p.html();
                }
                else {
                    str = str.html();
                }

                str && text.push($.trim(str));
            });

            // Few selections: list them; many: just show a count.
            text = text.length < 4 ? text.join(', ') : text.length + ' selected';
        }

        var caret = $toggle.find('.caret');

        $toggle.html(text || ' ');
        if (caret.length)
            $toggle.append(' ') && caret.appendTo($toggle);
    };

    // Offset centered/middle menus so they align with the control.
    function positioning($menu, $control) {
        if ($menu.hasClass('pull-center')) {
            $menu.css('margin-right', $menu.outerWidth() / -2);
        }

        if ($menu.hasClass('pull-middle')) {
            $menu.css('margin-top', ($menu.outerHeight() / -2) - ($control.outerHeight() / 2));
        }
    }

    // Close the currently open menu (or the part of it outside menuTree).
    function closeOpened(event, menuTree) {
        if (opened) {

            if (!menuTree) {
                menuTree = [opened];
            }

            var parent;

            if (opened[0] !== menuTree[0][0]) {
                parent = opened;
            } else {
                parent = menuTree[menuTree.length - 1];
                if (parent.parent().hasClass(menuClass)) {
                    parent = parent.parent();
                }
            }

            parent.find('.' + openClass).removeClass(openClass);

            if (parent.hasClass(openClass))
                parent.removeClass(openClass);

            if (parent === opened) {
                opened = null;
                $(backdrop).remove();
            }
        }
    }

    // Collect a submenu's ancestor dropdown elements, outermost first.
    function getSubMenuParents($submenu) {
        var result = [$submenu];
        var $parent;
        while (!$parent || $parent.hasClass(subMenuClass)) {
            $parent = ($parent || $submenu).parent();
            if ($parent.hasClass(menuClass)) {
                $parent = $parent.parent();
            }
            if ($parent.children(toggle)) {
                result.unshift($parent);
            }
        }
        return result;
    }

    // Resolve the element a toggle controls: data-target, href, or parent.
    function getParent($this) {
        var selector = $this.attr('data-target');

        if (!selector) {
            selector = $this.attr('href');
            selector = selector && /#[A-Za-z]/.test(selector) && selector.replace(/.*(?=#[^\s]*$)/, ''); //strip for ie7
        }

        var $parent = selector && $(selector);

        return $parent && $parent.length ? $parent : $this.parent()
    }

    // DROPDOWN PLUGIN DEFINITION
    // ==========================

    var old = $.fn.dropdown;

    $.fn.dropdown = function (option) {
        return this.each(function () {
            var $this = $(this);
            var data = $this.data('bs.dropdown');

            if (!data) $this.data('bs.dropdown', (data = new Dropdown(this)));
            if (typeof option == 'string') data[option].call($this);
        })
    };

    $.fn.dropdown.Constructor = Dropdown;

    $.fn.dropdown.clearMenus = function(e) {
        $(backdrop).remove();
        $('.' + openClass + ' ' + toggle).each(function () {
            var $parent = getParent($(this));
            var relatedTarget = { relatedTarget: this };
            if (!$parent.hasClass('open')) return;
            $parent.trigger(e = $.Event('hide' + eventNamespace, relatedTarget));
            if (e.isDefaultPrevented()) return;
            $parent.removeClass('open').trigger('hidden' + eventNamespace, relatedTarget);
        });
        return this;
    };

    // DROPDOWN NO CONFLICT
    // ====================

    $.fn.dropdown.noConflict = function () {
        $.fn.dropdown = old;
        return this
    };

    // Delegated document-level bindings for the data-api.
    $(document).off(namespace)
        .on('click' + namespace, closeOpened)
        .on('click' + namespace, toggle, proto.toggle)
        .on('click' + namespace, '.dropdown-menu > li > input[type="checkbox"] ~ label, .dropdown-menu > li > input[type="checkbox"], .dropdown-menu.noclose > li', function (e) {
            e.stopPropagation()
        })
        .on('change' + namespace, '.dropdown-menu > li > input[type="checkbox"], .dropdown-menu > li > input[type="radio"]', proto.change)
        .on('keydown' + namespace, toggle + ', [role="menu"], [role="listbox"]', proto.keydown)
}(jQuery));
================================================
FILE: viskit/static/js/jquery.loadTemplate-1.5.6.js
================================================
(function ($) {
    "use strict";
    // Plugin-wide closure state shared by the functions below.
    var templates = {},     // cache: template URL -> loaded container (null while fetching)
        queue = {},         // template URL -> requests waiting for an in-flight fetch
        formatters = {},    // named value formatters used during data binding
        isArray;            // whether the current top-level data was an array
function loadTemplate(template, data, options) {
var $that = this,
$template,
isFile,
settings;
data = data || {};
settings = $.extend(true, {
// These are the defaults.
async: true,
overwriteCache: false,
complete: null,
success: null,
error: function () {
$(this).each(function () {
$(this).html(settings.errorMessage);
});
},
errorMessage: "There was an error loading the template.",
paged: false,
pageNo: 1,
elemPerPage: 10,
append: false,
prepend: false,
beforeInsert: null,
afterInsert: null,
bindingOptions: {
ignoreUndefined: false,
ignoreNull: false,
ignoreEmptyString: false
}
}, options);
if ($.type(data) === "array") {
isArray = true;
return processArray.call(this, template, data, settings);
}
if (!containsSlashes(template)) {
$template = $(template);
if (typeof template === 'string' && template.indexOf('#') === 0) {
settings.isFile = false;
}
}
isFile = settings.isFile || (typeof settings.isFile === "undefined" && (typeof $template === "undefined" || $template.length === 0));
if (isFile && !settings.overwriteCache && templates[template]) {
prepareTemplateFromCache(template, $that, data, settings);
} else if (isFile && !settings.overwriteCache && templates.hasOwnProperty(template)) {
addToQueue(template, $that, data, settings);
} else if (isFile) {
loadAndPrepareTemplate(template, $that, data, settings);
} else {
loadTemplateFromDocument($template, $that, data, settings);
}
return this;
}
function addTemplateFormatter(key, formatter) {
if (formatter) {
formatters[key] = formatter;
} else {
formatters = $.extend(formatters, key);
}
}
function containsSlashes(str) {
return typeof str === "string" && str.indexOf("/") > -1;
}
function processArray(template, data, settings) {
settings = settings || {};
var $that = this,
todo = data.length,
doPrepend = settings.prepend && !settings.append,
done = 0,
success = 0,
errored = false,
errorObjects = [],
newOptions;
if (settings.paged) {
var startNo = (settings.pageNo - 1) * settings.elemPerPage;
data = data.slice(startNo, startNo + settings.elemPerPage);
todo = data.length;
}
newOptions = $.extend(
{},
settings,
{
async: false,
complete: function (data) {
if (this.html) {
var insertedElement;
if (doPrepend) {
insertedElement = $(this.html()).prependTo($that);
} else {
insertedElement = $(this.html()).appendTo($that);
}
if (settings.afterInsert && data) {
settings.afterInsert(insertedElement, data);
}
}
done++;
if (done === todo || errored) {
if (errored && settings && typeof settings.error === "function") {
settings.error.call($that, errorObjects);
}
if (settings && typeof settings.complete === "function") {
settings.complete();
}
}
},
success: function () {
success++;
if (success === todo) {
if (settings && typeof settings.success === "function") {
settings.success();
}
}
},
error: function (e) {
errored = true;
errorObjects.push(e);
}
}
);
if (!settings.append && !settings.prepend) {
$that.html("");
}
if (doPrepend) data.reverse();
$(data).each(function () {
var $div = $("");
loadTemplate.call($div, template, this, newOptions);
if (errored) {
return false;
}
});
return this;
}
function addToQueue(template, selection, data, settings) {
if (queue[template]) {
queue[template].push({ data: data, selection: selection, settings: settings });
} else {
queue[template] = [{ data: data, selection: selection, settings: settings}];
}
}
function prepareTemplateFromCache(template, selection, data, settings) {
var $templateContainer = templates[template].clone();
prepareTemplate.call(selection, $templateContainer, data, settings);
if (typeof settings.success === "function") {
settings.success();
}
}
function uniqueId() {
return new Date().getTime();
}
function urlAvoidCache(url) {
if (url.indexOf('?') !== -1) {
return url + "&_=" + uniqueId();
}
else {
return url + "?_=" + uniqueId();
}
}
function loadAndPrepareTemplate(template, selection, data, settings) {
    // Fetch the template file over AJAX, cache it, and render it into
    // `selection`.
    // Fix: the container must be a real detached element — $("") creates an
    // empty jQuery selection, so .html(templateContent) would hold nothing
    // and the render would silently produce no output.
    var $templateContainer = $("<div/>");
    // Mark this template as "loading" so concurrent callers queue up
    // (see addToQueue) instead of fetching the same URL again.
    templates[template] = null;
    var templateUrl = template;
    if (settings.overwriteCache) {
        templateUrl = urlAvoidCache(templateUrl);
    }
    $.ajax({
        url: templateUrl,
        async: settings.async,
        success: function (templateContent) {
            $templateContainer.html(templateContent);
            handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings);
        },
        error: function (e) {
            handleTemplateLoadingError(template, selection, data, settings, e);
        }
    });
}
function loadTemplateFromDocument($template, selection, data, settings) {
    // Render a template already present in the document (id-selector variant).
    // Fix: $("") is an empty selection; a detached <div> is required to host
    // the template markup during data binding.
    var $templateContainer = $("<div/>");
    if ($template.is("script") || $template.is("template")) {
        // <script>/<template> content is inert text; parse it into DOM nodes.
        $template = $.parseHTML($.trim($template.html()));
    }
    $templateContainer.html($template);
    prepareTemplate.call(selection, $templateContainer, data, settings);
    if (typeof settings.success === "function") {
        settings.success();
    }
}
// Bind `data` into the template container, then insert the rendered markup
// into every element of `this` (the jQuery selection the plugin was called on).
function prepareTemplate(template, data, settings) {
    bindData(template, data, settings);
    $(this).each(function () {
        // Fresh copy of the rendered markup for each target element.
        var $templateHtml = $(template.html());
        if (settings.beforeInsert) {
            settings.beforeInsert($templateHtml, data);
        }
        // Insert mode: append, prepend, or replace the target's content.
        if (settings.append) {
            $(this).append($templateHtml);
        } else if (settings.prepend) {
            $(this).prepend($templateHtml);
        } else {
            $(this).html($templateHtml);
        }
        // NOTE(review): `isArray` is a closure variable defined outside this
        // chunk — presumably true while rendering an array of data items (the
        // per-item afterInsert fires elsewhere in that case); confirm.
        if (settings.afterInsert && !isArray) {
            settings.afterInsert($templateHtml, data);
        }
    });
    if (typeof settings.complete === "function") {
        settings.complete.call($(this), data);
    }
}
// AJAX failure path: notify the original caller and every caller that queued
// up for the same template while it was in flight, then fire the `complete`
// callbacks.
function handleTemplateLoadingError(template, selection, data, settings, error) {
    var value;
    if (typeof settings.error === "function") {
        settings.error.call(selection, error);
    }
    // Notify every queued waiter of the failure. (The inner `value` parameter
    // shadows the outer `value` variable; harmless but confusing.)
    $(queue[template]).each(function (key, value) {
        if (typeof value.settings.error === "function") {
            value.settings.error.call(value.selection, error);
        }
    });
    if (typeof settings.complete === "function") {
        settings.complete.call(selection);
    }
    // Drain the queue, firing each waiter's complete callback.
    while (queue[template] && (value = queue[template].shift())) {
        if (typeof value.settings.complete === "function") {
            value.settings.complete.call(value.selection);
        }
    }
    // NOTE(review): dead code — the while loop above has already emptied the
    // queue, so this length check can never be true.
    if (typeof queue[template] !== 'undefined' && queue[template].length > 0) {
        queue[template] = [];
    }
}
function handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings) {
    // Cache a pristine copy before rendering mutates the container, then
    // render for the original caller.
    templates[template] = $templateContainer.clone();
    prepareTemplate.call(selection, $templateContainer, data, settings);
    if (typeof settings.success === "function") {
        settings.success.call(selection);
    }
    // Serve everyone who requested this template while it was loading.
    var pending;
    while (queue[template] && (pending = queue[template].shift())) {
        prepareTemplate.call(pending.selection, templates[template].clone(), pending.data, pending.settings);
        if (typeof pending.settings.success === "function") {
            pending.settings.success.call(pending.selection);
        }
    }
}
// Walk the template and substitute values from `data` into every element
// carrying one of the plugin's data-* binding attributes.
function bindData(template, data, settings) {
    data = data || {};
    processElements("data-content", template, data, settings, function ($elem, value) {
        $elem.html(applyFormatters($elem, value, "content", settings));
    });
    processElements("data-content-append", template, data, settings, function ($elem, value) {
        $elem.append(applyFormatters($elem, value, "content", settings));
    });
    processElements("data-content-prepend", template, data, settings, function ($elem, value) {
        $elem.prepend(applyFormatters($elem, value, "content", settings));
    });
    processElements("data-content-text", template, data, settings, function ($elem, value) {
        $elem.text(applyFormatters($elem, value, "content", settings));
    });
    processElements("data-innerHTML", template, data, settings, function ($elem, value) {
        $elem.html(applyFormatters($elem, value, "content", settings));
    });
    processElements("data-src", template, data, settings, function ($elem, value) {
        $elem.attr("src", applyFormatters($elem, value, "src", settings));
    }, function ($elem) {
        $elem.remove();
    });
    processElements("data-href", template, data, settings, function ($elem, value) {
        $elem.attr("href", applyFormatters($elem, value, "href", settings));
    }, function ($elem) {
        $elem.remove();
    });
    processElements("data-alt", template, data, settings, function ($elem, value) {
        $elem.attr("alt", applyFormatters($elem, value, "alt", settings));
    });
    processElements("data-id", template, data, settings, function ($elem, value) {
        $elem.attr("id", applyFormatters($elem, value, "id", settings));
    });
    processElements("data-value", template, data, settings, function ($elem, value) {
        $elem.attr("value", applyFormatters($elem, value, "value", settings));
    });
    processElements("data-class", template, data, settings, function ($elem, value) {
        $elem.addClass(applyFormatters($elem, value, "class", settings));
    });
    processElements("data-link", template, data, settings, function ($elem, value) {
        // Fix: the anchor must actually be created — $("") yields an empty
        // selection, so the link was never built.
        var $linkElem = $("<a/>");
        $linkElem.attr("href", applyFormatters($elem, value, "link", settings));
        $linkElem.html($elem.html());
        $elem.html($linkElem);
    });
    processElements("data-link-wrap", template, data, settings, function ($elem, value) {
        // Fix: was $("") (empty selection, nothing to wrap with).
        var $linkElem = $("<a/>");
        $linkElem.attr("href", applyFormatters($elem, value, "link-wrap", settings));
        $elem.wrap($linkElem);
    });
    processElements("data-options", template, data, settings, function ($elem, value) {
        $(value).each(function () {
            // Fix: was $("") (empty selection; no <option> was created).
            var $option = $("<option></option>");
            $option.attr('value', this).text(this).appendTo($elem);
        });
    });
    processAllElements(template, data, settings);
}
// For each element in `template` carrying `attribute`, resolve the bound value
// from `data` and hand it to `dataBindFunction`; fall back to `noDataFunction`
// (if supplied) when the value is undefined.
function processElements(attribute, template, data, settings, dataBindFunction, noDataFunction) {
    $("[" + attribute + "]", template).each(function () {
        var $elem = $(this);
        var param = $elem.attr(attribute);
        var value = getValue(data, param);
        // Binding options may ask that undefined/null/empty values drop
        // the element entirely.
        if (!valueIsAllowedByBindingOptions($elem, value, settings)) {
            $elem.remove();
            return;
        }
        $elem.removeAttr(attribute);
        if (typeof value !== 'undefined' && dataBindFunction) {
            dataBindFunction($elem, value);
        } else if (noDataFunction) {
            noDataFunction($elem);
        }
    });
}
// Whether a resolved value may be bound, per the element's/settings' binding
// options (ignoreUndefined / ignoreNull / ignoreEmptyString).
function valueIsAllowedByBindingOptions(bindingOptionsContainer, value, settings) {
    var opts = getBindingOptions(bindingOptionsContainer, settings);
    var rejected =
        (opts.ignoreUndefined && typeof value === "undefined") ||
        (opts.ignoreNull && value === null) ||
        (opts.ignoreEmptyString && value === "");
    return !rejected;
}
// Collect binding options for one element: element-specific options (either a
// "data-binding-options" attribute or a bindingOptions key on a
// data-template-bind descriptor) layered over the general settings.
function getBindingOptions(bindingOptionsContainer, settings) {
    var specific = {};
    if (bindingOptionsContainer instanceof jQuery && bindingOptionsContainer.attr("data-binding-options")) {
        specific = $.parseJSON(bindingOptionsContainer.attr("data-binding-options"));
        bindingOptionsContainer.removeAttr("data-binding-options");
    } else if (typeof bindingOptionsContainer === "object" && bindingOptionsContainer.hasOwnProperty('bindingOptions')) {
        specific = bindingOptionsContainer.bindingOptions;
    }
    return $.extend({}, settings.bindingOptions, specific);
}
// Handle the JSON-driven "data-template-bind" attribute, which lets a single
// element declare several attribute/value bindings at once.
function processAllElements(template, data, settings) {
    $("[data-template-bind]", template).each(function () {
        var $this = $(this),
            param = $.parseJSON($this.attr("data-template-bind"));
        $this.removeAttr("data-template-bind");
        $(param).each(function () {
            var value;
            if (typeof (this.value) === 'object') {
                value = getValue(data, this.value.data);
            } else {
                value = getValue(data, this.value);
            }
            if (this.attribute) {
                if (!valueIsAllowedByBindingOptions(this, value, settings)) {
                    $this.remove();
                    return;
                }
                switch (this.attribute) {
                    case "content":
                    case "innerHTML":
                        $this.html(applyDataBindFormatters($this, value, this));
                        break;
                    case "contentAppend":
                        $this.append(applyDataBindFormatters($this, value, this));
                        break;
                    case "contentPrepend":
                        $this.prepend(applyDataBindFormatters($this, value, this));
                        break;
                    case "contentText":
                        $this.text(applyDataBindFormatters($this, value, this));
                        break;
                    case "options":
                        var optionsData = this;
                        $(value).each(function () {
                            // Fix: was $("") — an empty selection that never
                            // creates an <option> element.
                            var $option = $("<option></option>");
                            $option
                                .attr('value', this[optionsData.value.value])
                                .text(applyDataBindFormatters($this, this[optionsData.value.content], optionsData))
                                // Fix: `typeof x == undefined` compares the
                                // string returned by typeof against the value
                                // `undefined` and is therefore always false.
                                // Compare against the string "undefined".
                                .attr('selected', typeof this[optionsData.value.selected] === "undefined" ? false : this[optionsData.value.selected])
                                .appendTo($this);
                        });
                        break;
                    default:
                        $this.attr(this.attribute, applyDataBindFormatters($this, value, this));
                }
            }
        });
    });
}
// Run a registered formatter named in a data-template-bind descriptor; pass
// the value through unchanged when none is declared/registered.
function applyDataBindFormatters($elem, value, data, settings) {
    var name = data.formatter;
    if (!name || !formatters[name]) {
        return value;
    }
    return formatters[name].call($elem, value, data.formatOptions, settings);
}
// Resolve a dotted path ("a.b.c") against `data`; the literal path "this"
// returns the data object itself. Traversal stops early on null/undefined.
function getValue(data, param) {
    if (param === "this") {
        return data;
    }
    var current = data;
    var segments = param.split('.');
    while (segments.length) {
        var segment = segments.shift();
        if (!segment || typeof current === "undefined" || current == null) {
            break;
        }
        current = current[segment];
    }
    return current;
}
// Apply the element's "data-format" formatter to `value`, but only when the
// formatter targets this attribute ("data-format-target", defaulting to
// "content").
function applyFormatters($elem, value, attr, settings) {
    var target = $elem.attr("data-format-target");
    var applies = target === attr || (!target && attr === "content");
    if (!applies) {
        return value;
    }
    var name = $elem.attr("data-format");
    if (!name || typeof formatters[name] !== "function") {
        return value;
    }
    var formatOptions = $elem.attr("data-format-options");
    return formatters[name].call($elem[0], value, formatOptions, $.extend({}, settings));
}
// Built-in formatter: renders `value` through another template, enabling
// nested templates driven from a data-format declaration.
addTemplateFormatter("nestedTemplateFormatter", function (value, options, internalSettings) {
    if (!options) {
        return;
    }
    // Options may arrive as a JSON string ("{...}") or as a bare template name.
    if (typeof options === "string" && options[0] === "{") {
        options = $.parseJSON(options);
    }
    var parentElement = options.parentElement || "div";
    var template = options.template || options;
    //If a parent is specified, return it; otherwise only return the generated children.
    if (options.parentElement)
        return $("<" + parentElement + "/>").loadTemplate(template, value, internalSettings);
    else
        return $("<" + parentElement + "/>").loadTemplate(template, value, internalSettings).children();
});
$.fn.loadTemplate = loadTemplate;
$.addTemplateFormatter = addTemplateFormatter;
})(jQuery);
================================================
FILE: viskit/tabulate.py
================================================
"""File taken from RLKit (https://github.com/vitchyr/rlkit)."""
# -*- coding: utf-8 -*-
# Taken from John's code
"""Pretty-print tabular data."""
from collections import namedtuple
from platform import python_version_tuple
import re
# Python 2/3 compatibility shims: unify the izip_longest name and the
# type aliases used for column-type sniffing throughout this module.
if python_version_tuple()[0] < "3":
    from itertools import izip_longest
    from functools import partial
    _none_type = type(None)
    _int_type = int
    _float_type = float
    # NOTE(review): upstream tabulate maps _text_type to `unicode` on
    # Python 2; here both aliases are `str` — confirm this is intentional.
    _text_type = str
    _binary_type = str
else:
    from itertools import zip_longest as izip_longest
    from functools import reduce, partial
    _none_type = type(None)
    _int_type = int
    _float_type = float
    _text_type = str
    _binary_type = bytes
__all__ = ["tabulate", "tabulate_formats", "simple_separated_format"]
__version__ = "0.7.2"
# Building blocks of a table layout: a horizontal rule (Line) and a row of
# cells (DataRow), each described by its delimiter strings.
Line = namedtuple("Line", ["begin", "hline", "sep", "end"])
DataRow = namedtuple("DataRow", ["begin", "sep", "end"])
# A table structure is supposed to be:
#
# --- lineabove ---------
# headerrow
# --- linebelowheader ---
# datarow
# --- linebetweenrows ---
# ... (more datarows) ...
# --- linebetweenrows ---
# last datarow
# --- linebelow ---------
#
# TableFormat's line* elements can be
#
# - either None, if the element is not used,
# - or a Line tuple,
# - or a function: [col_widths], [col_alignments] -> string.
#
# TableFormat's *row elements can be
#
# - either None, if the element is not used,
# - or a DataRow tuple,
# - or a function: [cell_values], [col_widths], [col_alignments] -> string.
#
# padding (an integer) is the amount of white space around data values.
#
# with_header_hide:
#
# - either None, to display all table elements unconditionally,
# - or a list of elements not to be displayed if the table has column headers.
#
TableFormat = namedtuple("TableFormat", ["lineabove", "linebelowheader",
                                         "linebetweenrows", "linebelow",
                                         "headerrow", "datarow",
                                         "padding", "with_header_hide"])
def _pipe_segment_with_colons(align, colwidth):
"""Return a segment of a horizontal line with optional colons which
indicate column's alignment (as in `pipe` output format)."""
w = colwidth
if align in ["right", "decimal"]:
return ('-' * (w - 1)) + ":"
elif align == "center":
return ":" + ('-' * (w - 2)) + ":"
elif align == "left":
return ":" + ('-' * (w - 1))
else:
return '-' * w
def _pipe_line_with_colons(colwidths, colaligns):
    """Return a full horizontal rule with alignment colons (`pipe` format)."""
    parts = [_pipe_segment_with_colons(align, width)
             for align, width in zip(colaligns, colwidths)]
    return "|" + "|".join(parts) + "|"
def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns):
alignment = { "left": '',
"right": 'align="right"| ',
"center": 'align="center"| ',
"decimal": 'align="right"| ' }
# hard-coded padding _around_ align attribute and value together
# rather than padding parameter which affects only the value
values_with_attrs = [' ' + alignment.get(a, '') + c + ' '
for c, a in zip(cell_values, colaligns)]
colsep = separator*2
return (separator + colsep.join(values_with_attrs)).rstrip()
def _latex_line_begin_tabular(colwidths, colaligns):
alignment = { "left": "l", "right": "r", "center": "c", "decimal": "r" }
tabular_columns_fmt = "".join([alignment.get(a, "l") for a in colaligns])
return "\\begin{tabular}{" + tabular_columns_fmt + "}\n\hline"
# Built-in output styles, keyed by the `tablefmt` argument of tabulate().
# NOTE(review): upstream tabulate 0.7.2 uses two-space separators ("  ") for
# the "simple", "plain" and "rst" formats; the single spaces here may be an
# extraction artifact — confirm against upstream before relying on exact
# spacing.
_table_formats = {"simple":
                  TableFormat(lineabove=Line("", "-", " ", ""),
                              linebelowheader=Line("", "-", " ", ""),
                              linebetweenrows=None,
                              linebelow=Line("", "-", " ", ""),
                              headerrow=DataRow("", " ", ""),
                              datarow=DataRow("", " ", ""),
                              padding=0,
                              with_header_hide=["lineabove", "linebelow"]),
                  "plain":
                  TableFormat(lineabove=None, linebelowheader=None,
                              linebetweenrows=None, linebelow=None,
                              headerrow=DataRow("", " ", ""),
                              datarow=DataRow("", " ", ""),
                              padding=0, with_header_hide=None),
                  "grid":
                  TableFormat(lineabove=Line("+", "-", "+", "+"),
                              linebelowheader=Line("+", "=", "+", "+"),
                              linebetweenrows=Line("+", "-", "+", "+"),
                              linebelow=Line("+", "-", "+", "+"),
                              headerrow=DataRow("|", "|", "|"),
                              datarow=DataRow("|", "|", "|"),
                              padding=1, with_header_hide=None),
                  "pipe":
                  TableFormat(lineabove=_pipe_line_with_colons,
                              linebelowheader=_pipe_line_with_colons,
                              linebetweenrows=None,
                              linebelow=None,
                              headerrow=DataRow("|", "|", "|"),
                              datarow=DataRow("|", "|", "|"),
                              padding=1,
                              with_header_hide=["lineabove"]),
                  "orgtbl":
                  TableFormat(lineabove=None,
                              linebelowheader=Line("|", "-", "+", "|"),
                              linebetweenrows=None,
                              linebelow=None,
                              headerrow=DataRow("|", "|", "|"),
                              datarow=DataRow("|", "|", "|"),
                              padding=1, with_header_hide=None),
                  "rst":
                  TableFormat(lineabove=Line("", "=", " ", ""),
                              linebelowheader=Line("", "=", " ", ""),
                              linebetweenrows=None,
                              linebelow=Line("", "=", " ", ""),
                              headerrow=DataRow("", " ", ""),
                              datarow=DataRow("", " ", ""),
                              padding=0, with_header_hide=None),
                  "mediawiki":
                  TableFormat(lineabove=Line("{| class=\"wikitable\" style=\"text-align: left;\"",
                                             "", "", "\n|+ \n|-"),
                              linebelowheader=Line("|-", "", "", ""),
                              linebetweenrows=Line("|-", "", "", ""),
                              linebelow=Line("|}", "", "", ""),
                              headerrow=partial(_mediawiki_row_with_attrs, "!"),
                              datarow=partial(_mediawiki_row_with_attrs, "|"),
                              padding=0, with_header_hide=None),
                  "latex":
                  TableFormat(lineabove=_latex_line_begin_tabular,
                              linebelowheader=Line("\\hline", "", "", ""),
                              linebetweenrows=None,
                              linebelow=Line("\\hline\n\\end{tabular}", "", "", ""),
                              headerrow=DataRow("", "&", "\\\\"),
                              datarow=DataRow("", "&", "\\\\"),
                              padding=1, with_header_hide=None),
                  "tsv":
                  TableFormat(lineabove=None, linebelowheader=None,
                              linebetweenrows=None, linebelow=None,
                              headerrow=DataRow("", "\t", ""),
                              datarow=DataRow("", "\t", ""),
                              padding=0, with_header_hide=None)}
tabulate_formats = list(sorted(_table_formats.keys()))
# Fix: the patterns were written as normal string literals containing the
# invalid escape sequence "\[" (DeprecationWarning today, SyntaxError in
# future Python). Raw literals produce the identical regex; the regex engine
# still interprets \x1b (ESC) and \d itself.
_invisible_codes = re.compile(r"\x1b\[\d*m")  # ANSI color codes
_invisible_codes_bytes = re.compile(br"\x1b\[\d*m")  # ANSI color codes
def simple_separated_format(separator):
    """Construct a simple TableFormat with columns separated by a separator.
    >>> tsv = simple_separated_format("\\t") ; \
        tabulate([["foo", 1], ["spam", 23]], tablefmt=tsv) == 'foo \\t 1\\nspam\\t23'
    True
    """
    # A bare format: no rules at all, every row is just cells joined by
    # `separator` (DataRow is an immutable namedtuple, so sharing is safe).
    row_format = DataRow('', separator, '')
    return TableFormat(None, None, None, None,
                       headerrow=row_format,
                       datarow=row_format,
                       padding=0, with_header_hide=None)
def _isconvertible(conv, string):
try:
n = conv(string)
return True
except ValueError:
return False
def _isnumber(string):
    """True when the value converts cleanly to float.

    >>> _isnumber("123.45")
    True
    >>> _isnumber("123")
    True
    >>> _isnumber("spam")
    False
    """
    return _isconvertible(float, string)
def _isint(string):
"""
>>> _isint("123")
True
>>> _isint("123.45")
False
"""
return type(string) is int or \
(isinstance(string, _binary_type) or isinstance(string, _text_type)) and \
_isconvertible(int, string)
def _type(string, has_invisible=True):
    """The least generic type (type(None), int, float, str, unicode).
    >>> _type(None) is type(None)
    True
    >>> _type("foo") is type("")
    True
    >>> _type("1") is type(1)
    True
    >>> _type('\x1b[31m42\x1b[0m') is type(42)
    True
    >>> _type('\x1b[31m42\x1b[0m') is type(42)
    True
    """
    # Strip ANSI color codes first so colored numbers still sniff as numbers.
    if has_invisible and \
       (isinstance(string, _text_type) or isinstance(string, _binary_type)):
        string = _strip_invisible(string)
    # The check order is load-bearing: int-like strings must win over
    # float-like, and datetimes (which aren't convertible) come before both.
    if string is None:
        return _none_type
    elif hasattr(string, "isoformat"):  # datetime.datetime, date, and time
        return _text_type
    elif _isint(string):
        return int
    elif _isnumber(string):
        return float
    elif isinstance(string, _binary_type):
        return _binary_type
    else:
        return _text_type
def _afterpoint(string):
    """Symbols after a decimal point, -1 if the string lacks the decimal point.

    >>> _afterpoint("123.45")
    2
    >>> _afterpoint("1001")
    -1
    >>> _afterpoint("eggs")
    -1
    >>> _afterpoint("123e45")
    2
    """
    # Non-numbers and plain integers have nothing after the point.
    if not _isnumber(string) or _isint(string):
        return -1
    marker = string.rfind(".")
    if marker < 0:
        # No dot: fall back to the exponent marker of scientific notation.
        marker = string.lower().rfind("e")
    if marker < 0:
        return -1  # no point
    return len(string) - marker - 1
def _padleft(width, s, has_invisible=True):
"""Flush right.
>>> _padleft(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430'
True
"""
iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width
fmt = "{0:>%ds}" % iwidth
return fmt.format(s)
def _padright(width, s, has_invisible=True):
"""Flush left.
>>> _padright(6, '\u044f\u0439\u0446\u0430') == '\u044f\u0439\u0446\u0430 '
True
"""
iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width
fmt = "{0:<%ds}" % iwidth
return fmt.format(s)
def _padboth(width, s, has_invisible=True):
"""Center string.
>>> _padboth(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430 '
True
"""
iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width
fmt = "{0:^%ds}" % iwidth
return fmt.format(s)
def _strip_invisible(s):
    """Remove invisible ANSI color codes from a str or bytestring.

    Fix: the bytestring branch used the text replacement ``""`` with a bytes
    pattern; under Python 3 ``re.sub`` then raises TypeError. Bytes patterns
    need a bytes replacement (``b""``).
    """
    if isinstance(s, _text_type):
        return re.sub(_invisible_codes, "", s)
    else:  # a bytestring
        return re.sub(_invisible_codes_bytes, b"", s)
def _visible_width(s):
    """Width of the string as printed, ignoring ANSI color codes.
    >>> _visible_width('\x1b[31mhello\x1b[0m'), _visible_width("world")
    (5, 5)
    """
    if isinstance(s, (_text_type, _binary_type)):
        return len(_strip_invisible(s))
    # Non-string cells: measure their text representation.
    return len(_text_type(s))
def _align_column(strings, alignment, minwidth=0, has_invisible=True):
    """[string] -> [padded_string]
    >>> list(map(str,_align_column(["12.345", "-1234.5", "1.23", "1234.5", "1e+234", "1.0e234"], "decimal")))
    [' 12.345 ', '-1234.5 ', ' 1.23 ', ' 1234.5 ', ' 1e+234 ', ' 1.0e234']
    >>> list(map(str,_align_column(['123.4', '56.7890'], None)))
    ['123.4', '56.7890']
    """
    # Pick the padding function for the requested alignment; "decimal"
    # first right-pads each value so the decimal points line up, then
    # flushes the whole column right.
    if alignment == "right":
        strings = [s.strip() for s in strings]
        padfn = _padleft
    elif alignment == "center":
        strings = [s.strip() for s in strings]
        padfn = _padboth
    elif alignment == "decimal":
        decimals = [_afterpoint(s) for s in strings]
        maxdecimals = max(decimals)
        strings = [s + (maxdecimals - decs) * " "
                   for s, decs in zip(strings, decimals)]
        padfn = _padleft
    elif not alignment:
        # No alignment requested: return the cells untouched.
        return strings
    else:
        strings = [s.strip() for s in strings]
        padfn = _padright
    # ANSI escapes occupy bytes but no columns; measure visible width only.
    if has_invisible:
        width_fn = _visible_width
    else:
        width_fn = len
    maxwidth = max(max(list(map(width_fn, strings))), minwidth)
    padded_strings = [padfn(maxwidth, s, has_invisible) for s in strings]
    return padded_strings
def _more_generic(type1, type2):
    """Return the more generic of two cell types.

    Generality order: NoneType < int < float < bytes < str; unknown types
    rank as str.
    """
    ranking = { _none_type: 0, int: 1, float: 2, _binary_type: 3, _text_type: 4 }
    by_rank = [_none_type, int, float, _binary_type, _text_type]
    winner = max(ranking.get(type1, 4), ranking.get(type2, 4))
    return by_rank[winner]
def _column_type(strings, has_invisible=True):
    """The least generic type all column values are convertible to.

    >>> _column_type(["1", "2"]) is _int_type
    True
    >>> _column_type(["1", "2.3"]) is _float_type
    True
    >>> _column_type(["1", "2.3", "four"]) is _text_type
    True
    >>> _column_type([1, 2, None]) is _int_type
    True
    """
    # Fold pairwise: the column type is the most generic of all cell types,
    # starting from int (so an empty column is int-typed).
    return reduce(_more_generic,
                  (_type(cell, has_invisible) for cell in strings),
                  int)
def _format(val, valtype, floatfmt, missingval=""):
"""Format a value accoding to its type.
Unicode is supported:
>>> hrow = ['\u0431\u0443\u043a\u0432\u0430', '\u0446\u0438\u0444\u0440\u0430'] ; \
tbl = [['\u0430\u0437', 2], ['\u0431\u0443\u043a\u0438', 4]] ; \
good_result = '\\u0431\\u0443\\u043a\\u0432\\u0430 \\u0446\\u0438\\u0444\\u0440\\u0430\\n------- -------\\n\\u0430\\u0437 2\\n\\u0431\\u0443\\u043a\\u0438 4' ; \
tabulate(tbl, headers=hrow) == good_result
True
"""
if val is None:
return missingval
if valtype in [int, _text_type]:
return "{0}".format(val)
elif valtype is _binary_type:
return _text_type(val, "ascii")
elif valtype is float:
return format(float(val), floatfmt)
else:
return "{0}".format(val)
def _align_header(header, alignment, width):
if alignment == "left":
return _padright(width, header)
elif alignment == "center":
return _padboth(width, header)
elif not alignment:
return "{0}".format(header)
else:
return _padleft(width, header)
def _normalize_tabular_data(tabular_data, headers):
"""Transform a supported data type to a list of lists, and a list of headers.
Supported tabular data types:
* list-of-lists or another iterable of iterables
* list of named tuples (usually used with headers="keys")
* 2D NumPy arrays
* NumPy record arrays (usually used with headers="keys")
* dict of iterables (usually used with headers="keys")
* pandas.DataFrame (usually used with headers="keys")
The first row can be used as headers if headers="firstrow",
column indices can be used as headers if headers="keys".
"""
if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"):
# dict-like and pandas.DataFrame?
if hasattr(tabular_data.values, "__call__"):
# likely a conventional dict
keys = list(tabular_data.keys())
rows = list(zip_longest(*list(tabular_data.values()))) # columns have to be transposed
elif hasattr(tabular_data, "index"):
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
keys = list(tabular_data.keys())
vals = tabular_data.values # values matrix doesn't need to be transposed
names = tabular_data.index
rows = [[v]+list(row) for v,row in zip(names, vals)]
else:
raise ValueError("tabular data doesn't appear to be a dict or a DataFrame")
if headers == "keys":
headers = list(map(_text_type,keys)) # headers should be strings
else: # it's a usual an iterable of iterables, or a NumPy array
rows = list(tabular_data)
if (headers == "keys" and
hasattr(tabular_data, "dtype") and
getattr(tabular_data.dtype, "names")):
# numpy record array
headers = tabular_data.dtype.names
elif (headers == "keys"
and len(rows) > 0
and isinstance(rows[0], tuple)
and hasattr(rows[0], "_fields")): # namedtuple
headers = list(map(_text_type, rows[0]._fields))
elif headers == "keys" and len(rows) > 0: # keys are column indices
headers = list(map(_text_type, list(range(len(rows[0])))))
# take headers from the first row if necessary
if headers == "firstrow" and len(rows) > 0:
headers = list(map(_text_type, rows[0])) # headers should be strings
rows = rows[1:]
headers = list(headers)
rows = list(map(list,rows))
# pad with empty headers for initial columns if necessary
if headers and len(rows) > 0:
nhs = len(headers)
ncols = len(rows[0])
if nhs < ncols:
headers = [""]*(ncols - nhs) + headers
return rows, headers
def tabulate(tabular_data, headers=[], tablefmt="simple",
             floatfmt="g", numalign="decimal", stralign="left",
             missingval=""):
    """Format a fixed width table for pretty printing.
    >>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]]))
    --- ---------
    1 2.34
    -56 8.999
    2 10001
    --- ---------
    The first required argument (`tabular_data`) can be a
    list-of-lists (or another iterable of iterables), a list of named
    tuples, a dictionary of iterables, a two-dimensional NumPy array,
    NumPy record array, or a Pandas' dataframe.
    Table headers
    -------------
    To print nice column headers, supply the second argument (`headers`):
    - `headers` can be an explicit list of column headers
    - if `headers="firstrow"`, then the first row of data is used
    - if `headers="keys"`, then dictionary keys or column indices are used
    Otherwise a headerless table is produced.
    If the number of headers is less than the number of columns, they
    are supposed to be names of the last columns. This is consistent
    with the plain-text format of R and Pandas' dataframes.
    >>> print(tabulate([["sex","age"],["Alice","F",24],["Bob","M",19]],
    ... headers="firstrow"))
    sex age
    ----- ----- -----
    Alice F 24
    Bob M 19
    Column alignment
    ----------------
    `tabulate` tries to detect column types automatically, and aligns
    the values properly. By default it aligns decimal points of the
    numbers (or flushes integer numbers to the right), and flushes
    everything else to the left. Possible column alignments
    (`numalign`, `stralign`) are: "right", "center", "left", "decimal"
    (only for `numalign`), and None (to disable alignment).
    Table formats
    -------------
    `floatfmt` is a format specification used for columns which
    contain numeric data with a decimal point.
    `None` values are replaced with a `missingval` string:
    >>> print(tabulate([["spam", 1, None],
    ... ["eggs", 42, 3.14],
    ... ["other", None, 2.7]], missingval="?"))
    ----- -- ----
    spam 1 ?
    eggs 42 3.14
    other ? 2.7
    ----- -- ----
    Various plain-text table formats (`tablefmt`) are supported:
    'plain', 'simple', 'grid', 'pipe', 'orgtbl', 'rst', 'mediawiki',
    and 'latex'. Variable `tabulate_formats` contains the list of
    currently supported formats.
    "plain" format doesn't use any pseudographics to draw tables,
    it separates columns with a double space:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
    ... ["strings", "numbers"], "plain"))
    strings numbers
    spam 41.9999
    eggs 451
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="plain"))
    spam 41.9999
    eggs 451
    "simple" format is like Pandoc simple_tables:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
    ... ["strings", "numbers"], "simple"))
    strings numbers
    --------- ---------
    spam 41.9999
    eggs 451
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="simple"))
    ---- --------
    spam 41.9999
    eggs 451
    ---- --------
    "grid" is similar to tables produced by Emacs table.el package or
    Pandoc grid_tables:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
    ... ["strings", "numbers"], "grid"))
    +-----------+-----------+
    | strings | numbers |
    +===========+===========+
    | spam | 41.9999 |
    +-----------+-----------+
    | eggs | 451 |
    +-----------+-----------+
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="grid"))
    +------+----------+
    | spam | 41.9999 |
    +------+----------+
    | eggs | 451 |
    +------+----------+
    "pipe" is like tables in PHP Markdown Extra extension or Pandoc
    pipe_tables:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
    ... ["strings", "numbers"], "pipe"))
    | strings | numbers |
    |:----------|----------:|
    | spam | 41.9999 |
    | eggs | 451 |
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="pipe"))
    |:-----|---------:|
    | spam | 41.9999 |
    | eggs | 451 |
    "orgtbl" is like tables in Emacs org-mode and orgtbl-mode. They
    are slightly different from "pipe" format by not using colons to
    define column alignment, and using a "+" sign to indicate line
    intersections:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
    ... ["strings", "numbers"], "orgtbl"))
    | strings | numbers |
    |-----------+-----------|
    | spam | 41.9999 |
    | eggs | 451 |
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="orgtbl"))
    | spam | 41.9999 |
    | eggs | 451 |
    "rst" is like a simple table format from reStructuredText; please
    note that reStructuredText accepts also "grid" tables:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
    ... ["strings", "numbers"], "rst"))
    ========= =========
    strings numbers
    ========= =========
    spam 41.9999
    eggs 451
    ========= =========
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="rst"))
    ==== ========
    spam 41.9999
    eggs 451
    ==== ========
    "mediawiki" produces a table markup used in Wikipedia and on other
    MediaWiki-based sites:
    >>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], ["eggs", "451.0"]],
    ... headers="firstrow", tablefmt="mediawiki"))
    {| class="wikitable" style="text-align: left;"
    |+
    |-
    ! strings !! align="right"| numbers
    |-
    | spam || align="right"| 41.9999
    |-
    | eggs || align="right"| 451
    |}
    "latex" produces a tabular environment of LaTeX document markup:
    >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="latex"))
    \\begin{tabular}{lr}
    \\hline
    spam & 41.9999 \\\\
    eggs & 451 \\\\
    \\hline
    \\end{tabular}
    """
    # NOTE(review): `headers=[]` is a mutable default; it is only rebound
    # below (never mutated), so it is safe here, but a tuple default would
    # be safer against future edits.
    list_of_lists, headers = _normalize_tabular_data(tabular_data, headers)
    # optimization: look for ANSI control codes once,
    # enable smart width functions only if a control code is found
    plain_text = '\n'.join(['\t'.join(map(_text_type, headers))] + \
                           ['\t'.join(map(_text_type, row)) for row in list_of_lists])
    has_invisible = re.search(_invisible_codes, plain_text)
    if has_invisible:
        width_fn = _visible_width
    else:
        width_fn = len
    # format rows and columns, convert numeric values to strings
    cols = list(zip(*list_of_lists))
    # Sniff each column's common type so numbers format and align uniformly.
    coltypes = list(map(_column_type, cols))
    cols = [[_format(v, ct, floatfmt, missingval) for v in c]
            for c,ct in zip(cols, coltypes)]
    # align columns
    aligns = [numalign if ct in [int,float] else stralign for ct in coltypes]
    # Headers reserve two extra columns of width for themselves.
    minwidths = [width_fn(h)+2 for h in headers] if headers else [0]*len(cols)
    cols = [_align_column(c, a, minw, has_invisible)
            for c, a, minw in zip(cols, aligns, minwidths)]
    if headers:
        # align headers and add headers
        minwidths = [max(minw, width_fn(c[0])) for minw, c in zip(minwidths, cols)]
        headers = [_align_header(h, a, minw)
                   for h, a, minw in zip(headers, aligns, minwidths)]
        rows = list(zip(*cols))
    else:
        minwidths = [width_fn(c[0]) for c in cols]
        rows = list(zip(*cols))
    # Unknown format names silently fall back to "simple".
    if not isinstance(tablefmt, TableFormat):
        tablefmt = _table_formats.get(tablefmt, _table_formats["simple"])
    return _format_table(tablefmt, headers, rows, minwidths, aligns)
def _build_simple_row(padded_cells, rowfmt):
"Format row according to DataRow format without padding."
begin, sep, end = rowfmt
return (begin + sep.join(padded_cells) + end).rstrip()
def _build_row(padded_cells, colwidths, colaligns, rowfmt):
"Return a string which represents a row of data cells."
if not rowfmt:
return None
if hasattr(rowfmt, "__call__"):
return rowfmt(padded_cells, colwidths, colaligns)
else:
return _build_simple_row(padded_cells, rowfmt)
def _build_line(colwidths, colaligns, linefmt):
"Return a string which represents a horizontal line."
if not linefmt:
return None
if hasattr(linefmt, "__call__"):
return linefmt(colwidths, colaligns)
else:
begin, fill, sep, end = linefmt
cells = [fill*w for w in colwidths]
return _build_simple_row(cells, (begin, sep, end))
def _pad_row(cells, padding):
if cells:
pad = " "*padding
padded_cells = [pad + cell + pad for cell in cells]
return padded_cells
else:
return cells
def _format_table(fmt, headers, rows, colwidths, colaligns):
    """Produce a plain-text representation of the table."""
    lines = []
    # Elements suppressed when a header row is present (e.g. "simple"
    # hides its outer rules).
    hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else []
    pad = fmt.padding
    headerrow = fmt.headerrow
    # Rules must span the padded cell width, not just the content width.
    padded_widths = [(w + 2*pad) for w in colwidths]
    padded_headers = _pad_row(headers, pad)
    padded_rows = [_pad_row(row, pad) for row in rows]
    if fmt.lineabove and "lineabove" not in hidden:
        lines.append(_build_line(padded_widths, colaligns, fmt.lineabove))
    if padded_headers:
        lines.append(_build_row(padded_headers, padded_widths, colaligns, headerrow))
        if fmt.linebelowheader and "linebelowheader" not in hidden:
            lines.append(_build_line(padded_widths, colaligns, fmt.linebelowheader))
    if padded_rows and fmt.linebetweenrows and "linebetweenrows" not in hidden:
        # initial rows with a line below
        for row in padded_rows[:-1]:
            lines.append(_build_row(row, padded_widths, colaligns, fmt.datarow))
            lines.append(_build_line(padded_widths, colaligns, fmt.linebetweenrows))
        # the last row without a line below
        lines.append(_build_row(padded_rows[-1], padded_widths, colaligns, fmt.datarow))
    else:
        for row in padded_rows:
            lines.append(_build_row(row, padded_widths, colaligns, fmt.datarow))
    if fmt.linebelow and "linebelow" not in hidden:
        lines.append(_build_line(padded_widths, colaligns, fmt.linebelow))
    return "\n".join(lines)
================================================
FILE: viskit/templates/main.html
================================================